Compare commits

..

6 Commits

Author SHA1 Message Date
ViperEkura 785d65436c fix: 修复 to_dict list 类型丢失与 OpenAI stop 参数失效
- to_dict() 增加 list 类型序列化支持,metrics 等字段不再丢失
- OpenAIHandler 补充 get_stop_sequences/on_token,读取 request.stop 并检测停止序列
- 文档类图补充缺失字段、修正关系分类、ChatCompletionRequest 字段增加 Optional
2026-05-19 21:07:07 +08:00
ViperEkura 64be81b7b3 feat: ProgressBarCallback 支持日志行输出到 stdout
- serialization 和 metric_logger 的 timestamp 统一使用 ISO 8601 格式
- ProgressBarCallback 新增 log_interval/file 参数,默认输出到 sys.stdout
2026-05-19 19:12:38 +08:00
ViperEkura 45479b5731 feat: metric 参数通过 TrainConfig 传递
- TrainConfig 新增 log_dir/log_interval/metrics 配置字段

- metric_logger 调用改用 **kwargs 传递,BaseFactory.create 自动过滤
2026-05-19 17:50:24 +08:00
ViperEkura e0a3337c22 docs: 更新视频链接 2026-05-19 17:34:01 +08:00
ViperEkura 812238060b fix: docker-compose UID/GID 添加默认值,修复 docker.sh logs 命令 2026-05-18 14:24:00 +08:00
ViperEkura 14b0d56197 fix: 修复无法创建子进程的问题
- mp.start_processes daemon=False
2026-05-18 09:40:32 +08:00
16 changed files with 144 additions and 34 deletions

View File

@ -1,7 +1,7 @@
# AstrAI Dockerfile - Multi-stage Build (Optimized) # AstrAI Dockerfile - Multi-stage Build (Optimized)
# Build stage - use base image with minimal build tools # Build stage - use base image with minimal build tools
FROM nvidia/cuda:12.6.0-base-ubuntu24.04 AS builder FROM ubuntu:24.04 AS builder
WORKDIR /app WORKDIR /app
@ -18,7 +18,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
RUN python3.12 -m venv --copies /opt/venv RUN python3.12 -m venv --copies /opt/venv
ENV PATH="/opt/venv/bin:$PATH" ENV PATH="/opt/venv/bin:$PATH"
# Copy source code and install dependencies # Copy source code and install (deps read from pyproject.toml)
COPY astrai/ ./astrai/ COPY astrai/ ./astrai/
COPY pyproject.toml . COPY pyproject.toml .
RUN pip install --no-cache-dir --upgrade pip \ RUN pip install --no-cache-dir --upgrade pip \
@ -26,13 +26,14 @@ RUN pip install --no-cache-dir --upgrade pip \
--extra-index-url https://download.pytorch.org/whl/cu126 --extra-index-url https://download.pytorch.org/whl/cu126
# Production stage # Production stage
FROM nvidia/cuda:12.6.0-base-ubuntu24.04 AS production FROM ubuntu:24.04 AS production
WORKDIR /app WORKDIR /app
# Install Python 3.12 runtime # Install Python 3.12 runtime and healthcheck dependency
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
python3.12 \ python3.12 \
curl \
&& rm -rf /var/lib/apt/lists/* && rm -rf /var/lib/apt/lists/*
# Copy virtual environment from builder # Copy virtual environment from builder

View File

@ -213,7 +213,7 @@ python scripts/demo/generate_batch.py
python scripts/demo/generate_ar.py python scripts/demo/generate_ar.py
``` ```
Watch a video walkthrough on [bilibili](https://www.bilibili.com/video/BV1z5RPYHEkd). Watch a video walkthrough on [bilibili](https://www.bilibili.com/video/BV1fuLB6yEj6).
### Documentation ### Documentation

View File

@ -219,7 +219,7 @@ python scripts/demo/generate_batch.py
python scripts/demo/generate_ar.py python scripts/demo/generate_ar.py
``` ```
观看 [bilibili](https://www.bilibili.com/video/BV1z5RPYHEkd) 上的视频演示。 观看 [bilibili](https://www.bilibili.com/video/BV1fuLB6yEj6) 上的视频演示。
### 文档 ### 文档

View File

@ -77,6 +77,9 @@ classDiagram
+int start_batch +int start_batch
+str ckpt_dir +str ckpt_dir
+int ckpt_interval +int ckpt_interval
+str log_dir
+int log_interval
+List[str] metrics
+int random_seed +int random_seed
+int num_workers +int num_workers
+Optional[int] prefetch_factor +Optional[int] prefetch_factor
@ -472,6 +475,10 @@ classDiagram
class CheckpointCallback { class CheckpointCallback {
+str save_dir +str save_dir
+int interval +int interval
+bool weight_only
+Callable state_dict_fn
+Callable save_extra_fn
+Callable load_extra_fn
+_save_checkpoint(context) +_save_checkpoint(context)
+on_train_begin(context) +on_train_begin(context)
+on_batch_end(context) +on_batch_end(context)
@ -483,6 +490,8 @@ classDiagram
class ProgressBarCallback { class ProgressBarCallback {
+int num_epoch +int num_epoch
+int log_interval
+IO file
+on_epoch_begin(context) +on_epoch_begin(context)
+on_batch_end(context) +on_batch_end(context)
+on_epoch_end(context) +on_epoch_end(context)
@ -491,6 +500,8 @@ classDiagram
class MetricLoggerCallback { class MetricLoggerCallback {
+str log_dir +str log_dir
+int save_interval +int save_interval
+int log_interval
+List[str] metrics
+on_batch_end(context) +on_batch_end(context)
+on_train_end(context) +on_train_end(context)
+on_error(context) +on_error(context)
@ -687,7 +698,7 @@ classDiagram
} }
class SamplingPipeline { class SamplingPipeline {
+List strategies +List[BaseSamplingStrategy] strategies
+apply(logits, filter_value) Tensor +apply(logits, filter_value) Tensor
+sample(logits, filter_value) Tensor +sample(logits, filter_value) Tensor
} }
@ -711,16 +722,16 @@ classDiagram
class ChatCompletionRequest { class ChatCompletionRequest {
+str model +str model
+List[ChatMessage] messages +List[ChatMessage] messages
+float temperature +Optional[float] temperature
+float top_p +Optional[float] top_p
+int top_k +Optional[int] top_k
+int max_tokens +Optional[int] max_tokens
+bool stream +Optional[bool] stream
+Optional[Union[str, List[str]]] stop +Optional[Union[str, List[str]]] stop
+Optional[int] n +Optional[int] n
+Optional[float] presence_penalty +Optional[float] presence_penalty
+Optional[float] frequency_penalty +Optional[float] frequency_penalty
+Optional[Dict] logit_bias +Optional[Dict[int, float]] logit_bias
+Optional[str] user +Optional[str] user
} }
@ -872,7 +883,6 @@ classDiagram
InferenceScheduler *-- KVCache InferenceScheduler *-- KVCache
InferenceScheduler *-- Executor InferenceScheduler *-- Executor
InferenceScheduler *-- TaskManager InferenceScheduler *-- TaskManager
SamplingPipeline *-- BaseSamplingStrategy
AutoRegressiveLM *-- DecoderBlock AutoRegressiveLM *-- DecoderBlock
AutoRegressiveLM *-- RotaryEmbedding AutoRegressiveLM *-- RotaryEmbedding
AutoRegressiveLM *-- Embedding AutoRegressiveLM *-- Embedding
@ -880,9 +890,10 @@ classDiagram
EmbeddingEncoder *-- RotaryEmbedding EmbeddingEncoder *-- RotaryEmbedding
EmbeddingEncoder *-- Embedding EmbeddingEncoder *-- Embedding
DecoderBlock *-- RMSNorm DecoderBlock *-- RMSNorm
BaseDataset o-- BaseStorage
ChatCompletionRequest *-- ChatMessage ChatCompletionRequest *-- ChatMessage
MessagesRequest *-- AnthropicMessage MessagesRequest *-- AnthropicMessage
AutoTokenizer *-- ChatTemplate
BaseFactory *-- Registry
%% --- Aggregation (weak ownership) --- %% --- Aggregation (weak ownership) ---
AutoModel o-- BaseModelConfig AutoModel o-- BaseModelConfig
@ -890,9 +901,9 @@ classDiagram
TrainContext o-- BaseStrategy TrainContext o-- BaseStrategy
TrainContext o-- BaseScheduler TrainContext o-- BaseScheduler
TrainContext o-- Checkpoint TrainContext o-- Checkpoint
AutoTokenizer o-- ChatTemplate
KvcacheView o-- Storage KvcacheView o-- Storage
BaseFactory o-- Registry SamplingPipeline o-- BaseSamplingStrategy
BaseDataset o-- BaseStorage
%% --- Dependency (uses temporarily) --- %% --- Dependency (uses temporarily) ---
TrainConfig ..> BaseStrategy : selects TrainConfig ..> BaseStrategy : selects

View File

@ -1,4 +1,4 @@
__version__ = "1.3.5" __version__ = "1.3.6"
__author__ = "ViperEkura" __author__ = "ViperEkura"
from astrai.config import ( from astrai.config import (

View File

@ -11,7 +11,6 @@ __all__ = [
"BaseModelConfig", "BaseModelConfig",
"AutoRegressiveLMConfig", "AutoRegressiveLMConfig",
"EncoderConfig", "EncoderConfig",
"ModelConfig",
"ConfigFactory", "ConfigFactory",
"TrainConfig", "TrainConfig",
] ]

View File

@ -13,7 +13,7 @@ class BaseConfig:
d[fld.name] = v d[fld.name] = v
elif v is None: elif v is None:
d[fld.name] = None d[fld.name] = None
elif isinstance(v, dict): elif isinstance(v, (dict, list)):
try: try:
json.dumps(v) json.dumps(v)
d[fld.name] = v d[fld.name] = v

View File

@ -1,5 +1,5 @@
from dataclasses import dataclass, field, fields from dataclasses import dataclass, field, fields
from typing import Callable, Optional from typing import Callable, List, Optional
import torch.nn as nn import torch.nn as nn
from torch.optim import Optimizer from torch.optim import Optimizer
@ -56,6 +56,19 @@ class TrainConfig(BaseConfig):
default=5000, metadata={"help": "Number of iterations between checkpoints."} default=5000, metadata={"help": "Number of iterations between checkpoints."}
) )
# metric setting
log_dir: str = field(
default="./checkpoint/logs", metadata={"help": "Directory for metric logs."}
)
log_interval: int = field(
default=100,
metadata={"help": "Number of batch iterations between metric logs."},
)
metrics: List[str] = field(
default_factory=lambda: ["loss", "lr"],
metadata={"help": "Metrics to record during training."},
)
# dataloader setting # dataloader setting
random_seed: int = field(default=3407, metadata={"help": "Random seed."}) random_seed: int = field(default=3407, metadata={"help": "Random seed."})
num_workers: int = field( num_workers: int = field(

View File

@ -226,6 +226,17 @@ class OpenAIHandler(ProtocolHandler):
def create_response_id(self) -> str: def create_response_id(self) -> str:
return f"chatcmpl-{uuid.uuid4().hex[:12]}" return f"chatcmpl-{uuid.uuid4().hex[:12]}"
def get_stop_sequences(self) -> List[str]:
stop = self.request.stop
if stop is None:
return []
return [stop] if isinstance(stop, str) else stop
def on_token(
self, ctx: StreamContext, token: str, stop_checker: StopChecker
) -> Optional[str]:
return stop_checker.check(ctx.accumulated)
def format_stream_start(self, ctx: StreamContext) -> List[str]: def format_stream_start(self, ctx: StreamContext) -> List[str]:
return [ return [
_sse_event( _sse_event(

View File

@ -163,5 +163,4 @@ def spawn_parallel_fn(
nprocs=world_size, nprocs=world_size,
start_method=start_method, start_method=start_method,
join=True, join=True,
daemon=True,
) )

View File

@ -38,7 +38,7 @@ class Checkpoint:
meta = { meta = {
"epoch": self.epoch, "epoch": self.epoch,
"iteration": self.iteration, "iteration": self.iteration,
"timestamp": time.time(), "timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"),
} }
meta.update(self.meta) meta.update(self.meta)
with open(save_path / "meta.json", "w") as f: with open(save_path / "meta.json", "w") as f:

View File

@ -1,9 +1,10 @@
import json import json
import logging import logging
import os import os
import sys
import time import time
from pathlib import Path from pathlib import Path
from typing import Callable, List, Optional, Protocol, runtime_checkable from typing import IO, Callable, List, Optional, Protocol, runtime_checkable
import torch import torch
import torch.distributed as dist import torch.distributed as dist
@ -211,8 +212,12 @@ class ProgressBarCallback(TrainCallback):
Progress bar callback for trainer. Progress bar callback for trainer.
""" """
def __init__(self, num_epoch: int): def __init__(
self, num_epoch: int, log_interval: int = 100, file: IO[str] = sys.stdout
):
self.num_epoch = num_epoch self.num_epoch = num_epoch
self.log_interval = log_interval
self.file = file
self.progress_bar: tqdm = None self.progress_bar: tqdm = None
@only_on_rank(0) @only_on_rank(0)
@ -221,6 +226,7 @@ class ProgressBarCallback(TrainCallback):
context.dataloader, context.dataloader,
desc=f"Epoch {context.epoch + 1}/{self.num_epoch}", desc=f"Epoch {context.epoch + 1}/{self.num_epoch}",
dynamic_ncols=True, dynamic_ncols=True,
file=self.file,
) )
@only_on_rank(0) @only_on_rank(0)
@ -274,7 +280,7 @@ class MetricLoggerCallback(TrainCallback):
def _get_log_data(self, context: TrainContext): def _get_log_data(self, context: TrainContext):
return { return {
"timestamp": time.strftime("%Y-%m-%d %H:%M:%S"), "timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"),
"epoch": context.epoch, "epoch": context.epoch,
"iter": context.iteration, "iter": context.iteration,
**{m: self._metric_funcs[m](context) for m in self.metrics}, **{m: self._metric_funcs[m](context) for m in self.metrics},

View File

@ -36,8 +36,14 @@ class Trainer:
cfg.ckpt_interval, cfg.ckpt_interval,
state_dict_fn=cfg.state_dict_fn, state_dict_fn=cfg.state_dict_fn,
), ),
CallbackFactory.create(
"metric_logger",
log_dir=cfg.log_dir,
save_interval=cfg.ckpt_interval,
log_interval=cfg.log_interval,
metrics=cfg.metrics,
),
CallbackFactory.create("progress_bar", cfg.n_epoch), CallbackFactory.create("progress_bar", cfg.n_epoch),
CallbackFactory.create("metric_logger", cfg.ckpt_dir, cfg.ckpt_interval),
CallbackFactory.create("gradient_clipping", cfg.max_grad_norm), CallbackFactory.create("gradient_clipping", cfg.max_grad_norm),
CallbackFactory.create("validation"), CallbackFactory.create("validation"),
] ]

View File

@ -1,12 +1,13 @@
services: services:
server: server:
build: . build:
image: astrai:latest context: .
dockerfile: Dockerfile
user: "${UID:-1000}:${GID:-1000}"
ports: ports:
- "8000:8000" - "8000:8000"
volumes: volumes:
- ./params:/app/params:ro - ./params:/app/params:ro
- ./checkpoints:/app/checkpoints
command: python -m scripts.tools.server --port 8000 --device cuda command: python -m scripts.tools.server --port 8000 --device cuda
deploy: deploy:
resources: resources:
@ -25,13 +26,14 @@ services:
server-cpu: server-cpu:
profiles: [cpu] profiles: [cpu]
build: . build:
image: astrai:latest context: .
dockerfile: Dockerfile
user: "${UID:-1000}:${GID:-1000}"
ports: ports:
- "8000:8000" - "8000:8000"
volumes: volumes:
- ./params:/app/params:ro - ./params:/app/params:ro
- ./checkpoints:/app/checkpoints
command: python -m scripts.tools.server --port 8000 --device cpu command: python -m scripts.tools.server --port 8000 --device cpu
healthcheck: healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:8000/health"] test: ["CMD", "curl", "-f", "http://localhost:8000/health"]

View File

@ -16,6 +16,7 @@ NC='\033[0m' # No Color
IMAGE_NAME="astrai" IMAGE_NAME="astrai"
IMAGE_TAG="latest" IMAGE_TAG="latest"
REGISTRY="" REGISTRY=""
CONTAINER_ID=""
# Print colored messages # Print colored messages
print_info() { print_info() {
@ -175,6 +176,10 @@ main() {
PORT="$2" PORT="$2"
shift 2 shift 2
;; ;;
--container)
CONTAINER_ID="$2"
shift 2
;;
--gpu) --gpu)
GPU=true GPU=true
shift shift
@ -197,6 +202,7 @@ main() {
echo " --dockerfile FILE Dockerfile path (default: Dockerfile)" echo " --dockerfile FILE Dockerfile path (default: Dockerfile)"
echo " --context PATH Build context (default: .)" echo " --context PATH Build context (default: .)"
echo " --port PORT Port for run (default: 8000)" echo " --port PORT Port for run (default: 8000)"
echo " --container ID Container ID for logs"
echo " --gpu Enable GPU support" echo " --gpu Enable GPU support"
echo " --help Show this help message" echo " --help Show this help message"
echo "" echo ""
@ -205,6 +211,7 @@ main() {
echo " $0 build --tag v1.0.0" echo " $0 build --tag v1.0.0"
echo " $0 run --port 8080" echo " $0 run --port 8080"
echo " $0 run --gpu" echo " $0 run --gpu"
echo " $0 logs --container abc123"
echo " $0 push --registry ghcr.io/username" echo " $0 push --registry ghcr.io/username"
exit 0 exit 0
;; ;;
@ -237,7 +244,7 @@ main() {
show_info show_info
;; ;;
logs) logs)
show_logs "$2" show_logs "$CONTAINER_ID"
;; ;;
"") "")
print_error "No command specified. Use --help for usage" print_error "No command specified. Use --help for usage"

View File

@ -157,5 +157,60 @@ def test_messages_with_system(client, loaded_model):
assert data["type"] == "message" assert data["type"] == "message"
def test_chat_completions_stop_sequence(client, loaded_model):
"""POST /v1/chat/completions with stop parameter truncates at stop sequence."""
async def async_gen():
yield "Hello"
yield "X"
yield "world"
app.state.engine = loaded_model
loaded_model.generate_async.return_value = async_gen()
response = client.post(
"/v1/chat/completions",
json={
"messages": [{"role": "user", "content": "Hello"}],
"max_tokens": 100,
"stream": False,
"stop": ["X"],
},
)
assert response.status_code == 200
data = response.json()
content = data["choices"][0]["message"]["content"]
assert "X" in content
assert "world" not in content
def test_chat_completions_stop_sequence_stream(client, loaded_model):
"""POST /v1/chat/completions with stop parameter truncates SSE stream."""
async def async_gen():
yield "Hello"
yield "X"
yield "world"
app.state.engine = loaded_model
loaded_model.generate_async.return_value = async_gen()
response = client.post(
"/v1/chat/completions",
json={
"messages": [{"role": "user", "content": "Hello"}],
"max_tokens": 100,
"stream": True,
"stop": ["X"],
},
headers={"Accept": "text/event-stream"},
)
assert response.status_code == 200
content = response.content.decode("utf-8")
assert "Hello" in content
assert "world" not in content
assert any(
"finish_reason" in line for line in content.split("\n") if "stop" in line
)
if __name__ == "__main__": if __name__ == "__main__":
pytest.main([__file__, "-v"]) pytest.main([__file__, "-v"])