feat : 新增 HumanEval pass@k 代码生成评测
- InferenceEngine.generate() 批量生成 n 个补全 - 正则提取函数体 + 停止符截断 - multiprocessing sandbox 执行 + timeout 保护 - 标准无偏 pass@k 公式 (1, 10, 100)
This commit is contained in:
parent
02a7cb9fa0
commit
615ba5d8ef
|
|
@ -0,0 +1,336 @@
|
||||||
|
"""HumanEval code generation benchmark.
|
||||||
|
|
||||||
|
Generates n completions per problem, extracts function bodies, executes
|
||||||
|
against hidden tests, and computes pass@k.
|
||||||
|
|
||||||
|
Usage::
|
||||||
|
|
||||||
|
python scripts/tools/evaluate_humaneval.py --param_path ./params \
|
||||||
|
--data_path HumanEval.jsonl.gz --output results.json \
|
||||||
|
--num_samples 200 --temperature 0.8 --max_tokens 512
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import signal
|
||||||
|
import sys
|
||||||
|
from math import prod
|
||||||
|
from multiprocessing import Process, Queue
|
||||||
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import tqdm
|
||||||
|
|
||||||
|
from astrai.inference import InferenceEngine
|
||||||
|
from astrai.model import AutoModel
|
||||||
|
from astrai.tokenize import AutoTokenizer
|
||||||
|
|
||||||
|
HUMANEVAL_URL = (
|
||||||
|
"https://github.com/openai/human-eval/raw/master/data/HumanEval.jsonl.gz"
|
||||||
|
)
|
||||||
|
|
||||||
|
_STOP_SEQUENCES = [
|
||||||
|
"\nclass ",
|
||||||
|
"\ndef ",
|
||||||
|
"\n# ",
|
||||||
|
"\nif __name__",
|
||||||
|
"\nprint(",
|
||||||
|
"\n\n\n",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _download_humaneval(data_path: str):
|
||||||
|
if os.path.exists(data_path):
|
||||||
|
return
|
||||||
|
import gzip
|
||||||
|
import urllib.request
|
||||||
|
|
||||||
|
os.makedirs(os.path.dirname(data_path) or ".", exist_ok=True)
|
||||||
|
print(f"Downloading HumanEval from {HUMANEVAL_URL} ...")
|
||||||
|
tmp = data_path + ".tmp"
|
||||||
|
urllib.request.urlretrieve(HUMANEVAL_URL, tmp)
|
||||||
|
with gzip.open(tmp, "rb") as f_in:
|
||||||
|
with open(data_path, "wb") as f_out:
|
||||||
|
f_out.write(f_in.read())
|
||||||
|
os.remove(tmp)
|
||||||
|
print(f" saved to {data_path}")
|
||||||
|
|
||||||
|
|
||||||
|
def _load_problems(data_path: str) -> List[dict]:
|
||||||
|
problems = []
|
||||||
|
with open(data_path, "r", encoding="utf-8") as f:
|
||||||
|
for line in f:
|
||||||
|
line = line.strip()
|
||||||
|
if line:
|
||||||
|
problems.append(json.loads(line))
|
||||||
|
return problems
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_function_body(code: str, entry_point: str) -> Optional[str]:
|
||||||
|
"""Extract the function body from a completion."""
|
||||||
|
pattern = rf"def\s+{re.escape(entry_point)}\b[^:]*:"
|
||||||
|
match = re.search(pattern, code)
|
||||||
|
if not match:
|
||||||
|
# Use the full code as-is if we can't find the function
|
||||||
|
return code
|
||||||
|
|
||||||
|
body_start = match.end()
|
||||||
|
lines = code[body_start:].split("\n")
|
||||||
|
body_lines = []
|
||||||
|
started = False
|
||||||
|
|
||||||
|
for line in lines:
|
||||||
|
stripped = line.rstrip()
|
||||||
|
if not stripped and not started:
|
||||||
|
continue
|
||||||
|
if not stripped and started:
|
||||||
|
body_lines.append("")
|
||||||
|
continue
|
||||||
|
if not started:
|
||||||
|
started = True
|
||||||
|
if stripped.lstrip() == stripped and started:
|
||||||
|
break
|
||||||
|
body_lines.append(stripped)
|
||||||
|
|
||||||
|
body = "\n".join(body_lines)
|
||||||
|
if not body.strip():
|
||||||
|
return None
|
||||||
|
return body
|
||||||
|
|
||||||
|
|
||||||
|
def _trim_stop_sequences(text: str) -> str:
|
||||||
|
for stop in _STOP_SEQUENCES:
|
||||||
|
idx = text.find(stop)
|
||||||
|
if idx != -1:
|
||||||
|
text = text[:idx]
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
def _execute_code(problem: dict, completion: str, timeout: float = 3.0) -> bool:
|
||||||
|
"""Run the completion against hidden tests in a subprocess."""
|
||||||
|
|
||||||
|
def _worker(queue, full_code):
|
||||||
|
try:
|
||||||
|
namespace = {}
|
||||||
|
exec(full_code, namespace)
|
||||||
|
check = namespace.get("check")
|
||||||
|
if check is None:
|
||||||
|
queue.put(False)
|
||||||
|
return
|
||||||
|
check(namespace.get(problem["entry_point"]))
|
||||||
|
queue.put(True)
|
||||||
|
except Exception:
|
||||||
|
queue.put(False)
|
||||||
|
|
||||||
|
full_code = problem["prompt"] + completion + "\n" + problem["test"]
|
||||||
|
|
||||||
|
queue: Queue = Queue()
|
||||||
|
proc = Process(target=_worker, args=(queue, full_code))
|
||||||
|
proc.start()
|
||||||
|
proc.join(timeout)
|
||||||
|
|
||||||
|
if proc.is_alive():
|
||||||
|
proc.terminate()
|
||||||
|
proc.join()
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
return queue.get_nowait()
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _pass_at_k(n: int, c: int, k: int) -> float:
|
||||||
|
"""Unbiased estimator of pass@k."""
|
||||||
|
if n - c < k:
|
||||||
|
return 1.0
|
||||||
|
return 1.0 - float(prod(1.0 - k / np.arange(n - c + 1, n + 1)))
|
||||||
|
|
||||||
|
|
||||||
|
def _deduplicate(completions: List[str]) -> List[str]:
|
||||||
|
seen = set()
|
||||||
|
unique = []
|
||||||
|
for c in completions:
|
||||||
|
if c not in seen:
|
||||||
|
seen.add(c)
|
||||||
|
unique.append(c)
|
||||||
|
return unique
|
||||||
|
|
||||||
|
|
||||||
|
def _generate(
|
||||||
|
engine: InferenceEngine,
|
||||||
|
prompt: str,
|
||||||
|
num_samples: int,
|
||||||
|
max_tokens: int,
|
||||||
|
temperature: float,
|
||||||
|
top_p: float,
|
||||||
|
top_k: int,
|
||||||
|
batch_size: int,
|
||||||
|
) -> List[str]:
|
||||||
|
batches = [prompt] * min(batch_size, num_samples)
|
||||||
|
completions = []
|
||||||
|
remaining = num_samples
|
||||||
|
|
||||||
|
while remaining > 0:
|
||||||
|
current = min(batch_size, remaining)
|
||||||
|
batch_prompts = batches[:current]
|
||||||
|
outputs = engine.generate(
|
||||||
|
prompt=batch_prompts,
|
||||||
|
stream=False,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
temperature=temperature,
|
||||||
|
top_p=top_p,
|
||||||
|
top_k=top_k,
|
||||||
|
)
|
||||||
|
if isinstance(outputs, str):
|
||||||
|
outputs = [outputs]
|
||||||
|
completions.extend(outputs)
|
||||||
|
remaining -= current
|
||||||
|
|
||||||
|
return _deduplicate(completions)
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate(
|
||||||
|
engine: InferenceEngine,
|
||||||
|
problems: List[dict],
|
||||||
|
num_samples: int,
|
||||||
|
max_tokens: int,
|
||||||
|
temperature: float,
|
||||||
|
top_p: float,
|
||||||
|
top_k: int,
|
||||||
|
batch_size: int,
|
||||||
|
k_values: Tuple[int, ...] = (1, 10, 100),
|
||||||
|
) -> Dict:
|
||||||
|
results = {}
|
||||||
|
all_pass_at_k = {k: [] for k in k_values}
|
||||||
|
|
||||||
|
for problem in tqdm.tqdm(problems, desc="HumanEval", unit="problem"):
|
||||||
|
task_id = problem["task_id"]
|
||||||
|
prompt = problem["prompt"]
|
||||||
|
entry_point = problem["entry_point"]
|
||||||
|
|
||||||
|
raw_completions = _generate(
|
||||||
|
engine,
|
||||||
|
prompt,
|
||||||
|
num_samples,
|
||||||
|
max_tokens,
|
||||||
|
temperature,
|
||||||
|
top_p,
|
||||||
|
top_k,
|
||||||
|
batch_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
completions = []
|
||||||
|
for raw in raw_completions:
|
||||||
|
trimmed = _trim_stop_sequences(raw)
|
||||||
|
body = _extract_function_body(trimmed, entry_point)
|
||||||
|
if body:
|
||||||
|
completions.append(body)
|
||||||
|
|
||||||
|
passed = 0
|
||||||
|
for comp in completions:
|
||||||
|
if _execute_code(problem, comp):
|
||||||
|
passed += 1
|
||||||
|
|
||||||
|
n = len(completions)
|
||||||
|
c = passed
|
||||||
|
result = {"task_id": task_id, "n": n, "passed": c}
|
||||||
|
for k in k_values:
|
||||||
|
result[f"pass@{k}"] = round(_pass_at_k(n, c, k), 4)
|
||||||
|
all_pass_at_k[k].append(_pass_at_k(n, c, k))
|
||||||
|
results[task_id] = result
|
||||||
|
|
||||||
|
summary = {}
|
||||||
|
for k in k_values:
|
||||||
|
vals = all_pass_at_k[k]
|
||||||
|
summary[f"pass@{k}"] = round(float(np.mean(vals)), 4)
|
||||||
|
results["_summary"] = summary
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description="HumanEval benchmark")
|
||||||
|
parser.add_argument(
|
||||||
|
"--param_path", type=str, default="./params", help="Model directory"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--data_path",
|
||||||
|
type=str,
|
||||||
|
default="./humaneval/HumanEval.jsonl",
|
||||||
|
help="HumanEval JSONL file (auto-download if missing)",
|
||||||
|
)
|
||||||
|
parser.add_argument("--output", type=str, default=None, help="Output JSON path")
|
||||||
|
parser.add_argument(
|
||||||
|
"--num_samples",
|
||||||
|
type=int,
|
||||||
|
default=200,
|
||||||
|
help="Completions per problem",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max_tokens", type=int, default=512, help="Max generation tokens"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--temperature", type=float, default=0.8, help="Sampling temperature"
|
||||||
|
)
|
||||||
|
parser.add_argument("--top_p", type=float, default=0.95, help="Top-p sampling")
|
||||||
|
parser.add_argument("--top_k", type=int, default=50, help="Top-k sampling")
|
||||||
|
parser.add_argument(
|
||||||
|
"--batch_size", type=int, default=1, help="Inference batch size"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--problems",
|
||||||
|
type=int,
|
||||||
|
nargs="+",
|
||||||
|
default=None,
|
||||||
|
help="Specific problem indices (0-based)",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
_download_humaneval(args.data_path)
|
||||||
|
problems = _load_problems(args.data_path)
|
||||||
|
if args.problems:
|
||||||
|
problems = [problems[i] for i in args.problems if i < len(problems)]
|
||||||
|
|
||||||
|
model = AutoModel.from_pretrained(args.param_path)
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(args.param_path)
|
||||||
|
model.to(device="cuda", dtype=torch.bfloat16)
|
||||||
|
|
||||||
|
engine = InferenceEngine(
|
||||||
|
model=model,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
max_batch_size=args.batch_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
results = evaluate(
|
||||||
|
engine=engine,
|
||||||
|
problems=problems,
|
||||||
|
num_samples=args.num_samples,
|
||||||
|
max_tokens=args.max_tokens,
|
||||||
|
temperature=args.temperature,
|
||||||
|
top_p=args.top_p,
|
||||||
|
top_k=args.top_k,
|
||||||
|
batch_size=args.batch_size,
|
||||||
|
k_values=(1, 10, 100),
|
||||||
|
)
|
||||||
|
|
||||||
|
summary = results.pop("_summary")
|
||||||
|
print(f"\n{'=' * 60}")
|
||||||
|
for k, v in summary.items():
|
||||||
|
print(f" {k}: {v:.2%}")
|
||||||
|
print(f"{'=' * 60}")
|
||||||
|
|
||||||
|
if args.output:
|
||||||
|
results["_summary"] = summary
|
||||||
|
with open(args.output, "w", encoding="utf-8") as f:
|
||||||
|
json.dump(results, f, indent=2, ensure_ascii=False)
|
||||||
|
print(f"Results saved to {args.output}")
|
||||||
|
|
||||||
|
engine.shutdown()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
Loading…
Reference in New Issue