Compare commits

..

6 Commits

Author SHA1 Message Date
ViperEkura 785d65436c fix: 修复 to_dict list 类型丢失与 OpenAI stop 参数失效
- to_dict() 增加 list 类型序列化支持,metrics 等字段不再丢失
- OpenAIHandler 补充 get_stop_sequences/on_token,读取 request.stop 并检测停止序列
- 文档类图补充缺失字段、修正关系分类、ChatCompletionRequest 字段增加 Optional
2026-05-19 21:07:07 +08:00
ViperEkura 64be81b7b3 feat: ProgressBarCallback 支持日志行输出到 stdout
- serialization 和 metric_logger 的 timestamp 统一使用 ISO 8601 格式
- ProgressBarCallback 新增 log_interval/file 参数,默认输出到 sys.stdout
2026-05-19 19:12:38 +08:00
ViperEkura 45479b5731 feat: metric 参数通过 TrainConfig 传递
- TrainConfig 新增 log_dir/log_interval/metrics 配置字段

- metric_logger 调用改用 **kwargs 传递,BaseFactory.create 自动过滤
2026-05-19 17:50:24 +08:00
ViperEkura e0a3337c22 docs: 更新视频链接 2026-05-19 17:34:01 +08:00
ViperEkura 812238060b fix: docker-compose UID/GID 添加默认值,修复 docker.sh logs 命令 2026-05-18 14:24:00 +08:00
ViperEkura 14b0d56197 fix: 修复无法创建子进程的问题
- mp.start_processes daemon=False
2026-05-18 09:40:32 +08:00
16 changed files with 144 additions and 34 deletions

View File

@ -1,7 +1,7 @@
# AstrAI Dockerfile - Multi-stage Build (Optimized)
# Build stage - use base image with minimal build tools
FROM nvidia/cuda:12.6.0-base-ubuntu24.04 AS builder
FROM ubuntu:24.04 AS builder
WORKDIR /app
@ -18,7 +18,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
RUN python3.12 -m venv --copies /opt/venv
ENV PATH="/opt/venv/bin:$PATH"
# Copy source code and install dependencies
# Copy source code and install (deps read from pyproject.toml)
COPY astrai/ ./astrai/
COPY pyproject.toml .
RUN pip install --no-cache-dir --upgrade pip \
@ -26,13 +26,14 @@ RUN pip install --no-cache-dir --upgrade pip \
--extra-index-url https://download.pytorch.org/whl/cu126
# Production stage
FROM nvidia/cuda:12.6.0-base-ubuntu24.04 AS production
FROM ubuntu:24.04 AS production
WORKDIR /app
# Install Python 3.12 runtime
# Install Python 3.12 runtime and healthcheck dependency
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
python3.12 \
curl \
&& rm -rf /var/lib/apt/lists/*
# Copy virtual environment from builder

View File

@ -213,7 +213,7 @@ python scripts/demo/generate_batch.py
python scripts/demo/generate_ar.py
```
Watch a video walkthrough on [bilibili](https://www.bilibili.com/video/BV1z5RPYHEkd).
Watch a video walkthrough on [bilibili](https://www.bilibili.com/video/BV1fuLB6yEj6).
### Documentation

View File

@ -219,7 +219,7 @@ python scripts/demo/generate_batch.py
python scripts/demo/generate_ar.py
```
观看 [bilibili](https://www.bilibili.com/video/BV1z5RPYHEkd) 上的视频演示。
观看 [bilibili](https://www.bilibili.com/video/BV1fuLB6yEj6) 上的视频演示。
### 文档

View File

@ -77,6 +77,9 @@ classDiagram
+int start_batch
+str ckpt_dir
+int ckpt_interval
+str log_dir
+int log_interval
+List[str] metrics
+int random_seed
+int num_workers
+Optional[int] prefetch_factor
@ -472,6 +475,10 @@ classDiagram
class CheckpointCallback {
+str save_dir
+int interval
+bool weight_only
+Callable state_dict_fn
+Callable save_extra_fn
+Callable load_extra_fn
+_save_checkpoint(context)
+on_train_begin(context)
+on_batch_end(context)
@ -483,6 +490,8 @@ classDiagram
class ProgressBarCallback {
+int num_epoch
+int log_interval
+IO file
+on_epoch_begin(context)
+on_batch_end(context)
+on_epoch_end(context)
@ -491,6 +500,8 @@ classDiagram
class MetricLoggerCallback {
+str log_dir
+int save_interval
+int log_interval
+List[str] metrics
+on_batch_end(context)
+on_train_end(context)
+on_error(context)
@ -687,7 +698,7 @@ classDiagram
}
class SamplingPipeline {
+List strategies
+List[BaseSamplingStrategy] strategies
+apply(logits, filter_value) Tensor
+sample(logits, filter_value) Tensor
}
@ -711,16 +722,16 @@ classDiagram
class ChatCompletionRequest {
+str model
+List[ChatMessage] messages
+float temperature
+float top_p
+int top_k
+int max_tokens
+bool stream
+Optional[float] temperature
+Optional[float] top_p
+Optional[int] top_k
+Optional[int] max_tokens
+Optional[bool] stream
+Optional[Union[str, List[str]]] stop
+Optional[int] n
+Optional[float] presence_penalty
+Optional[float] frequency_penalty
+Optional[Dict] logit_bias
+Optional[Dict[int, float]] logit_bias
+Optional[str] user
}
@ -872,7 +883,6 @@ classDiagram
InferenceScheduler *-- KVCache
InferenceScheduler *-- Executor
InferenceScheduler *-- TaskManager
SamplingPipeline *-- BaseSamplingStrategy
AutoRegressiveLM *-- DecoderBlock
AutoRegressiveLM *-- RotaryEmbedding
AutoRegressiveLM *-- Embedding
@ -880,9 +890,10 @@ classDiagram
EmbeddingEncoder *-- RotaryEmbedding
EmbeddingEncoder *-- Embedding
DecoderBlock *-- RMSNorm
BaseDataset o-- BaseStorage
ChatCompletionRequest *-- ChatMessage
MessagesRequest *-- AnthropicMessage
AutoTokenizer *-- ChatTemplate
BaseFactory *-- Registry
%% --- Aggregation (weak ownership) ---
AutoModel o-- BaseModelConfig
@ -890,9 +901,9 @@ classDiagram
TrainContext o-- BaseStrategy
TrainContext o-- BaseScheduler
TrainContext o-- Checkpoint
AutoTokenizer o-- ChatTemplate
KvcacheView o-- Storage
BaseFactory o-- Registry
SamplingPipeline o-- BaseSamplingStrategy
BaseDataset o-- BaseStorage
%% --- Dependency (uses temporarily) ---
TrainConfig ..> BaseStrategy : selects

View File

@ -1,4 +1,4 @@
__version__ = "1.3.5"
__version__ = "1.3.6"
__author__ = "ViperEkura"
from astrai.config import (

View File

@ -11,7 +11,6 @@ __all__ = [
"BaseModelConfig",
"AutoRegressiveLMConfig",
"EncoderConfig",
"ModelConfig",
"ConfigFactory",
"TrainConfig",
]

View File

@ -13,7 +13,7 @@ class BaseConfig:
d[fld.name] = v
elif v is None:
d[fld.name] = None
elif isinstance(v, dict):
elif isinstance(v, (dict, list)):
try:
json.dumps(v)
d[fld.name] = v

View File

@ -1,5 +1,5 @@
from dataclasses import dataclass, field, fields
from typing import Callable, Optional
from typing import Callable, List, Optional
import torch.nn as nn
from torch.optim import Optimizer
@ -56,6 +56,19 @@ class TrainConfig(BaseConfig):
default=5000, metadata={"help": "Number of iterations between checkpoints."}
)
# metric setting
log_dir: str = field(
default="./checkpoint/logs", metadata={"help": "Directory for metric logs."}
)
log_interval: int = field(
default=100,
metadata={"help": "Number of batch iterations between metric logs."},
)
metrics: List[str] = field(
default_factory=lambda: ["loss", "lr"],
metadata={"help": "Metrics to record during training."},
)
# dataloader setting
random_seed: int = field(default=3407, metadata={"help": "Random seed."})
num_workers: int = field(

View File

@ -226,6 +226,17 @@ class OpenAIHandler(ProtocolHandler):
def create_response_id(self) -> str:
return f"chatcmpl-{uuid.uuid4().hex[:12]}"
def get_stop_sequences(self) -> List[str]:
    """Return the request's ``stop`` parameter normalised to a list.

    ``self.request.stop`` may be ``None``, a single string, or a list of
    strings; all three shapes are folded into one list form.

    Returns:
        A new list holding the non-empty stop sequences, or an empty list
        when the request did not specify any.
    """
    stop = self.request.stop
    if stop is None:
        return []
    candidates = [stop] if isinstance(stop, str) else stop
    # An empty string is a substring of every text, so keeping it would
    # presumably end generation immediately (assumes the stop checker does
    # substring matching -- TODO confirm). Building a fresh list also avoids
    # handing the request's own list object to callers.
    return [seq for seq in candidates if seq]
def on_token(
    self, ctx: StreamContext, token: str, stop_checker: StopChecker
) -> Optional[str]:
    """Run the stop check against the text accumulated so far.

    Args:
        ctx: Stream context; only ``ctx.accumulated`` is read.
        token: The newly produced token. It is not inspected here; the
            check runs on the accumulated text instead.
        stop_checker: Object whose ``check`` method receives the
            accumulated text.

    Returns:
        Whatever ``stop_checker.check`` returns for the accumulated text.
    """
    text_so_far = ctx.accumulated
    return stop_checker.check(text_so_far)
def format_stream_start(self, ctx: StreamContext) -> List[str]:
return [
_sse_event(

View File

@ -163,5 +163,4 @@ def spawn_parallel_fn(
nprocs=world_size,
start_method=start_method,
join=True,
daemon=True,
)

View File

@ -38,7 +38,7 @@ class Checkpoint:
meta = {
"epoch": self.epoch,
"iteration": self.iteration,
"timestamp": time.time(),
"timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"),
}
meta.update(self.meta)
with open(save_path / "meta.json", "w") as f:

View File

@ -1,9 +1,10 @@
import json
import logging
import os
import sys
import time
from pathlib import Path
from typing import Callable, List, Optional, Protocol, runtime_checkable
from typing import IO, Callable, List, Optional, Protocol, runtime_checkable
import torch
import torch.distributed as dist
@ -211,8 +212,12 @@ class ProgressBarCallback(TrainCallback):
Progress bar callback for trainer.
"""
def __init__(self, num_epoch: int):
def __init__(
self, num_epoch: int, log_interval: int = 100, file: IO[str] = sys.stdout
):
self.num_epoch = num_epoch
self.log_interval = log_interval
self.file = file
self.progress_bar: tqdm = None
@only_on_rank(0)
@ -221,6 +226,7 @@ class ProgressBarCallback(TrainCallback):
context.dataloader,
desc=f"Epoch {context.epoch + 1}/{self.num_epoch}",
dynamic_ncols=True,
file=self.file,
)
@only_on_rank(0)
@ -274,7 +280,7 @@ class MetricLoggerCallback(TrainCallback):
def _get_log_data(self, context: TrainContext):
return {
"timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
"timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"),
"epoch": context.epoch,
"iter": context.iteration,
**{m: self._metric_funcs[m](context) for m in self.metrics},

View File

@ -36,8 +36,14 @@ class Trainer:
cfg.ckpt_interval,
state_dict_fn=cfg.state_dict_fn,
),
CallbackFactory.create(
"metric_logger",
log_dir=cfg.log_dir,
save_interval=cfg.ckpt_interval,
log_interval=cfg.log_interval,
metrics=cfg.metrics,
),
CallbackFactory.create("progress_bar", cfg.n_epoch),
CallbackFactory.create("metric_logger", cfg.ckpt_dir, cfg.ckpt_interval),
CallbackFactory.create("gradient_clipping", cfg.max_grad_norm),
CallbackFactory.create("validation"),
]

View File

@ -1,12 +1,13 @@
services:
server:
build: .
image: astrai:latest
build:
context: .
dockerfile: Dockerfile
user: "${UID:-1000}:${GID:-1000}"
ports:
- "8000:8000"
volumes:
- ./params:/app/params:ro
- ./checkpoints:/app/checkpoints
command: python -m scripts.tools.server --port 8000 --device cuda
deploy:
resources:
@ -25,13 +26,14 @@ services:
server-cpu:
profiles: [cpu]
build: .
image: astrai:latest
build:
context: .
dockerfile: Dockerfile
user: "${UID:-1000}:${GID:-1000}"
ports:
- "8000:8000"
volumes:
- ./params:/app/params:ro
- ./checkpoints:/app/checkpoints
command: python -m scripts.tools.server --port 8000 --device cpu
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:8000/health"]

View File

@ -16,6 +16,7 @@ NC='\033[0m' # No Color
IMAGE_NAME="astrai"
IMAGE_TAG="latest"
REGISTRY=""
CONTAINER_ID=""
# Print colored messages
print_info() {
@ -175,6 +176,10 @@ main() {
PORT="$2"
shift 2
;;
--container)
CONTAINER_ID="$2"
shift 2
;;
--gpu)
GPU=true
shift
@ -197,6 +202,7 @@ main() {
echo " --dockerfile FILE Dockerfile path (default: Dockerfile)"
echo " --context PATH Build context (default: .)"
echo " --port PORT Port for run (default: 8000)"
echo " --container ID Container ID for logs"
echo " --gpu Enable GPU support"
echo " --help Show this help message"
echo ""
@ -205,6 +211,7 @@ main() {
echo " $0 build --tag v1.0.0"
echo " $0 run --port 8080"
echo " $0 run --gpu"
echo " $0 logs --container abc123"
echo " $0 push --registry ghcr.io/username"
exit 0
;;
@ -237,7 +244,7 @@ main() {
show_info
;;
logs)
show_logs "$2"
show_logs "$CONTAINER_ID"
;;
"")
print_error "No command specified. Use --help for usage"

View File

@ -157,5 +157,60 @@ def test_messages_with_system(client, loaded_model):
assert data["type"] == "message"
def test_chat_completions_stop_sequence(client, loaded_model):
    """Non-streaming /v1/chat/completions request with a ``stop`` list.

    The mocked engine yields "Hello", "X", "world"; with ``stop=["X"]`` the
    returned message must contain "X" and must not contain "world".
    """

    async def async_gen():
        # Three chunks, with the stop sequence in the middle.
        for chunk in ("Hello", "X", "world"):
            yield chunk

    app.state.engine = loaded_model
    loaded_model.generate_async.return_value = async_gen()

    payload = {
        "messages": [{"role": "user", "content": "Hello"}],
        "max_tokens": 100,
        "stream": False,
        "stop": ["X"],
    }
    response = client.post("/v1/chat/completions", json=payload)

    assert response.status_code == 200
    message_text = response.json()["choices"][0]["message"]["content"]
    # NOTE(review): this expects the stop sequence itself to remain in the
    # content, whereas the OpenAI API excludes it from the returned text --
    # confirm this is the intended behaviour.
    assert "X" in message_text
    assert "world" not in message_text
def test_chat_completions_stop_sequence_stream(client, loaded_model):
    """Streaming /v1/chat/completions request with a ``stop`` list.

    The mocked engine yields "Hello", "X", "world"; with ``stop=["X"]`` the
    SSE body must contain "Hello", must not contain "world", and must have
    at least one line mentioning both "stop" and "finish_reason".
    """

    async def async_gen():
        # Three chunks, with the stop sequence in the middle.
        for chunk in ("Hello", "X", "world"):
            yield chunk

    app.state.engine = loaded_model
    loaded_model.generate_async.return_value = async_gen()

    payload = {
        "messages": [{"role": "user", "content": "Hello"}],
        "max_tokens": 100,
        "stream": True,
        "stop": ["X"],
    }
    response = client.post(
        "/v1/chat/completions",
        json=payload,
        headers={"Accept": "text/event-stream"},
    )

    assert response.status_code == 200
    body = response.content.decode("utf-8")
    assert "Hello" in body
    assert "world" not in body
    # At least one line has to carry both markers.
    assert any(
        "stop" in line and "finish_reason" in line for line in body.split("\n")
    )
if __name__ == "__main__":
    # Run this test module directly through pytest in verbose mode.
    cli_args = [__file__, "-v"]
    pytest.main(cli_args)