fix : MMLU eval 使用 chat template 格式匹配 SFT 训练数据

- 原 prompt 为纯文本格式,与 SFT chat template 不匹配导致模型输出随机
- 新增 apply_chat() 将 MMLU prompt 包装为 user/assistant 对话格式
- choice_text 改为单字母(去掉空格前缀)适配模板输出
- 5-shot 时 few-shot 示例作为独立 user/assistant 轮次插入
This commit is contained in:
ViperEkura 2026-06-03 11:59:42 +08:00
parent 615ba5d8ef
commit 438dc10391
1 changed files with 28 additions and 3 deletions

View File

@ -157,10 +157,32 @@ def build_prompt(
return prompt return prompt
def apply_chat(
tokenizer, raw_prompt: str, n_shot: int, dev_data: list[dict] | None
) -> str:
"""Wrap raw MMLU prompt in the model's chat template format.
For few-shot, prepend example Q&A pairs as a second user/assistant exchange.
"""
messages = []
if n_shot > 0 and dev_data:
for item in dev_data[:n_shot]:
q = f"Question: {item['question']}\n"
for k in ("A", "B", "C", "D"):
q += f"{k}. {item[k]}\n"
q += "Answer:"
messages.append({"role": "user", "content": q})
messages.append({"role": "assistant", "content": item["answer"]})
messages.append({"role": "user", "content": raw_prompt})
return tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
def choice_logprob( def choice_logprob(
model, tokenizer, context_ids: list[int], choice_letter: str, device: str model, tokenizer, context_ids: list[int], choice_letter: str, device: str
) -> float: ) -> float:
choice_text = f" {choice_letter}" choice_text = choice_letter
choice_ids = tokenizer.encode(choice_text, add_special_tokens=False) choice_ids = tokenizer.encode(choice_text, add_special_tokens=False)
input_ids = context_ids + choice_ids input_ids = context_ids + choice_ids
max_len = model.config.max_len max_len = model.config.max_len
@ -196,8 +218,11 @@ def evaluate_subject(
correct = 0 correct = 0
total = 0 total = 0
for item in tqdm.tqdm(test_data, desc=f"{subject:40s}", leave=False): for item in tqdm.tqdm(test_data, desc=f"{subject:40s}", leave=False):
prompt = build_prompt(item["question"], item, subject, n_shot, dev_data or []) raw_prompt = build_prompt(
context_ids = tokenizer.encode(prompt) item["question"], item, subject, n_shot, dev_data or []
)
context = apply_chat(tokenizer, raw_prompt, n_shot, dev_data or [])
context_ids = tokenizer.encode(context)
scores = { scores = {
c: choice_logprob(model, tokenizer, context_ids, c, device) c: choice_logprob(model, tokenizer, context_ids, c, device)
for c in ("A", "B", "C", "D") for c in ("A", "B", "C", "D")