209 lines
7.1 KiB
Python
209 lines
7.1 KiB
Python
import csv
|
|
import os
|
|
import re
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
import requests
|
|
|
|
from llm_eval.base import BaseEvaluator, EvalResult
|
|
from llm_eval.registry import EvalFactory
|
|
|
|
|
|
@EvalFactory.register("mmlu")
class MMLUEvaluator(BaseEvaluator):
    """MMLU-style multiple-choice evaluator via HTTP API.

    Sends each question as a chat completion request and parses the
    answer letter from the response. Two scoring modes:

    * ``logprobs`` (default) — requests per-token logprobs and picks the
      highest-probability letter among A/B/C/D; falls back to parsing the
      generated text when the response carries no usable logprobs.
    * ``generation`` — asks the model to reply with a single letter and
      extracts it from the generated text.

    Any ``mode`` value other than ``"logprobs"`` is handled as
    ``generation``. Predictions are 0-based choice indices (A=0 … D=3);
    ``-1`` means no answer letter could be determined.
    """

    LETTERS = ["A", "B", "C", "D"]

    def __init__(
        self,
        api_base: str,
        api_key: str = "not-needed",
        model: str = "default",
        subject: str = "all",
        mode: str = "logprobs",
        max_retries: int = 3,
        **kwargs,
    ):
        """Configure the evaluator.

        Args:
            api_base: Base URL of the server; ``/v1/chat/completions`` is
                appended to it for every request.
            api_key: Sent as a ``Bearer`` token in the Authorization header.
            model: Value of the ``model`` field in the request body.
            subject: ``"all"`` to load every ``*.csv`` file in the data
                directory, otherwise the name of a single ``<subject>.csv``.
            mode: ``"logprobs"`` or ``"generation"`` (see class docstring).
            max_retries: Total number of attempts per request; values below
                1 are treated as 1.
            **kwargs: Accepted for call-site compatibility and ignored.
        """
        super().__init__(api_base, api_key=api_key)
        self.model = model
        self.subject = subject
        self.mode = mode
        self.max_retries = max_retries

    # ------------------------------------------------------------------
    # Data
    # ------------------------------------------------------------------

    def load_data(self, data_path: str) -> List[Dict[str, Any]]:
        """Load the questions for the configured subject.

        Args:
            data_path: Directory containing the subject CSV files.

        Returns:
            A list of item dicts as produced by :meth:`_load_csv`. With
            ``subject == "all"`` the files are read in sorted name order and
            each item's subject is the file name without ``.csv``.
        """
        if self.subject != "all":
            return self._load_csv(
                os.path.join(data_path, f"{self.subject}.csv"), self.subject
            )

        items: List[Dict[str, Any]] = []
        for fname in sorted(os.listdir(data_path)):
            if fname.endswith(".csv"):
                items.extend(
                    self._load_csv(os.path.join(data_path, fname), fname[:-4])
                )
        return items

    @staticmethod
    def _load_csv(path: str, subject: str) -> List[Dict[str, Any]]:
        """Parse one CSV file of questions.

        Each usable row has at least six columns: the question, four
        choices, and the answer letter. Rows that are too short, or whose
        answer cell is not a single letter A-D (for example a header row or
        an empty cell), are skipped instead of raising.

        Returns:
            A list of dicts with keys ``question``, ``choices`` (list of
            four strings), ``answer`` (0-based index) and ``subject``.
        """
        items: List[Dict[str, Any]] = []
        with open(path, newline="", encoding="utf-8") as f:
            for row in csv.reader(f):
                if len(row) < 6:
                    continue
                label = row[5].strip().upper()
                # ord() only accepts a single character; anything else is
                # not a valid answer key, so the row cannot be scored.
                if len(label) != 1 or not "A" <= label <= "D":
                    continue
                items.append(
                    dict(
                        question=row[0],
                        choices=row[1:5],
                        answer=ord(label) - ord("A"),
                        subject=subject,
                    )
                )
        return items

    # ------------------------------------------------------------------
    # Prompt helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _build_messages(item: Dict[str, Any]) -> List[Dict[str, str]]:
        """Build the system + user chat messages for one question item."""
        ch = item["choices"]
        prompt = (
            f"{item['question']}\n\n"
            f"A. {ch[0]}\nB. {ch[1]}\nC. {ch[2]}\nD. {ch[3]}"
        )
        return [
            {
                "role": "system",
                "content": (
                    "Answer the multiple-choice question by responding with "
                    "only the single letter (A, B, C, or D) of the correct answer."
                ),
            },
            {"role": "user", "content": prompt},
        ]

    # ------------------------------------------------------------------
    # API call
    # ------------------------------------------------------------------

    def _chat(self, messages: List[Dict[str, str]]) -> Dict[str, Any]:
        """POST one chat completion request and return the decoded JSON.

        In ``logprobs`` mode the request asks for the top 5 logprobs of a
        single generated token; otherwise it asks for up to 5 tokens of text.

        Raises:
            requests.RequestException: If the final attempt fails (network
                error or non-2xx status).
        """
        # Function-scope import (as in the original) so this block does not
        # rely on the file's import section; hoisted out of the retry loop.
        import time

        body: Dict[str, Any] = dict(
            model=self.model, messages=messages, temperature=0.0
        )
        if self.mode == "logprobs":
            body["logprobs"] = True
            body["top_logprobs"] = 5
            body["max_tokens"] = 1
        else:
            body["max_tokens"] = 5

        # Tolerate a trailing slash on the configured base URL.
        url = f"{self.api_base.rstrip('/')}/v1/chat/completions"
        headers = {"Authorization": f"Bearer {self.api_key}"}

        # Always try at least once so this method never falls off the end
        # and returns None when max_retries is 0 or negative.
        attempts = max(1, self.max_retries)
        for attempt in range(attempts):
            try:
                resp = requests.post(url, json=body, headers=headers, timeout=60)
                resp.raise_for_status()
                return resp.json()
            except requests.RequestException:
                if attempt == attempts - 1:
                    raise
                time.sleep(1)

    # ------------------------------------------------------------------
    # Scoring
    # ------------------------------------------------------------------

    @staticmethod
    def _extract_letter(text: str) -> Optional[int]:
        """Return the 0-based index of the first standalone A-D in *text*.

        Matching is case-insensitive (the text is upper-cased first).
        Returns ``None`` when *text* is empty/``None`` or has no such letter.
        """
        if not text:
            return None
        m = re.search(r"\b([A-D])\b", text.strip().upper())
        return ord(m.group(1)) - ord("A") if m else None

    def _predict_from_text(self, choice: Dict[str, Any]) -> int:
        """Parse the answer index from a response choice's message text.

        Returns ``-1`` when no letter is found. The ``None`` check is
        explicit because index 0 ("A") is falsy and must not be mistaken
        for "no answer".
        """
        message = choice.get("message") or {}
        idx = self._extract_letter(message.get("content") or "")
        return -1 if idx is None else idx

    @staticmethod
    def _first_token_logprobs(lp_data: Any) -> Dict[str, float]:
        """Return ``{token: logprob}`` candidates for the first output token.

        Two response layouts are understood:

        * chat-completions style: ``logprobs["content"][0]["top_logprobs"]``
          is a list of ``{"token": ..., "logprob": ...}`` objects;
        * legacy style: ``logprobs["top_logprobs"][0]`` is already a
          ``{token: logprob}`` mapping.

        Returns an empty dict when neither layout is present.
        """
        if not isinstance(lp_data, dict):
            return {}

        content = lp_data.get("content")
        if isinstance(content, list) and content and isinstance(content[0], dict):
            candidates: Dict[str, float] = {}
            for entry in content[0].get("top_logprobs") or []:
                if not isinstance(entry, dict):
                    continue
                token, lp = entry.get("token"), entry.get("logprob")
                if isinstance(token, str) and isinstance(lp, (int, float)):
                    candidates[token] = lp
            if candidates:
                return candidates

        legacy = lp_data.get("top_logprobs")
        if isinstance(legacy, list) and legacy and isinstance(legacy[0], dict):
            return dict(legacy[0])
        return {}

    def _best_letter(self, candidates: Dict[str, float]) -> Optional[int]:
        """Return the index of the most probable answer letter, if any.

        A token counts as a letter when it equals A/B/C/D after stripping
        whitespace and upper-casing. When several tokens map to the same
        letter the highest logprob wins.
        """
        best: Optional[int] = None
        best_lp = float("-inf")
        for token, lp in candidates.items():
            if not isinstance(token, str) or not isinstance(lp, (int, float)):
                continue
            letter = token.strip().upper()
            if letter in self.LETTERS and lp > best_lp:
                best_lp = lp
                best = self.LETTERS.index(letter)
        return best

    def _score_logprobs(self, data: List[Dict[str, Any]]) -> List[int]:
        """Predict one answer index per item from first-token logprobs.

        Falls back to parsing the generated text when the response has no
        logprobs or none of the candidate tokens is an answer letter.
        Items whose response has no choices get ``-1``.
        """
        preds: List[int] = []
        for item in data:
            resp = self._chat(self._build_messages(item))
            choices = resp.get("choices") or []
            if not choices:
                preds.append(-1)
                continue

            choice = choices[0]
            best = self._best_letter(
                self._first_token_logprobs(choice.get("logprobs"))
            )
            if best is None:
                preds.append(self._predict_from_text(choice))
            else:
                preds.append(best)
        return preds

    def _score_generation(self, data: List[Dict[str, Any]]) -> List[int]:
        """Predict one answer index per item by parsing the generated text.

        Items with no choices or no parsable letter get ``-1``.
        """
        preds: List[int] = []
        for item in data:
            resp = self._chat(self._build_messages(item))
            choices = resp.get("choices") or []
            preds.append(self._predict_from_text(choices[0]) if choices else -1)
        return preds

    # ------------------------------------------------------------------
    # Evaluate
    # ------------------------------------------------------------------

    def evaluate(self, data_path: str) -> EvalResult:
        """Run the evaluation over the data found under *data_path*.

        Returns:
            An ``EvalResult`` carrying the task name ``mmlu_<subject>``, the
            sample count, the accuracy (``0.0`` when there are no samples),
            one result dict per item, and the run's subject/mode/model as
            metadata.
        """
        data = self.load_data(data_path)
        scorer = (
            self._score_logprobs if self.mode == "logprobs" else self._score_generation
        )
        preds = scorer(data)

        correct = 0
        results = []
        for item, pred in zip(data, preds):
            ok = pred == item["answer"]
            correct += int(ok)
            results.append(
                dict(
                    subject=item["subject"],
                    question=item["question"],
                    choices=item["choices"],
                    answer=item["answer"],
                    prediction=pred,
                    correct=ok,
                )
            )

        total = len(data)
        return EvalResult(
            task_name=f"mmlu_{self.subject}",
            num_samples=total,
            accuracy=correct / total if total else 0.0,
            results=results,
            metadata={"subject": self.subject, "mode": self.mode, "model": self.model},
        )
|