From a923e0a23a6965538ae944186ccdd2ad763ac330 Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Sat, 30 May 2026 16:51:24 +0800 Subject: [PATCH] =?UTF-8?q?fix=20:=20=E4=BF=AE=E5=A4=8D=20MMLU=20=E8=AF=84?= =?UTF-8?q?=E6=B5=8B=E8=84=9A=E6=9C=AC=E6=95=B0=E6=8D=AE=E6=BA=90=E5=92=8C?= =?UTF-8?q?=E4=BE=9D=E8=B5=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 数据源改为 Berkeley data.tar(GitHub zip 不含数据文件) - urllib 替换为 requests,支持代理下载 - zip 解压替换为 tar,增加目录 flatten 逻辑 - 添加 model.eval() 确保推理模式正确 --- scripts/tools/evaluate_mmlu.py | 37 ++++++++++++++++++++++++---------- 1 file changed, 26 insertions(+), 11 deletions(-) diff --git a/scripts/tools/evaluate_mmlu.py b/scripts/tools/evaluate_mmlu.py index fb9321b..80e19c9 100644 --- a/scripts/tools/evaluate_mmlu.py +++ b/scripts/tools/evaluate_mmlu.py @@ -5,9 +5,9 @@ import csv import json import os import shutil -import urllib.request -import zipfile +import tarfile +import requests import torch import torch.nn.functional as F import tqdm @@ -15,7 +15,7 @@ import tqdm from astrai.model import AutoModel from astrai.tokenize import AutoTokenizer -MMLU_URL = "https://github.com/hendrycks/test/archive/refs/heads/master.zip" +MMLU_URL = "https://people.eecs.berkeley.edu/~hendrycks/data.tar" MMLU_SUBJECTS = [ "abstract_algebra", "anatomy", @@ -78,23 +78,37 @@ MMLU_SUBJECTS = [ def _download_and_extract(url: str, data_dir: str): - zip_path = os.path.join(data_dir, "mmlu.zip") + tar_path = os.path.join(data_dir, "data.tar") os.makedirs(data_dir, exist_ok=True) print(f"Downloading MMLU data from {url}...") - urllib.request.urlretrieve(url, zip_path) + resp = requests.get(url, stream=True, timeout=300) + resp.raise_for_status() + total = int(resp.headers.get("content-length", 0)) + with tqdm.tqdm(total=total, unit="B", unit_scale=True, desc=" Download") as bar: + with open(tar_path, "wb") as f: + for chunk in resp.iter_content(chunk_size=8192): + f.write(chunk) + bar.update(len(chunk)) print("Extracting...") - with zipfile.ZipFile(zip_path, "r") as zf: - zf.extractall(data_dir) - os.remove(zip_path) + with tarfile.open(tar_path, "r") as tf: + tf.extractall(data_dir) + os.remove(tar_path) def download_mmlu(data_dir: str): _download_and_extract(MMLU_URL, data_dir) - src = os.path.join(data_dir, "test-master", "data") + src = os.path.join(data_dir, "data") if os.path.exists(src): for item in os.listdir(src): - os.rename(os.path.join(src, item), os.path.join(data_dir, item)) - shutil.rmtree(os.path.join(data_dir, "test-master")) + src_item = os.path.join(src, item) + dst_item = os.path.join(data_dir, item) + if os.path.exists(dst_item): + if os.path.isdir(dst_item): + shutil.rmtree(dst_item) + else: + os.remove(dst_item) + os.rename(src_item, dst_item) + os.rmdir(src) print(f"MMLU data saved to {data_dir}") @@ -233,6 +247,7 @@ def main(): device = args.device dtype = getattr(torch, args.dtype) model.to(device=device, dtype=dtype) + model.eval() subjects = args.subjects or MMLU_SUBJECTS results = {}