feat: 新增 IFD 数据质量评分工具, 移动 ppl 至 eval
- 计算指令遵循难度分数用于数据筛选 - IFD = 条件交叉熵 / 无条件交叉熵 - perplexity 移至 scripts/eval/
This commit is contained in:
parent
4e8d1ee24e
commit
1818d06576
|
|
@ -0,0 +1,183 @@
|
|||
"""IFD (Instruction Following Difficulty) data quality scoring.
|
||||
|
||||
Computes IFD scores for instruction-response pairs to guide data selection.
|
||||
IFD = conditional_NLL / unconditional_NLL, where:
|
||||
|
||||
- conditional_NLL: average CE loss on response tokens given instruction context
|
||||
- unconditional_NLL: average CE loss on response tokens alone
|
||||
|
||||
Higher IFD (close to 1) = instruction provides less help = harder sample.
|
||||
Lower IFD (close to 0) = instruction provides strong guidance = easy sample.
|
||||
IFD > 1 = instruction misleads the model = likely low-quality data.
|
||||
|
||||
Usage::
|
||||
|
||||
python scripts/eval/ifd.py --param_path ./params \
|
||||
--input data.jsonl --output data_with_ifd.jsonl \
|
||||
--instr_key instruction --resp_key response
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import tqdm
|
||||
|
||||
from astrai.model import AutoModel
|
||||
from astrai.tokenize import AutoTokenizer
|
||||
|
||||
|
||||
def compute_ifd(
|
||||
model,
|
||||
tokenizer,
|
||||
instruction: str,
|
||||
response: str,
|
||||
device: str,
|
||||
max_len: int = 2048,
|
||||
) -> dict:
|
||||
instr_ids = tokenizer.encode(instruction)
|
||||
resp_ids = tokenizer.encode(response)
|
||||
|
||||
if not resp_ids:
|
||||
return {
|
||||
"L_cond": None,
|
||||
"L_uncond": None,
|
||||
"ifd": None,
|
||||
"error": "empty response",
|
||||
}
|
||||
|
||||
# Truncate instruction if total length exceeds max_len
|
||||
qa_len = len(instr_ids) + len(resp_ids)
|
||||
if qa_len > max_len:
|
||||
overflow = qa_len - max_len
|
||||
instr_ids = instr_ids[overflow:]
|
||||
|
||||
instr_len = len(instr_ids)
|
||||
resp_len = len(resp_ids)
|
||||
|
||||
# Conditional: instruction + response
|
||||
qa_ids = instr_ids + resp_ids
|
||||
qa_tensor = torch.tensor([qa_ids], device=device, dtype=torch.long)
|
||||
|
||||
with torch.inference_mode():
|
||||
logits_qa = model(qa_tensor)["logits"][0] # [qa_len, vocab]
|
||||
|
||||
resp_logits = logits_qa[instr_len - 1 : -1] # predict response tokens
|
||||
resp_targets = torch.tensor(resp_ids, device=device, dtype=torch.long)
|
||||
L_cond = F.cross_entropy(resp_logits, resp_targets, reduction="mean").item()
|
||||
|
||||
# Unconditional: response alone
|
||||
resp_tensor = torch.tensor([resp_ids], device=device, dtype=torch.long)
|
||||
|
||||
with torch.inference_mode():
|
||||
logits_resp = model(resp_tensor)["logits"][0] # [resp_len, vocab]
|
||||
|
||||
unp_logits = logits_resp[:-1] # causal shift
|
||||
unp_targets = resp_tensor[0, 1:]
|
||||
L_uncond = F.cross_entropy(unp_logits, unp_targets, reduction="mean").item()
|
||||
|
||||
ifd = L_cond / L_uncond if L_uncond > 0 else None
|
||||
|
||||
return {
|
||||
"L_cond": round(L_cond, 6),
|
||||
"L_uncond": round(L_uncond, 6),
|
||||
"ifd": round(ifd, 6) if ifd is not None else None,
|
||||
"instr_len": instr_len,
|
||||
"resp_len": resp_len,
|
||||
"error": None,
|
||||
}
|
||||
|
||||
|
||||
def process_file(
|
||||
param_path: str,
|
||||
input_file: str,
|
||||
output_file: str,
|
||||
instr_key: str,
|
||||
resp_key: str,
|
||||
max_len: int,
|
||||
):
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
dtype = torch.bfloat16 if device == "cuda" else torch.float32
|
||||
|
||||
model = AutoModel.from_pretrained(param_path)
|
||||
tokenizer = AutoTokenizer.from_pretrained(param_path)
|
||||
model.to(device=device, dtype=dtype)
|
||||
model.eval()
|
||||
|
||||
with open(input_file, "r", encoding="utf-8") as f:
|
||||
data = [json.loads(line) for line in f if line.strip()]
|
||||
|
||||
results = []
|
||||
ifd_values = []
|
||||
|
||||
with torch.inference_mode():
|
||||
for item in tqdm.tqdm(data, desc="Computing IFD", unit="sample"):
|
||||
instruction = item[instr_key]
|
||||
response = item[resp_key]
|
||||
scores = compute_ifd(
|
||||
model, tokenizer, instruction, response, device, max_len
|
||||
)
|
||||
ifd_values.append(scores["ifd"])
|
||||
results.append({**item, "ifd": scores["ifd"], "ifd_detail": scores})
|
||||
|
||||
with open(output_file, "w", encoding="utf-8") as f:
|
||||
for item in results:
|
||||
f.write(json.dumps(item, ensure_ascii=False) + "\n")
|
||||
|
||||
valid_ifd = [v for v in ifd_values if v is not None]
|
||||
if valid_ifd:
|
||||
import statistics
|
||||
|
||||
print(f"\n{'=' * 50}")
|
||||
print(f" Samples: {len(data)}")
|
||||
print(f" Valid IFD: {len(valid_ifd)}")
|
||||
print(f" Mean IFD: {statistics.mean(valid_ifd):.4f}")
|
||||
print(f" Median IFD: {statistics.median(valid_ifd):.4f}")
|
||||
print(f" Stdev IFD: {statistics.stdev(valid_ifd):.4f}")
|
||||
print(f" Min IFD: {min(valid_ifd):.4f}")
|
||||
print(f" Max IFD: {max(valid_ifd):.4f}")
|
||||
print(f"{'=' * 50}")
|
||||
|
||||
print(f"Results saved to {output_file}")
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Compute IFD scores for instruction-response data"
|
||||
)
|
||||
parser.add_argument("--param_path", type=str, required=True, help="Model directory")
|
||||
parser.add_argument("--input", type=str, required=True, help="Input JSONL file")
|
||||
parser.add_argument("--output", type=str, required=True, help="Output JSONL file")
|
||||
parser.add_argument(
|
||||
"--instr_key",
|
||||
type=str,
|
||||
default="instruction",
|
||||
help="Key for instruction field",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--resp_key",
|
||||
type=str,
|
||||
default="response",
|
||||
help="Key for response field",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_len",
|
||||
type=int,
|
||||
default=2048,
|
||||
help="Max token length (instruction truncated to fit)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
process_file(
|
||||
args.param_path,
|
||||
args.input,
|
||||
args.output,
|
||||
args.instr_key,
|
||||
args.resp_key,
|
||||
args.max_len,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Loading…
Reference in New Issue