feat: IFD 默认使用 chat template,支持裸文本模式
- 新增 _compute_ifd_with_template,用 tokenizer chat template 格式化后计算 IFD
- 默认开启 chat template,可通过 --no_chat_template 切换回裸拼接
- chat template 缺失时给出 RuntimeError 提示
This commit is contained in:
parent
a4e5a8c81c
commit
a62c2e11a2
|
|
@ -15,6 +15,13 @@ Usage::
|
|||
python scripts/eval/ifd.py --param_path ./params \
|
||||
--input data.jsonl --output data_with_ifd.jsonl \
|
||||
--instr_key instruction --resp_key response
|
||||
|
||||
Disable chat template::
|
||||
|
||||
python scripts/eval/ifd.py --param_path ./params \
|
||||
--input data.jsonl --output data_with_ifd.jsonl \
|
||||
--instr_key instruction --resp_key response \
|
||||
--no_chat_template
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
|
@ -35,7 +42,16 @@ def compute_ifd(
|
|||
response: str,
|
||||
device: str,
|
||||
max_len: int = 2048,
|
||||
use_chat_template: bool = False,
|
||||
) -> dict:
|
||||
if use_chat_template:
|
||||
return _compute_ifd_with_template(
|
||||
model, tokenizer, instruction, response, device, max_len
|
||||
)
|
||||
return _compute_ifd_raw(model, tokenizer, instruction, response, device, max_len)
|
||||
|
||||
|
||||
def _compute_ifd_raw(model, tokenizer, instruction, response, device, max_len) -> dict:
|
||||
instr_ids = tokenizer.encode(instruction)
|
||||
resp_ids = tokenizer.encode(response)
|
||||
|
||||
|
|
@ -47,7 +63,6 @@ def compute_ifd(
|
|||
"error": "empty response",
|
||||
}
|
||||
|
||||
# Truncate instruction if total length exceeds max_len
|
||||
qa_len = len(instr_ids) + len(resp_ids)
|
||||
if qa_len > max_len:
|
||||
overflow = qa_len - max_len
|
||||
|
|
@ -56,24 +71,22 @@ def compute_ifd(
|
|||
instr_len = len(instr_ids)
|
||||
resp_len = len(resp_ids)
|
||||
|
||||
# Conditional: instruction + response
|
||||
qa_ids = instr_ids + resp_ids
|
||||
qa_tensor = torch.tensor([qa_ids], device=device, dtype=torch.long)
|
||||
|
||||
with torch.inference_mode():
|
||||
logits_qa = model(qa_tensor)["logits"][0] # [qa_len, vocab]
|
||||
logits_qa = model(qa_tensor)["logits"][0]
|
||||
|
||||
resp_logits = logits_qa[instr_len - 1 : -1] # predict response tokens
|
||||
resp_logits = logits_qa[instr_len - 1 : -1]
|
||||
resp_targets = torch.tensor(resp_ids, device=device, dtype=torch.long)
|
||||
L_cond = F.cross_entropy(resp_logits, resp_targets, reduction="mean").item()
|
||||
|
||||
# Unconditional: response alone
|
||||
resp_tensor = torch.tensor([resp_ids], device=device, dtype=torch.long)
|
||||
|
||||
with torch.inference_mode():
|
||||
logits_resp = model(resp_tensor)["logits"][0] # [resp_len, vocab]
|
||||
logits_resp = model(resp_tensor)["logits"][0]
|
||||
|
||||
unp_logits = logits_resp[:-1] # causal shift
|
||||
unp_logits = logits_resp[:-1]
|
||||
unp_targets = resp_tensor[0, 1:]
|
||||
L_uncond = F.cross_entropy(unp_logits, unp_targets, reduction="mean").item()
|
||||
|
||||
|
|
@ -89,6 +102,83 @@ def compute_ifd(
|
|||
}
|
||||
|
||||
|
||||
def _compute_ifd_with_template(
    model, tokenizer, instruction, response, device, max_len
) -> dict:
    """Compute the IFD score with the prompt rendered through the chat template.

    The conditional loss ``L_cond`` is the mean cross-entropy of every token
    that follows the templated user turn (with generation prompt) inside the
    templated user+assistant conversation.  The unconditional loss
    ``L_uncond`` is the mean next-token cross-entropy of the bare response
    encoded on its own.  ``ifd`` is ``L_cond / L_uncond``.

    Args:
        model: Callable that maps a ``[1, seq]`` tensor of token ids to a
            mapping whose ``"logits"`` entry is indexed as ``[0]`` to obtain
            per-position logits.
        tokenizer: Object providing ``encode(text)`` and
            ``apply_chat_template(messages, tokenize=..., add_generation_prompt=...)``.
        instruction: User-turn text.
        response: Assistant-turn text.
        device: Torch device on which the id tensors are created.
        max_len: Maximum number of tokens kept for the conditional sequence;
            overflow is dropped from the left (prompt side first).

    Returns:
        A dict with ``L_cond``, ``L_uncond``, ``ifd``, ``instr_len``,
        ``resp_len`` and ``error``.  When scoring is impossible only the three
        score keys (set to ``None``) and ``error`` are present.
    """
    resp_ids = tokenizer.encode(response)
    if not resp_ids:
        return {
            "L_cond": None,
            "L_uncond": None,
            "ifd": None,
            "error": "empty response",
        }

    user_turn = {"role": "user", "content": instruction}
    instr_prefix = tokenizer.apply_chat_template(
        [user_turn],
        tokenize=False,
        add_generation_prompt=True,
    )
    full_text = tokenizer.apply_chat_template(
        [user_turn, {"role": "assistant", "content": response}],
        tokenize=False,
        add_generation_prompt=False,
    )

    # NOTE(review): this assumes the encoded prefix is a token-level prefix of
    # the encoded full conversation -- confirm for the tokenizer in use.
    full_ids = tokenizer.encode(full_text)
    prefix_ids = tokenizer.encode(instr_prefix)

    # Drop overflow from the left so the tail of the conversation is kept.
    overflow = max(0, len(full_ids) - max_len)
    if overflow:
        full_ids = full_ids[overflow:]
    prefix_len = max(0, len(prefix_ids) - overflow)

    # The token at index i is predicted from the logits at index i - 1, so
    # index 0 can never be a target.  Without this clamp a fully truncated
    # prefix (prefix_len == 0) produced an empty logits slice against a
    # full-length target tensor and crashed inside cross_entropy.
    first_target = max(prefix_len, 1)
    if first_target >= len(full_ids):
        return {
            "L_cond": None,
            "L_uncond": None,
            "ifd": None,
            "error": "response truncated entirely",
        }

    cond_tensor = torch.tensor([full_ids], device=device, dtype=torch.long)
    with torch.inference_mode():
        logits_qa = model(cond_tensor)["logits"][0]

    resp_logits = logits_qa[first_target - 1 : -1]
    resp_targets = torch.tensor(
        full_ids[first_target:], device=device, dtype=torch.long
    )
    L_cond = F.cross_entropy(resp_logits, resp_targets, reduction="mean").item()

    # Unconditional pass: the raw response alone, with the usual causal shift.
    resp_tensor = torch.tensor([resp_ids], device=device, dtype=torch.long)
    with torch.inference_mode():
        logits_resp = model(resp_tensor)["logits"][0]

    unp_logits = logits_resp[:-1]
    unp_targets = resp_tensor[0, 1:]
    L_uncond = F.cross_entropy(unp_logits, unp_targets, reduction="mean").item()

    # The comparison is False for 0 and for NaN, so ifd becomes None then.
    ifd = L_cond / L_uncond if L_uncond > 0 else None

    return {
        "L_cond": round(L_cond, 6),
        "L_uncond": round(L_uncond, 6),
        "ifd": round(ifd, 6) if ifd is not None else None,
        "instr_len": prefix_len,
        "resp_len": len(resp_ids),
        "error": None,
    }
|
||||
|
||||
|
||||
def process_file(
|
||||
param_path: str,
|
||||
input_file: str,
|
||||
|
|
@ -96,6 +186,7 @@ def process_file(
|
|||
instr_key: str,
|
||||
resp_key: str,
|
||||
max_len: int,
|
||||
use_chat_template: bool = False,
|
||||
):
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
dtype = torch.bfloat16 if device == "cuda" else torch.float32
|
||||
|
|
@ -105,6 +196,12 @@ def process_file(
|
|||
model.to(device=device, dtype=dtype)
|
||||
model.eval()
|
||||
|
||||
if use_chat_template and tokenizer._chat_template is None:
|
||||
raise RuntimeError(
|
||||
"--use_chat_template specified but tokenizer has no chat template. "
|
||||
"Add a chat_template to tokenizer_config.json or omit the flag."
|
||||
)
|
||||
|
||||
with open(input_file, "r", encoding="utf-8") as f:
|
||||
data = [json.loads(line) for line in f if line.strip()]
|
||||
|
||||
|
|
@ -116,7 +213,13 @@ def process_file(
|
|||
instruction = item[instr_key]
|
||||
response = item[resp_key]
|
||||
scores = compute_ifd(
|
||||
model, tokenizer, instruction, response, device, max_len
|
||||
model,
|
||||
tokenizer,
|
||||
instruction,
|
||||
response,
|
||||
device,
|
||||
max_len,
|
||||
use_chat_template=use_chat_template,
|
||||
)
|
||||
ifd_values.append(scores["ifd"])
|
||||
results.append({**item, "ifd": scores["ifd"], "ifd_detail": scores})
|
||||
|
|
@ -167,6 +270,12 @@ def main():
|
|||
default=2048,
|
||||
help="Max token length (instruction truncated to fit)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no_chat_template",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Disable chat template, use raw text concatenation",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
process_file(
|
||||
|
|
@ -176,6 +285,7 @@ def main():
|
|||
args.instr_key,
|
||||
args.resp_key,
|
||||
args.max_len,
|
||||
use_chat_template=not args.no_chat_template,
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue