add all project source files
This commit is contained in:
parent
d8b83a175b
commit
ac814e5c52
|
|
@ -0,0 +1,5 @@
|
|||
__pycache__/
|
||||
*.pyc
|
||||
*.egg-info/
|
||||
dist/
|
||||
data/
|
||||
|
|
@ -0,0 +1,10 @@
|
|||
from llm_eval.base import BaseEvaluator, EvalResult
|
||||
from llm_eval.registry import EvalFactory
|
||||
from llm_eval.mmlu import MMLUEvaluator
|
||||
|
||||
__all__ = [
|
||||
"BaseEvaluator",
|
||||
"EvalResult",
|
||||
"EvalFactory",
|
||||
"MMLUEvaluator",
|
||||
]
|
||||
|
|
@ -0,0 +1,24 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
||||
@dataclass
|
||||
class EvalResult:
|
||||
task_name: str
|
||||
num_samples: int
|
||||
accuracy: float
|
||||
results: List[Dict[str, Any]] = field(default_factory=list)
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
class BaseEvaluator(ABC):
|
||||
def __init__(self, api_base: str, api_key: str = "not-needed", **kwargs):
|
||||
self.api_base = api_base.rstrip("/")
|
||||
self.api_key = api_key
|
||||
|
||||
@abstractmethod
|
||||
def evaluate(self, data_path: str) -> EvalResult: ...
|
||||
|
||||
@abstractmethod
|
||||
def load_data(self, data_path: str) -> List[Dict[str, Any]]: ...
|
||||
|
|
@ -0,0 +1,69 @@
|
|||
import argparse
|
||||
import json
|
||||
import os
|
||||
|
||||
from llm_eval import EvalFactory
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(description="LLM Evaluation Benchmark (via HTTP API)")
|
||||
parser.add_argument("--api_base", type=str, default="http://localhost:8000",
|
||||
help="API base URL (default: http://localhost:8000)")
|
||||
parser.add_argument("--api_key", type=str, default="not-needed",
|
||||
help="API key")
|
||||
parser.add_argument("--model", type=str, default="default",
|
||||
help="Model name sent in request body")
|
||||
parser.add_argument("--eval_type", type=str, default="mmlu",
|
||||
choices=EvalFactory.list_registered(),
|
||||
help="Evaluation task")
|
||||
parser.add_argument("--data_path", type=str, required=True,
|
||||
help="Dataset directory")
|
||||
parser.add_argument("--subject", type=str, default="all",
|
||||
help="Subject (default: all)")
|
||||
parser.add_argument("--mode", type=str, default="logprobs",
|
||||
choices=["logprobs", "generation"],
|
||||
help="Scoring mode")
|
||||
parser.add_argument("--output_file", type=str, default=None,
|
||||
help="Path to save results JSON")
|
||||
parser.add_argument("--max_retries", type=int, default=3)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
|
||||
evaluator = EvalFactory.create(
|
||||
args.eval_type,
|
||||
api_base=args.api_base,
|
||||
api_key=args.api_key,
|
||||
model=args.model,
|
||||
subject=args.subject,
|
||||
mode=args.mode,
|
||||
max_retries=args.max_retries,
|
||||
)
|
||||
|
||||
print(f"Running {args.eval_type} (subject={args.subject}, mode={args.mode})...")
|
||||
print(f"API: {args.api_base}")
|
||||
result = evaluator.evaluate(data_path=args.data_path)
|
||||
|
||||
print(f"\n{'='*50}")
|
||||
print(f"Task: {result.task_name}")
|
||||
print(f"Samples: {result.num_samples}")
|
||||
print(f"Acc: {result.accuracy:.4f} ({result.accuracy*100:.2f}%)")
|
||||
print(f"{'='*50}")
|
||||
|
||||
if args.output_file:
|
||||
os.makedirs(os.path.dirname(args.output_file) or ".", exist_ok=True)
|
||||
with open(args.output_file, "w") as f:
|
||||
json.dump({
|
||||
"task_name": result.task_name,
|
||||
"num_samples": result.num_samples,
|
||||
"accuracy": result.accuracy,
|
||||
"metadata": result.metadata,
|
||||
"results": result.results,
|
||||
}, f, ensure_ascii=False, indent=2)
|
||||
print(f"Saved to {args.output_file}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -0,0 +1,208 @@
|
|||
import csv
|
||||
import os
|
||||
import re
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import requests
|
||||
|
||||
from llm_eval.base import BaseEvaluator, EvalResult
|
||||
from llm_eval.registry import EvalFactory
|
||||
|
||||
|
||||
@EvalFactory.register("mmlu")
|
||||
class MMLUEvaluator(BaseEvaluator):
|
||||
"""MMLU-style multiple-choice evaluator via HTTP API.
|
||||
|
||||
Sends each question as a chat completion request and parses the
|
||||
answer letter from the response. Two scoring modes:
|
||||
|
||||
* ``logprobs`` (default) — requests per-token logprobs and picks the
|
||||
highest-probability letter among A/B/C/D.
|
||||
* ``generation`` — asks the model to reply with a single letter and
|
||||
extracts it from the generated text.
|
||||
"""
|
||||
|
||||
LETTERS = ["A", "B", "C", "D"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_base: str,
|
||||
api_key: str = "not-needed",
|
||||
model: str = "default",
|
||||
subject: str = "all",
|
||||
mode: str = "logprobs",
|
||||
max_retries: int = 3,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(api_base, api_key=api_key)
|
||||
self.model = model
|
||||
self.subject = subject
|
||||
self.mode = mode
|
||||
self.max_retries = max_retries
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Data
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def load_data(self, data_path: str) -> List[Dict[str, Any]]:
|
||||
items = []
|
||||
if self.subject == "all":
|
||||
for fname in sorted(os.listdir(data_path)):
|
||||
if fname.endswith(".csv"):
|
||||
items.extend(self._load_csv(os.path.join(data_path, fname), fname[:-4]))
|
||||
else:
|
||||
items = self._load_csv(os.path.join(data_path, f"{self.subject}.csv"), self.subject)
|
||||
return items
|
||||
|
||||
@staticmethod
|
||||
def _load_csv(path: str, subject: str) -> List[Dict[str, Any]]:
|
||||
items = []
|
||||
with open(path, newline="", encoding="utf-8") as f:
|
||||
reader = csv.reader(f)
|
||||
for row in reader:
|
||||
if len(row) < 6:
|
||||
continue
|
||||
items.append(dict(
|
||||
question=row[0],
|
||||
choices=row[1:5],
|
||||
answer=ord(row[5].strip().upper()) - ord("A"),
|
||||
subject=subject,
|
||||
))
|
||||
return items
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Prompt helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
def _build_messages(item: Dict[str, Any]) -> List[Dict[str, str]]:
|
||||
ch = item["choices"]
|
||||
prompt = (
|
||||
f"{item['question']}\n\n"
|
||||
f"A. {ch[0]}\nB. {ch[1]}\nC. {ch[2]}\nD. {ch[3]}"
|
||||
)
|
||||
return [
|
||||
{"role": "system", "content": (
|
||||
"Answer the multiple-choice question by responding with "
|
||||
"only the single letter (A, B, C, or D) of the correct answer."
|
||||
)},
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# API call
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _chat(self, messages: List[Dict[str, str]]) -> Dict[str, Any]:
|
||||
body = dict(model=self.model, messages=messages, temperature=0.0)
|
||||
|
||||
if self.mode == "logprobs":
|
||||
body["logprobs"] = True
|
||||
body["top_logprobs"] = 5
|
||||
body["max_tokens"] = 1
|
||||
else:
|
||||
body["max_tokens"] = 5
|
||||
|
||||
for attempt in range(self.max_retries):
|
||||
try:
|
||||
resp = requests.post(
|
||||
f"{self.api_base}/v1/chat/completions",
|
||||
json=body,
|
||||
headers={"Authorization": f"Bearer {self.api_key}"},
|
||||
timeout=60,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
except requests.RequestException as e:
|
||||
if attempt == self.max_retries - 1:
|
||||
raise
|
||||
import time
|
||||
time.sleep(1)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Scoring
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
def _extract_letter(text: str) -> Optional[int]:
|
||||
m = re.search(r"\b([A-D])\b", text.strip().upper())
|
||||
return ord(m.group(1)) - ord("A") if m else None
|
||||
|
||||
def _score_logprobs(self, data: List[Dict[str, Any]]) -> List[int]:
|
||||
preds = []
|
||||
for item in data:
|
||||
resp = self._chat(self._build_messages(item))
|
||||
choices = resp.get("choices", [])
|
||||
if not choices:
|
||||
preds.append(-1)
|
||||
continue
|
||||
|
||||
lp_data = choices[0].get("logprobs")
|
||||
if not lp_data or not lp_data.get("top_logprobs"):
|
||||
text = choices[0].get("message", {}).get("content", "")
|
||||
preds.append(self._extract_letter(text) or -1)
|
||||
continue
|
||||
|
||||
top = lp_data["top_logprobs"][0]
|
||||
best = -1
|
||||
best_lp = float("-inf")
|
||||
for letter in self.LETTERS:
|
||||
found = False
|
||||
for token, lp in top.items():
|
||||
if token.strip().upper() == letter:
|
||||
if lp > best_lp:
|
||||
best_lp = lp
|
||||
best = ord(letter) - ord("A")
|
||||
found = True
|
||||
break
|
||||
if not found:
|
||||
continue
|
||||
|
||||
if best == -1:
|
||||
text = choices[0].get("message", {}).get("content", "")
|
||||
preds.append(self._extract_letter(text) or -1)
|
||||
else:
|
||||
preds.append(best)
|
||||
|
||||
return preds
|
||||
|
||||
def _score_generation(self, data: List[Dict[str, Any]]) -> List[int]:
|
||||
preds = []
|
||||
for item in data:
|
||||
resp = self._chat(self._build_messages(item))
|
||||
choices = resp.get("choices", [])
|
||||
text = choices[0]["message"]["content"] if choices else ""
|
||||
preds.append(self._extract_letter(text) or -1)
|
||||
return preds
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Evaluate
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def evaluate(self, data_path: str) -> EvalResult:
|
||||
data = self.load_data(data_path)
|
||||
scorer = self._score_logprobs if self.mode == "logprobs" else self._score_generation
|
||||
preds = scorer(data)
|
||||
|
||||
correct = 0
|
||||
results = []
|
||||
for item, pred in zip(data, preds):
|
||||
ok = pred == item["answer"]
|
||||
correct += int(ok)
|
||||
results.append(dict(
|
||||
subject=item["subject"],
|
||||
question=item["question"],
|
||||
choices=item["choices"],
|
||||
answer=item["answer"],
|
||||
prediction=pred,
|
||||
correct=ok,
|
||||
))
|
||||
|
||||
total = len(data)
|
||||
return EvalResult(
|
||||
task_name=f"mmlu_{self.subject}",
|
||||
num_samples=total,
|
||||
accuracy=correct / total if total else 0.0,
|
||||
results=results,
|
||||
metadata={"subject": self.subject, "mode": self.mode, "model": self.model},
|
||||
)
|
||||
|
|
@ -0,0 +1,24 @@
|
|||
from typing import Dict, Type
|
||||
|
||||
from llm_eval.base import BaseEvaluator
|
||||
|
||||
|
||||
class EvalFactory:
|
||||
_registry: Dict[str, Type[BaseEvaluator]] = {}
|
||||
|
||||
@classmethod
|
||||
def register(cls, name: str):
|
||||
def decorator(klass: Type[BaseEvaluator]):
|
||||
cls._registry[name] = klass
|
||||
return klass
|
||||
return decorator
|
||||
|
||||
@classmethod
|
||||
def create(cls, name: str, **kwargs) -> BaseEvaluator:
|
||||
if name not in cls._registry:
|
||||
raise KeyError(f"Unknown evaluator '{name}'. Available: {list(cls._registry)}")
|
||||
return cls._registry[name](**kwargs)
|
||||
|
||||
@classmethod
|
||||
def list_registered(cls) -> list:
|
||||
return list(cls._registry)
|
||||
|
|
@ -0,0 +1,19 @@
|
|||
[project]
|
||||
name = "llm-eval"
|
||||
version = "0.1.0"
|
||||
description = "LLM evaluation via HTTP API (MMLU, etc.)"
|
||||
requires-python = ">=3.12"
|
||||
dependencies = [
|
||||
"requests>=2.31",
|
||||
"tqdm",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
llm-eval = "llm_eval.cli:main"
|
||||
|
||||
[build-system]
|
||||
requires = ["setuptools>=75.0"]
|
||||
build-backend = "setuptools.backends._legacy:_Backend"
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
include = ["llm_eval*"]
|
||||
|
|
@ -0,0 +1,39 @@
|
|||
"""Download MMLU dataset to data/mmlu/."""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import urllib.request
|
||||
import zipfile
|
||||
|
||||
REPO = "https://github.com/hendrycks/test/raw/master/"
|
||||
FILES = [
|
||||
"auxiliary.zip",
|
||||
"dev.zip",
|
||||
"test.zip",
|
||||
"val.zip",
|
||||
]
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--output_dir", type=str, default="data/mmlu")
|
||||
args = parser.parse_args()
|
||||
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
for fname in FILES:
|
||||
url = REPO + fname
|
||||
zip_path = os.path.join(args.output_dir, fname)
|
||||
print(f"Downloading {url}...")
|
||||
urllib.request.urlretrieve(url, zip_path)
|
||||
|
||||
print(f"Extracting {zip_path}...")
|
||||
with zipfile.ZipFile(zip_path, "r") as z:
|
||||
z.extractall(args.output_dir)
|
||||
os.remove(zip_path)
|
||||
|
||||
print(f"MMLU data saved to {args.output_dir}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Loading…
Reference in New Issue