"""IFD (Instruction Following Difficulty) data quality scoring. Computes IFD scores for instruction-response pairs to guide data selection. IFD = conditional_NLL / unconditional_NLL, where: - conditional_NLL: average CE loss on response tokens given instruction context - unconditional_NLL: average CE loss on response tokens alone Higher IFD (close to 1) = instruction provides less help = harder sample. Lower IFD (close to 0) = instruction provides strong guidance = easy sample. IFD > 1 = instruction misleads the model = likely low-quality data. Usage:: python scripts/eval/ifd.py --param_path ./params \ --input data.jsonl --output data_with_ifd.jsonl \ --instr_key instruction --resp_key response Disable chat template:: python scripts/eval/ifd.py --param_path ./params \ --input data.jsonl --output data_with_ifd.jsonl \ --instr_key instruction --resp_key response \ --no_chat_template """ import argparse import json import torch import torch.nn.functional as F import tqdm from astrai.model import AutoModel from astrai.tokenize import AutoTokenizer def compute_ifd( model, tokenizer, instruction: str, response: str, device: str, max_len: int = 2048, use_chat_template: bool = False, ) -> dict: if use_chat_template: return _compute_ifd_with_template( model, tokenizer, instruction, response, device, max_len ) return _compute_ifd_raw(model, tokenizer, instruction, response, device, max_len) def _compute_ifd_raw(model, tokenizer, instruction, response, device, max_len) -> dict: instr_ids = tokenizer.encode(instruction) resp_ids = tokenizer.encode(response) if not resp_ids: return { "L_cond": None, "L_uncond": None, "ifd": None, "error": "empty response", } qa_len = len(instr_ids) + len(resp_ids) if qa_len > max_len: overflow = qa_len - max_len instr_ids = instr_ids[overflow:] instr_len = len(instr_ids) resp_len = len(resp_ids) qa_ids = instr_ids + resp_ids qa_tensor = torch.tensor([qa_ids], device=device, dtype=torch.long) with torch.inference_mode(): logits_qa = model(qa_tensor)["logits"][0] resp_logits = logits_qa[instr_len - 1 : -1] resp_targets = torch.tensor(resp_ids, device=device, dtype=torch.long) L_cond = F.cross_entropy(resp_logits, resp_targets, reduction="mean").item() resp_tensor = torch.tensor([resp_ids], device=device, dtype=torch.long) with torch.inference_mode(): logits_resp = model(resp_tensor)["logits"][0] unp_logits = logits_resp[:-1] unp_targets = resp_tensor[0, 1:] L_uncond = F.cross_entropy(unp_logits, unp_targets, reduction="mean").item() ifd = L_cond / L_uncond if L_uncond > 0 else None return { "L_cond": round(L_cond, 6), "L_uncond": round(L_uncond, 6), "ifd": round(ifd, 6) if ifd is not None else None, "instr_len": instr_len, "resp_len": resp_len, "error": None, } def _compute_ifd_with_template( model, tokenizer, instruction, response, device, max_len ) -> dict: instr_prefix = tokenizer.apply_chat_template( [{"role": "user", "content": instruction}], tokenize=False, add_generation_prompt=True, ) full_text = tokenizer.apply_chat_template( [ {"role": "user", "content": instruction}, {"role": "assistant", "content": response}, ], tokenize=False, add_generation_prompt=False, ) full_ids = tokenizer.encode(full_text) prefix_ids = tokenizer.encode(instr_prefix) resp_ids = tokenizer.encode(response) if not resp_ids: return { "L_cond": None, "L_uncond": None, "ifd": None, "error": "empty response", } if len(full_ids) > max_len: overflow = len(full_ids) - max_len full_ids = full_ids[overflow:] prefix_len = len(prefix_ids) - overflow prefix_len = max(0, prefix_len) else: prefix_len = len(prefix_ids) cond_tensor = torch.tensor([full_ids], device=device, dtype=torch.long) with torch.inference_mode(): logits_qa = model(cond_tensor)["logits"][0] resp_start = prefix_len - 1 resp_end = len(full_ids) - 1 if resp_end <= resp_start: return { "L_cond": None, "L_uncond": None, "ifd": None, "error": "response truncated entirely", } resp_logits = logits_qa[resp_start:resp_end] resp_targets = torch.tensor(full_ids[prefix_len:], device=device, dtype=torch.long) L_cond = F.cross_entropy(resp_logits, resp_targets, reduction="mean").item() resp_tensor = torch.tensor([resp_ids], device=device, dtype=torch.long) with torch.inference_mode(): logits_resp = model(resp_tensor)["logits"][0] unp_logits = logits_resp[:-1] unp_targets = resp_tensor[0, 1:] L_uncond = F.cross_entropy(unp_logits, unp_targets, reduction="mean").item() ifd = L_cond / L_uncond if L_uncond > 0 else None return { "L_cond": round(L_cond, 6), "L_uncond": round(L_uncond, 6), "ifd": round(ifd, 6) if ifd is not None else None, "instr_len": prefix_len, "resp_len": len(resp_ids), "error": None, } def process_file( param_path: str, input_file: str, output_file: str, instr_key: str, resp_key: str, max_len: int, use_chat_template: bool = False, ): device = "cuda" if torch.cuda.is_available() else "cpu" dtype = torch.bfloat16 if device == "cuda" else torch.float32 model = AutoModel.from_pretrained(param_path) tokenizer = AutoTokenizer.from_pretrained(param_path) model.to(device=device, dtype=dtype) model.eval() if use_chat_template and tokenizer._chat_template is None: raise RuntimeError( "--use_chat_template specified but tokenizer has no chat template. " "Add a chat_template to tokenizer_config.json or omit the flag." ) with open(input_file, "r", encoding="utf-8") as f: data = [json.loads(line) for line in f if line.strip()] results = [] ifd_values = [] with torch.inference_mode(): for item in tqdm.tqdm(data, desc="Computing IFD", unit="sample"): instruction = item[instr_key] response = item[resp_key] scores = compute_ifd( model, tokenizer, instruction, response, device, max_len, use_chat_template=use_chat_template, ) ifd_values.append(scores["ifd"]) results.append({**item, "ifd": scores["ifd"], "ifd_detail": scores}) with open(output_file, "w", encoding="utf-8") as f: for item in results: f.write(json.dumps(item, ensure_ascii=False) + "\n") valid_ifd = [v for v in ifd_values if v is not None] if valid_ifd: import statistics print(f"\n{'=' * 50}") print(f" Samples: {len(data)}") print(f" Valid IFD: {len(valid_ifd)}") print(f" Mean IFD: {statistics.mean(valid_ifd):.4f}") print(f" Median IFD: {statistics.median(valid_ifd):.4f}") print(f" Stdev IFD: {statistics.stdev(valid_ifd):.4f}") print(f" Min IFD: {min(valid_ifd):.4f}") print(f" Max IFD: {max(valid_ifd):.4f}") print(f"{'=' * 50}") print(f"Results saved to {output_file}") def main(): parser = argparse.ArgumentParser( description="Compute IFD scores for instruction-response data" ) parser.add_argument("--param_path", type=str, required=True, help="Model directory") parser.add_argument("--input", type=str, required=True, help="Input JSONL file") parser.add_argument("--output", type=str, required=True, help="Output JSONL file") parser.add_argument( "--instr_key", type=str, default="instruction", help="Key for instruction field", ) parser.add_argument( "--resp_key", type=str, default="response", help="Key for response field", ) parser.add_argument( "--max_len", type=int, default=2048, help="Max token length (instruction truncated to fit)", ) parser.add_argument( "--no_chat_template", action="store_true", default=False, help="Disable chat template, use raw text concatenation", ) args = parser.parse_args() process_file( args.param_path, args.input, args.output, args.instr_key, args.resp_key, args.max_len, use_chat_template=not args.no_chat_template, ) if __name__ == "__main__": main()