601 lines
18 KiB
Python
601 lines
18 KiB
Python
"""IFEval instruction-following evaluation benchmark.
|
|
|
|
Evaluates model responses against regex-based constraint verifiers.
|
|
Supports all IFEval constraint types except language detection.
|
|
|
|
Usage::
|
|
|
|
python scripts/tools/evaluate_ifeval.py --param_path ./params \
|
|
--data_path ifeval.jsonl --output results.json \
|
|
--temperature 0.1 --max_tokens 512
|
|
"""
|
|
|
|
import argparse
|
|
import json
|
|
import os
|
|
import re
|
|
import urllib.request
|
|
from typing import Callable, Dict, List, Optional
|
|
|
|
import torch
|
|
import tqdm
|
|
|
|
from astrai.inference import InferenceEngine
|
|
from astrai.model import AutoModel
|
|
from astrai.tokenize import AutoTokenizer
|
|
|
|
IFEVAL_URL = (
|
|
"https://raw.githubusercontent.com/google-research/"
|
|
"google-research/master/instruction_following_eval/data/input_data.jsonl"
|
|
)
|
|
|
|
CONSTRAINT_VERIFIERS: Dict[str, Callable[[str, dict], bool]] = {}
|
|
|
|
|
|
def register(instruction_id: str):
|
|
def decorator(fn):
|
|
CONSTRAINT_VERIFIERS[instruction_id] = fn
|
|
return fn
|
|
|
|
return decorator
|
|
|
|
|
|
@register("keywords:existence")
|
|
def check_keyword_existence(response: str, kwargs: dict) -> bool:
|
|
for kw in kwargs["keywords"]:
|
|
if not re.search(re.escape(kw), response, re.IGNORECASE):
|
|
return False
|
|
return True
|
|
|
|
|
|
@register("keywords:frequency")
|
|
def check_keyword_frequency(response: str, kwargs: dict) -> bool:
|
|
keyword = kwargs["keyword"]
|
|
frequency = kwargs.get("frequency", 1)
|
|
relation = kwargs.get("relation", "at least")
|
|
count = len(re.findall(re.escape(keyword), response, re.IGNORECASE))
|
|
if relation == "less than":
|
|
return count < frequency
|
|
return count >= frequency
|
|
|
|
|
|
@register("keywords:forbidden_words")
|
|
def check_forbidden_words(response: str, kwargs: dict) -> bool:
|
|
for word in kwargs["forbidden_words"]:
|
|
if re.search(r"\b" + re.escape(word) + r"\b", response, re.IGNORECASE):
|
|
return False
|
|
return True
|
|
|
|
|
|
@register("keywords:letter_frequency")
|
|
def check_letter_frequency(response: str, kwargs: dict) -> bool:
|
|
letter = kwargs["letter"].lower()
|
|
frequency = kwargs.get("let_frequency", 1)
|
|
relation = kwargs.get("let_relation", "at least")
|
|
count = response.lower().count(letter)
|
|
if relation == "less than":
|
|
return count < frequency
|
|
return count >= frequency
|
|
|
|
|
|
@register("detectable_content:number_placeholders")
|
|
def check_placeholders(response: str, kwargs: dict) -> bool:
|
|
num = kwargs.get("num_placeholders", 1)
|
|
placeholders = re.findall(r"\[.*?\]", response)
|
|
return len(placeholders) >= num
|
|
|
|
|
|
@register("detectable_content:postscript")
|
|
def check_postscript(response: str, kwargs: dict) -> bool:
|
|
marker = kwargs.get("postscript_marker", "P.S.")
|
|
response_lower = response.lower()
|
|
if marker == "P.P.S":
|
|
return bool(re.search(r"p\.\s?p\.\s?s", response_lower))
|
|
elif marker == "P.S.":
|
|
return bool(re.search(r"p\.\s?s\.", response_lower))
|
|
else:
|
|
return bool(re.search(re.escape(marker.lower()), response_lower))
|
|
|
|
|
|
@register("detectable_format:number_bullet_lists")
|
|
def check_bullet_lists(response: str, kwargs: dict) -> bool:
|
|
num = kwargs.get("num_bullets", 1)
|
|
bullets = re.findall(r"^\s*\*[^\*].*$", response, re.MULTILINE)
|
|
dashes = re.findall(r"^\s*-.*$", response, re.MULTILINE)
|
|
return len(bullets) + len(dashes) == num
|
|
|
|
|
|
@register("detectable_format:number_highlighted_sections")
|
|
def check_highlighted_sections(response: str, kwargs: dict) -> bool:
|
|
num = kwargs.get("num_highlights", 1)
|
|
highlights = re.findall(r"\*[^\n\*]+\*", response)
|
|
count = 0
|
|
for h in highlights:
|
|
if h.strip("*").strip():
|
|
count += 1
|
|
return count >= num
|
|
|
|
|
|
@register("detectable_format:multiple_sections")
|
|
def check_multiple_sections(response: str, kwargs: dict) -> bool:
|
|
splitter = kwargs.get("section_spliter", "Section")
|
|
num = kwargs.get("num_sections", 1)
|
|
pattern = r"\s?" + re.escape(splitter) + r"\s?\d+\s?"
|
|
sections = re.split(pattern, response)
|
|
return len(sections) - 1 >= num
|
|
|
|
|
|
@register("detectable_format:title")
|
|
def check_title(response: str, kwargs: dict) -> bool:
|
|
titles = re.findall(r"<<[^>\n]+>>", response)
|
|
for title in titles:
|
|
if title.strip("<>").strip():
|
|
return True
|
|
return False
|
|
|
|
|
|
@register("detectable_format:json_format")
|
|
def check_json_format(response: str, kwargs: dict) -> bool:
|
|
value = response.strip()
|
|
for prefix in ("```json", "```Json", "```JSON", "```"):
|
|
if value.lower().startswith(prefix.lower()):
|
|
value = value[len(prefix) :].strip()
|
|
if value.endswith("```"):
|
|
value = value[:-3].strip()
|
|
try:
|
|
json.loads(value)
|
|
return True
|
|
except (ValueError, json.JSONDecodeError):
|
|
return False
|
|
|
|
|
|
@register("detectable_format:general_punctuation")
|
|
def check_general_punctuation(response: str, kwargs: dict) -> bool:
|
|
punctuation_blacklist = kwargs.get("punctuation_blacklist", [])
|
|
for punct in punctuation_blacklist:
|
|
if punct in response:
|
|
return False
|
|
return True
|
|
|
|
|
|
@register("detectable_format:number_highlighted_words")
|
|
def check_highlighted_words(response: str, kwargs: dict) -> bool:
|
|
num = kwargs.get("num_highlights", 1)
|
|
highlights = re.findall(r"\*[^\s\*][^\*]*[^\s\*]\*", response)
|
|
return len(highlights) >= num
|
|
|
|
|
|
@register("startend:end_checker")
|
|
def check_end_checker(response: str, kwargs: dict) -> bool:
|
|
end_phrase = kwargs["end_phrase"]
|
|
return (
|
|
response.strip()
|
|
.rstrip('"')
|
|
.rstrip()
|
|
.lower()
|
|
.endswith(end_phrase.strip().lower())
|
|
)
|
|
|
|
|
|
@register("startend:quotation")
|
|
def check_quotation(response: str, kwargs: dict) -> bool:
|
|
value = response.strip()
|
|
return value.startswith('"') and value.endswith('"')
|
|
|
|
|
|
@register("startend:start_checker")
|
|
def check_start_checker(response: str, kwargs: dict) -> bool:
|
|
starter = kwargs["starter"]
|
|
return bool(re.search(r"^\s*" + re.escape(starter), response, re.MULTILINE))
|
|
|
|
|
|
@register("change_case:english_capital")
|
|
def check_english_capital(response: str, kwargs: dict) -> bool:
|
|
return response.isupper()
|
|
|
|
|
|
@register("change_case:english_lowercase")
|
|
def check_english_lowercase(response: str, kwargs: dict) -> bool:
|
|
return response.islower()
|
|
|
|
|
|
@register("change_case:capital_word_frequency")
|
|
def check_capital_word_frequency(response: str, kwargs: dict) -> bool:
|
|
frequency = kwargs.get("capital_frequency", 1)
|
|
relation = kwargs.get("capital_relation", "at least")
|
|
capital_words = re.findall(r"\b[A-Z]{2,}\b", response)
|
|
count = len(capital_words)
|
|
if relation == "less than":
|
|
return count < frequency
|
|
return count >= frequency
|
|
|
|
|
|
@register("punctuation:no_comma")
|
|
def check_no_comma(response: str, kwargs: dict) -> bool:
|
|
return "," not in response
|
|
|
|
|
|
def count_words(text: str) -> int:
|
|
return len(re.findall(r"\b\w+\b", text))
|
|
|
|
|
|
def count_sentences(text: str) -> int:
|
|
text = text.strip()
|
|
if not text:
|
|
return 0
|
|
sentences = re.split(r"(?<=[.!?])\s+", text)
|
|
return len([s for s in sentences if s.strip()])
|
|
|
|
|
|
@register("length_constraints:number_words")
|
|
def check_number_words(response: str, kwargs: dict) -> bool:
|
|
num = kwargs.get("num_words", 100)
|
|
relation = kwargs.get("relation", "at least")
|
|
cnt = count_words(response)
|
|
if relation == "less than":
|
|
return cnt < num
|
|
return cnt >= num
|
|
|
|
|
|
@register("length_constraints:number_sentences")
|
|
def check_number_sentences(response: str, kwargs: dict) -> bool:
|
|
num = kwargs.get("num_sentences", 5)
|
|
relation = kwargs.get("relation", "at least")
|
|
cnt = count_sentences(response)
|
|
if relation == "less than":
|
|
return cnt < num
|
|
return cnt >= num
|
|
|
|
|
|
@register("length_constraints:number_paragraphs")
|
|
def check_number_paragraphs(response: str, kwargs: dict) -> bool:
|
|
num = kwargs.get("num_paragraphs", 1)
|
|
if "***" in response:
|
|
paragraphs = re.split(r"\s?\*\*\*\s?", response)
|
|
else:
|
|
paragraphs = re.split(r"\n\n+", response)
|
|
actual = len([p for p in paragraphs if p.strip()])
|
|
return actual == num
|
|
|
|
|
|
@register("length_constraints:nth_paragraph_first_word")
|
|
def check_nth_paragraph_first_word(response: str, kwargs: dict) -> bool:
|
|
num_paragraphs = kwargs.get("num_paragraphs", 1)
|
|
nth = kwargs.get("nth_paragraph", 1)
|
|
first_word = kwargs.get("first_word", "").lower()
|
|
|
|
paragraphs = re.split(r"\n\n+", response)
|
|
paragraphs = [p.strip() for p in paragraphs if p.strip()]
|
|
|
|
if len(paragraphs) != num_paragraphs:
|
|
return False
|
|
if nth > len(paragraphs):
|
|
return False
|
|
|
|
target = paragraphs[nth - 1]
|
|
words = target.split()
|
|
if not words:
|
|
return False
|
|
|
|
word = words[0].strip().lstrip("'\"").rstrip(".,!?:;\"'")
|
|
return word.lower() == first_word
|
|
|
|
|
|
@register("length_constraints:nth_word_checker")
|
|
def check_nth_word(response: str, kwargs: dict) -> bool:
|
|
nth = kwargs.get("nth_word", 1)
|
|
target = kwargs.get("target_word", "").lower()
|
|
words = re.findall(r"\b\w+\b", response)
|
|
if nth > len(words):
|
|
return False
|
|
return words[nth - 1].lower() == target
|
|
|
|
|
|
@register("combination:repeat_prompt")
|
|
def check_repeat_prompt(response: str, kwargs: dict) -> bool:
|
|
prompt = kwargs["prompt_to_repeat"]
|
|
return response.strip().lower().startswith(prompt.strip().lower())
|
|
|
|
|
|
@register("combination:two_responses")
|
|
def check_two_responses(response: str, kwargs: dict) -> bool:
|
|
parts = response.split("******")
|
|
valid = [p for p in parts if p.strip()]
|
|
if len(valid) != 2:
|
|
return False
|
|
return valid[0].strip() != valid[1].strip()
|
|
|
|
|
|
def download_ifeval(data_path: str):
|
|
if os.path.exists(data_path):
|
|
return
|
|
os.makedirs(os.path.dirname(data_path) or ".", exist_ok=True)
|
|
print(f"Downloading IFEval from {IFEVAL_URL} ...")
|
|
tmp = data_path + ".tmp"
|
|
urllib.request.urlretrieve(IFEVAL_URL, tmp)
|
|
with open(tmp, "rb") as f_in:
|
|
content = f_in.read()
|
|
with open(data_path, "wb") as f_out:
|
|
f_out.write(content)
|
|
os.remove(tmp)
|
|
print(f" saved to {data_path}")
|
|
|
|
|
|
def load_problems(data_path: str) -> List[dict]:
|
|
problems = []
|
|
with open(data_path, "r", encoding="utf-8") as f:
|
|
for line in f:
|
|
line = line.strip()
|
|
if line:
|
|
problems.append(json.loads(line))
|
|
return problems
|
|
|
|
|
|
def verify_response(response: str, instruction_id: str, kwargs: dict) -> Optional[bool]:
|
|
verifier = CONSTRAINT_VERIFIERS.get(instruction_id)
|
|
if verifier is None:
|
|
return None
|
|
try:
|
|
return verifier(response, kwargs)
|
|
except Exception:
|
|
return False
|
|
|
|
|
|
def generate_one(
|
|
engine: InferenceEngine,
|
|
prompt: str,
|
|
max_tokens: int,
|
|
temperature: float,
|
|
top_p: float,
|
|
top_k: int,
|
|
) -> str:
|
|
output = engine.generate(
|
|
prompt=prompt,
|
|
stream=False,
|
|
max_tokens=max_tokens,
|
|
temperature=temperature,
|
|
top_p=top_p,
|
|
top_k=top_k,
|
|
)
|
|
if isinstance(output, list):
|
|
return output[0]
|
|
return output
|
|
|
|
|
|
def evaluate(
|
|
engine: InferenceEngine,
|
|
problems: List[dict],
|
|
max_tokens: int,
|
|
temperature: float,
|
|
top_p: float,
|
|
top_k: int,
|
|
num_samples: int = 1,
|
|
) -> Dict:
|
|
results = {}
|
|
constraint_stats: Dict[str, Dict[str, int]] = {}
|
|
total_constraints = 0
|
|
total_passed = 0
|
|
|
|
for problem in tqdm.tqdm(problems, desc="IFEval", unit="problem"):
|
|
key = problem["key"]
|
|
prompt = problem["prompt"]
|
|
instruction_ids = problem["instruction_id_list"]
|
|
kwargs_list = problem["kwargs"]
|
|
|
|
samples = []
|
|
for _ in range(num_samples):
|
|
response = generate_one(
|
|
engine, prompt, max_tokens, temperature, top_p, top_k
|
|
)
|
|
samples.append(response)
|
|
|
|
constraint_results = []
|
|
passed = 0
|
|
verified = 0
|
|
|
|
for idx, instruction_id in enumerate(instruction_ids):
|
|
kwargs = kwargs_list[idx] if idx < len(kwargs_list) else {}
|
|
best_pass = False
|
|
for response in samples:
|
|
result = verify_response(response, instruction_id, kwargs)
|
|
if result is None:
|
|
continue
|
|
if result:
|
|
best_pass = True
|
|
break
|
|
|
|
verifier_exists = instruction_id in CONSTRAINT_VERIFIERS
|
|
if verifier_exists:
|
|
verified += 1
|
|
if best_pass:
|
|
passed += 1
|
|
|
|
constraint_results.append(
|
|
{
|
|
"instruction_id": instruction_id,
|
|
"passed": best_pass,
|
|
"supported": verifier_exists,
|
|
"kwargs": kwargs,
|
|
}
|
|
)
|
|
|
|
if verifier_exists:
|
|
if instruction_id not in constraint_stats:
|
|
constraint_stats[instruction_id] = {
|
|
"total": 0,
|
|
"passed": 0,
|
|
}
|
|
constraint_stats[instruction_id]["total"] += 1
|
|
if best_pass:
|
|
constraint_stats[instruction_id]["passed"] += 1
|
|
|
|
total_constraints += verified
|
|
total_passed += passed
|
|
|
|
accuracy = passed / verified if verified > 0 else None
|
|
results[str(key)] = {
|
|
"key": key,
|
|
"prompt": prompt,
|
|
"response": samples[0],
|
|
"num_samples": num_samples,
|
|
"num_constraints": len(instruction_ids),
|
|
"num_verified": verified,
|
|
"num_passed": passed,
|
|
"accuracy": round(accuracy, 4) if accuracy is not None else None,
|
|
"constraints": constraint_results,
|
|
}
|
|
|
|
overall_accuracy = (
|
|
round(total_passed / total_constraints, 4) if total_constraints > 0 else 0.0
|
|
)
|
|
|
|
type_summary = {}
|
|
for inst_id, stats in sorted(constraint_stats.items()):
|
|
type_summary[inst_id] = {
|
|
"total": stats["total"],
|
|
"passed": stats["passed"],
|
|
"accuracy": round(stats["passed"] / stats["total"], 4)
|
|
if stats["total"] > 0
|
|
else 0.0,
|
|
}
|
|
|
|
unsupported_count = sum(
|
|
1
|
|
for p in problems
|
|
for iid in p["instruction_id_list"]
|
|
if iid not in CONSTRAINT_VERIFIERS
|
|
)
|
|
|
|
results["_summary"] = {
|
|
"total_problems": len(problems),
|
|
"total_constraints": total_constraints,
|
|
"total_passed": total_passed,
|
|
"overall_accuracy": overall_accuracy,
|
|
"unsupported_constraints": unsupported_count,
|
|
"supported_types": sorted(CONSTRAINT_VERIFIERS.keys()),
|
|
"per_type_accuracy": type_summary,
|
|
}
|
|
|
|
return results
|
|
|
|
|
|
def main():
|
|
parser = argparse.ArgumentParser(description="IFEval benchmark")
|
|
parser.add_argument(
|
|
"--param_path", type=str, default="./params", help="Model directory"
|
|
)
|
|
parser.add_argument(
|
|
"--data_path",
|
|
type=str,
|
|
default="./ifeval/input_data.jsonl",
|
|
help="IFEval JSONL file (auto-download if missing)",
|
|
)
|
|
parser.add_argument("--output", type=str, default=None, help="Output JSON path")
|
|
parser.add_argument(
|
|
"--max_tokens", type=int, default=512, help="Max generation tokens"
|
|
)
|
|
parser.add_argument(
|
|
"--temperature",
|
|
type=float,
|
|
default=0.1,
|
|
help="Sampling temperature",
|
|
)
|
|
parser.add_argument("--top_p", type=float, default=0.95, help="Top-p sampling")
|
|
parser.add_argument("--top_k", type=int, default=50, help="Top-k sampling")
|
|
parser.add_argument(
|
|
"--num_samples",
|
|
type=int,
|
|
default=1,
|
|
help="Number of samples per problem (best-of-n scoring)",
|
|
)
|
|
parser.add_argument(
|
|
"--batch_size", type=int, default=1, help="Inference batch size"
|
|
)
|
|
parser.add_argument(
|
|
"--limit",
|
|
type=int,
|
|
default=None,
|
|
help="Limit to first N problems (for quick testing)",
|
|
)
|
|
parser.add_argument(
|
|
"--dump_responses",
|
|
type=str,
|
|
default=None,
|
|
help="Path to dump raw model responses (JSONL)",
|
|
)
|
|
args = parser.parse_args()
|
|
|
|
download_ifeval(args.data_path)
|
|
problems = load_problems(args.data_path)
|
|
if args.limit:
|
|
problems = problems[: args.limit]
|
|
|
|
print(f"Loaded {len(problems)} problems")
|
|
print(f"Supported constraint types: {len(CONSTRAINT_VERIFIERS)}")
|
|
|
|
model = AutoModel.from_pretrained(args.param_path)
|
|
tokenizer = AutoTokenizer.from_pretrained(args.param_path)
|
|
model.to(device="cuda", dtype=torch.bfloat16)
|
|
|
|
engine = InferenceEngine(
|
|
model=model,
|
|
tokenizer=tokenizer,
|
|
max_batch_size=args.batch_size,
|
|
)
|
|
|
|
results = evaluate(
|
|
engine=engine,
|
|
problems=problems,
|
|
max_tokens=args.max_tokens,
|
|
temperature=args.temperature,
|
|
top_p=args.top_p,
|
|
top_k=args.top_k,
|
|
num_samples=args.num_samples,
|
|
)
|
|
|
|
summary = results.pop("_summary")
|
|
print(f"\n{'=' * 60}")
|
|
print(f" Problems: {summary['total_problems']}")
|
|
print(f" Constraints: {summary['total_constraints']}")
|
|
print(f" Passed: {summary['total_passed']}")
|
|
print(f" Accuracy: {summary['overall_accuracy']:.2%}")
|
|
print(f" Unsupported: {summary['unsupported_constraints']}")
|
|
print(f"{'=' * 60}")
|
|
|
|
print(f"\nPer-type accuracy:")
|
|
for inst_id, stats in sorted(summary["per_type_accuracy"].items()):
|
|
print(
|
|
f" {inst_id:50s} {stats['accuracy']:.2%} "
|
|
f"({stats['passed']}/{stats['total']})"
|
|
)
|
|
|
|
if args.output:
|
|
results["_summary"] = summary
|
|
with open(args.output, "w", encoding="utf-8") as f:
|
|
json.dump(results, f, indent=2, ensure_ascii=False)
|
|
print(f"\nResults saved to {args.output}")
|
|
|
|
if args.dump_responses:
|
|
with open(args.dump_responses, "w", encoding="utf-8") as f:
|
|
for k, v in results.items():
|
|
if k.startswith("_"):
|
|
continue
|
|
f.write(
|
|
json.dumps(
|
|
{
|
|
"key": v["key"],
|
|
"prompt": v["prompt"],
|
|
"response": v["response"],
|
|
},
|
|
ensure_ascii=False,
|
|
)
|
|
+ "\n"
|
|
)
|
|
print(f"Responses dumped to {args.dump_responses}")
|
|
|
|
engine.shutdown()
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|