fix : 修复 MMLU 评测脚本数据源和依赖
- 数据源改为 Berkeley data.tar(GitHub zip 不含数据文件) - urllib 替换为 requests,支持代理下载 - zip 解压替换为 tar,增加目录 flatten 逻辑 - 添加 model.eval() 确保推理模式正确
This commit is contained in:
parent
f521a30b22
commit
a923e0a23a
|
|
@ -5,9 +5,9 @@ import csv
|
|||
import json
|
||||
import os
|
||||
import shutil
|
||||
import urllib.request
|
||||
import zipfile
|
||||
import tarfile
|
||||
|
||||
import requests
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import tqdm
|
||||
|
|
@ -15,7 +15,7 @@ import tqdm
|
|||
from astrai.model import AutoModel
|
||||
from astrai.tokenize import AutoTokenizer
|
||||
|
||||
MMLU_URL = "https://github.com/hendrycks/test/archive/refs/heads/master.zip"
|
||||
MMLU_URL = "https://people.eecs.berkeley.edu/~hendrycks/data.tar"
|
||||
MMLU_SUBJECTS = [
|
||||
"abstract_algebra",
|
||||
"anatomy",
|
||||
|
|
@ -78,23 +78,37 @@ MMLU_SUBJECTS = [
|
|||
|
||||
|
||||
def _download_and_extract(url: str, data_dir: str):
|
||||
zip_path = os.path.join(data_dir, "mmlu.zip")
|
||||
tar_path = os.path.join(data_dir, "data.tar")
|
||||
os.makedirs(data_dir, exist_ok=True)
|
||||
print(f"Downloading MMLU data from {url}...")
|
||||
urllib.request.urlretrieve(url, zip_path)
|
||||
resp = requests.get(url, stream=True, timeout=300)
|
||||
resp.raise_for_status()
|
||||
total = int(resp.headers.get("content-length", 0))
|
||||
with tqdm.tqdm(total=total, unit="B", unit_scale=True, desc=" Download") as bar:
|
||||
with open(tar_path, "wb") as f:
|
||||
for chunk in resp.iter_content(chunk_size=8192):
|
||||
f.write(chunk)
|
||||
bar.update(len(chunk))
|
||||
print("Extracting...")
|
||||
with zipfile.ZipFile(zip_path, "r") as zf:
|
||||
zf.extractall(data_dir)
|
||||
os.remove(zip_path)
|
||||
with tarfile.open(tar_path, "r") as tf:
|
||||
tf.extractall(data_dir)
|
||||
os.remove(tar_path)
|
||||
|
||||
|
||||
def download_mmlu(data_dir: str):
|
||||
_download_and_extract(MMLU_URL, data_dir)
|
||||
src = os.path.join(data_dir, "test-master", "data")
|
||||
src = os.path.join(data_dir, "data")
|
||||
if os.path.exists(src):
|
||||
for item in os.listdir(src):
|
||||
os.rename(os.path.join(src, item), os.path.join(data_dir, item))
|
||||
shutil.rmtree(os.path.join(data_dir, "test-master"))
|
||||
src_item = os.path.join(src, item)
|
||||
dst_item = os.path.join(data_dir, item)
|
||||
if os.path.exists(dst_item):
|
||||
if os.path.isdir(dst_item):
|
||||
shutil.rmtree(dst_item)
|
||||
else:
|
||||
os.remove(dst_item)
|
||||
os.rename(src_item, dst_item)
|
||||
os.rmdir(src)
|
||||
print(f"MMLU data saved to {data_dir}")
|
||||
|
||||
|
||||
|
|
@ -233,6 +247,7 @@ def main():
|
|||
device = args.device
|
||||
dtype = getattr(torch, args.dtype)
|
||||
model.to(device=device, dtype=dtype)
|
||||
model.eval()
|
||||
|
||||
subjects = args.subjects or MMLU_SUBJECTS
|
||||
results = {}
|
||||
|
|
|
|||
Loading…
Reference in New Issue