From 283bcaf2ff62fa5e264f8d54c41a0d4448c8c3b4 Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Sat, 9 May 2026 14:36:42 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=20CLI=20=E5=8F=82?= =?UTF-8?q?=E6=95=B0=E7=BC=BA=E5=A4=B1/=E9=87=8D=E5=A4=8D=E3=80=81device?= =?UTF-8?q?=5Fids=20=E8=B6=8A=E7=95=8C=E3=80=81generate=20=E5=8F=82?= =?UTF-8?q?=E6=95=B0=E5=90=8D=E4=B8=8D=E4=B8=80=E8=87=B4=E3=80=81scheduler?= =?UTF-8?q?=20=E6=97=B6=E5=BA=8F=E3=80=81=E9=9D=9E=E6=B5=81=E5=BC=8F?= =?UTF-8?q?=E6=88=AA=E6=96=AD=E7=AD=89=20bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - train.py: 补上 --batch_size、--grpo_clip_eps,删除 3 处重复 --group_size - generate.py: --model_dir 改为 --param_path 对齐 README - automodel.py: from_pretrained 新增 strict 参数(默认 True) - parallel/setup.py: 修复 device_ids 索引越界 - train_callback.py: scheduler.step() 移至 on_step_end - test_train_strategy.py: 测试中补 optimizer.step() - engine.py: 非流式改为循环等待所有任务完成,补 remove_task 清理 - scheduler.py: Task 添加 _pages_freed 标志,杜绝双重释放 - trainer.py: accumulation_steps=0 时 clamp 为 1 - tokenizer.py: save_pretrained 添加 _tokenizer is None 检查 - benchmark.py: 修复 ModelConfig 过时 import 路径 - inference/__init__.py: 修复 stale docstring --- astrai/inference/__init__.py | 2 +- astrai/inference/engine.py | 11 +++++++++-- astrai/inference/scheduler.py | 17 +++++++++++------ astrai/model/automodel.py | 3 ++- astrai/parallel/setup.py | 5 +++-- astrai/tokenize/tokenizer.py | 5 +++++ astrai/trainer/train_callback.py | 2 +- astrai/trainer/trainer.py | 5 +++-- scripts/tools/benchmark.py | 3 ++- scripts/tools/generate.py | 8 ++++---- scripts/tools/train.py | 15 +++++---------- tests/trainer/test_train_strategy.py | 3 +++ 12 files changed, 49 insertions(+), 30 deletions(-) diff --git a/astrai/inference/__init__.py b/astrai/inference/__init__.py index 9b32ffd..c95d819 100644 --- a/astrai/inference/__init__.py +++ b/astrai/inference/__init__.py @@ -3,7 +3,7 @@ Layers: - engine.py: Facade (InferenceEngine), Value Object (GenerationParams, GenerationRequest) - scheduler.py: Continuous-batching loop, Task state machine, TaskStatus enum - - cache.py: Object Pool (SlotAllocator), PrefixCacheManager + - cache.py: PagedCache (page-table-indirected KV cache with alloc/free) - sampling.py: Strategy pattern (TemperatureStrategy, TopKStrategy, TopPStrategy) - server.py: FastAPI HTTP server (OpenAI-compatible endpoints) """ diff --git a/astrai/inference/engine.py b/astrai/inference/engine.py index 440c958..e2b2a0f 100644 --- a/astrai/inference/engine.py +++ b/astrai/inference/engine.py @@ -408,13 +408,14 @@ class InferenceEngine: Single string for one prompt, list of strings for batch. """ result = _Result(count=len(prompts)) + task_ids = [] for i, p in enumerate(prompts): def make_cb(idx): return lambda tok: result.append(tok, idx) - self.scheduler.add_task( + task_id = self.scheduler.add_task( prompt=p, max_tokens=max_tokens, temperature=temperature, @@ -422,8 +423,14 @@ class InferenceEngine: top_k=top_k, stream_callback=make_cb(i), ) + task_ids.append(task_id) + + while result._completed < result._total: + result.wait(timeout=1.0) + + for task_id in task_ids: + self.scheduler.remove_task(task_id) - result.wait() res = result.get_results() return res if is_batch else res[0] diff --git a/astrai/inference/scheduler.py b/astrai/inference/scheduler.py index e9614db..c6833aa 100644 --- a/astrai/inference/scheduler.py +++ b/astrai/inference/scheduler.py @@ -56,6 +56,7 @@ class Task: self.arrival_time = time.time() self.finish_time: Optional[float] = None self.stream_callback = stream_callback + self._pages_freed: bool = False @property def next_pos(self) -> int: @@ -167,9 +168,11 @@ class InferenceScheduler: self.active_tasks = [t for t in self.active_tasks if t.task_id != task_id] for task in removed_active: - self._free_pages(task.page_table) - task.page_table.clear() - task.n_pages = 0 + if not task._pages_freed: + self._free_pages(task.page_table) + task.page_table.clear() + task.n_pages = 0 + task._pages_freed = True def _free_pages(self, indices: List[int]) -> None: for idx in indices: @@ -185,9 +188,11 @@ class InferenceScheduler: self._total_tokens += task.output_tokens for task in finished: - self._free_pages(task.page_table) - task.page_table.clear() - task.n_pages = 0 + if not task._pages_freed: + self._free_pages(task.page_table) + task.page_table.clear() + task.n_pages = 0 + task._pages_freed = True self.active_tasks = [ t for t in self.active_tasks if t.status != TaskStatus.FINISHED diff --git a/astrai/model/automodel.py b/astrai/model/automodel.py index 6fa86b8..3cd6e8e 100644 --- a/astrai/model/automodel.py +++ b/astrai/model/automodel.py @@ -84,6 +84,7 @@ class AutoModel(nn.Module): cls, path: Union[str, Path], disable_random_init: bool = True, + strict: bool = True, ) -> nn.Module: model_path = Path(path) @@ -106,7 +107,7 @@ class AutoModel(nn.Module): weights_path = model_path / "model.safetensors" if weights_path.exists(): state_dict = st.load_file(str(weights_path)) - model.load_state_dict(state_dict, strict=False) + model.load_state_dict(state_dict, strict=strict) return model diff --git a/astrai/parallel/setup.py b/astrai/parallel/setup.py index 9a17f9c..b00a3ed 100644 --- a/astrai/parallel/setup.py +++ b/astrai/parallel/setup.py @@ -48,8 +48,9 @@ def setup_parallel( if device_ids is None: device_ids = [i for i in range(world_size)] - rank = device_ids[rank % len(device_ids)] - device_id = torch.device(device_type, device_ids[rank]) + effective_rank = rank % len(device_ids) + device_id = torch.device(device_type, device_ids[effective_rank]) + rank = device_ids[effective_rank] os.environ["MASTER_ADDR"] = master_addr os.environ["MASTER_PORT"] = master_port diff --git a/astrai/tokenize/tokenizer.py b/astrai/tokenize/tokenizer.py index 14fcb18..41d86bb 100644 --- a/astrai/tokenize/tokenizer.py +++ b/astrai/tokenize/tokenizer.py @@ -64,6 +64,11 @@ class AutoTokenizer: save_path: Path to save the tokenizer """ + if self._tokenizer is None: + raise RuntimeError( + "Tokenizer not initialized. Load or create a tokenizer first." + ) + save_path = Path(save_path) save_path.mkdir(parents=True, exist_ok=True) diff --git a/astrai/trainer/train_callback.py b/astrai/trainer/train_callback.py index 6381b31..c2ac8be 100644 --- a/astrai/trainer/train_callback.py +++ b/astrai/trainer/train_callback.py @@ -104,7 +104,7 @@ class SchedulerCallback(TrainCallback): if "initial_lr" not in group: group["initial_lr"] = group["lr"] - def on_batch_end(self, context: TrainContext): + def on_step_end(self, context: TrainContext): if context.scheduler: context.scheduler.step() diff --git a/astrai/trainer/trainer.py b/astrai/trainer/trainer.py index b7f2361..8b688f3 100644 --- a/astrai/trainer/trainer.py +++ b/astrai/trainer/trainer.py @@ -68,8 +68,9 @@ class Trainer: context.epoch = epoch self._call_callbacks("on_epoch_begin", context) + accumulation_steps = max(self.train_config.accumulation_steps, 1) for batch in context.dataloader: - if context.iteration % self.train_config.accumulation_steps == 0: + if context.iteration % accumulation_steps == 0: # 2. step self._call_callbacks("on_step_begin", context) context.optimizer.step() @@ -83,7 +84,7 @@ class Trainer: context.iteration += 1 # to make the loss normalized by accumulation steps - stand_loss = loss / self.train_config.accumulation_steps + stand_loss = loss / accumulation_steps stand_loss.backward() self._call_callbacks("on_batch_end", context) diff --git a/scripts/tools/benchmark.py b/scripts/tools/benchmark.py index 7c6e7d5..ad03496 100644 --- a/scripts/tools/benchmark.py +++ b/scripts/tools/benchmark.py @@ -6,8 +6,9 @@ from typing import Any, Dict import torch from torch import Tensor +from astrai.config import ModelConfig from astrai.inference.cache import PagedCache -from astrai.model.transformer import ModelConfig, Transformer +from astrai.model.transformer import Transformer @dataclass diff --git a/scripts/tools/generate.py b/scripts/tools/generate.py index ab54528..2931300 100644 --- a/scripts/tools/generate.py +++ b/scripts/tools/generate.py @@ -9,7 +9,7 @@ from astrai.tokenize import AutoTokenizer def processor( - model_dir: str, + param_path: str, input_json_file: str, output_json_file: str, temperature: float, @@ -20,8 +20,8 @@ def processor( max_tokens: int, ): # Load model and tokenizer - model = AutoModel.from_pretrained(model_dir) - tokenizer = AutoTokenizer.from_pretrained(model_dir) + model = AutoModel.from_pretrained(param_path) + tokenizer = AutoTokenizer.from_pretrained(param_path) model.to(device="cuda", dtype=torch.bfloat16) # Create inference engine @@ -72,7 +72,7 @@ if __name__ == "__main__": parser = argparse.ArgumentParser(description="Run generate with a Khaosz model.") parser.add_argument( - "--model_dir", type=str, required=True, help="Path to the model directory." + "--param_path", type=str, required=True, help="Path to the model directory." ) parser.add_argument( "--input_json_file", diff --git a/scripts/tools/train.py b/scripts/tools/train.py index 5268c02..07b9452 100644 --- a/scripts/tools/train.py +++ b/scripts/tools/train.py @@ -42,7 +42,7 @@ def parse_args() -> argparse.Namespace: parser.add_argument( "--n_epoch", type=int, default=1, help="Number of epochs to train." ) - parser.add_argument("--group_size", type=int, default=4, help="GRPO group size.") + parser.add_argument("--batch_size", type=int, default=1, help="Batch size per GPU.") parser.add_argument( "--accumulation_steps", type=int, @@ -53,7 +53,7 @@ def parse_args() -> argparse.Namespace: "--warmup_steps", type=int, default=1000, - help="Number of iters between warnings.", + help="Number of warmup steps for LR scheduler.", ) parser.add_argument( "--max_lr", type=float, default=3e-4, help="Max learning rate for training." @@ -98,23 +98,19 @@ def parse_args() -> argparse.Namespace: "--window_size", type=int, default=None, - help="the max length of the input sequence.", + help="Max length of the input sequence.", ) parser.add_argument( - "--stride", type=int, default=None, help="the step size of the input sequence." + "--stride", type=int, default=None, help="Step size of the input sequence." ) parser.add_argument("--dpo_beta", type=float, default=0.1, help="DPO beta value.") parser.add_argument("--group_size", type=int, default=4, help="GRPO group size.") parser.add_argument( - "--on_policy", - action="store_true", - default=False, - help="Enable on-policy GRPO mode.", + "--grpo_clip_eps", type=float, default=0.2, help="GRPO clipping epsilon." ) parser.add_argument( "--grpo_kl_coef", type=float, default=0.01, help="GRPO KL penalty coefficient." ) - parser.add_argument("--group_size", type=int, default=4, help="GRPO group size.") parser.add_argument( "--label_smoothing", type=float, @@ -134,7 +130,6 @@ def parse_args() -> argparse.Namespace: default="checkpoint", help="Directory to save checkpoints.", ) - parser.add_argument("--group_size", type=int, default=4, help="GRPO group size.") parser.add_argument( "--grpo_sync_interval", type=int, diff --git a/tests/trainer/test_train_strategy.py b/tests/trainer/test_train_strategy.py index 2926f56..de43b2b 100644 --- a/tests/trainer/test_train_strategy.py +++ b/tests/trainer/test_train_strategy.py @@ -72,6 +72,7 @@ def test_schedule_factory_random_configs(): # Test scheduler step functionality initial_lr = scheduler.get_last_lr() + optimizer.step() scheduler.step() new_lr = scheduler.get_last_lr() @@ -112,6 +113,7 @@ def test_schedule_factory_edge_cases(): # Test multiple steps for _ in range(10): + optimizer.step() scheduler.step() @@ -136,6 +138,7 @@ def test_schedule_factory_state_persistence(): # Take a few steps for _ in range(5): + optimizer.step() scheduler.step() # Save state