fix: 修复 CLI 参数缺失/重复、device_ids 越界、generate 参数名不一致、scheduler 时序、非流式截断等 bug

- train.py: 补上 --batch_size、--grpo_clip_eps,删除 3 处重复 --group_size
- generate.py: --model_dir 改为 --param_path 对齐 README
- automodel.py: from_pretrained 新增 strict 参数(默认 True)
- parallel/setup.py: 修复 device_ids 索引越界
- train_callback.py: scheduler.step() 移至 on_step_end
- test_train_strategy.py: 测试中补 optimizer.step()
- engine.py: 非流式改为循环等待所有任务完成,补 remove_task 清理
- scheduler.py: Task 添加 _pages_freed 标志,杜绝双重释放
- trainer.py: accumulation_steps=0 时 clamp 为 1
- tokenizer.py: save_pretrained 添加 _tokenizer is None 检查
- benchmark.py: 修复 ModelConfig 过时 import 路径
- inference/__init__.py: 修复 stale docstring
This commit is contained in:
ViperEkura 2026-05-09 14:36:42 +08:00
parent bc7c82977e
commit 283bcaf2ff
12 changed files with 49 additions and 30 deletions

View File

@ -3,7 +3,7 @@
Layers: Layers:
- engine.py: Facade (InferenceEngine), Value Object (GenerationParams, GenerationRequest) - engine.py: Facade (InferenceEngine), Value Object (GenerationParams, GenerationRequest)
- scheduler.py: Continuous-batching loop, Task state machine, TaskStatus enum - scheduler.py: Continuous-batching loop, Task state machine, TaskStatus enum
- cache.py: Object Pool (SlotAllocator), PrefixCacheManager - cache.py: PagedCache (page-table-indirected KV cache with alloc/free)
- sampling.py: Strategy pattern (TemperatureStrategy, TopKStrategy, TopPStrategy) - sampling.py: Strategy pattern (TemperatureStrategy, TopKStrategy, TopPStrategy)
- server.py: FastAPI HTTP server (OpenAI-compatible endpoints) - server.py: FastAPI HTTP server (OpenAI-compatible endpoints)
""" """

View File

@ -408,13 +408,14 @@ class InferenceEngine:
Single string for one prompt, list of strings for batch. Single string for one prompt, list of strings for batch.
""" """
result = _Result(count=len(prompts)) result = _Result(count=len(prompts))
task_ids = []
for i, p in enumerate(prompts): for i, p in enumerate(prompts):
def make_cb(idx): def make_cb(idx):
return lambda tok: result.append(tok, idx) return lambda tok: result.append(tok, idx)
self.scheduler.add_task( task_id = self.scheduler.add_task(
prompt=p, prompt=p,
max_tokens=max_tokens, max_tokens=max_tokens,
temperature=temperature, temperature=temperature,
@ -422,8 +423,14 @@ class InferenceEngine:
top_k=top_k, top_k=top_k,
stream_callback=make_cb(i), stream_callback=make_cb(i),
) )
task_ids.append(task_id)
while result._completed < result._total:
result.wait(timeout=1.0)
for task_id in task_ids:
self.scheduler.remove_task(task_id)
result.wait()
res = result.get_results() res = result.get_results()
return res if is_batch else res[0] return res if is_batch else res[0]

View File

@ -56,6 +56,7 @@ class Task:
self.arrival_time = time.time() self.arrival_time = time.time()
self.finish_time: Optional[float] = None self.finish_time: Optional[float] = None
self.stream_callback = stream_callback self.stream_callback = stream_callback
self._pages_freed: bool = False
@property @property
def next_pos(self) -> int: def next_pos(self) -> int:
@ -167,9 +168,11 @@ class InferenceScheduler:
self.active_tasks = [t for t in self.active_tasks if t.task_id != task_id] self.active_tasks = [t for t in self.active_tasks if t.task_id != task_id]
for task in removed_active: for task in removed_active:
self._free_pages(task.page_table) if not task._pages_freed:
task.page_table.clear() self._free_pages(task.page_table)
task.n_pages = 0 task.page_table.clear()
task.n_pages = 0
task._pages_freed = True
def _free_pages(self, indices: List[int]) -> None: def _free_pages(self, indices: List[int]) -> None:
for idx in indices: for idx in indices:
@ -185,9 +188,11 @@ class InferenceScheduler:
self._total_tokens += task.output_tokens self._total_tokens += task.output_tokens
for task in finished: for task in finished:
self._free_pages(task.page_table) if not task._pages_freed:
task.page_table.clear() self._free_pages(task.page_table)
task.n_pages = 0 task.page_table.clear()
task.n_pages = 0
task._pages_freed = True
self.active_tasks = [ self.active_tasks = [
t for t in self.active_tasks if t.status != TaskStatus.FINISHED t for t in self.active_tasks if t.status != TaskStatus.FINISHED

View File

@ -84,6 +84,7 @@ class AutoModel(nn.Module):
cls, cls,
path: Union[str, Path], path: Union[str, Path],
disable_random_init: bool = True, disable_random_init: bool = True,
strict: bool = True,
) -> nn.Module: ) -> nn.Module:
model_path = Path(path) model_path = Path(path)
@ -106,7 +107,7 @@ class AutoModel(nn.Module):
weights_path = model_path / "model.safetensors" weights_path = model_path / "model.safetensors"
if weights_path.exists(): if weights_path.exists():
state_dict = st.load_file(str(weights_path)) state_dict = st.load_file(str(weights_path))
model.load_state_dict(state_dict, strict=False) model.load_state_dict(state_dict, strict=strict)
return model return model

View File

@ -48,8 +48,9 @@ def setup_parallel(
if device_ids is None: if device_ids is None:
device_ids = [i for i in range(world_size)] device_ids = [i for i in range(world_size)]
rank = device_ids[rank % len(device_ids)] effective_rank = rank % len(device_ids)
device_id = torch.device(device_type, device_ids[rank]) device_id = torch.device(device_type, device_ids[effective_rank])
rank = device_ids[effective_rank]
os.environ["MASTER_ADDR"] = master_addr os.environ["MASTER_ADDR"] = master_addr
os.environ["MASTER_PORT"] = master_port os.environ["MASTER_PORT"] = master_port

View File

@ -64,6 +64,11 @@ class AutoTokenizer:
save_path: Path to save the tokenizer save_path: Path to save the tokenizer
""" """
if self._tokenizer is None:
raise RuntimeError(
"Tokenizer not initialized. Load or create a tokenizer first."
)
save_path = Path(save_path) save_path = Path(save_path)
save_path.mkdir(parents=True, exist_ok=True) save_path.mkdir(parents=True, exist_ok=True)

View File

@ -104,7 +104,7 @@ class SchedulerCallback(TrainCallback):
if "initial_lr" not in group: if "initial_lr" not in group:
group["initial_lr"] = group["lr"] group["initial_lr"] = group["lr"]
def on_batch_end(self, context: TrainContext): def on_step_end(self, context: TrainContext):
if context.scheduler: if context.scheduler:
context.scheduler.step() context.scheduler.step()

View File

@ -68,8 +68,9 @@ class Trainer:
context.epoch = epoch context.epoch = epoch
self._call_callbacks("on_epoch_begin", context) self._call_callbacks("on_epoch_begin", context)
accumulation_steps = max(self.train_config.accumulation_steps, 1)
for batch in context.dataloader: for batch in context.dataloader:
if context.iteration % self.train_config.accumulation_steps == 0: if context.iteration % accumulation_steps == 0:
# 2. step # 2. step
self._call_callbacks("on_step_begin", context) self._call_callbacks("on_step_begin", context)
context.optimizer.step() context.optimizer.step()
@ -83,7 +84,7 @@ class Trainer:
context.iteration += 1 context.iteration += 1
# to make the loss normalized by accumulation steps # to make the loss normalized by accumulation steps
stand_loss = loss / self.train_config.accumulation_steps stand_loss = loss / accumulation_steps
stand_loss.backward() stand_loss.backward()
self._call_callbacks("on_batch_end", context) self._call_callbacks("on_batch_end", context)

View File

@ -6,8 +6,9 @@ from typing import Any, Dict
import torch import torch
from torch import Tensor from torch import Tensor
from astrai.config import ModelConfig
from astrai.inference.cache import PagedCache from astrai.inference.cache import PagedCache
from astrai.model.transformer import ModelConfig, Transformer from astrai.model.transformer import Transformer
@dataclass @dataclass

View File

@ -9,7 +9,7 @@ from astrai.tokenize import AutoTokenizer
def processor( def processor(
model_dir: str, param_path: str,
input_json_file: str, input_json_file: str,
output_json_file: str, output_json_file: str,
temperature: float, temperature: float,
@ -20,8 +20,8 @@ def processor(
max_tokens: int, max_tokens: int,
): ):
# Load model and tokenizer # Load model and tokenizer
model = AutoModel.from_pretrained(model_dir) model = AutoModel.from_pretrained(param_path)
tokenizer = AutoTokenizer.from_pretrained(model_dir) tokenizer = AutoTokenizer.from_pretrained(param_path)
model.to(device="cuda", dtype=torch.bfloat16) model.to(device="cuda", dtype=torch.bfloat16)
# Create inference engine # Create inference engine
@ -72,7 +72,7 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run generate with a Khaosz model.") parser = argparse.ArgumentParser(description="Run generate with a Khaosz model.")
parser.add_argument( parser.add_argument(
"--model_dir", type=str, required=True, help="Path to the model directory." "--param_path", type=str, required=True, help="Path to the model directory."
) )
parser.add_argument( parser.add_argument(
"--input_json_file", "--input_json_file",

View File

@ -42,7 +42,7 @@ def parse_args() -> argparse.Namespace:
parser.add_argument( parser.add_argument(
"--n_epoch", type=int, default=1, help="Number of epochs to train." "--n_epoch", type=int, default=1, help="Number of epochs to train."
) )
parser.add_argument("--group_size", type=int, default=4, help="GRPO group size.") parser.add_argument("--batch_size", type=int, default=1, help="Batch size per GPU.")
parser.add_argument( parser.add_argument(
"--accumulation_steps", "--accumulation_steps",
type=int, type=int,
@ -53,7 +53,7 @@ def parse_args() -> argparse.Namespace:
"--warmup_steps", "--warmup_steps",
type=int, type=int,
default=1000, default=1000,
help="Number of iters between warnings.", help="Number of warmup steps for LR scheduler.",
) )
parser.add_argument( parser.add_argument(
"--max_lr", type=float, default=3e-4, help="Max learning rate for training." "--max_lr", type=float, default=3e-4, help="Max learning rate for training."
@ -98,23 +98,19 @@ def parse_args() -> argparse.Namespace:
"--window_size", "--window_size",
type=int, type=int,
default=None, default=None,
help="the max length of the input sequence.", help="Max length of the input sequence.",
) )
parser.add_argument( parser.add_argument(
"--stride", type=int, default=None, help="the step size of the input sequence." "--stride", type=int, default=None, help="Step size of the input sequence."
) )
parser.add_argument("--dpo_beta", type=float, default=0.1, help="DPO beta value.") parser.add_argument("--dpo_beta", type=float, default=0.1, help="DPO beta value.")
parser.add_argument("--group_size", type=int, default=4, help="GRPO group size.") parser.add_argument("--group_size", type=int, default=4, help="GRPO group size.")
parser.add_argument( parser.add_argument(
"--on_policy", "--grpo_clip_eps", type=float, default=0.2, help="GRPO clipping epsilon."
action="store_true",
default=False,
help="Enable on-policy GRPO mode.",
) )
parser.add_argument( parser.add_argument(
"--grpo_kl_coef", type=float, default=0.01, help="GRPO KL penalty coefficient." "--grpo_kl_coef", type=float, default=0.01, help="GRPO KL penalty coefficient."
) )
parser.add_argument("--group_size", type=int, default=4, help="GRPO group size.")
parser.add_argument( parser.add_argument(
"--label_smoothing", "--label_smoothing",
type=float, type=float,
@ -134,7 +130,6 @@ def parse_args() -> argparse.Namespace:
default="checkpoint", default="checkpoint",
help="Directory to save checkpoints.", help="Directory to save checkpoints.",
) )
parser.add_argument("--group_size", type=int, default=4, help="GRPO group size.")
parser.add_argument( parser.add_argument(
"--grpo_sync_interval", "--grpo_sync_interval",
type=int, type=int,

View File

@ -72,6 +72,7 @@ def test_schedule_factory_random_configs():
# Test scheduler step functionality # Test scheduler step functionality
initial_lr = scheduler.get_last_lr() initial_lr = scheduler.get_last_lr()
optimizer.step()
scheduler.step() scheduler.step()
new_lr = scheduler.get_last_lr() new_lr = scheduler.get_last_lr()
@ -112,6 +113,7 @@ def test_schedule_factory_edge_cases():
# Test multiple steps # Test multiple steps
for _ in range(10): for _ in range(10):
optimizer.step()
scheduler.step() scheduler.step()
@ -136,6 +138,7 @@ def test_schedule_factory_state_persistence():
# Take a few steps # Take a few steps
for _ in range(5): for _ in range(5):
optimizer.step()
scheduler.step() scheduler.step()
# Save state # Save state