fix: 修复 CLI 参数缺失/重复、device_ids 越界、generate 参数名不一致、scheduler 时序、非流式截断等 bug

- train.py: 补上 --batch_size、--grpo_clip_eps,删除 3 处重复 --group_size(保留 1 处),移除 --on_policy,并修正 --warmup_steps、--window_size、--stride 的帮助文本
- generate.py: --model_dir 改为 --param_path 对齐 README
- automodel.py: from_pretrained 新增 strict 参数(默认 True)
- parallel/setup.py: 修复 device_ids 索引越界(先计算 effective_rank = rank % len(device_ids),再用它索引 device_ids)
- train_callback.py: SchedulerCallback 的 scheduler.step() 从 on_batch_end 移至 on_step_end
- test_train_strategy.py: 测试中补 optimizer.step()
- engine.py: 非流式改为循环等待所有任务完成,补 remove_task 清理
- scheduler.py: Task 添加 _pages_freed 标志,杜绝双重释放
- trainer.py: accumulation_steps 小于 1(如 0)时 clamp 为 1,避免取模与损失归一化时出现除零错误
- tokenizer.py: save_pretrained 添加 _tokenizer is None 检查
- benchmark.py: 修复 ModelConfig 过时 import 路径
- inference/__init__.py: 修复 stale docstring(cache.py 的描述由 SlotAllocator/PrefixCacheManager 更新为 PagedCache)
This commit is contained in:
ViperEkura 2026-05-09 14:36:42 +08:00
parent bc7c82977e
commit 283bcaf2ff
12 changed files with 49 additions and 30 deletions

View File

@ -3,7 +3,7 @@
Layers:
- engine.py: Facade (InferenceEngine), Value Object (GenerationParams, GenerationRequest)
- scheduler.py: Continuous-batching loop, Task state machine, TaskStatus enum
- cache.py: Object Pool (SlotAllocator), PrefixCacheManager
- cache.py: PagedCache (page-table-indirected KV cache with alloc/free)
- sampling.py: Strategy pattern (TemperatureStrategy, TopKStrategy, TopPStrategy)
- server.py: FastAPI HTTP server (OpenAI-compatible endpoints)
"""

View File

@ -408,13 +408,14 @@ class InferenceEngine:
Single string for one prompt, list of strings for batch.
"""
result = _Result(count=len(prompts))
task_ids = []
for i, p in enumerate(prompts):
def make_cb(idx):
return lambda tok: result.append(tok, idx)
self.scheduler.add_task(
task_id = self.scheduler.add_task(
prompt=p,
max_tokens=max_tokens,
temperature=temperature,
@ -422,8 +423,14 @@ class InferenceEngine:
top_k=top_k,
stream_callback=make_cb(i),
)
task_ids.append(task_id)
while result._completed < result._total:
result.wait(timeout=1.0)
for task_id in task_ids:
self.scheduler.remove_task(task_id)
result.wait()
res = result.get_results()
return res if is_batch else res[0]

View File

@ -56,6 +56,7 @@ class Task:
self.arrival_time = time.time()
self.finish_time: Optional[float] = None
self.stream_callback = stream_callback
self._pages_freed: bool = False
@property
def next_pos(self) -> int:
@ -167,9 +168,11 @@ class InferenceScheduler:
self.active_tasks = [t for t in self.active_tasks if t.task_id != task_id]
for task in removed_active:
if not task._pages_freed:
self._free_pages(task.page_table)
task.page_table.clear()
task.n_pages = 0
task._pages_freed = True
def _free_pages(self, indices: List[int]) -> None:
for idx in indices:
@ -185,9 +188,11 @@ class InferenceScheduler:
self._total_tokens += task.output_tokens
for task in finished:
if not task._pages_freed:
self._free_pages(task.page_table)
task.page_table.clear()
task.n_pages = 0
task._pages_freed = True
self.active_tasks = [
t for t in self.active_tasks if t.status != TaskStatus.FINISHED

View File

@ -84,6 +84,7 @@ class AutoModel(nn.Module):
cls,
path: Union[str, Path],
disable_random_init: bool = True,
strict: bool = True,
) -> nn.Module:
model_path = Path(path)
@ -106,7 +107,7 @@ class AutoModel(nn.Module):
weights_path = model_path / "model.safetensors"
if weights_path.exists():
state_dict = st.load_file(str(weights_path))
model.load_state_dict(state_dict, strict=False)
model.load_state_dict(state_dict, strict=strict)
return model

View File

@ -48,8 +48,9 @@ def setup_parallel(
if device_ids is None:
device_ids = [i for i in range(world_size)]
rank = device_ids[rank % len(device_ids)]
device_id = torch.device(device_type, device_ids[rank])
effective_rank = rank % len(device_ids)
device_id = torch.device(device_type, device_ids[effective_rank])
rank = device_ids[effective_rank]
os.environ["MASTER_ADDR"] = master_addr
os.environ["MASTER_PORT"] = master_port

View File

@ -64,6 +64,11 @@ class AutoTokenizer:
save_path: Path to save the tokenizer
"""
if self._tokenizer is None:
raise RuntimeError(
"Tokenizer not initialized. Load or create a tokenizer first."
)
save_path = Path(save_path)
save_path.mkdir(parents=True, exist_ok=True)

View File

@ -104,7 +104,7 @@ class SchedulerCallback(TrainCallback):
if "initial_lr" not in group:
group["initial_lr"] = group["lr"]
def on_batch_end(self, context: TrainContext):
def on_step_end(self, context: TrainContext):
if context.scheduler:
context.scheduler.step()

View File

@ -68,8 +68,9 @@ class Trainer:
context.epoch = epoch
self._call_callbacks("on_epoch_begin", context)
accumulation_steps = max(self.train_config.accumulation_steps, 1)
for batch in context.dataloader:
if context.iteration % self.train_config.accumulation_steps == 0:
if context.iteration % accumulation_steps == 0:
# 2. step
self._call_callbacks("on_step_begin", context)
context.optimizer.step()
@ -83,7 +84,7 @@ class Trainer:
context.iteration += 1
# to make the loss normalized by accumulation steps
stand_loss = loss / self.train_config.accumulation_steps
stand_loss = loss / accumulation_steps
stand_loss.backward()
self._call_callbacks("on_batch_end", context)

View File

@ -6,8 +6,9 @@ from typing import Any, Dict
import torch
from torch import Tensor
from astrai.config import ModelConfig
from astrai.inference.cache import PagedCache
from astrai.model.transformer import ModelConfig, Transformer
from astrai.model.transformer import Transformer
@dataclass

View File

@ -9,7 +9,7 @@ from astrai.tokenize import AutoTokenizer
def processor(
model_dir: str,
param_path: str,
input_json_file: str,
output_json_file: str,
temperature: float,
@ -20,8 +20,8 @@ def processor(
max_tokens: int,
):
# Load model and tokenizer
model = AutoModel.from_pretrained(model_dir)
tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = AutoModel.from_pretrained(param_path)
tokenizer = AutoTokenizer.from_pretrained(param_path)
model.to(device="cuda", dtype=torch.bfloat16)
# Create inference engine
@ -72,7 +72,7 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run generate with a Khaosz model.")
parser.add_argument(
"--model_dir", type=str, required=True, help="Path to the model directory."
"--param_path", type=str, required=True, help="Path to the model directory."
)
parser.add_argument(
"--input_json_file",

View File

@ -42,7 +42,7 @@ def parse_args() -> argparse.Namespace:
parser.add_argument(
"--n_epoch", type=int, default=1, help="Number of epochs to train."
)
parser.add_argument("--group_size", type=int, default=4, help="GRPO group size.")
parser.add_argument("--batch_size", type=int, default=1, help="Batch size per GPU.")
parser.add_argument(
"--accumulation_steps",
type=int,
@ -53,7 +53,7 @@ def parse_args() -> argparse.Namespace:
"--warmup_steps",
type=int,
default=1000,
help="Number of iters between warnings.",
help="Number of warmup steps for LR scheduler.",
)
parser.add_argument(
"--max_lr", type=float, default=3e-4, help="Max learning rate for training."
@ -98,23 +98,19 @@ def parse_args() -> argparse.Namespace:
"--window_size",
type=int,
default=None,
help="the max length of the input sequence.",
help="Max length of the input sequence.",
)
parser.add_argument(
"--stride", type=int, default=None, help="the step size of the input sequence."
"--stride", type=int, default=None, help="Step size of the input sequence."
)
parser.add_argument("--dpo_beta", type=float, default=0.1, help="DPO beta value.")
parser.add_argument("--group_size", type=int, default=4, help="GRPO group size.")
parser.add_argument(
"--on_policy",
action="store_true",
default=False,
help="Enable on-policy GRPO mode.",
"--grpo_clip_eps", type=float, default=0.2, help="GRPO clipping epsilon."
)
parser.add_argument(
"--grpo_kl_coef", type=float, default=0.01, help="GRPO KL penalty coefficient."
)
parser.add_argument("--group_size", type=int, default=4, help="GRPO group size.")
parser.add_argument(
"--label_smoothing",
type=float,
@ -134,7 +130,6 @@ def parse_args() -> argparse.Namespace:
default="checkpoint",
help="Directory to save checkpoints.",
)
parser.add_argument("--group_size", type=int, default=4, help="GRPO group size.")
parser.add_argument(
"--grpo_sync_interval",
type=int,

View File

@ -72,6 +72,7 @@ def test_schedule_factory_random_configs():
# Test scheduler step functionality
initial_lr = scheduler.get_last_lr()
optimizer.step()
scheduler.step()
new_lr = scheduler.get_last_lr()
@ -112,6 +113,7 @@ def test_schedule_factory_edge_cases():
# Test multiple steps
for _ in range(10):
optimizer.step()
scheduler.step()
@ -136,6 +138,7 @@ def test_schedule_factory_state_persistence():
# Take a few steps
for _ in range(5):
optimizer.step()
scheduler.step()
# Save state