import csv
import os
import re
import time
from typing import Any, Dict, List, Optional

import requests

from llm_eval.base import BaseEvaluator, EvalResult
from llm_eval.registry import EvalFactory

# A standalone answer letter, e.g. "B", "B.", "(B)" or "Answer: B".
_LETTER_RE = re.compile(r"\b([A-D])\b")


@EvalFactory.register("mmlu")
class MMLUEvaluator(BaseEvaluator):
    """MMLU-style multiple-choice evaluator via HTTP API.

    Sends each question as a chat completion request and parses the answer
    letter from the response. Two scoring modes:

    * ``logprobs`` (default) — requests per-token logprobs and picks the
      highest-probability letter among A/B/C/D.
    * ``generation`` — asks the model to reply with a single letter and
      extracts it from the generated text.

    Any ``mode`` other than ``"logprobs"`` is scored as ``generation``.
    Predictions are 0-based choice indices (A=0 ... D=3); ``-1`` means no
    answer letter could be recovered from the response.
    """

    LETTERS = ["A", "B", "C", "D"]

    def __init__(
        self,
        api_base: str,
        api_key: str = "not-needed",
        model: str = "default",
        subject: str = "all",
        mode: str = "logprobs",
        max_retries: int = 3,
        **kwargs,
    ):
        # Extra keyword arguments are accepted and ignored (not forwarded).
        super().__init__(api_base, api_key=api_key)
        self.model = model
        self.subject = subject
        self.mode = mode
        self.max_retries = max_retries

    # ------------------------------------------------------------------
    # Data
    # ------------------------------------------------------------------

    def load_data(self, data_path: str) -> List[Dict[str, Any]]:
        """Load question dicts for ``self.subject`` from the ``data_path`` directory.

        With ``subject == "all"`` every ``*.csv`` file in the directory is
        loaded (in sorted filename order) and tagged with its filename stem
        as the subject; otherwise only ``<data_path>/<subject>.csv`` is read.
        """
        if self.subject != "all":
            return self._load_csv(
                os.path.join(data_path, f"{self.subject}.csv"), self.subject
            )
        items: List[Dict[str, Any]] = []
        for fname in sorted(os.listdir(data_path)):
            if fname.endswith(".csv"):
                items.extend(
                    self._load_csv(os.path.join(data_path, fname), fname[:-4])
                )
        return items

    @classmethod
    def _load_csv(cls, path: str, subject: str) -> List[Dict[str, Any]]:
        """Parse one CSV of ``question, A, B, C, D, answer`` rows.

        Rows with fewer than six columns, or whose answer cell is not a single
        letter A-D (for example a header row), are skipped.
        """
        items: List[Dict[str, Any]] = []
        # utf-8-sig drops a leading BOM that would otherwise be glued onto
        # the first question's text; it is identical to utf-8 otherwise.
        with open(path, newline="", encoding="utf-8-sig") as f:
            for row in csv.reader(f):
                if len(row) < 6:
                    continue
                letter = row[5].strip().upper()
                if letter not in cls.LETTERS:
                    continue
                items.append(dict(
                    question=row[0],
                    choices=row[1:5],
                    answer=cls.LETTERS.index(letter),
                    subject=subject,
                ))
        return items

    # ------------------------------------------------------------------
    # Prompt helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _build_messages(item: Dict[str, Any]) -> List[Dict[str, str]]:
        """Build the system + user chat messages for one question."""
        ch = item["choices"]
        prompt = (
            f"{item['question']}\n\n"
            f"A. {ch[0]}\nB. {ch[1]}\nC. {ch[2]}\nD. {ch[3]}"
        )
        return [
            {"role": "system", "content": (
                "Answer the multiple-choice question by responding with "
                "only the single letter (A, B, C, or D) of the correct answer."
            )},
            {"role": "user", "content": prompt},
        ]

    # ------------------------------------------------------------------
    # API call
    # ------------------------------------------------------------------

    def _chat(self, messages: List[Dict[str, str]]) -> Dict[str, Any]:
        """POST *messages* to ``<api_base>/v1/chat/completions`` and return the JSON body.

        ``requests`` errors (connection failures, timeouts, non-2xx statuses
        via ``raise_for_status``) are retried after a 1 s pause; the last
        error is re-raised once the budget is spent. At least one attempt is
        always made, even when ``max_retries`` is 0 or negative.
        """
        body: Dict[str, Any] = dict(
            model=self.model, messages=messages, temperature=0.0
        )
        if self.mode == "logprobs":
            # A single token suffices: the answer is read from the first
            # token's top alternatives.
            body["logprobs"] = True
            body["top_logprobs"] = 5
            body["max_tokens"] = 1
        else:
            body["max_tokens"] = 5

        # Strip a trailing slash so "http://host/" does not produce "//v1".
        url = f"{self.api_base.rstrip('/')}/v1/chat/completions"
        headers = {"Authorization": f"Bearer {self.api_key}"}
        attempts = max(1, self.max_retries)
        for attempt in range(1, attempts + 1):
            try:
                resp = requests.post(url, json=body, headers=headers, timeout=60)
                resp.raise_for_status()
                return resp.json()
            except requests.RequestException:
                if attempt == attempts:
                    raise
                time.sleep(1)
        raise RuntimeError("unreachable: the retry loop always returns or raises")

    # ------------------------------------------------------------------
    # Scoring
    # ------------------------------------------------------------------

    @staticmethod
    def _extract_letter(text: Optional[str]) -> Optional[int]:
        """Return the 0-based index of the first standalone A-D letter in *text*.

        Uppercase letters in the original text are searched first, so the
        English article "a" is not preferred over a real answer letter;
        lowercase answers such as "b" are still accepted as a fallback.
        Returns ``None`` when *text* is empty/None or contains no letter.
        """
        if not text:
            return None
        text = text.strip()
        m = _LETTER_RE.search(text) or _LETTER_RE.search(text.upper())
        return ord(m.group(1)) - ord("A") if m else None

    @classmethod
    def _text_to_pred(cls, text: Optional[str]) -> int:
        """Map response text to a prediction index, using -1 when no letter is found."""
        idx = cls._extract_letter(text)
        # Compare against None explicitly: ``idx or -1`` would turn a
        # correct "A" (index 0) into -1.
        return -1 if idx is None else idx

    @staticmethod
    def _message_text(choice: Dict[str, Any]) -> str:
        """Return the assistant text of one response choice, or "" if absent/null."""
        message = choice.get("message") or {}
        content = message.get("content")
        return content if isinstance(content, str) else ""

    @staticmethod
    def _first_token_logprobs(lp_data: Any) -> Dict[str, float]:
        """Return ``{token: logprob}`` for the alternatives at the first generated position.

        Two response shapes are understood, checked in this order:

        * legacy completions: ``{"top_logprobs": [{token: logprob, ...}, ...]}``
        * chat completions: ``{"content": [{"token", "logprob",
          "top_logprobs": [{"token", "logprob"}, ...]}, ...]}``

        Returns an empty dict when neither shape is present.
        """
        if not isinstance(lp_data, dict):
            return {}

        top = lp_data.get("top_logprobs")
        if isinstance(top, list) and top and isinstance(top[0], dict):
            return dict(top[0])

        content = lp_data.get("content")
        if isinstance(content, list) and content and isinstance(content[0], dict):
            first = content[0]
            # Include the sampled token itself in case the server omitted
            # (or truncated) its top_logprobs list.
            entries = list(first.get("top_logprobs") or []) + [first]
            alts: Dict[str, float] = {}
            for entry in entries:
                if not isinstance(entry, dict):
                    continue
                token, lp = entry.get("token"), entry.get("logprob")
                if isinstance(token, str) and lp is not None:
                    alts[token] = max(lp, alts.get(token, float("-inf")))
            return alts

        return {}

    @classmethod
    def _best_logprob_letter(cls, lp_data: Any) -> Optional[int]:
        """Return the index of the most likely answer letter, or ``None``.

        Tokens are matched after stripping whitespace and upper-casing, so
        "B", " B" and "b" all count towards B; the highest log-probability
        among a letter's variants is used. Ties go to the earlier letter.
        """
        letter_lp: Dict[str, float] = {}
        for token, lp in cls._first_token_logprobs(lp_data).items():
            letter = token.strip().upper()
            if letter in cls.LETTERS and lp > letter_lp.get(letter, float("-inf")):
                letter_lp[letter] = lp
        candidates = [i for i, letter in enumerate(cls.LETTERS) if letter in letter_lp]
        if not candidates:
            return None
        # max() keeps the first maximal candidate, i.e. the earliest letter.
        return max(candidates, key=lambda i: letter_lp[cls.LETTERS[i]])

    def _score_logprobs(self, data: List[Dict[str, Any]]) -> List[int]:
        """Predict each item from first-token logprobs, falling back to the text.

        The generated text is parsed instead when the response carries no
        usable logprobs or none of the top tokens is an answer letter; an
        empty ``choices`` list yields -1.
        """
        preds: List[int] = []
        for item in data:
            resp = self._chat(self._build_messages(item))
            choices = resp.get("choices") or []
            if not choices:
                preds.append(-1)
                continue
            best = self._best_logprob_letter(choices[0].get("logprobs"))
            if best is None:
                best = self._text_to_pred(self._message_text(choices[0]))
            preds.append(best)
        return preds

    def _score_generation(self, data: List[Dict[str, Any]]) -> List[int]:
        """Predict each item by parsing the answer letter from the generated text."""
        preds: List[int] = []
        for item in data:
            resp = self._chat(self._build_messages(item))
            choices = resp.get("choices") or []
            text = self._message_text(choices[0]) if choices else ""
            preds.append(self._text_to_pred(text))
        return preds

    # ------------------------------------------------------------------
    # Evaluate
    # ------------------------------------------------------------------

    def evaluate(self, data_path: str) -> EvalResult:
        """Load the data, score every item, and return accuracy plus per-item results."""
        data = self.load_data(data_path)
        scorer = (
            self._score_logprobs if self.mode == "logprobs" else self._score_generation
        )
        preds = scorer(data)

        correct = 0
        results = []
        for item, pred in zip(data, preds):
            ok = pred == item["answer"]
            correct += int(ok)
            results.append(dict(
                subject=item["subject"],
                question=item["question"],
                choices=item["choices"],
                answer=item["answer"],
                prediction=pred,
                correct=ok,
            ))

        total = len(data)
        return EvalResult(
            task_name=f"mmlu_{self.subject}",
            num_samples=total,
            accuracy=correct / total if total else 0.0,
            results=results,
            metadata={"subject": self.subject, "mode": self.mode, "model": self.model},
        )