fix : 修正类型标注与统一 CLI 参数命名
- AutoRegressiveLM.forward 返回类型标注 -> Dict[str, Tensor] - EmbeddingEncoder 移除冗余 position_ids 自动创建 - CLI 脚本模型目录参数统一为 --param_path
This commit is contained in:
parent
4145d35e3c
commit
2d5dc93b3d
|
|
@ -68,9 +68,6 @@ class EmbeddingEncoder(AutoModel):
|
||||||
|
|
||||||
x = self.embed_tokens(input_ids)
|
x = self.embed_tokens(input_ids)
|
||||||
|
|
||||||
if position_ids is None:
|
|
||||||
position_ids = torch.arange(S, device=x.device).unsqueeze(0).expand(B, -1)
|
|
||||||
|
|
||||||
rotary_emb = self.rotary_embedding(x, position_ids)
|
rotary_emb = self.rotary_embedding(x, position_ids)
|
||||||
attn_mask = process_attention_mask(x, position_ids, input_mask, is_causal=False)
|
attn_mask = process_attention_mask(x, position_ids, input_mask, is_causal=False)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
from typing import Any, Mapping, Optional
|
from typing import Any, Dict, Mapping, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
@ -136,7 +136,7 @@ class AutoRegressiveLM(AutoModel):
|
||||||
input_mask: Optional[Tensor] = None,
|
input_mask: Optional[Tensor] = None,
|
||||||
paged_cache: Optional[KvcacheView] = None,
|
paged_cache: Optional[KvcacheView] = None,
|
||||||
position_ids: Optional[Tensor] = None,
|
position_ids: Optional[Tensor] = None,
|
||||||
) -> Tensor:
|
) -> Dict[str, Tensor]:
|
||||||
assert input_ids.ndim == 2
|
assert input_ids.ndim == 2
|
||||||
|
|
||||||
x = self.embed_tokens(input_ids)
|
x = self.embed_tokens(input_ids)
|
||||||
|
|
|
||||||
|
|
@ -197,7 +197,7 @@ def evaluate_subject(
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser(description="MMLU evaluation")
|
parser = argparse.ArgumentParser(description="MMLU evaluation")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--model_dir", type=str, default="./params", help="Model directory"
|
"--param_path", type=str, default="./params", help="Model directory"
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--data_dir", type=str, default="./mmlu_data", help="MMLU data directory"
|
"--data_dir", type=str, default="./mmlu_data", help="MMLU data directory"
|
||||||
|
|
@ -228,8 +228,8 @@ def main():
|
||||||
if args.download or not os.path.exists(args.data_dir):
|
if args.download or not os.path.exists(args.data_dir):
|
||||||
download_mmlu(args.data_dir)
|
download_mmlu(args.data_dir)
|
||||||
|
|
||||||
model = AutoModel.from_pretrained(args.model_dir)
|
model = AutoModel.from_pretrained(args.param_path)
|
||||||
tokenizer = AutoTokenizer.from_pretrained(args.model_dir)
|
tokenizer = AutoTokenizer.from_pretrained(args.param_path)
|
||||||
device = args.device
|
device = args.device
|
||||||
dtype = getattr(torch, args.dtype)
|
dtype = getattr(torch, args.dtype)
|
||||||
model.to(device=device, dtype=dtype)
|
model.to(device=device, dtype=dtype)
|
||||||
|
|
|
||||||
|
|
@ -10,11 +10,11 @@ from astrai.tokenize import AutoTokenizer
|
||||||
|
|
||||||
|
|
||||||
def process_file(
|
def process_file(
|
||||||
model_dir: str, input_file: str, output_file: str, batch_size: int, text_key: str
|
param_path: str, input_file: str, output_file: str, batch_size: int, text_key: str
|
||||||
):
|
):
|
||||||
# Load model and tokenizer
|
# Load model and tokenizer
|
||||||
model = AutoModel.from_pretrained(model_dir)
|
model = AutoModel.from_pretrained(param_path)
|
||||||
tokenizer = AutoTokenizer.from_pretrained(model_dir)
|
tokenizer = AutoTokenizer.from_pretrained(param_path)
|
||||||
model.to(device="cuda", dtype=torch.bfloat16)
|
model.to(device="cuda", dtype=torch.bfloat16)
|
||||||
|
|
||||||
with open(input_file, "r", encoding="utf-8") as f:
|
with open(input_file, "r", encoding="utf-8") as f:
|
||||||
|
|
@ -88,7 +88,7 @@ def process_file(
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(description="Run perplexity with a Khaosz model.")
|
parser = argparse.ArgumentParser(description="Run perplexity with a Khaosz model.")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--model_dir", type=str, required=True, help="Path to the model directory."
|
"--param_path", type=str, required=True, help="Path to the model directory."
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--input_file", type=str, required=True, help="Path to the input file."
|
"--input_file", type=str, required=True, help="Path to the input file."
|
||||||
|
|
|
||||||
|
|
@ -18,7 +18,7 @@ def main():
|
||||||
"--reload", action="store_true", help="Enable auto-reload for development"
|
"--reload", action="store_true", help="Enable auto-reload for development"
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--param-path",
|
"--param_path",
|
||||||
type=Path,
|
type=Path,
|
||||||
default=None,
|
default=None,
|
||||||
help="Path to model parameters (default: project_root/params)",
|
help="Path to model parameters (default: project_root/params)",
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue