From 438dc103916b49683715a48e83dbc953313172fb Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Wed, 3 Jun 2026 11:59:42 +0800 Subject: [PATCH] =?UTF-8?q?fix=20:=20MMLU=20eval=20=E4=BD=BF=E7=94=A8=20ch?= =?UTF-8?q?at=20template=20=E6=A0=BC=E5=BC=8F=E5=8C=B9=E9=85=8D=20SFT=20?= =?UTF-8?q?=E8=AE=AD=E7=BB=83=E6=95=B0=E6=8D=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 原 prompt 为纯文本格式,与 SFT chat template 不匹配导致模型输出随机 - 新增 apply_chat() 将 MMLU prompt 包装为 user/assistant 对话格式 - choice_text 改为单字母(去掉空格前缀)适配模板输出 - 5-shot 时 few-shot 示例作为独立 user/assistant 轮次插入 --- scripts/tools/evaluate_mmlu.py | 31 ++++++++++++++++++++++++++++--- 1 file changed, 28 insertions(+), 3 deletions(-) diff --git a/scripts/tools/evaluate_mmlu.py b/scripts/tools/evaluate_mmlu.py index 80e19c9..7f4e7f5 100644 --- a/scripts/tools/evaluate_mmlu.py +++ b/scripts/tools/evaluate_mmlu.py @@ -157,10 +157,32 @@ def build_prompt( return prompt +def apply_chat( + tokenizer, raw_prompt: str, n_shot: int, dev_data: list[dict] | None +) -> str: + """Wrap raw MMLU prompt in the model's chat template format. + + For few-shot, prepend each example Q&A pair as its own user/assistant turn. + """ + messages = [] + if n_shot > 0 and dev_data: + for item in dev_data[:n_shot]: + q = f"Question: {item['question']}\n" + for k in ("A", "B", "C", "D"): + q += f"{k}. 
{item[k]}\n" + q += "Answer:" + messages.append({"role": "user", "content": q}) + messages.append({"role": "assistant", "content": item["answer"]}) + messages.append({"role": "user", "content": raw_prompt}) + return tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + + def choice_logprob( model, tokenizer, context_ids: list[int], choice_letter: str, device: str ) -> float: - choice_text = f" {choice_letter}" + choice_text = choice_letter choice_ids = tokenizer.encode(choice_text, add_special_tokens=False) input_ids = context_ids + choice_ids max_len = model.config.max_len @@ -196,8 +218,11 @@ def evaluate_subject( correct = 0 total = 0 for item in tqdm.tqdm(test_data, desc=f"{subject:40s}", leave=False): - prompt = build_prompt(item["question"], item, subject, n_shot, dev_data or []) - context_ids = tokenizer.encode(prompt) + raw_prompt = build_prompt( + item["question"], item, subject, n_shot, dev_data or [] + ) + context = apply_chat(tokenizer, raw_prompt, n_shot, dev_data or []) + context_ids = tokenizer.encode(context) scores = { c: choice_logprob(model, tokenizer, context_ids, c, device) for c in ("A", "B", "C", "D")