From 34c6c45bd6604980767eed36bb6d257e0dc2061d Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Tue, 26 May 2026 20:19:11 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=88=9D=E6=AD=A5=E5=AE=9E=E7=8E=B0=20?= =?UTF-8?q?MMLU=20=E8=AF=84=E6=B5=8B=E8=84=9A=E6=9C=AC?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 支持 few-shot (log-likelihood ranking) 与 zero-shot - 自动下载 Hendrycks MMLU 数据集 - --device / --dtype 可配置,默认 GPU bf16 --- scripts/tools/evaluate_mmlu.py | 279 +++++++++++++++++++++++++++++++++ 1 file changed, 279 insertions(+) create mode 100644 scripts/tools/evaluate_mmlu.py diff --git a/scripts/tools/evaluate_mmlu.py b/scripts/tools/evaluate_mmlu.py new file mode 100644 index 0000000..8bbef5b --- /dev/null +++ b/scripts/tools/evaluate_mmlu.py @@ -0,0 +1,279 @@ +"""MMLU evaluation via log-likelihood ranking.""" + +import argparse +import csv +import json +import os +import shutil +import urllib.request +import zipfile + +import torch +import torch.nn.functional as F +import tqdm + +from astrai.model import AutoModel +from astrai.tokenize import AutoTokenizer + +MMLU_URL = "https://github.com/hendrycks/test/archive/refs/heads/master.zip" +MMLU_SUBJECTS = [ + "abstract_algebra", + "anatomy", + "astronomy", + "business_ethics", + "clinical_knowledge", + "college_biology", + "college_chemistry", + "college_computer_science", + "college_mathematics", + "college_medicine", + "college_physics", + "computer_security", + "conceptual_physics", + "econometrics", + "electrical_engineering", + "elementary_mathematics", + "formal_logic", + "global_facts", + "high_school_biology", + "high_school_chemistry", + "high_school_computer_science", + "high_school_european_history", + "high_school_geography", + "high_school_government_and_politics", + "high_school_macroeconomics", + "high_school_mathematics", + "high_school_microeconomics", + "high_school_physics", + "high_school_psychology", + "high_school_statistics", + "high_school_us_history", + "high_school_world_history", + "human_aging", + "human_sexuality", + "international_law", + "jurisprudence", + "logical_fallacies", + "machine_learning", + "management", + "marketing", + "medical_genetics", + "miscellaneous", + "moral_disputes", + "moral_scenarios", + "nutrition", + "philosophy", + "prehistory", + "professional_accounting", + "professional_law", + "professional_medicine", + "professional_psychology", + "public_relations", + "security_studies", + "sociology", + "us_foreign_policy", + "virology", + "world_religions", +] + + +def _download_and_extract(url: str, data_dir: str): + zip_path = os.path.join(data_dir, "mmlu.zip") + os.makedirs(data_dir, exist_ok=True) + print(f"Downloading MMLU data from {url}...") + urllib.request.urlretrieve(url, zip_path) + print("Extracting...") + with zipfile.ZipFile(zip_path, "r") as zf: + zf.extractall(data_dir) + os.remove(zip_path) + + +def download_mmlu(data_dir: str): + _download_and_extract(MMLU_URL, data_dir) + src = os.path.join(data_dir, "test-master", "data") + if os.path.exists(src): + for item in os.listdir(src): + os.rename(os.path.join(src, item), os.path.join(data_dir, item)) + shutil.rmtree(os.path.join(data_dir, "test-master")) + print(f"MMLU data saved to {data_dir}") + + +def _strip_prefix(text: str, prefix: str) -> str: + if text.startswith(prefix): + return text[len(prefix) :].strip() + return text + + +def load_csv(path: str) -> list[dict]: + data = [] + with open(path, "r", encoding="utf-8") as f: + for row in csv.reader(f): + if len(row) < 6: + continue + if row[0].strip().lower() == "question": + continue + data.append( + { + "question": row[0].strip(), + "A": _strip_prefix(row[1].strip(), "A)"), + "B": _strip_prefix(row[2].strip(), "B)"), + "C": _strip_prefix(row[3].strip(), "C)"), + "D": _strip_prefix(row[4].strip(), "D)"), + "answer": row[5].strip(), + } + ) + return data + + +def build_prompt( + question: str, choices: dict, subject: str, n_shot: int, dev_data: list[dict] +) -> str: + prompt = "" + if n_shot > 0 and dev_data: + prompt = f"The following are multiple choice questions (with answers) about {subject}.\n\n" + for item in dev_data[:n_shot]: + prompt += f"Question: {item['question']}\n" + for k in ("A", "B", "C", "D"): + prompt += f"{k}. {item[k]}\n" + prompt += f"Answer: {item['answer']}\n\n" + prompt += f"Question: {question}\n" + for k in ("A", "B", "C", "D"): + prompt += f"{k}. {choices[k]}\n" + prompt += "Answer:" + return prompt + + +def choice_logprob( + model, tokenizer, context_ids: list[int], choice_letter: str, device: str +) -> float: + choice_text = f" {choice_letter}" + choice_ids = tokenizer.encode(choice_text, add_special_tokens=False) + input_ids = context_ids + choice_ids + max_len = model.config.max_len + if len(input_ids) > max_len: + overflow = len(input_ids) - max_len + input_ids = input_ids[overflow:] + ctx_len = len(input_ids) - len(choice_ids) + else: + ctx_len = len(context_ids) + + input_tensor = torch.tensor([input_ids], device=device, dtype=torch.long) + with torch.inference_mode(): + logits = model(input_tensor)["logits"][0] + + score = 0.0 + for i, tid in enumerate(choice_ids): + pos = ctx_len - 1 + i + if pos >= len(logits): + break + score += F.log_softmax(logits[pos], dim=-1)[tid].item() + return score + + +def evaluate_subject( + model, + tokenizer, + subject: str, + test_data: list[dict], + dev_data: list[dict] | None, + device: str, + n_shot: int, +) -> tuple[float, int, int]: + correct = 0 + total = 0 + for item in tqdm.tqdm(test_data, desc=f"{subject:40s}", leave=False): + prompt = build_prompt(item["question"], item, subject, n_shot, dev_data or []) + context_ids = tokenizer.encode(prompt) + scores = { + c: choice_logprob(model, tokenizer, context_ids, c, device) + for c in ("A", "B", "C", "D") + } + if max(scores, key=scores.get) == item["answer"]: + correct += 1 + total += 1 + return correct / total, correct, total + + +def main(): + parser = argparse.ArgumentParser(description="MMLU evaluation") + parser.add_argument( + "--model_dir", type=str, default="./params", help="Model directory" + ) + parser.add_argument( + "--data_dir", type=str, default="./mmlu_data", help="MMLU data directory" + ) + parser.add_argument("--download", action="store_true", help="Download MMLU data") + parser.add_argument( + "--n_shot", type=int, default=5, help="Few-shot examples (0 for zero-shot)" + ) + parser.add_argument( + "--subjects", type=str, nargs="+", help="Specific subjects (default: all)" + ) + parser.add_argument("--output", type=str, help="Output JSON path") + parser.add_argument("--split", type=str, default="test", choices=["test", "val"]) + parser.add_argument( + "--device", + type=str, + default="cuda" if torch.cuda.is_available() else "cpu", + help="Device", + ) + parser.add_argument( + "--dtype", + type=str, + default="bfloat16" if torch.cuda.is_available() else "float32", + help="Torch dtype", + ) + args = parser.parse_args() + + if args.download or not os.path.exists(args.data_dir): + download_mmlu(args.data_dir) + + model = AutoModel.from_pretrained(args.model_dir) + tokenizer = AutoTokenizer.from_pretrained(args.model_dir) + device = args.device + dtype = getattr(torch, args.dtype) + model.to(device=device, dtype=dtype) + + subjects = args.subjects or MMLU_SUBJECTS + results = {} + total_correct = 0 + total_questions = 0 + + for subject in subjects: + dev_path = os.path.join(args.data_dir, "dev", f"{subject}_dev.csv") + test_path = os.path.join( + args.data_dir, args.split, f"{subject}_{args.split}.csv" + ) + + if not os.path.exists(test_path): + print(f" Skipping {subject}: test file not found") + continue + + dev_data = load_csv(dev_path) if os.path.exists(dev_path) else None + test_data = load_csv(test_path) + + acc, corr, tot = evaluate_subject( + model, tokenizer, subject, test_data, dev_data, device, args.n_shot + ) + results[subject] = {"accuracy": round(acc, 4), "correct": corr, "total": tot} + total_correct += corr + total_questions += tot + print(f" {subject:40s} {acc:.2%} ({corr}/{tot})") + + overall = total_correct / total_questions if total_questions else 0 + print(f"\n{'=' * 70}") + print(f" Overall: {overall:.2%} ({total_correct}/{total_questions})") + results["_overall"] = { + "accuracy": round(overall, 4), + "correct": total_correct, + "total": total_questions, + } + + if args.output: + with open(args.output, "w", encoding="utf-8") as f: + json.dump(results, f, indent=2) + print(f"Results saved to {args.output}") + + +if __name__ == "__main__": + main()