diff --git a/scripts/tools/evaluate_mmlu.py b/scripts/tools/evaluate_mmlu.py
index fb9321b..80e19c9 100644
--- a/scripts/tools/evaluate_mmlu.py
+++ b/scripts/tools/evaluate_mmlu.py
@@ -5,9 +5,9 @@ import csv
 import json
 import os
 import shutil
-import urllib.request
-import zipfile
+import tarfile
 
+import requests
 import torch
 import torch.nn.functional as F
 import tqdm
@@ -15,7 +15,7 @@ import tqdm
 from astrai.model import AutoModel
 from astrai.tokenize import AutoTokenizer
 
-MMLU_URL = "https://github.com/hendrycks/test/archive/refs/heads/master.zip"
+MMLU_URL = "https://people.eecs.berkeley.edu/~hendrycks/data.tar"
 MMLU_SUBJECTS = [
     "abstract_algebra",
     "anatomy",
@@ -78,23 +78,60 @@ MMLU_SUBJECTS = [
 
 
 def _download_and_extract(url: str, data_dir: str):
-    zip_path = os.path.join(data_dir, "mmlu.zip")
+    """Download the tar archive at ``url`` and unpack it into ``data_dir``.
+
+    The archive is streamed to ``data_dir/data.tar`` with a progress bar and
+    removed afterwards, whether or not the download or extraction succeeded.
+
+    Raises:
+        requests.RequestException: if the request fails or returns an error status.
+    """
+    tar_path = os.path.join(data_dir, "data.tar")
     os.makedirs(data_dir, exist_ok=True)
     print(f"Downloading MMLU data from {url}...")
-    urllib.request.urlretrieve(url, zip_path)
-    print("Extracting...")
-    with zipfile.ZipFile(zip_path, "r") as zf:
-        zf.extractall(data_dir)
-    os.remove(zip_path)
+    try:
+        # The context manager closes the streamed connection even on error.
+        with requests.get(url, stream=True, timeout=300) as resp:
+            resp.raise_for_status()
+            total = int(resp.headers.get("content-length", 0))
+            with tqdm.tqdm(total=total, unit="B", unit_scale=True, desc=" Download") as bar:
+                with open(tar_path, "wb") as f:
+                    for chunk in resp.iter_content(chunk_size=8192):
+                        f.write(chunk)
+                        bar.update(len(chunk))
+        print("Extracting...")
+        with tarfile.open(tar_path, "r") as tf:
+            # The "data" filter rejects absolute paths, ".." components and
+            # links escaping data_dir; interpreters without it fall back.
+            if hasattr(tarfile, "data_filter"):
+                tf.extractall(data_dir, filter="data")
+            else:
+                tf.extractall(data_dir)
+    finally:
+        # Never leave a partial or already-extracted archive behind.
+        if os.path.exists(tar_path):
+            os.remove(tar_path)
 
 
 def download_mmlu(data_dir: str):
+    """Download MMLU and move the archive's ``data/`` contents into ``data_dir``.
+
+    Entries already present in ``data_dir`` under the same name are replaced.
+    """
     _download_and_extract(MMLU_URL, data_dir)
-    src = os.path.join(data_dir, "test-master", "data")
+    src = os.path.join(data_dir, "data")
     if os.path.exists(src):
         for item in os.listdir(src):
-            os.rename(os.path.join(src, item), os.path.join(data_dir, item))
-        shutil.rmtree(os.path.join(data_dir, "test-master"))
+            src_item = os.path.join(src, item)
+            dst_item = os.path.join(data_dir, item)
+            # os.rename cannot replace a non-empty directory, so clear dst first.
+            if os.path.exists(dst_item):
+                if os.path.isdir(dst_item):
+                    shutil.rmtree(dst_item)
+                else:
+                    os.remove(dst_item)
+            os.rename(src_item, dst_item)
+        os.rmdir(src)
     print(f"MMLU data saved to {data_dir}")
 
 
@@ -233,6 +270,7 @@ def main():
     device = args.device
     dtype = getattr(torch, args.dtype)
     model.to(device=device, dtype=dtype)
+    model.eval()
 
     subjects = args.subjects or MMLU_SUBJECTS
     results = {}