diff --git a/scripts/eval/evaluate_ifd.py b/scripts/eval/evaluate_ifd.py new file mode 100644 index 0000000..4fe43ab --- /dev/null +++ b/scripts/eval/evaluate_ifd.py @@ -0,0 +1,183 @@ +"""IFD (Instruction Following Difficulty) data quality scoring. + +Computes IFD scores for instruction-response pairs to guide data selection. +IFD = conditional_NLL / unconditional_NLL, where: + +- conditional_NLL: average CE loss on response tokens given instruction context +- unconditional_NLL: average CE loss on response tokens alone + +Higher IFD (close to 1) = instruction provides less help = harder sample. +Lower IFD (close to 0) = instruction provides strong guidance = easy sample. +IFD > 1 = instruction misleads the model = likely low-quality data. + +Usage:: + + python scripts/eval/ifd.py --param_path ./params \ + --input data.jsonl --output data_with_ifd.jsonl \ + --instr_key instruction --resp_key response +""" + +import argparse +import json + +import torch +import torch.nn.functional as F +import tqdm + +from astrai.model import AutoModel +from astrai.tokenize import AutoTokenizer + + +def compute_ifd( + model, + tokenizer, + instruction: str, + response: str, + device: str, + max_len: int = 2048, +) -> dict: + instr_ids = tokenizer.encode(instruction) + resp_ids = tokenizer.encode(response) + + if not resp_ids: + return { + "L_cond": None, + "L_uncond": None, + "ifd": None, + "error": "empty response", + } + + # Truncate instruction if total length exceeds max_len + qa_len = len(instr_ids) + len(resp_ids) + if qa_len > max_len: + overflow = qa_len - max_len + instr_ids = instr_ids[overflow:] + + instr_len = len(instr_ids) + resp_len = len(resp_ids) + + # Conditional: instruction + response + qa_ids = instr_ids + resp_ids + qa_tensor = torch.tensor([qa_ids], device=device, dtype=torch.long) + + with torch.inference_mode(): + logits_qa = model(qa_tensor)["logits"][0] # [qa_len, vocab] + + resp_logits = logits_qa[instr_len - 1 : -1] # predict response tokens + resp_targets = torch.tensor(resp_ids, device=device, dtype=torch.long) + L_cond = F.cross_entropy(resp_logits, resp_targets, reduction="mean").item() + + # Unconditional: response alone + resp_tensor = torch.tensor([resp_ids], device=device, dtype=torch.long) + + with torch.inference_mode(): + logits_resp = model(resp_tensor)["logits"][0] # [resp_len, vocab] + + unp_logits = logits_resp[:-1] # causal shift + unp_targets = resp_tensor[0, 1:] + L_uncond = F.cross_entropy(unp_logits, unp_targets, reduction="mean").item() + + ifd = L_cond / L_uncond if L_uncond > 0 else None + + return { + "L_cond": round(L_cond, 6), + "L_uncond": round(L_uncond, 6), + "ifd": round(ifd, 6) if ifd is not None else None, + "instr_len": instr_len, + "resp_len": resp_len, + "error": None, + } + + +def process_file( + param_path: str, + input_file: str, + output_file: str, + instr_key: str, + resp_key: str, + max_len: int, +): + device = "cuda" if torch.cuda.is_available() else "cpu" + dtype = torch.bfloat16 if device == "cuda" else torch.float32 + + model = AutoModel.from_pretrained(param_path) + tokenizer = AutoTokenizer.from_pretrained(param_path) + model.to(device=device, dtype=dtype) + model.eval() + + with open(input_file, "r", encoding="utf-8") as f: + data = [json.loads(line) for line in f if line.strip()] + + results = [] + ifd_values = [] + + with torch.inference_mode(): + for item in tqdm.tqdm(data, desc="Computing IFD", unit="sample"): + instruction = item[instr_key] + response = item[resp_key] + scores = compute_ifd( + model, tokenizer, instruction, response, device, max_len + ) + ifd_values.append(scores["ifd"]) + results.append({**item, "ifd": scores["ifd"], "ifd_detail": scores}) + + with open(output_file, "w", encoding="utf-8") as f: + for item in results: + f.write(json.dumps(item, ensure_ascii=False) + "\n") + + valid_ifd = [v for v in ifd_values if v is not None] + if valid_ifd: + import statistics + + print(f"\n{'=' * 50}") + print(f" Samples: {len(data)}") + print(f" Valid IFD: {len(valid_ifd)}") + print(f" Mean IFD: {statistics.mean(valid_ifd):.4f}") + print(f" Median IFD: {statistics.median(valid_ifd):.4f}") + print(f" Stdev IFD: {statistics.stdev(valid_ifd):.4f}") + print(f" Min IFD: {min(valid_ifd):.4f}") + print(f" Max IFD: {max(valid_ifd):.4f}") + print(f"{'=' * 50}") + + print(f"Results saved to {output_file}") + + +def main(): + parser = argparse.ArgumentParser( + description="Compute IFD scores for instruction-response data" + ) + parser.add_argument("--param_path", type=str, required=True, help="Model directory") + parser.add_argument("--input", type=str, required=True, help="Input JSONL file") + parser.add_argument("--output", type=str, required=True, help="Output JSONL file") + parser.add_argument( + "--instr_key", + type=str, + default="instruction", + help="Key for instruction field", + ) + parser.add_argument( + "--resp_key", + type=str, + default="response", + help="Key for response field", + ) + parser.add_argument( + "--max_len", + type=int, + default=2048, + help="Max token length (instruction truncated to fit)", + ) + args = parser.parse_args() + + process_file( + args.param_path, + args.input, + args.output, + args.instr_key, + args.resp_key, + args.max_len, + ) + + +if __name__ == "__main__": + main() diff --git a/scripts/tools/perplexity.py b/scripts/eval/evaluate_ppl.py similarity index 100% rename from scripts/tools/perplexity.py rename to scripts/eval/evaluate_ppl.py