diff --git a/scripts/tools/evaluate_mmlu.py b/scripts/tools/evaluate_mmlu.py index 80e19c9..7f4e7f5 100644 --- a/scripts/tools/evaluate_mmlu.py +++ b/scripts/tools/evaluate_mmlu.py @@ -157,10 +157,32 @@ def build_prompt( return prompt +def apply_chat( + tokenizer, raw_prompt: str, n_shot: int, dev_data: list[dict] | None +) -> str: + """Wrap raw MMLU prompt in the model's chat template format. + + For few-shot, prepend example Q&A pairs as a second user/assistant exchange. + """ + messages = [] + if n_shot > 0 and dev_data: + for item in dev_data[:n_shot]: + q = f"Question: {item['question']}\n" + for k in ("A", "B", "C", "D"): + q += f"{k}. {item[k]}\n" + q += "Answer:" + messages.append({"role": "user", "content": q}) + messages.append({"role": "assistant", "content": item["answer"]}) + messages.append({"role": "user", "content": raw_prompt}) + return tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + + def choice_logprob( model, tokenizer, context_ids: list[int], choice_letter: str, device: str ) -> float: - choice_text = f" {choice_letter}" + choice_text = choice_letter choice_ids = tokenizer.encode(choice_text, add_special_tokens=False) input_ids = context_ids + choice_ids max_len = model.config.max_len @@ -196,8 +218,11 @@ def evaluate_subject( correct = 0 total = 0 for item in tqdm.tqdm(test_data, desc=f"{subject:40s}", leave=False): - prompt = build_prompt(item["question"], item, subject, n_shot, dev_data or []) - context_ids = tokenizer.encode(prompt) + raw_prompt = build_prompt( + item["question"], item, subject, n_shot, dev_data or [] + ) + context = apply_chat(tokenizer, raw_prompt, n_shot, dev_data or []) + context_ids = tokenizer.encode(context) scores = { c: choice_logprob(model, tokenizer, context_ids, c, device) for c in ("A", "B", "C", "D")