fix: 修复 CLI 参数缺失/重复、device_ids 越界、generate 参数名不一致、scheduler 时序、非流式截断等 bug
- train.py: 补上 --batch_size、--grpo_clip_eps,删除 3 处重复 --group_size
- generate.py: --model_dir 改为 --param_path 对齐 README
- automodel.py: from_pretrained 新增 strict 参数(默认 True)
- parallel/setup.py: 修复 device_ids 索引越界
- train_callback.py: scheduler.step() 移至 on_step_end
- test_train_strategy.py: 测试中补 optimizer.step()
- engine.py: 非流式改为循环等待所有任务完成,补 remove_task 清理
- scheduler.py: Task 添加 _pages_freed 标志,杜绝双重释放
- trainer.py: accumulation_steps=0 时 clamp 为 1
- tokenizer.py: save_pretrained 添加 _tokenizer is None 检查
- benchmark.py: 修复 ModelConfig 过时 import 路径
- inference/__init__.py: 修复 stale docstring
This commit is contained in:
parent
bc7c82977e
commit
283bcaf2ff
|
|
@ -3,7 +3,7 @@
|
|||
Layers:
|
||||
- engine.py: Facade (InferenceEngine), Value Object (GenerationParams, GenerationRequest)
|
||||
- scheduler.py: Continuous-batching loop, Task state machine, TaskStatus enum
|
||||
- cache.py: Object Pool (SlotAllocator), PrefixCacheManager
|
||||
- cache.py: PagedCache (page-table-indirected KV cache with alloc/free)
|
||||
- sampling.py: Strategy pattern (TemperatureStrategy, TopKStrategy, TopPStrategy)
|
||||
- server.py: FastAPI HTTP server (OpenAI-compatible endpoints)
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -408,13 +408,14 @@ class InferenceEngine:
|
|||
Single string for one prompt, list of strings for batch.
|
||||
"""
|
||||
result = _Result(count=len(prompts))
|
||||
task_ids = []
|
||||
|
||||
for i, p in enumerate(prompts):
|
||||
|
||||
def make_cb(idx):
|
||||
return lambda tok: result.append(tok, idx)
|
||||
|
||||
self.scheduler.add_task(
|
||||
task_id = self.scheduler.add_task(
|
||||
prompt=p,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
|
|
@ -422,8 +423,14 @@ class InferenceEngine:
|
|||
top_k=top_k,
|
||||
stream_callback=make_cb(i),
|
||||
)
|
||||
task_ids.append(task_id)
|
||||
|
||||
while result._completed < result._total:
|
||||
result.wait(timeout=1.0)
|
||||
|
||||
for task_id in task_ids:
|
||||
self.scheduler.remove_task(task_id)
|
||||
|
||||
result.wait()
|
||||
res = result.get_results()
|
||||
return res if is_batch else res[0]
|
||||
|
||||
|
|
|
|||
|
|
@ -56,6 +56,7 @@ class Task:
|
|||
self.arrival_time = time.time()
|
||||
self.finish_time: Optional[float] = None
|
||||
self.stream_callback = stream_callback
|
||||
self._pages_freed: bool = False
|
||||
|
||||
@property
|
||||
def next_pos(self) -> int:
|
||||
|
|
@ -167,9 +168,11 @@ class InferenceScheduler:
|
|||
self.active_tasks = [t for t in self.active_tasks if t.task_id != task_id]
|
||||
|
||||
for task in removed_active:
|
||||
if not task._pages_freed:
|
||||
self._free_pages(task.page_table)
|
||||
task.page_table.clear()
|
||||
task.n_pages = 0
|
||||
task._pages_freed = True
|
||||
|
||||
def _free_pages(self, indices: List[int]) -> None:
|
||||
for idx in indices:
|
||||
|
|
@ -185,9 +188,11 @@ class InferenceScheduler:
|
|||
self._total_tokens += task.output_tokens
|
||||
|
||||
for task in finished:
|
||||
if not task._pages_freed:
|
||||
self._free_pages(task.page_table)
|
||||
task.page_table.clear()
|
||||
task.n_pages = 0
|
||||
task._pages_freed = True
|
||||
|
||||
self.active_tasks = [
|
||||
t for t in self.active_tasks if t.status != TaskStatus.FINISHED
|
||||
|
|
|
|||
|
|
@ -84,6 +84,7 @@ class AutoModel(nn.Module):
|
|||
cls,
|
||||
path: Union[str, Path],
|
||||
disable_random_init: bool = True,
|
||||
strict: bool = True,
|
||||
) -> nn.Module:
|
||||
|
||||
model_path = Path(path)
|
||||
|
|
@ -106,7 +107,7 @@ class AutoModel(nn.Module):
|
|||
weights_path = model_path / "model.safetensors"
|
||||
if weights_path.exists():
|
||||
state_dict = st.load_file(str(weights_path))
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
model.load_state_dict(state_dict, strict=strict)
|
||||
|
||||
return model
|
||||
|
||||
|
|
|
|||
|
|
@ -48,8 +48,9 @@ def setup_parallel(
|
|||
if device_ids is None:
|
||||
device_ids = [i for i in range(world_size)]
|
||||
|
||||
rank = device_ids[rank % len(device_ids)]
|
||||
device_id = torch.device(device_type, device_ids[rank])
|
||||
effective_rank = rank % len(device_ids)
|
||||
device_id = torch.device(device_type, device_ids[effective_rank])
|
||||
rank = device_ids[effective_rank]
|
||||
|
||||
os.environ["MASTER_ADDR"] = master_addr
|
||||
os.environ["MASTER_PORT"] = master_port
|
||||
|
|
|
|||
|
|
@ -64,6 +64,11 @@ class AutoTokenizer:
|
|||
save_path: Path to save the tokenizer
|
||||
"""
|
||||
|
||||
if self._tokenizer is None:
|
||||
raise RuntimeError(
|
||||
"Tokenizer not initialized. Load or create a tokenizer first."
|
||||
)
|
||||
|
||||
save_path = Path(save_path)
|
||||
save_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
|
|
|||
|
|
@ -104,7 +104,7 @@ class SchedulerCallback(TrainCallback):
|
|||
if "initial_lr" not in group:
|
||||
group["initial_lr"] = group["lr"]
|
||||
|
||||
def on_batch_end(self, context: TrainContext):
|
||||
def on_step_end(self, context: TrainContext):
|
||||
if context.scheduler:
|
||||
context.scheduler.step()
|
||||
|
||||
|
|
|
|||
|
|
@ -68,8 +68,9 @@ class Trainer:
|
|||
context.epoch = epoch
|
||||
self._call_callbacks("on_epoch_begin", context)
|
||||
|
||||
accumulation_steps = max(self.train_config.accumulation_steps, 1)
|
||||
for batch in context.dataloader:
|
||||
if context.iteration % self.train_config.accumulation_steps == 0:
|
||||
if context.iteration % accumulation_steps == 0:
|
||||
# 2. step
|
||||
self._call_callbacks("on_step_begin", context)
|
||||
context.optimizer.step()
|
||||
|
|
@ -83,7 +84,7 @@ class Trainer:
|
|||
context.iteration += 1
|
||||
|
||||
# to make the loss normalized by accumulation steps
|
||||
stand_loss = loss / self.train_config.accumulation_steps
|
||||
stand_loss = loss / accumulation_steps
|
||||
stand_loss.backward()
|
||||
|
||||
self._call_callbacks("on_batch_end", context)
|
||||
|
|
|
|||
|
|
@ -6,8 +6,9 @@ from typing import Any, Dict
|
|||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from astrai.config import ModelConfig
|
||||
from astrai.inference.cache import PagedCache
|
||||
from astrai.model.transformer import ModelConfig, Transformer
|
||||
from astrai.model.transformer import Transformer
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from astrai.tokenize import AutoTokenizer
|
|||
|
||||
|
||||
def processor(
|
||||
model_dir: str,
|
||||
param_path: str,
|
||||
input_json_file: str,
|
||||
output_json_file: str,
|
||||
temperature: float,
|
||||
|
|
@ -20,8 +20,8 @@ def processor(
|
|||
max_tokens: int,
|
||||
):
|
||||
# Load model and tokenizer
|
||||
model = AutoModel.from_pretrained(model_dir)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_dir)
|
||||
model = AutoModel.from_pretrained(param_path)
|
||||
tokenizer = AutoTokenizer.from_pretrained(param_path)
|
||||
model.to(device="cuda", dtype=torch.bfloat16)
|
||||
|
||||
# Create inference engine
|
||||
|
|
@ -72,7 +72,7 @@ if __name__ == "__main__":
|
|||
parser = argparse.ArgumentParser(description="Run generate with a Khaosz model.")
|
||||
|
||||
parser.add_argument(
|
||||
"--model_dir", type=str, required=True, help="Path to the model directory."
|
||||
"--param_path", type=str, required=True, help="Path to the model directory."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--input_json_file",
|
||||
|
|
|
|||
|
|
@ -42,7 +42,7 @@ def parse_args() -> argparse.Namespace:
|
|||
parser.add_argument(
|
||||
"--n_epoch", type=int, default=1, help="Number of epochs to train."
|
||||
)
|
||||
parser.add_argument("--group_size", type=int, default=4, help="GRPO group size.")
|
||||
parser.add_argument("--batch_size", type=int, default=1, help="Batch size per GPU.")
|
||||
parser.add_argument(
|
||||
"--accumulation_steps",
|
||||
type=int,
|
||||
|
|
@ -53,7 +53,7 @@ def parse_args() -> argparse.Namespace:
|
|||
"--warmup_steps",
|
||||
type=int,
|
||||
default=1000,
|
||||
help="Number of iters between warnings.",
|
||||
help="Number of warmup steps for LR scheduler.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_lr", type=float, default=3e-4, help="Max learning rate for training."
|
||||
|
|
@ -98,23 +98,19 @@ def parse_args() -> argparse.Namespace:
|
|||
"--window_size",
|
||||
type=int,
|
||||
default=None,
|
||||
help="the max length of the input sequence.",
|
||||
help="Max length of the input sequence.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--stride", type=int, default=None, help="the step size of the input sequence."
|
||||
"--stride", type=int, default=None, help="Step size of the input sequence."
|
||||
)
|
||||
parser.add_argument("--dpo_beta", type=float, default=0.1, help="DPO beta value.")
|
||||
parser.add_argument("--group_size", type=int, default=4, help="GRPO group size.")
|
||||
parser.add_argument(
|
||||
"--on_policy",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Enable on-policy GRPO mode.",
|
||||
"--grpo_clip_eps", type=float, default=0.2, help="GRPO clipping epsilon."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--grpo_kl_coef", type=float, default=0.01, help="GRPO KL penalty coefficient."
|
||||
)
|
||||
parser.add_argument("--group_size", type=int, default=4, help="GRPO group size.")
|
||||
parser.add_argument(
|
||||
"--label_smoothing",
|
||||
type=float,
|
||||
|
|
@ -134,7 +130,6 @@ def parse_args() -> argparse.Namespace:
|
|||
default="checkpoint",
|
||||
help="Directory to save checkpoints.",
|
||||
)
|
||||
parser.add_argument("--group_size", type=int, default=4, help="GRPO group size.")
|
||||
parser.add_argument(
|
||||
"--grpo_sync_interval",
|
||||
type=int,
|
||||
|
|
|
|||
|
|
@ -72,6 +72,7 @@ def test_schedule_factory_random_configs():
|
|||
|
||||
# Test scheduler step functionality
|
||||
initial_lr = scheduler.get_last_lr()
|
||||
optimizer.step()
|
||||
scheduler.step()
|
||||
new_lr = scheduler.get_last_lr()
|
||||
|
||||
|
|
@ -112,6 +113,7 @@ def test_schedule_factory_edge_cases():
|
|||
|
||||
# Test multiple steps
|
||||
for _ in range(10):
|
||||
optimizer.step()
|
||||
scheduler.step()
|
||||
|
||||
|
||||
|
|
@ -136,6 +138,7 @@ def test_schedule_factory_state_persistence():
|
|||
|
||||
# Take a few steps
|
||||
for _ in range(5):
|
||||
optimizer.step()
|
||||
scheduler.step()
|
||||
|
||||
# Save state
|
||||
|
|
|
|||
Loading…
Reference in New Issue