"""MMLU evaluation via log-likelihood ranking.

For every question the prompt ends with "Answer:" and the model scores each
answer letter (" A" .. " D") by the summed log-probability of its tokens; the
highest-scoring letter is the prediction.
"""

import argparse
import csv
import json
import os
import shutil
import urllib.request
import zipfile

import torch
import torch.nn.functional as F
import tqdm

from astrai.model import AutoModel
from astrai.tokenize import AutoTokenizer

MMLU_URL = "https://github.com/hendrycks/test/archive/refs/heads/master.zip"
MMLU_SUBJECTS = [
    "abstract_algebra",
    "anatomy",
    "astronomy",
    "business_ethics",
    "clinical_knowledge",
    "college_biology",
    "college_chemistry",
    "college_computer_science",
    "college_mathematics",
    "college_medicine",
    "college_physics",
    "computer_security",
    "conceptual_physics",
    "econometrics",
    "electrical_engineering",
    "elementary_mathematics",
    "formal_logic",
    "global_facts",
    "high_school_biology",
    "high_school_chemistry",
    "high_school_computer_science",
    "high_school_european_history",
    "high_school_geography",
    "high_school_government_and_politics",
    "high_school_macroeconomics",
    "high_school_mathematics",
    "high_school_microeconomics",
    "high_school_physics",
    "high_school_psychology",
    "high_school_statistics",
    "high_school_us_history",
    "high_school_world_history",
    "human_aging",
    "human_sexuality",
    "international_law",
    "jurisprudence",
    "logical_fallacies",
    "machine_learning",
    "management",
    "marketing",
    "medical_genetics",
    "miscellaneous",
    "moral_disputes",
    "moral_scenarios",
    "nutrition",
    "philosophy",
    "prehistory",
    "professional_accounting",
    "professional_law",
    "professional_medicine",
    "professional_psychology",
    "public_relations",
    "security_studies",
    "sociology",
    "us_foreign_policy",
    "virology",
    "world_religions",
]

# Answer letters, in the order they appear in the CSV columns and the prompt.
_CHOICES = ("A", "B", "C", "D")


def _download_and_extract(url: str, data_dir: str):
    """Download the zip archive at *url* into *data_dir* and unpack it there.

    The temporary ``mmlu.zip`` is deleted afterwards, also when the download
    or the extraction fails.
    """
    zip_path = os.path.join(data_dir, "mmlu.zip")
    os.makedirs(data_dir, exist_ok=True)
    print(f"Downloading MMLU data from {url}...")
    try:
        urllib.request.urlretrieve(url, zip_path)
        print("Extracting...")
        with zipfile.ZipFile(zip_path, "r") as zf:
            zf.extractall(data_dir)
    finally:
        # A failed download may leave a partial file (or none at all).
        if os.path.exists(zip_path):
            os.remove(zip_path)


def download_mmlu(data_dir: str):
    """Download MMLU and move the contents of ``test-master/data`` into *data_dir*.

    Entries already present in *data_dir* with the same name are replaced, so
    the download can be re-run over an existing data directory.
    """
    _download_and_extract(MMLU_URL, data_dir)
    archive_root = os.path.join(data_dir, "test-master")
    src = os.path.join(archive_root, "data")
    if os.path.exists(src):
        for item in os.listdir(src):
            dst = os.path.join(data_dir, item)
            # os.rename cannot replace a non-empty directory (nor any existing
            # target on Windows), so remove the stale copy before moving.
            if os.path.isdir(dst) and not os.path.islink(dst):
                shutil.rmtree(dst)
            elif os.path.lexists(dst):
                os.remove(dst)
            shutil.move(os.path.join(src, item), dst)
        shutil.rmtree(archive_root)
    print(f"MMLU data saved to {data_dir}")


def _strip_prefix(text: str, prefix: str) -> str:
    """Return *text* without *prefix* (and surrounding whitespace) if it starts with it."""
    if text.startswith(prefix):
        return text[len(prefix) :].strip()
    return text


def load_csv(path: str) -> list[dict]:
    """Parse an MMLU CSV file into a list of question dicts.

    Each row is expected as ``question, A, B, C, D, answer``. Rows with fewer
    than six columns, and a header row whose first cell is "question", are
    skipped. A leading "A)".."D)" label on an option text is removed.

    Returns:
        Dicts with keys "question", "A", "B", "C", "D", "answer".
    """
    data = []
    # newline="" lets the csv module itself handle line breaks, which is
    # required for quoted fields that contain newlines.
    with open(path, "r", encoding="utf-8", newline="") as f:
        for row in csv.reader(f):
            if len(row) < 6:
                continue
            if row[0].strip().lower() == "question":
                continue
            item = {"question": row[0].strip()}
            for letter, cell in zip(_CHOICES, row[1:5]):
                item[letter] = _strip_prefix(cell.strip(), f"{letter})")
            item["answer"] = row[5].strip()
            data.append(item)
    return data


def _format_question(question: str, choices: dict) -> str:
    """Render one question with its lettered options, ending in "Answer:"."""
    lines = [f"Question: {question}"]
    lines.extend(f"{k}. {choices[k]}" for k in _CHOICES)
    lines.append("Answer:")
    return "\n".join(lines)


def build_prompt(
    question: str, choices: dict, subject: str, n_shot: int, dev_data: list[dict]
) -> str:
    """Build the (optionally few-shot) prompt for a single question.

    Args:
        question: The question text.
        choices: Mapping with keys "A".."D" holding the option texts.
        subject: MMLU subject id, e.g. "high_school_biology".
        n_shot: Number of solved dev examples to prepend (0 for zero-shot).
        dev_data: Solved examples in the format returned by ``load_csv``.

    Returns:
        The prompt; it ends with "Answer:" so the next tokens scored are the
        answer letter.
    """
    parts = []
    if n_shot > 0 and dev_data:
        # Subject ids use underscores; render them as words in the prose header.
        subject_name = subject.replace("_", " ")
        parts.append(
            "The following are multiple choice questions (with answers) "
            f"about {subject_name}.\n\n"
        )
        for item in dev_data[:n_shot]:
            parts.append(_format_question(item["question"], item))
            parts.append(f" {item['answer']}\n\n")
    parts.append(_format_question(question, choices))
    return "".join(parts)


def choice_logprob(
    model, tokenizer, context_ids: list[int], choice_letter: str, device: str
) -> float:
    """Return the summed log-probability of " <choice_letter>" after the context.

    Context and choice tokens are run through the model in one forward pass;
    each choice token is scored with the log-softmax at the position just
    before it. If the sequence exceeds ``model.config.max_len``, the oldest
    context tokens are dropped so the choice tokens are always scored.

    Assumes ``model(x)["logits"]`` has shape (batch, seq_len, vocab) — TODO
    confirm against astrai.model.

    Raises:
        ValueError: if the context is empty or the choice tokens alone do not
            fit in ``model.config.max_len``.
    """
    choice_ids = tokenizer.encode(f" {choice_letter}", add_special_tokens=False)
    max_len = model.config.max_len
    if len(choice_ids) >= max_len:
        raise ValueError(
            f"choice {choice_letter!r} needs {len(choice_ids)} tokens, "
            f"which does not fit in max_len={max_len}"
        )
    if not context_ids:
        raise ValueError("context_ids must contain at least one token")

    input_ids = context_ids + choice_ids
    if len(input_ids) > max_len:
        input_ids = input_ids[len(input_ids) - max_len :]
    ctx_len = len(input_ids) - len(choice_ids)

    input_tensor = torch.tensor([input_ids], device=device, dtype=torch.long)
    with torch.inference_mode():
        logits = model(input_tensor)["logits"][0]

    # Position p predicts token p + 1, so choice token i is scored at
    # ctx_len - 1 + i. Upcast first: log_softmax in bf16/fp16 is too coarse to
    # rank four nearly-equal candidates reliably.
    log_probs = F.log_softmax(logits[ctx_len - 1 : len(input_ids) - 1].float(), dim=-1)
    score = 0.0
    for i, tid in enumerate(choice_ids):
        if i >= log_probs.shape[0]:
            # Defensive: the model returned fewer positions than requested.
            break
        score += log_probs[i, tid].item()
    return score


def evaluate_subject(
    model,
    tokenizer,
    subject: str,
    test_data: list[dict],
    dev_data: list[dict] | None,
    device: str,
    n_shot: int,
) -> tuple[float, int, int]:
    """Evaluate one subject by picking the highest-scoring letter per question.

    Returns:
        ``(accuracy, correct, total)``; accuracy is 0.0 when *test_data* is empty.
    """
    correct = 0
    for item in tqdm.tqdm(test_data, desc=f"{subject:40s}", leave=False):
        prompt = build_prompt(item["question"], item, subject, n_shot, dev_data or [])
        context_ids = tokenizer.encode(prompt)
        scores = {
            c: choice_logprob(model, tokenizer, context_ids, c, device)
            for c in _CHOICES
        }
        if max(scores, key=scores.get) == item["answer"]:
            correct += 1
    total = len(test_data)
    # An empty CSV (or one whose rows were all filtered out) must not crash the run.
    accuracy = correct / total if total else 0.0
    return accuracy, correct, total


def main():
    """Parse CLI arguments, run the evaluation and optionally write JSON results."""
    parser = argparse.ArgumentParser(description="MMLU evaluation")
    parser.add_argument(
        "--param_path", type=str, default="./params", help="Model directory"
    )
    parser.add_argument(
        "--data_dir", type=str, default="./mmlu_data", help="MMLU data directory"
    )
    parser.add_argument("--download", action="store_true", help="Download MMLU data")
    parser.add_argument(
        "--n_shot", type=int, default=5, help="Few-shot examples (0 for zero-shot)"
    )
    parser.add_argument(
        "--subjects", type=str, nargs="+", help="Specific subjects (default: all)"
    )
    parser.add_argument("--output", type=str, help="Output JSON path")
    parser.add_argument("--split", type=str, default="test", choices=["test", "val"])
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="Device",
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="bfloat16" if torch.cuda.is_available() else "float32",
        help="Torch dtype",
    )
    args = parser.parse_args()

    if args.download or not os.path.exists(args.data_dir):
        download_mmlu(args.data_dir)

    model = AutoModel.from_pretrained(args.param_path)
    tokenizer = AutoTokenizer.from_pretrained(args.param_path)
    device = args.device
    dtype = getattr(torch, args.dtype)
    model.to(device=device, dtype=dtype)
    # Disable dropout and other training-only behaviour so scores are
    # deterministic. Assumes AutoModel is a torch.nn.Module (it accepts
    # .to(device=, dtype=)) — confirm against astrai.model.
    model.eval()

    subjects = args.subjects or MMLU_SUBJECTS
    results = {}
    total_correct = 0
    total_questions = 0

    for subject in subjects:
        dev_path = os.path.join(args.data_dir, "dev", f"{subject}_dev.csv")
        test_path = os.path.join(
            args.data_dir, args.split, f"{subject}_{args.split}.csv"
        )
        if not os.path.exists(test_path):
            print(f"  Skipping {subject}: test file not found")
            continue

        dev_data = load_csv(dev_path) if os.path.exists(dev_path) else None
        test_data = load_csv(test_path)

        acc, corr, tot = evaluate_subject(
            model, tokenizer, subject, test_data, dev_data, device, args.n_shot
        )
        results[subject] = {"accuracy": round(acc, 4), "correct": corr, "total": tot}
        total_correct += corr
        total_questions += tot
        print(f"  {subject:40s} {acc:.2%} ({corr}/{tot})")

    overall = total_correct / total_questions if total_questions else 0
    print(f"\n{'=' * 70}")
    print(f"  Overall: {overall:.2%} ({total_correct}/{total_questions})")

    results["_overall"] = {
        "accuracy": round(overall, 4),
        "correct": total_correct,
        "total": total_questions,
    }

    if args.output:
        with open(args.output, "w", encoding="utf-8") as f:
            json.dump(results, f, indent=2)
        print(f"Results saved to {args.output}")


if __name__ == "__main__":
    main()