fix : 修复 MMLU 评测脚本数据源和依赖

- 数据源改为 Berkeley data.tar(GitHub zip 不含数据文件)
- urllib 替换为 requests,支持代理下载
- zip 解压替换为 tar,增加目录 flatten 逻辑
- 添加 model.eval() 确保推理模式正确
This commit is contained in:
ViperEkura 2026-05-30 16:51:24 +08:00
parent f521a30b22
commit a923e0a23a
1 changed files with 26 additions and 11 deletions

View File

@ -5,9 +5,9 @@ import csv
import json
import os
import shutil
import urllib.request
import zipfile
import tarfile
import requests
import torch
import torch.nn.functional as F
import tqdm
@ -15,7 +15,7 @@ import tqdm
from astrai.model import AutoModel
from astrai.tokenize import AutoTokenizer
MMLU_URL = "https://github.com/hendrycks/test/archive/refs/heads/master.zip"
MMLU_URL = "https://people.eecs.berkeley.edu/~hendrycks/data.tar"
MMLU_SUBJECTS = [
"abstract_algebra",
"anatomy",
@ -78,23 +78,37 @@ MMLU_SUBJECTS = [
def _download_and_extract(url: str, data_dir: str):
zip_path = os.path.join(data_dir, "mmlu.zip")
tar_path = os.path.join(data_dir, "data.tar")
os.makedirs(data_dir, exist_ok=True)
print(f"Downloading MMLU data from {url}...")
urllib.request.urlretrieve(url, zip_path)
resp = requests.get(url, stream=True, timeout=300)
resp.raise_for_status()
total = int(resp.headers.get("content-length", 0))
with tqdm.tqdm(total=total, unit="B", unit_scale=True, desc=" Download") as bar:
with open(tar_path, "wb") as f:
for chunk in resp.iter_content(chunk_size=8192):
f.write(chunk)
bar.update(len(chunk))
print("Extracting...")
with zipfile.ZipFile(zip_path, "r") as zf:
zf.extractall(data_dir)
os.remove(zip_path)
with tarfile.open(tar_path, "r") as tf:
tf.extractall(data_dir)
os.remove(tar_path)
def download_mmlu(data_dir: str):
_download_and_extract(MMLU_URL, data_dir)
src = os.path.join(data_dir, "test-master", "data")
src = os.path.join(data_dir, "data")
if os.path.exists(src):
for item in os.listdir(src):
os.rename(os.path.join(src, item), os.path.join(data_dir, item))
shutil.rmtree(os.path.join(data_dir, "test-master"))
src_item = os.path.join(src, item)
dst_item = os.path.join(data_dir, item)
if os.path.exists(dst_item):
if os.path.isdir(dst_item):
shutil.rmtree(dst_item)
else:
os.remove(dst_item)
os.rename(src_item, dst_item)
os.rmdir(src)
print(f"MMLU data saved to {data_dir}")
@ -233,6 +247,7 @@ def main():
device = args.device
dtype = getattr(torch, args.dtype)
model.to(device=device, dtype=dtype)
model.eval()
subjects = args.subjects or MMLU_SUBJECTS
results = {}