280 lines
8.4 KiB
Python
280 lines
8.4 KiB
Python
"""MMLU evaluation via log-likelihood ranking."""
|
|
|
|
import argparse
|
|
import csv
|
|
import json
|
|
import os
|
|
import shutil
|
|
import urllib.request
|
|
import zipfile
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
import tqdm
|
|
|
|
from astrai.model import AutoModel
|
|
from astrai.tokenize import AutoTokenizer
|
|
|
|
MMLU_URL = "https://github.com/hendrycks/test/archive/refs/heads/master.zip"
|
|
MMLU_SUBJECTS = [
|
|
"abstract_algebra",
|
|
"anatomy",
|
|
"astronomy",
|
|
"business_ethics",
|
|
"clinical_knowledge",
|
|
"college_biology",
|
|
"college_chemistry",
|
|
"college_computer_science",
|
|
"college_mathematics",
|
|
"college_medicine",
|
|
"college_physics",
|
|
"computer_security",
|
|
"conceptual_physics",
|
|
"econometrics",
|
|
"electrical_engineering",
|
|
"elementary_mathematics",
|
|
"formal_logic",
|
|
"global_facts",
|
|
"high_school_biology",
|
|
"high_school_chemistry",
|
|
"high_school_computer_science",
|
|
"high_school_european_history",
|
|
"high_school_geography",
|
|
"high_school_government_and_politics",
|
|
"high_school_macroeconomics",
|
|
"high_school_mathematics",
|
|
"high_school_microeconomics",
|
|
"high_school_physics",
|
|
"high_school_psychology",
|
|
"high_school_statistics",
|
|
"high_school_us_history",
|
|
"high_school_world_history",
|
|
"human_aging",
|
|
"human_sexuality",
|
|
"international_law",
|
|
"jurisprudence",
|
|
"logical_fallacies",
|
|
"machine_learning",
|
|
"management",
|
|
"marketing",
|
|
"medical_genetics",
|
|
"miscellaneous",
|
|
"moral_disputes",
|
|
"moral_scenarios",
|
|
"nutrition",
|
|
"philosophy",
|
|
"prehistory",
|
|
"professional_accounting",
|
|
"professional_law",
|
|
"professional_medicine",
|
|
"professional_psychology",
|
|
"public_relations",
|
|
"security_studies",
|
|
"sociology",
|
|
"us_foreign_policy",
|
|
"virology",
|
|
"world_religions",
|
|
]
|
|
|
|
|
|
def _download_and_extract(url: str, data_dir: str):
|
|
zip_path = os.path.join(data_dir, "mmlu.zip")
|
|
os.makedirs(data_dir, exist_ok=True)
|
|
print(f"Downloading MMLU data from {url}...")
|
|
urllib.request.urlretrieve(url, zip_path)
|
|
print("Extracting...")
|
|
with zipfile.ZipFile(zip_path, "r") as zf:
|
|
zf.extractall(data_dir)
|
|
os.remove(zip_path)
|
|
|
|
|
|
def download_mmlu(data_dir: str):
|
|
_download_and_extract(MMLU_URL, data_dir)
|
|
src = os.path.join(data_dir, "test-master", "data")
|
|
if os.path.exists(src):
|
|
for item in os.listdir(src):
|
|
os.rename(os.path.join(src, item), os.path.join(data_dir, item))
|
|
shutil.rmtree(os.path.join(data_dir, "test-master"))
|
|
print(f"MMLU data saved to {data_dir}")
|
|
|
|
|
|
def _strip_prefix(text: str, prefix: str) -> str:
|
|
if text.startswith(prefix):
|
|
return text[len(prefix) :].strip()
|
|
return text
|
|
|
|
|
|
def load_csv(path: str) -> list[dict]:
|
|
data = []
|
|
with open(path, "r", encoding="utf-8") as f:
|
|
for row in csv.reader(f):
|
|
if len(row) < 6:
|
|
continue
|
|
if row[0].strip().lower() == "question":
|
|
continue
|
|
data.append(
|
|
{
|
|
"question": row[0].strip(),
|
|
"A": _strip_prefix(row[1].strip(), "A)"),
|
|
"B": _strip_prefix(row[2].strip(), "B)"),
|
|
"C": _strip_prefix(row[3].strip(), "C)"),
|
|
"D": _strip_prefix(row[4].strip(), "D)"),
|
|
"answer": row[5].strip(),
|
|
}
|
|
)
|
|
return data
|
|
|
|
|
|
def build_prompt(
|
|
question: str, choices: dict, subject: str, n_shot: int, dev_data: list[dict]
|
|
) -> str:
|
|
prompt = ""
|
|
if n_shot > 0 and dev_data:
|
|
prompt = f"The following are multiple choice questions (with answers) about {subject}.\n\n"
|
|
for item in dev_data[:n_shot]:
|
|
prompt += f"Question: {item['question']}\n"
|
|
for k in ("A", "B", "C", "D"):
|
|
prompt += f"{k}. {item[k]}\n"
|
|
prompt += f"Answer: {item['answer']}\n\n"
|
|
prompt += f"Question: {question}\n"
|
|
for k in ("A", "B", "C", "D"):
|
|
prompt += f"{k}. {choices[k]}\n"
|
|
prompt += "Answer:"
|
|
return prompt
|
|
|
|
|
|
def choice_logprob(
|
|
model, tokenizer, context_ids: list[int], choice_letter: str, device: str
|
|
) -> float:
|
|
choice_text = f" {choice_letter}"
|
|
choice_ids = tokenizer.encode(choice_text, add_special_tokens=False)
|
|
input_ids = context_ids + choice_ids
|
|
max_len = model.config.max_len
|
|
if len(input_ids) > max_len:
|
|
overflow = len(input_ids) - max_len
|
|
input_ids = input_ids[overflow:]
|
|
ctx_len = len(input_ids) - len(choice_ids)
|
|
else:
|
|
ctx_len = len(context_ids)
|
|
|
|
input_tensor = torch.tensor([input_ids], device=device, dtype=torch.long)
|
|
with torch.inference_mode():
|
|
logits = model(input_tensor)["logits"][0]
|
|
|
|
score = 0.0
|
|
for i, tid in enumerate(choice_ids):
|
|
pos = ctx_len - 1 + i
|
|
if pos >= len(logits):
|
|
break
|
|
score += F.log_softmax(logits[pos], dim=-1)[tid].item()
|
|
return score
|
|
|
|
|
|
def evaluate_subject(
|
|
model,
|
|
tokenizer,
|
|
subject: str,
|
|
test_data: list[dict],
|
|
dev_data: list[dict] | None,
|
|
device: str,
|
|
n_shot: int,
|
|
) -> tuple[float, int, int]:
|
|
correct = 0
|
|
total = 0
|
|
for item in tqdm.tqdm(test_data, desc=f"{subject:40s}", leave=False):
|
|
prompt = build_prompt(item["question"], item, subject, n_shot, dev_data or [])
|
|
context_ids = tokenizer.encode(prompt)
|
|
scores = {
|
|
c: choice_logprob(model, tokenizer, context_ids, c, device)
|
|
for c in ("A", "B", "C", "D")
|
|
}
|
|
if max(scores, key=scores.get) == item["answer"]:
|
|
correct += 1
|
|
total += 1
|
|
return correct / total, correct, total
|
|
|
|
|
|
def main():
|
|
parser = argparse.ArgumentParser(description="MMLU evaluation")
|
|
parser.add_argument(
|
|
"--param_path", type=str, default="./params", help="Model directory"
|
|
)
|
|
parser.add_argument(
|
|
"--data_dir", type=str, default="./mmlu_data", help="MMLU data directory"
|
|
)
|
|
parser.add_argument("--download", action="store_true", help="Download MMLU data")
|
|
parser.add_argument(
|
|
"--n_shot", type=int, default=5, help="Few-shot examples (0 for zero-shot)"
|
|
)
|
|
parser.add_argument(
|
|
"--subjects", type=str, nargs="+", help="Specific subjects (default: all)"
|
|
)
|
|
parser.add_argument("--output", type=str, help="Output JSON path")
|
|
parser.add_argument("--split", type=str, default="test", choices=["test", "val"])
|
|
parser.add_argument(
|
|
"--device",
|
|
type=str,
|
|
default="cuda" if torch.cuda.is_available() else "cpu",
|
|
help="Device",
|
|
)
|
|
parser.add_argument(
|
|
"--dtype",
|
|
type=str,
|
|
default="bfloat16" if torch.cuda.is_available() else "float32",
|
|
help="Torch dtype",
|
|
)
|
|
args = parser.parse_args()
|
|
|
|
if args.download or not os.path.exists(args.data_dir):
|
|
download_mmlu(args.data_dir)
|
|
|
|
model = AutoModel.from_pretrained(args.param_path)
|
|
tokenizer = AutoTokenizer.from_pretrained(args.param_path)
|
|
device = args.device
|
|
dtype = getattr(torch, args.dtype)
|
|
model.to(device=device, dtype=dtype)
|
|
|
|
subjects = args.subjects or MMLU_SUBJECTS
|
|
results = {}
|
|
total_correct = 0
|
|
total_questions = 0
|
|
|
|
for subject in subjects:
|
|
dev_path = os.path.join(args.data_dir, "dev", f"{subject}_dev.csv")
|
|
test_path = os.path.join(
|
|
args.data_dir, args.split, f"{subject}_{args.split}.csv"
|
|
)
|
|
|
|
if not os.path.exists(test_path):
|
|
print(f" Skipping {subject}: test file not found")
|
|
continue
|
|
|
|
dev_data = load_csv(dev_path) if os.path.exists(dev_path) else None
|
|
test_data = load_csv(test_path)
|
|
|
|
acc, corr, tot = evaluate_subject(
|
|
model, tokenizer, subject, test_data, dev_data, device, args.n_shot
|
|
)
|
|
results[subject] = {"accuracy": round(acc, 4), "correct": corr, "total": tot}
|
|
total_correct += corr
|
|
total_questions += tot
|
|
print(f" {subject:40s} {acc:.2%} ({corr}/{tot})")
|
|
|
|
overall = total_correct / total_questions if total_questions else 0
|
|
print(f"\n{'=' * 70}")
|
|
print(f" Overall: {overall:.2%} ({total_correct}/{total_questions})")
|
|
results["_overall"] = {
|
|
"accuracy": round(overall, 4),
|
|
"correct": total_correct,
|
|
"total": total_questions,
|
|
}
|
|
|
|
if args.output:
|
|
with open(args.output, "w", encoding="utf-8") as f:
|
|
json.dump(results, f, indent=2)
|
|
print(f"Results saved to {args.output}")
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|