From a62c2e11a2f1514a92b22ecf85bab532b77b81e6 Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Thu, 18 Jun 2026 16:34:27 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20IFD=20=E9=BB=98=E8=AE=A4=E4=BD=BF?= =?UTF-8?q?=E7=94=A8=20chat=20template=EF=BC=8C=E6=94=AF=E6=8C=81=E8=A3=B8?= =?UTF-8?q?=E6=96=87=E6=9C=AC=E6=A8=A1=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增 _compute_ifd_with_template,用 tokenizer chat template 格式化后计算 IFD - 默认开启 chat template,可通过 --no_chat_template 切换回裸拼接 - chat template 缺失时给出 RuntimeError 提示 --- scripts/eval/evaluate_ifd.py | 126 ++++++++++++++++++++++++++++++++--- 1 file changed, 118 insertions(+), 8 deletions(-) diff --git a/scripts/eval/evaluate_ifd.py b/scripts/eval/evaluate_ifd.py index 4fe43ab..7ce49d2 100644 --- a/scripts/eval/evaluate_ifd.py +++ b/scripts/eval/evaluate_ifd.py @@ -15,6 +15,13 @@ Usage:: python scripts/eval/ifd.py --param_path ./params \ --input data.jsonl --output data_with_ifd.jsonl \ --instr_key instruction --resp_key response + +Disable chat template:: + + python scripts/eval/ifd.py --param_path ./params \ + --input data.jsonl --output data_with_ifd.jsonl \ + --instr_key instruction --resp_key response \ + --no_chat_template """ import argparse @@ -35,7 +42,16 @@ def compute_ifd( response: str, device: str, max_len: int = 2048, + use_chat_template: bool = False, ) -> dict: + if use_chat_template: + return _compute_ifd_with_template( + model, tokenizer, instruction, response, device, max_len + ) + return _compute_ifd_raw(model, tokenizer, instruction, response, device, max_len) + + +def _compute_ifd_raw(model, tokenizer, instruction, response, device, max_len) -> dict: instr_ids = tokenizer.encode(instruction) resp_ids = tokenizer.encode(response) @@ -47,7 +63,6 @@ def compute_ifd( "error": "empty response", } - # Truncate instruction if total length exceeds max_len qa_len = len(instr_ids) + len(resp_ids) if qa_len > max_len: overflow = qa_len - max_len @@ -56,24 +71,22 @@ def compute_ifd( instr_len = len(instr_ids) resp_len = len(resp_ids) - # Conditional: instruction + response qa_ids = instr_ids + resp_ids qa_tensor = torch.tensor([qa_ids], device=device, dtype=torch.long) with torch.inference_mode(): - logits_qa = model(qa_tensor)["logits"][0] # [qa_len, vocab] + logits_qa = model(qa_tensor)["logits"][0] - resp_logits = logits_qa[instr_len - 1 : -1] # predict response tokens + resp_logits = logits_qa[instr_len - 1 : -1] resp_targets = torch.tensor(resp_ids, device=device, dtype=torch.long) L_cond = F.cross_entropy(resp_logits, resp_targets, reduction="mean").item() - # Unconditional: response alone resp_tensor = torch.tensor([resp_ids], device=device, dtype=torch.long) with torch.inference_mode(): - logits_resp = model(resp_tensor)["logits"][0] # [resp_len, vocab] + logits_resp = model(resp_tensor)["logits"][0] - unp_logits = logits_resp[:-1] # causal shift + unp_logits = logits_resp[:-1] unp_targets = resp_tensor[0, 1:] L_uncond = F.cross_entropy(unp_logits, unp_targets, reduction="mean").item() @@ -89,6 +102,83 @@ def compute_ifd( } +def _compute_ifd_with_template( + model, tokenizer, instruction, response, device, max_len +) -> dict: + instr_prefix = tokenizer.apply_chat_template( + [{"role": "user", "content": instruction}], + tokenize=False, + add_generation_prompt=True, + ) + full_text = tokenizer.apply_chat_template( + [ + {"role": "user", "content": instruction}, + {"role": "assistant", "content": response}, + ], + tokenize=False, + add_generation_prompt=False, + ) + + full_ids = tokenizer.encode(full_text) + prefix_ids = tokenizer.encode(instr_prefix) + resp_ids = tokenizer.encode(response) + + if not resp_ids: + return { + "L_cond": None, + "L_uncond": None, + "ifd": None, + "error": "empty response", + } + + if len(full_ids) > max_len: + overflow = len(full_ids) - max_len + full_ids = full_ids[overflow:] + prefix_len = len(prefix_ids) - overflow + prefix_len = max(0, prefix_len) + else: + prefix_len = len(prefix_ids) + + cond_tensor = torch.tensor([full_ids], device=device, dtype=torch.long) + + with torch.inference_mode(): + logits_qa = model(cond_tensor)["logits"][0] + + resp_start = prefix_len - 1 + resp_end = len(full_ids) - 1 + if resp_end <= resp_start: + return { + "L_cond": None, + "L_uncond": None, + "ifd": None, + "error": "response truncated entirely", + } + + resp_logits = logits_qa[resp_start:resp_end] + resp_targets = torch.tensor(full_ids[prefix_len:], device=device, dtype=torch.long) + L_cond = F.cross_entropy(resp_logits, resp_targets, reduction="mean").item() + + resp_tensor = torch.tensor([resp_ids], device=device, dtype=torch.long) + + with torch.inference_mode(): + logits_resp = model(resp_tensor)["logits"][0] + + unp_logits = logits_resp[:-1] + unp_targets = resp_tensor[0, 1:] + L_uncond = F.cross_entropy(unp_logits, unp_targets, reduction="mean").item() + + ifd = L_cond / L_uncond if L_uncond > 0 else None + + return { + "L_cond": round(L_cond, 6), + "L_uncond": round(L_uncond, 6), + "ifd": round(ifd, 6) if ifd is not None else None, + "instr_len": prefix_len, + "resp_len": len(resp_ids), + "error": None, + } + + def process_file( param_path: str, input_file: str, @@ -96,6 +186,7 @@ def process_file( instr_key: str, resp_key: str, max_len: int, + use_chat_template: bool = False, ): device = "cuda" if torch.cuda.is_available() else "cpu" dtype = torch.bfloat16 if device == "cuda" else torch.float32 @@ -105,6 +196,12 @@ def process_file( model.to(device=device, dtype=dtype) model.eval() + if use_chat_template and tokenizer._chat_template is None: + raise RuntimeError( + "--use_chat_template specified but tokenizer has no chat template. " + "Add a chat_template to tokenizer_config.json or omit the flag." + ) + with open(input_file, "r", encoding="utf-8") as f: data = [json.loads(line) for line in f if line.strip()] @@ -116,7 +213,13 @@ def process_file( instruction = item[instr_key] response = item[resp_key] scores = compute_ifd( - model, tokenizer, instruction, response, device, max_len + model, + tokenizer, + instruction, + response, + device, + max_len, + use_chat_template=use_chat_template, ) ifd_values.append(scores["ifd"]) results.append({**item, "ifd": scores["ifd"], "ifd_detail": scores}) @@ -167,6 +270,12 @@ def main(): default=2048, help="Max token length (instruction truncated to fit)", ) + parser.add_argument( + "--no_chat_template", + action="store_true", + default=False, + help="Disable chat template, use raw text concatenation", + ) args = parser.parse_args() process_file( @@ -176,6 +285,7 @@ def main(): args.instr_key, args.resp_key, args.max_len, + use_chat_template=not args.no_chat_template, )