feat: 初步实现 MMLU 评测脚本
- 支持 few-shot (log-likelihood ranking) 与 zero-shot
- 自动下载 Hendrycks MMLU 数据集
- --device / --dtype 可配置,默认 GPU bf16
This commit is contained in:
parent
e9def84ce7
commit
34c6c45bd6
|
|
@ -0,0 +1,279 @@
|
|||
"""MMLU evaluation via log-likelihood ranking."""
|
||||
|
||||
import argparse
|
||||
import csv
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import urllib.request
|
||||
import zipfile
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import tqdm
|
||||
|
||||
from astrai.model import AutoModel
|
||||
from astrai.tokenize import AutoTokenizer
|
||||
|
||||
# Zip archive of the master branch of the Hendrycks MMLU repository.
MMLU_URL = "https://github.com/hendrycks/test/archive/refs/heads/master.zip"

# Subject names; each one is used to build the per-subject CSV file names
# ("<subject>_dev.csv", "<subject>_<split>.csv") when the data is loaded.
MMLU_SUBJECTS = [
    "abstract_algebra",
    "anatomy",
    "astronomy",
    "business_ethics",
    "clinical_knowledge",
    "college_biology",
    "college_chemistry",
    "college_computer_science",
    "college_mathematics",
    "college_medicine",
    "college_physics",
    "computer_security",
    "conceptual_physics",
    "econometrics",
    "electrical_engineering",
    "elementary_mathematics",
    "formal_logic",
    "global_facts",
    "high_school_biology",
    "high_school_chemistry",
    "high_school_computer_science",
    "high_school_european_history",
    "high_school_geography",
    "high_school_government_and_politics",
    "high_school_macroeconomics",
    "high_school_mathematics",
    "high_school_microeconomics",
    "high_school_physics",
    "high_school_psychology",
    "high_school_statistics",
    "high_school_us_history",
    "high_school_world_history",
    "human_aging",
    "human_sexuality",
    "international_law",
    "jurisprudence",
    "logical_fallacies",
    "machine_learning",
    "management",
    "marketing",
    "medical_genetics",
    "miscellaneous",
    "moral_disputes",
    "moral_scenarios",
    "nutrition",
    "philosophy",
    "prehistory",
    "professional_accounting",
    "professional_law",
    "professional_medicine",
    "professional_psychology",
    "public_relations",
    "security_studies",
    "sociology",
    "us_foreign_policy",
    "virology",
    "world_religions",
]
|
||||
|
||||
|
||||
def _download_and_extract(url: str, data_dir: str) -> None:
    """Download a zip archive from ``url`` and unpack it into ``data_dir``.

    The archive is stored temporarily as ``<data_dir>/mmlu.zip`` and is
    removed afterwards, whether or not the download or extraction succeeded.

    Args:
        url: Location of the zip archive to fetch.
        data_dir: Directory to extract into; created if it does not exist.

    Raises:
        OSError: If the download or a filesystem operation fails
            (``urllib.error.URLError`` is a subclass).
        zipfile.BadZipFile: If the downloaded file is not a valid zip archive.
    """
    zip_path = os.path.join(data_dir, "mmlu.zip")
    os.makedirs(data_dir, exist_ok=True)
    print(f"Downloading MMLU data from {url}...")
    try:
        urllib.request.urlretrieve(url, zip_path)
        print("Extracting...")
        with zipfile.ZipFile(zip_path, "r") as zf:
            zf.extractall(data_dir)
    finally:
        # Never leave a partial or corrupt archive behind in the data
        # directory; the file may not exist if the download failed early.
        if os.path.exists(zip_path):
            os.remove(zip_path)
|
||||
|
||||
|
||||
def download_mmlu(data_dir: str) -> None:
    """Download the MMLU archive and flatten its ``data`` folder into ``data_dir``.

    The archive unpacks to ``<data_dir>/test-master/data/...``. Every entry of
    that ``data`` folder is moved directly under ``data_dir`` and the
    ``test-master`` directory is then deleted.

    Entries that already exist in ``data_dir`` (for example from an earlier
    download) are replaced, so the function can be re-run on a populated
    directory.

    Args:
        data_dir: Target directory for the dataset.
    """
    _download_and_extract(MMLU_URL, data_dir)
    archive_root = os.path.join(data_dir, "test-master")
    src = os.path.join(archive_root, "data")
    if os.path.isdir(src):
        for item in os.listdir(src):
            dst = os.path.join(data_dir, item)
            # A plain rename fails when the destination is an existing
            # non-empty directory, so clear any previous copy first.
            if os.path.isdir(dst) and not os.path.islink(dst):
                shutil.rmtree(dst)
            elif os.path.lexists(dst):
                os.remove(dst)
            shutil.move(os.path.join(src, item), dst)
    # Remove the unpacked archive root even when it had no "data" folder.
    if os.path.isdir(archive_root):
        shutil.rmtree(archive_root)
    print(f"MMLU data saved to {data_dir}")
|
||||
|
||||
|
||||
def _strip_prefix(text: str, prefix: str) -> str:
    """Drop a leading ``prefix`` from ``text`` and trim the remainder.

    When ``text`` does not start with ``prefix`` it is returned untouched
    (no whitespace trimming is applied in that case).
    """
    if not text.startswith(prefix):
        return text
    remainder = text[len(prefix):]
    return remainder.strip()
|
||||
|
||||
|
||||
def load_csv(path: str) -> list[dict]:
    """Read a question CSV into a list of records.

    Each usable row becomes a dict with the keys ``question``, ``A``, ``B``,
    ``C``, ``D`` and ``answer``, taken from the first six columns. Rows with
    fewer than six columns are skipped, as is any row whose first column is
    the word "question" (case-insensitive), which is treated as a header.
    A leading "A)" .. "D)" label is stripped from the matching choice text.

    Args:
        path: Path of the CSV file, read as UTF-8.

    Returns:
        The parsed records in file order.
    """
    records = []
    with open(path, "r", encoding="utf-8") as handle:
        reader = csv.reader(handle)
        for fields in reader:
            too_short = len(fields) < 6
            if too_short or fields[0].strip().lower() == "question":
                continue
            record = {"question": fields[0].strip()}
            # Columns 1-4 hold the choices A-D, in that order.
            for column, letter in enumerate("ABCD", start=1):
                record[letter] = _strip_prefix(fields[column].strip(), f"{letter})")
            record["answer"] = fields[5].strip()
            records.append(record)
    return records
|
||||
|
||||
|
||||
def build_prompt(
    question: str, choices: dict, subject: str, n_shot: int, dev_data: list[dict]
) -> str:
    """Assemble the prompt text for one multiple-choice question.

    With ``n_shot > 0`` and a non-empty ``dev_data``, the prompt opens with a
    subject header followed by the first ``n_shot`` dev examples, each shown
    with its answer. The target question and its four choices come last, and
    the prompt ends with a bare "Answer:" for the model to complete.

    Args:
        question: Text of the question being asked.
        choices: Mapping with the keys "A", "B", "C" and "D".
        subject: Subject name inserted into the few-shot header.
        n_shot: Number of solved examples to prepend.
        dev_data: Solved examples; each needs "question", "A".."D", "answer".

    Returns:
        The complete prompt string.
    """
    letters = ("A", "B", "C", "D")
    pieces = []
    if n_shot > 0 and dev_data:
        pieces.append(
            f"The following are multiple choice questions (with answers) about {subject}.\n\n"
        )
        for example in dev_data[:n_shot]:
            pieces.append(f"Question: {example['question']}\n")
            pieces.extend(f"{letter}. {example[letter]}\n" for letter in letters)
            pieces.append(f"Answer: {example['answer']}\n\n")
    pieces.append(f"Question: {question}\n")
    pieces.extend(f"{letter}. {choices[letter]}\n" for letter in letters)
    pieces.append("Answer:")
    return "".join(pieces)
|
||||
|
||||
|
||||
def choice_logprob(
    model, tokenizer, context_ids: list[int], choice_letter: str, device: str
) -> float:
    """Score one answer letter as a continuation of the prompt.

    The letter is encoded as ``" <letter>"`` and appended to ``context_ids``.
    If the sequence is longer than ``model.config.max_len`` it is truncated
    from the left. The model is run once, and the log-probabilities of the
    choice tokens are summed.

    Args:
        model: Callable returning a mapping whose "logits" entry is indexed
            as ``[batch][position][token_id]``; must expose ``config.max_len``.
        tokenizer: Object with ``encode(text, add_special_tokens=False)``.
        context_ids: Token ids of the prompt.
        choice_letter: Answer letter to score, e.g. "A".
        device: Device on which the input tensor is created.

    Returns:
        Sum of log-probabilities of the choice tokens. Tokens for which no
        preceding position exists contribute nothing.
    """
    choice_ids = tokenizer.encode(f" {choice_letter}", add_special_tokens=False)
    input_ids = context_ids + choice_ids
    max_len = model.config.max_len
    if len(input_ids) > max_len:
        # Keep the tail so the choice tokens always survive truncation.
        input_ids = input_ids[len(input_ids) - max_len:]
    # Number of context tokens that remain in front of the choice tokens.
    ctx_len = len(input_ids) - len(choice_ids)

    input_tensor = torch.tensor([input_ids], device=device, dtype=torch.long)
    with torch.inference_mode():
        logits = model(input_tensor)["logits"][0]

    score = 0.0
    for i, tid in enumerate(choice_ids):
        # The row one position before a token is read as that token's
        # prediction (assumes standard next-token logits).
        pos = ctx_len - 1 + i
        if pos < 0:
            # No preceding position exists; a negative index would silently
            # wrap around to the end of the sequence.
            continue
        if pos >= len(logits):
            break
        # Normalise in float32: in half precision the rounded log-probs can
        # hide small differences between the answer choices.
        score += F.log_softmax(logits[pos].float(), dim=-1)[tid].item()
    return score
|
||||
|
||||
|
||||
def evaluate_subject(
    model,
    tokenizer,
    subject: str,
    test_data: list[dict],
    dev_data: list[dict] | None,
    device: str,
    n_shot: int,
) -> tuple[float, int, int]:
    """Evaluate one subject by ranking the four answer letters.

    For every test item a prompt is built, each of "A".."D" is scored with
    ``choice_logprob``, and the highest-scoring letter is compared with the
    item's "answer". On a tie the earliest letter wins.

    Args:
        model: Model passed through to ``choice_logprob``.
        tokenizer: Tokenizer used to encode prompts and choices.
        subject: Subject name; used in the prompt and the progress bar label.
        test_data: Items to evaluate (keys "question", "A".."D", "answer").
        dev_data: Few-shot examples, or ``None`` when unavailable.
        device: Device passed through to ``choice_logprob``.
        n_shot: Number of few-shot examples to include in each prompt.

    Returns:
        ``(accuracy, correct, total)``. Accuracy is 0.0 when ``test_data`` is
        empty.
    """
    few_shot = dev_data or []
    correct = 0
    total = 0
    for item in tqdm.tqdm(test_data, desc=f"{subject:40s}", leave=False):
        prompt = build_prompt(item["question"], item, subject, n_shot, few_shot)
        context_ids = tokenizer.encode(prompt)
        scores = {
            c: choice_logprob(model, tokenizer, context_ids, c, device)
            for c in ("A", "B", "C", "D")
        }
        if max(scores, key=scores.get) == item["answer"]:
            correct += 1
        total += 1
    # An empty subject must not crash the whole run with a division by zero.
    accuracy = correct / total if total else 0.0
    return accuracy, correct, total
|
||||
|
||||
|
||||
def main() -> None:
    """Command-line entry point: evaluate a model on MMLU and report accuracy.

    Parses the options, downloads the dataset when requested or when the data
    directory is missing, loads model and tokenizer from ``--model_dir``, and
    evaluates every selected subject. Per-subject and overall accuracy are
    printed and, if ``--output`` is given, written as JSON.
    """
    parser = argparse.ArgumentParser(description="MMLU evaluation")
    parser.add_argument(
        "--model_dir", type=str, default="./params", help="Model directory"
    )
    parser.add_argument(
        "--data_dir", type=str, default="./mmlu_data", help="MMLU data directory"
    )
    parser.add_argument("--download", action="store_true", help="Download MMLU data")
    parser.add_argument(
        "--n_shot", type=int, default=5, help="Few-shot examples (0 for zero-shot)"
    )
    parser.add_argument(
        "--subjects", type=str, nargs="+", help="Specific subjects (default: all)"
    )
    parser.add_argument("--output", type=str, help="Output JSON path")
    parser.add_argument("--split", type=str, default="test", choices=["test", "val"])
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="Device",
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="bfloat16" if torch.cuda.is_available() else "float32",
        help="Torch dtype",
    )
    args = parser.parse_args()

    # Validate the dtype before any download or model load, and reject names
    # that exist on the torch module but are not dtypes (e.g. "nn").
    dtype = getattr(torch, args.dtype, None)
    if not isinstance(dtype, torch.dtype):
        parser.error(f"--dtype: {args.dtype!r} is not a torch dtype")

    if args.download or not os.path.exists(args.data_dir):
        download_mmlu(args.data_dir)

    model = AutoModel.from_pretrained(args.model_dir)
    tokenizer = AutoTokenizer.from_pretrained(args.model_dir)
    device = args.device
    model.to(device=device, dtype=dtype)
    # Scores must be deterministic, so switch off training-only behaviour such
    # as dropout. NOTE(review): assumes the model is a torch.nn.Module, as the
    # .to(device=..., dtype=...) call above suggests — confirm.
    model.eval()

    subjects = args.subjects or MMLU_SUBJECTS
    results = {}
    total_correct = 0
    total_questions = 0

    for subject in subjects:
        dev_path = os.path.join(args.data_dir, "dev", f"{subject}_dev.csv")
        test_path = os.path.join(
            args.data_dir, args.split, f"{subject}_{args.split}.csv"
        )

        if not os.path.exists(test_path):
            print(f"  Skipping {subject}: test file not found")
            continue

        dev_data = load_csv(dev_path) if os.path.exists(dev_path) else None
        test_data = load_csv(test_path)

        if not test_data:
            # Nothing to score; avoid reporting a meaningless 0/0 accuracy.
            print(f"  Skipping {subject}: no usable rows in {test_path}")
            continue
        if args.n_shot > 0 and not dev_data:
            # Without dev examples the prompt silently becomes zero-shot,
            # which changes what the reported number means.
            print(f"  Warning: no dev examples for {subject}; using zero-shot prompts")

        acc, corr, tot = evaluate_subject(
            model, tokenizer, subject, test_data, dev_data, device, args.n_shot
        )
        results[subject] = {"accuracy": round(acc, 4), "correct": corr, "total": tot}
        total_correct += corr
        total_questions += tot
        print(f"  {subject:40s} {acc:.2%} ({corr}/{tot})")

    overall = total_correct / total_questions if total_questions else 0
    print(f"\n{'=' * 70}")
    print(f"  Overall: {overall:.2%} ({total_correct}/{total_questions})")
    results["_overall"] = {
        "accuracy": round(overall, 4),
        "correct": total_correct,
        "total": total_questions,
    }

    if args.output:
        # Create the target directory so a long run is not lost at the end.
        os.makedirs(os.path.dirname(args.output) or ".", exist_ok=True)
        with open(args.output, "w", encoding="utf-8") as f:
            json.dump(results, f, indent=2)
        print(f"Results saved to {args.output}")
|
||||
|
||||
|
||||
# Run the evaluation only when this file is executed as a script.
if __name__ == "__main__":
    main()
|
||||
Loading…
Reference in New Issue