fix : 修正类型标注与统一 CLI 参数命名

- AutoRegressiveLM.forward 返回类型标注 -> Dict[str, Tensor]
- EmbeddingEncoder 移除冗余 position_ids 自动创建
- CLI 脚本模型目录参数统一为 --param_path
This commit is contained in:
ViperEkura 2026-05-27 20:48:53 +08:00
parent 4145d35e3c
commit 2d5dc93b3d
5 changed files with 10 additions and 13 deletions

View File

@ -68,9 +68,6 @@ class EmbeddingEncoder(AutoModel):
x = self.embed_tokens(input_ids) x = self.embed_tokens(input_ids)
if position_ids is None:
position_ids = torch.arange(S, device=x.device).unsqueeze(0).expand(B, -1)
rotary_emb = self.rotary_embedding(x, position_ids) rotary_emb = self.rotary_embedding(x, position_ids)
attn_mask = process_attention_mask(x, position_ids, input_mask, is_causal=False) attn_mask = process_attention_mask(x, position_ids, input_mask, is_causal=False)

View File

@ -1,4 +1,4 @@
from typing import Any, Mapping, Optional from typing import Any, Dict, Mapping, Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -136,7 +136,7 @@ class AutoRegressiveLM(AutoModel):
input_mask: Optional[Tensor] = None, input_mask: Optional[Tensor] = None,
paged_cache: Optional[KvcacheView] = None, paged_cache: Optional[KvcacheView] = None,
position_ids: Optional[Tensor] = None, position_ids: Optional[Tensor] = None,
) -> Tensor: ) -> Dict[str, Tensor]:
assert input_ids.ndim == 2 assert input_ids.ndim == 2
x = self.embed_tokens(input_ids) x = self.embed_tokens(input_ids)

View File

@ -197,7 +197,7 @@ def evaluate_subject(
def main(): def main():
parser = argparse.ArgumentParser(description="MMLU evaluation") parser = argparse.ArgumentParser(description="MMLU evaluation")
parser.add_argument( parser.add_argument(
"--model_dir", type=str, default="./params", help="Model directory" "--param_path", type=str, default="./params", help="Model directory"
) )
parser.add_argument( parser.add_argument(
"--data_dir", type=str, default="./mmlu_data", help="MMLU data directory" "--data_dir", type=str, default="./mmlu_data", help="MMLU data directory"
@ -228,8 +228,8 @@ def main():
if args.download or not os.path.exists(args.data_dir): if args.download or not os.path.exists(args.data_dir):
download_mmlu(args.data_dir) download_mmlu(args.data_dir)
model = AutoModel.from_pretrained(args.model_dir) model = AutoModel.from_pretrained(args.param_path)
tokenizer = AutoTokenizer.from_pretrained(args.model_dir) tokenizer = AutoTokenizer.from_pretrained(args.param_path)
device = args.device device = args.device
dtype = getattr(torch, args.dtype) dtype = getattr(torch, args.dtype)
model.to(device=device, dtype=dtype) model.to(device=device, dtype=dtype)

View File

@ -10,11 +10,11 @@ from astrai.tokenize import AutoTokenizer
def process_file( def process_file(
model_dir: str, input_file: str, output_file: str, batch_size: int, text_key: str param_path: str, input_file: str, output_file: str, batch_size: int, text_key: str
): ):
# Load model and tokenizer # Load model and tokenizer
model = AutoModel.from_pretrained(model_dir) model = AutoModel.from_pretrained(param_path)
tokenizer = AutoTokenizer.from_pretrained(model_dir) tokenizer = AutoTokenizer.from_pretrained(param_path)
model.to(device="cuda", dtype=torch.bfloat16) model.to(device="cuda", dtype=torch.bfloat16)
with open(input_file, "r", encoding="utf-8") as f: with open(input_file, "r", encoding="utf-8") as f:
@ -88,7 +88,7 @@ def process_file(
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run perplexity with a Khaosz model.") parser = argparse.ArgumentParser(description="Run perplexity with a Khaosz model.")
parser.add_argument( parser.add_argument(
"--model_dir", type=str, required=True, help="Path to the model directory." "--param_path", type=str, required=True, help="Path to the model directory."
) )
parser.add_argument( parser.add_argument(
"--input_file", type=str, required=True, help="Path to the input file." "--input_file", type=str, required=True, help="Path to the input file."

View File

@ -18,7 +18,7 @@ def main():
"--reload", action="store_true", help="Enable auto-reload for development" "--reload", action="store_true", help="Enable auto-reload for development"
) )
parser.add_argument( parser.add_argument(
"--param-path", "--param_path",
type=Path, type=Path,
default=None, default=None,
help="Path to model parameters (default: project_root/params)", help="Path to model parameters (default: project_root/params)",