fix: 修复 MMLU 评测脚本数据源和依赖
- 数据源改为 Berkeley data.tar(GitHub zip 不含数据文件)
- urllib 替换为 requests,支持代理下载
- zip 解压替换为 tar,增加目录 flatten 逻辑
- 添加 model.eval() 确保推理模式正确
This commit is contained in:
parent
f521a30b22
commit
a923e0a23a
|
|
@ -5,9 +5,9 @@ import csv
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
import urllib.request
|
import tarfile
|
||||||
import zipfile
|
|
||||||
|
|
||||||
|
import requests
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import tqdm
|
import tqdm
|
||||||
|
|
@ -15,7 +15,7 @@ import tqdm
|
||||||
from astrai.model import AutoModel
|
from astrai.model import AutoModel
|
||||||
from astrai.tokenize import AutoTokenizer
|
from astrai.tokenize import AutoTokenizer
|
||||||
|
|
||||||
MMLU_URL = "https://github.com/hendrycks/test/archive/refs/heads/master.zip"
|
MMLU_URL = "https://people.eecs.berkeley.edu/~hendrycks/data.tar"
|
||||||
MMLU_SUBJECTS = [
|
MMLU_SUBJECTS = [
|
||||||
"abstract_algebra",
|
"abstract_algebra",
|
||||||
"anatomy",
|
"anatomy",
|
||||||
|
|
@ -78,23 +78,37 @@ MMLU_SUBJECTS = [
|
||||||
|
|
||||||
|
|
||||||
def _download_and_extract(url: str, data_dir: str):
|
def _download_and_extract(url: str, data_dir: str):
|
||||||
zip_path = os.path.join(data_dir, "mmlu.zip")
|
tar_path = os.path.join(data_dir, "data.tar")
|
||||||
os.makedirs(data_dir, exist_ok=True)
|
os.makedirs(data_dir, exist_ok=True)
|
||||||
print(f"Downloading MMLU data from {url}...")
|
print(f"Downloading MMLU data from {url}...")
|
||||||
urllib.request.urlretrieve(url, zip_path)
|
resp = requests.get(url, stream=True, timeout=300)
|
||||||
|
resp.raise_for_status()
|
||||||
|
total = int(resp.headers.get("content-length", 0))
|
||||||
|
with tqdm.tqdm(total=total, unit="B", unit_scale=True, desc=" Download") as bar:
|
||||||
|
with open(tar_path, "wb") as f:
|
||||||
|
for chunk in resp.iter_content(chunk_size=8192):
|
||||||
|
f.write(chunk)
|
||||||
|
bar.update(len(chunk))
|
||||||
print("Extracting...")
|
print("Extracting...")
|
||||||
with zipfile.ZipFile(zip_path, "r") as zf:
|
with tarfile.open(tar_path, "r") as tf:
|
||||||
zf.extractall(data_dir)
|
tf.extractall(data_dir)
|
||||||
os.remove(zip_path)
|
os.remove(tar_path)
|
||||||
|
|
||||||
|
|
||||||
def download_mmlu(data_dir: str):
|
def download_mmlu(data_dir: str):
|
||||||
_download_and_extract(MMLU_URL, data_dir)
|
_download_and_extract(MMLU_URL, data_dir)
|
||||||
src = os.path.join(data_dir, "test-master", "data")
|
src = os.path.join(data_dir, "data")
|
||||||
if os.path.exists(src):
|
if os.path.exists(src):
|
||||||
for item in os.listdir(src):
|
for item in os.listdir(src):
|
||||||
os.rename(os.path.join(src, item), os.path.join(data_dir, item))
|
src_item = os.path.join(src, item)
|
||||||
shutil.rmtree(os.path.join(data_dir, "test-master"))
|
dst_item = os.path.join(data_dir, item)
|
||||||
|
if os.path.exists(dst_item):
|
||||||
|
if os.path.isdir(dst_item):
|
||||||
|
shutil.rmtree(dst_item)
|
||||||
|
else:
|
||||||
|
os.remove(dst_item)
|
||||||
|
os.rename(src_item, dst_item)
|
||||||
|
os.rmdir(src)
|
||||||
print(f"MMLU data saved to {data_dir}")
|
print(f"MMLU data saved to {data_dir}")
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -233,6 +247,7 @@ def main():
|
||||||
device = args.device
|
device = args.device
|
||||||
dtype = getattr(torch, args.dtype)
|
dtype = getattr(torch, args.dtype)
|
||||||
model.to(device=device, dtype=dtype)
|
model.to(device=device, dtype=dtype)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
subjects = args.subjects or MMLU_SUBJECTS
|
subjects = args.subjects or MMLU_SUBJECTS
|
||||||
results = {}
|
results = {}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue