From ac814e5c5247e8dac1be988e615f2a37b234fdb0 Mon Sep 17 00:00:00 2001
From: ViperEkura <3081035982@qq.com>
Date: Sun, 24 May 2026 22:36:37 +0800
Subject: [PATCH] add all project source files

---
 .gitignore               |   5 +
 llm_eval/__init__.py     |  10 ++
 llm_eval/base.py         |  29 +++++
 llm_eval/cli.py          |  71 +++++++++++
 llm_eval/mmlu.py         | 262 +++++++++++++++++++++++++++++++++++++++
 llm_eval/registry.py     |  29 +++++
 pyproject.toml           |  19 +++
 scripts/download_mmlu.py |  39 ++++++
 8 files changed, 464 insertions(+)
 create mode 100644 .gitignore
 create mode 100644 llm_eval/__init__.py
 create mode 100644 llm_eval/base.py
 create mode 100644 llm_eval/cli.py
 create mode 100644 llm_eval/mmlu.py
 create mode 100644 llm_eval/registry.py
 create mode 100644 pyproject.toml
 create mode 100644 scripts/download_mmlu.py

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..ab6d2ac
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,5 @@
+__pycache__/
+*.pyc
+*.egg-info/
+dist/
+data/
diff --git a/llm_eval/__init__.py b/llm_eval/__init__.py
new file mode 100644
index 0000000..c118683
--- /dev/null
+++ b/llm_eval/__init__.py
@@ -0,0 +1,10 @@
+from llm_eval.base import BaseEvaluator, EvalResult
+from llm_eval.registry import EvalFactory
+from llm_eval.mmlu import MMLUEvaluator
+
+__all__ = [
+    "BaseEvaluator",
+    "EvalResult",
+    "EvalFactory",
+    "MMLUEvaluator",
+]
diff --git a/llm_eval/base.py b/llm_eval/base.py
new file mode 100644
index 0000000..879f293
--- /dev/null
+++ b/llm_eval/base.py
@@ -0,0 +1,29 @@
+from abc import ABC, abstractmethod
+from dataclasses import dataclass, field
+from typing import Any, Dict, List, Optional
+
+
+@dataclass
+class EvalResult:
+    """Aggregate outcome of one evaluation run plus its per-item records."""
+
+    task_name: str
+    num_samples: int
+    accuracy: float
+    results: List[Dict[str, Any]] = field(default_factory=list)
+    metadata: Dict[str, Any] = field(default_factory=dict)
+
+
+class BaseEvaluator(ABC):
+    """Common base for evaluators that query a model through an HTTP API."""
+
+    def __init__(self, api_base: str, api_key: str = "not-needed", **kwargs):
+        # Drop trailing slashes so endpoint paths can be appended directly.
+        self.api_base = api_base.rstrip("/")
+        self.api_key = api_key
+
+    @abstractmethod
+    def evaluate(self, data_path: str) -> EvalResult: ...
+
+    @abstractmethod
+    def load_data(self, data_path: str) -> List[Dict[str, Any]]: ...
diff --git a/llm_eval/cli.py b/llm_eval/cli.py
new file mode 100644
index 0000000..298b3c7
--- /dev/null
+++ b/llm_eval/cli.py
@@ -0,0 +1,71 @@
+import argparse
+import json
+import os
+
+from llm_eval import EvalFactory
+
+
+def parse_args() -> argparse.Namespace:
+    """Parse the command-line options for one evaluation run."""
+    parser = argparse.ArgumentParser(description="LLM Evaluation Benchmark (via HTTP API)")
+    parser.add_argument("--api_base", type=str, default="http://localhost:8000",
+                        help="API base URL (default: http://localhost:8000)")
+    parser.add_argument("--api_key", type=str, default="not-needed",
+                        help="API key")
+    parser.add_argument("--model", type=str, default="default",
+                        help="Model name sent in request body")
+    parser.add_argument("--eval_type", type=str, default="mmlu",
+                        choices=EvalFactory.list_registered(),
+                        help="Evaluation task")
+    parser.add_argument("--data_path", type=str, required=True,
+                        help="Dataset directory")
+    parser.add_argument("--subject", type=str, default="all",
+                        help="Subject (default: all)")
+    parser.add_argument("--mode", type=str, default="logprobs",
+                        choices=["logprobs", "generation"],
+                        help="Scoring mode")
+    parser.add_argument("--output_file", type=str, default=None,
+                        help="Path to save results JSON")
+    parser.add_argument("--max_retries", type=int, default=3)
+    return parser.parse_args()
+
+
+def main():
+    """Run the selected evaluation, print a summary, optionally save JSON."""
+    args = parse_args()
+
+    evaluator = EvalFactory.create(
+        args.eval_type,
+        api_base=args.api_base,
+        api_key=args.api_key,
+        model=args.model,
+        subject=args.subject,
+        mode=args.mode,
+        max_retries=args.max_retries,
+    )
+
+    print(f"Running {args.eval_type} (subject={args.subject}, mode={args.mode})...")
+    print(f"API: {args.api_base}")
+    result = evaluator.evaluate(data_path=args.data_path)
+
+    print(f"\n{'='*50}")
+    print(f"Task: {result.task_name}")
+    print(f"Samples: {result.num_samples}")
+    print(f"Acc: {result.accuracy:.4f} ({result.accuracy*100:.2f}%)")
+    print(f"{'='*50}")
+
+    if args.output_file:
+        os.makedirs(os.path.dirname(args.output_file) or ".", exist_ok=True)
+        with open(args.output_file, "w", encoding="utf-8") as f:
+            json.dump({
+                "task_name": result.task_name,
+                "num_samples": result.num_samples,
+                "accuracy": result.accuracy,
+                "metadata": result.metadata,
+                "results": result.results,
+            }, f, ensure_ascii=False, indent=2)
+        print(f"Saved to {args.output_file}")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/llm_eval/mmlu.py b/llm_eval/mmlu.py
new file mode 100644
index 0000000..5cf1842
--- /dev/null
+++ b/llm_eval/mmlu.py
@@ -0,0 +1,262 @@
+import csv
+import os
+import re
+import time
+from typing import Any, Dict, List, Optional
+
+import requests
+
+from llm_eval.base import BaseEvaluator, EvalResult
+from llm_eval.registry import EvalFactory
+
+
+@EvalFactory.register("mmlu")
+class MMLUEvaluator(BaseEvaluator):
+    """MMLU-style multiple-choice evaluator via HTTP API.
+
+    Sends each question as a chat completion request and parses the
+    answer letter from the response. Two scoring modes:
+
+    * ``logprobs`` (default) — requests per-token logprobs and picks the
+      highest-probability letter among A/B/C/D.
+    * ``generation`` — asks the model to reply with a single letter and
+      extracts it from the generated text.
+
+    Predictions are 0-based choice indices (A=0 ... D=3); ``-1`` means no
+    answer letter could be recovered from the response.
+    """
+
+    LETTERS = ["A", "B", "C", "D"]
+
+    def __init__(
+        self,
+        api_base: str,
+        api_key: str = "not-needed",
+        model: str = "default",
+        subject: str = "all",
+        mode: str = "logprobs",
+        max_retries: int = 3,
+        **kwargs,
+    ):
+        super().__init__(api_base, api_key=api_key)
+        self.model = model
+        self.subject = subject
+        self.mode = mode
+        self.max_retries = max_retries
+
+    # ------------------------------------------------------------------
+    # Data
+    # ------------------------------------------------------------------
+
+    def load_data(self, data_path: str) -> List[Dict[str, Any]]:
+        """Load one subject's CSV, or every ``*.csv`` in *data_path* for "all"."""
+        items = []
+        if self.subject == "all":
+            for fname in sorted(os.listdir(data_path)):
+                if fname.endswith(".csv"):
+                    items.extend(self._load_csv(os.path.join(data_path, fname), fname[:-4]))
+        else:
+            items = self._load_csv(os.path.join(data_path, f"{self.subject}.csv"), self.subject)
+        return items
+
+    @staticmethod
+    def _load_csv(path: str, subject: str) -> List[Dict[str, Any]]:
+        """Parse one CSV whose rows are ``question, A, B, C, D, answer``.
+
+        Rows with fewer than six columns, or whose answer column is not a
+        single letter A-D, are skipped.
+        """
+        items = []
+        with open(path, newline="", encoding="utf-8") as f:
+            reader = csv.reader(f)
+            for row in reader:
+                if len(row) < 6:
+                    continue
+                label = row[5].strip().upper()
+                # ord() on an empty or multi-character cell would raise;
+                # treat such rows (e.g. a header line) as malformed.
+                if label not in ("A", "B", "C", "D"):
+                    continue
+                items.append(dict(
+                    question=row[0],
+                    choices=row[1:5],
+                    answer=ord(label) - ord("A"),
+                    subject=subject,
+                ))
+        return items
+
+    # ------------------------------------------------------------------
+    # Prompt helpers
+    # ------------------------------------------------------------------
+
+    @staticmethod
+    def _build_messages(item: Dict[str, Any]) -> List[Dict[str, str]]:
+        """Build the system + user chat messages for one question."""
+        ch = item["choices"]
+        prompt = (
+            f"{item['question']}\n\n"
+            f"A. {ch[0]}\nB. {ch[1]}\nC. {ch[2]}\nD. {ch[3]}"
+        )
+        return [
+            {"role": "system", "content": (
+                "Answer the multiple-choice question by responding with "
+                "only the single letter (A, B, C, or D) of the correct answer."
+            )},
+            {"role": "user", "content": prompt},
+        ]
+
+    # ------------------------------------------------------------------
+    # API call
+    # ------------------------------------------------------------------
+
+    def _chat(self, messages: List[Dict[str, str]]) -> Dict[str, Any]:
+        """POST a chat completion request and return the decoded JSON body.
+
+        Retries on ``requests.RequestException`` with a one-second pause and
+        re-raises the last error once the attempts are exhausted.
+        """
+        body = dict(model=self.model, messages=messages, temperature=0.0)
+
+        if self.mode == "logprobs":
+            body["logprobs"] = True
+            body["top_logprobs"] = 5
+            body["max_tokens"] = 1
+        else:
+            body["max_tokens"] = 5
+
+        # Always try at least once: with max_retries <= 0 the loop would be
+        # skipped and the method would silently return None.
+        attempts = max(1, self.max_retries)
+        for attempt in range(attempts):
+            try:
+                resp = requests.post(
+                    f"{self.api_base}/v1/chat/completions",
+                    json=body,
+                    headers={"Authorization": f"Bearer {self.api_key}"},
+                    timeout=60,
+                )
+                resp.raise_for_status()
+                return resp.json()
+            except requests.RequestException:
+                if attempt == attempts - 1:
+                    raise
+                time.sleep(1)
+
+    # ------------------------------------------------------------------
+    # Scoring
+    # ------------------------------------------------------------------
+
+    @staticmethod
+    def _extract_letter(text: str) -> Optional[int]:
+        """Return the 0-based index of the first standalone A-D, or None."""
+        m = re.search(r"\b([A-D])\b", text.strip().upper())
+        return ord(m.group(1)) - ord("A") if m else None
+
+    @classmethod
+    def _predict_from_text(cls, text: Optional[str]) -> int:
+        """Map generated text to a choice index, or ``-1`` when none is found.
+
+        ``_extract_letter(text) or -1`` must not be used here: index 0
+        (answer "A") is falsy and would be reported as a miss.
+        """
+        idx = cls._extract_letter(text or "")
+        return -1 if idx is None else idx
+
+    @staticmethod
+    def _first_token_logprobs(lp_data: Any) -> Dict[str, float]:
+        """Return ``{token: logprob}`` candidates for the first generated token.
+
+        Accepts the chat-completions layout
+        (``logprobs.content[0].top_logprobs`` as a list of
+        ``{"token", "logprob"}`` objects) and the older completions layout
+        (``logprobs.top_logprobs[0]`` as a ``{token: logprob}`` mapping).
+        Returns an empty dict when neither is present.
+        """
+        if not isinstance(lp_data, dict):
+            return {}
+        content = lp_data.get("content")
+        if isinstance(content, list) and content and isinstance(content[0], dict):
+            entries = content[0].get("top_logprobs") or []
+            return {
+                e["token"]: e["logprob"]
+                for e in entries
+                if isinstance(e, dict) and "token" in e and "logprob" in e
+            }
+        legacy = lp_data.get("top_logprobs")
+        if isinstance(legacy, list) and legacy and isinstance(legacy[0], dict):
+            return dict(legacy[0])
+        return {}
+
+    def _score_logprobs(self, data: List[Dict[str, Any]]) -> List[int]:
+        """Predict by the most likely answer letter among first-token logprobs.
+
+        Falls back to parsing the generated text when the response carries
+        no usable logprobs or none of the candidate tokens is A-D.
+        """
+        preds = []
+        for item in data:
+            resp = self._chat(self._build_messages(item))
+            choices = resp.get("choices", [])
+            if not choices:
+                preds.append(-1)
+                continue
+
+            top = self._first_token_logprobs(choices[0].get("logprobs"))
+            best = -1
+            best_lp = float("-inf")
+            for token, lp in top.items():
+                letter = token.strip().upper()
+                if letter in self.LETTERS and lp > best_lp:
+                    best_lp = lp
+                    best = self.LETTERS.index(letter)
+
+            if best == -1:
+                text = (choices[0].get("message") or {}).get("content")
+                preds.append(self._predict_from_text(text))
+            else:
+                preds.append(best)
+
+        return preds
+
+    def _score_generation(self, data: List[Dict[str, Any]]) -> List[int]:
+        """Predict by parsing the answer letter out of the generated text."""
+        preds = []
+        for item in data:
+            resp = self._chat(self._build_messages(item))
+            choices = resp.get("choices", [])
+            text = choices[0]["message"]["content"] if choices else ""
+            preds.append(self._predict_from_text(text))
+        return preds
+
+    # ------------------------------------------------------------------
+    # Evaluate
+    # ------------------------------------------------------------------
+
+    def evaluate(self, data_path: str) -> EvalResult:
+        """Score every loaded question and return accuracy with per-item records."""
+        data = self.load_data(data_path)
+        scorer = self._score_logprobs if self.mode == "logprobs" else self._score_generation
+        preds = scorer(data)
+
+        correct = 0
+        results = []
+        for item, pred in zip(data, preds):
+            ok = pred == item["answer"]
+            correct += int(ok)
+            results.append(dict(
+                subject=item["subject"],
+                question=item["question"],
+                choices=item["choices"],
+                answer=item["answer"],
+                prediction=pred,
+                correct=ok,
+            ))
+
+        total = len(data)
+        return EvalResult(
+            task_name=f"mmlu_{self.subject}",
+            num_samples=total,
+            accuracy=correct / total if total else 0.0,
+            results=results,
+            metadata={"subject": self.subject, "mode": self.mode, "model": self.model},
+        )
diff --git a/llm_eval/registry.py b/llm_eval/registry.py
new file mode 100644
index 0000000..2e2736a
--- /dev/null
+++ b/llm_eval/registry.py
@@ -0,0 +1,29 @@
+from typing import Dict, Type
+
+from llm_eval.base import BaseEvaluator
+
+
+class EvalFactory:
+    """Name-to-class registry used to look up and build evaluators."""
+
+    _registry: Dict[str, Type[BaseEvaluator]] = {}
+
+    @classmethod
+    def register(cls, name: str):
+        """Return a class decorator that registers the class under *name*."""
+        def decorator(klass: Type[BaseEvaluator]):
+            cls._registry[name] = klass
+            return klass
+        return decorator
+
+    @classmethod
+    def create(cls, name: str, **kwargs) -> BaseEvaluator:
+        """Instantiate the evaluator registered as *name* with *kwargs*."""
+        if name not in cls._registry:
+            raise KeyError(f"Unknown evaluator '{name}'. Available: {list(cls._registry)}")
+        return cls._registry[name](**kwargs)
+
+    @classmethod
+    def list_registered(cls) -> list:
+        """Return the names of all registered evaluators."""
+        return list(cls._registry)
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..fd4f8d9
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,19 @@
+[project]
+name = "llm-eval"
+version = "0.1.0"
+description = "LLM evaluation via HTTP API (MMLU, etc.)"
+requires-python = ">=3.12"
+dependencies = [
+    "requests>=2.31",
+    "tqdm",
+]
+
+[project.scripts]
+llm-eval = "llm_eval.cli:main"
+
+[build-system]
+requires = ["setuptools>=75.0"]
+build-backend = "setuptools.build_meta"
+
+[tool.setuptools.packages.find]
+include = ["llm_eval*"]
diff --git a/scripts/download_mmlu.py b/scripts/download_mmlu.py
new file mode 100644
index 0000000..b6b8708
--- /dev/null
+++ b/scripts/download_mmlu.py
@@ -0,0 +1,39 @@
+"""Download MMLU dataset to data/mmlu/."""
+
+import argparse
+import os
+import urllib.request
+import zipfile
+
+REPO = "https://github.com/hendrycks/test/raw/master/"
+FILES = [
+    "auxiliary.zip",
+    "dev.zip",
+    "test.zip",
+    "val.zip",
+]
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--output_dir", type=str, default="data/mmlu")
+    args = parser.parse_args()
+
+    os.makedirs(args.output_dir, exist_ok=True)
+
+    for fname in FILES:
+        url = REPO + fname
+        zip_path = os.path.join(args.output_dir, fname)
+        print(f"Downloading {url}...")
+        urllib.request.urlretrieve(url, zip_path)
+
+        print(f"Extracting {zip_path}...")
+        with zipfile.ZipFile(zip_path, "r") as z:
+            z.extractall(args.output_dir)
+        os.remove(zip_path)
+
+    print(f"MMLU data saved to {args.output_dir}")
+
+
+if __name__ == "__main__":
+    main()