70 lines
2.5 KiB
Python
70 lines
2.5 KiB
Python
import argparse
|
|
import json
|
|
import os
|
|
|
|
from llm_eval import EvalFactory
|
|
|
|
|
|
def parse_args(argv: "list[str] | None" = None) -> argparse.Namespace:
    """Parse command-line options for the HTTP-API evaluation benchmark.

    Args:
        argv: Argument list to parse. ``None`` (the default) lets argparse
            read ``sys.argv[1:]``, exactly as before; passing an explicit
            list lets callers drive the parser without touching ``sys.argv``.

    Returns:
        Namespace with ``api_base``, ``api_key``, ``model``, ``eval_type``,
        ``data_path``, ``subject``, ``mode``, ``output_file`` and
        ``max_retries``.
    """
    parser = argparse.ArgumentParser(description="LLM Evaluation Benchmark (via HTTP API)")
    parser.add_argument("--api_base", type=str, default="http://localhost:8000",
                        help="API base URL (default: http://localhost:8000)")
    parser.add_argument("--api_key", type=str, default="not-needed",
                        help="API key")
    parser.add_argument("--model", type=str, default="default",
                        help="Model name sent in request body")
    # Valid task names are whatever evaluators are registered with the
    # factory at the moment the parser is built.
    parser.add_argument("--eval_type", type=str, default="mmlu",
                        choices=EvalFactory.list_registered(),
                        help="Evaluation task")
    parser.add_argument("--data_path", type=str, required=True,
                        help="Dataset directory")
    parser.add_argument("--subject", type=str, default="all",
                        help="Subject (default: all)")
    parser.add_argument("--mode", type=str, default="logprobs",
                        choices=["logprobs", "generation"],
                        help="Scoring mode")
    parser.add_argument("--output_file", type=str, default=None,
                        help="Path to save results JSON")
    # The value is forwarded to EvalFactory.create(max_retries=...); how the
    # evaluator applies it is defined there.
    parser.add_argument("--max_retries", type=int, default=3,
                        help="Retry limit passed to the evaluator (default: 3)")
    return parser.parse_args(argv)
|
|
|
|
|
|
def main():
    """Run one evaluation against an HTTP model endpoint and report the result.

    Parses the CLI options, builds an evaluator via ``EvalFactory.create``,
    runs it on ``--data_path``, prints a short summary and, when
    ``--output_file`` is given, writes the full result to that path as JSON.
    """
    args = parse_args()

    evaluator = EvalFactory.create(
        args.eval_type,
        api_base=args.api_base,
        api_key=args.api_key,
        model=args.model,
        subject=args.subject,
        mode=args.mode,
        max_retries=args.max_retries,
    )

    print(f"Running {args.eval_type} (subject={args.subject}, mode={args.mode})...")
    print(f"API: {args.api_base}")
    result = evaluator.evaluate(data_path=args.data_path)

    _print_summary(result)

    if args.output_file:
        _save_results(result, args.output_file)
        print(f"Saved to {args.output_file}")


def _print_summary(result):
    """Print task name, sample count and accuracy framed by separator rules."""
    rule = "=" * 50
    print(f"\n{rule}")
    print(f"Task: {result.task_name}")
    print(f"Samples: {result.num_samples}")
    print(f"Acc: {result.accuracy:.4f} ({result.accuracy*100:.2f}%)")
    print(rule)


def _save_results(result, path):
    """Write the evaluation result to ``path`` as indented UTF-8 JSON.

    Parent directories are created as needed.
    NOTE(review): assumes ``result.metadata`` and ``result.results`` are
    JSON-serializable; json.dump raises TypeError otherwise.
    """
    # dirname() is "" for a bare filename, and makedirs("") raises, so fall
    # back to the current directory in that case.
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    payload = {
        "task_name": result.task_name,
        "num_samples": result.num_samples,
        "accuracy": result.accuracy,
        "metadata": result.metadata,
        "results": result.results,
    }
    # ensure_ascii=False writes raw non-ASCII characters, so the file encoding
    # must be pinned; the platform default (e.g. cp1252) may not encode them.
    with open(path, "w", encoding="utf-8") as f:
        json.dump(payload, f, ensure_ascii=False, indent=2)
|
|
|
|
|
|
if __name__ == "__main__":
    # Entry point: run the benchmark only when executed as a script,
    # never as a side effect of importing this module.
    main()
|