Compare commits

..

No commits in common. "785d65436c53de0fcbc28416bad6a7a4b9d8607f" and "6c8533f1d278191ea74749adda20ef461245f49d" have entirely different histories.

16 changed files with 34 additions and 144 deletions

View File

@ -1,7 +1,7 @@
# AstrAI Dockerfile - Multi-stage Build (Optimized) # AstrAI Dockerfile - Multi-stage Build (Optimized)
# Build stage - use base image with minimal build tools # Build stage - use base image with minimal build tools
FROM ubuntu:24.04 AS builder FROM nvidia/cuda:12.6.0-base-ubuntu24.04 AS builder
WORKDIR /app WORKDIR /app
@ -18,7 +18,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
RUN python3.12 -m venv --copies /opt/venv RUN python3.12 -m venv --copies /opt/venv
ENV PATH="/opt/venv/bin:$PATH" ENV PATH="/opt/venv/bin:$PATH"
# Copy source code and install (deps read from pyproject.toml) # Copy source code and install dependencies
COPY astrai/ ./astrai/ COPY astrai/ ./astrai/
COPY pyproject.toml . COPY pyproject.toml .
RUN pip install --no-cache-dir --upgrade pip \ RUN pip install --no-cache-dir --upgrade pip \
@ -26,14 +26,13 @@ RUN pip install --no-cache-dir --upgrade pip \
--extra-index-url https://download.pytorch.org/whl/cu126 --extra-index-url https://download.pytorch.org/whl/cu126
# Production stage # Production stage
FROM ubuntu:24.04 AS production FROM nvidia/cuda:12.6.0-base-ubuntu24.04 AS production
WORKDIR /app WORKDIR /app
# Install Python 3.12 runtime and healthcheck dependency # Install Python 3.12 runtime
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
python3.12 \ python3.12 \
curl \
&& rm -rf /var/lib/apt/lists/* && rm -rf /var/lib/apt/lists/*
# Copy virtual environment from builder # Copy virtual environment from builder

View File

@ -213,7 +213,7 @@ python scripts/demo/generate_batch.py
python scripts/demo/generate_ar.py python scripts/demo/generate_ar.py
``` ```
Watch a video walkthrough on [bilibili](https://www.bilibili.com/video/BV1fuLB6yEj6). Watch a video walkthrough on [bilibili](https://www.bilibili.com/video/BV1z5RPYHEkd).
### Documentation ### Documentation

View File

@ -219,7 +219,7 @@ python scripts/demo/generate_batch.py
python scripts/demo/generate_ar.py python scripts/demo/generate_ar.py
``` ```
观看 [bilibili](https://www.bilibili.com/video/BV1fuLB6yEj6) 上的视频演示。 观看 [bilibili](https://www.bilibili.com/video/BV1z5RPYHEkd) 上的视频演示。
### 文档 ### 文档

View File

@ -77,9 +77,6 @@ classDiagram
+int start_batch +int start_batch
+str ckpt_dir +str ckpt_dir
+int ckpt_interval +int ckpt_interval
+str log_dir
+int log_interval
+List[str] metrics
+int random_seed +int random_seed
+int num_workers +int num_workers
+Optional[int] prefetch_factor +Optional[int] prefetch_factor
@ -475,10 +472,6 @@ classDiagram
class CheckpointCallback { class CheckpointCallback {
+str save_dir +str save_dir
+int interval +int interval
+bool weight_only
+Callable state_dict_fn
+Callable save_extra_fn
+Callable load_extra_fn
+_save_checkpoint(context) +_save_checkpoint(context)
+on_train_begin(context) +on_train_begin(context)
+on_batch_end(context) +on_batch_end(context)
@ -490,8 +483,6 @@ classDiagram
class ProgressBarCallback { class ProgressBarCallback {
+int num_epoch +int num_epoch
+int log_interval
+IO file
+on_epoch_begin(context) +on_epoch_begin(context)
+on_batch_end(context) +on_batch_end(context)
+on_epoch_end(context) +on_epoch_end(context)
@ -500,8 +491,6 @@ classDiagram
class MetricLoggerCallback { class MetricLoggerCallback {
+str log_dir +str log_dir
+int save_interval +int save_interval
+int log_interval
+List[str] metrics
+on_batch_end(context) +on_batch_end(context)
+on_train_end(context) +on_train_end(context)
+on_error(context) +on_error(context)
@ -698,7 +687,7 @@ classDiagram
} }
class SamplingPipeline { class SamplingPipeline {
+List[BaseSamplingStrategy] strategies +List strategies
+apply(logits, filter_value) Tensor +apply(logits, filter_value) Tensor
+sample(logits, filter_value) Tensor +sample(logits, filter_value) Tensor
} }
@ -722,16 +711,16 @@ classDiagram
class ChatCompletionRequest { class ChatCompletionRequest {
+str model +str model
+List[ChatMessage] messages +List[ChatMessage] messages
+Optional[float] temperature +float temperature
+Optional[float] top_p +float top_p
+Optional[int] top_k +int top_k
+Optional[int] max_tokens +int max_tokens
+Optional[bool] stream +bool stream
+Optional[Union[str, List[str]]] stop +Optional[Union[str, List[str]]] stop
+Optional[int] n +Optional[int] n
+Optional[float] presence_penalty +Optional[float] presence_penalty
+Optional[float] frequency_penalty +Optional[float] frequency_penalty
+Optional[Dict[int, float]] logit_bias +Optional[Dict] logit_bias
+Optional[str] user +Optional[str] user
} }
@ -883,6 +872,7 @@ classDiagram
InferenceScheduler *-- KVCache InferenceScheduler *-- KVCache
InferenceScheduler *-- Executor InferenceScheduler *-- Executor
InferenceScheduler *-- TaskManager InferenceScheduler *-- TaskManager
SamplingPipeline *-- BaseSamplingStrategy
AutoRegressiveLM *-- DecoderBlock AutoRegressiveLM *-- DecoderBlock
AutoRegressiveLM *-- RotaryEmbedding AutoRegressiveLM *-- RotaryEmbedding
AutoRegressiveLM *-- Embedding AutoRegressiveLM *-- Embedding
@ -890,10 +880,9 @@ classDiagram
EmbeddingEncoder *-- RotaryEmbedding EmbeddingEncoder *-- RotaryEmbedding
EmbeddingEncoder *-- Embedding EmbeddingEncoder *-- Embedding
DecoderBlock *-- RMSNorm DecoderBlock *-- RMSNorm
BaseDataset o-- BaseStorage
ChatCompletionRequest *-- ChatMessage ChatCompletionRequest *-- ChatMessage
MessagesRequest *-- AnthropicMessage MessagesRequest *-- AnthropicMessage
AutoTokenizer *-- ChatTemplate
BaseFactory *-- Registry
%% --- Aggregation (weak ownership) --- %% --- Aggregation (weak ownership) ---
AutoModel o-- BaseModelConfig AutoModel o-- BaseModelConfig
@ -901,9 +890,9 @@ classDiagram
TrainContext o-- BaseStrategy TrainContext o-- BaseStrategy
TrainContext o-- BaseScheduler TrainContext o-- BaseScheduler
TrainContext o-- Checkpoint TrainContext o-- Checkpoint
AutoTokenizer o-- ChatTemplate
KvcacheView o-- Storage KvcacheView o-- Storage
SamplingPipeline o-- BaseSamplingStrategy BaseFactory o-- Registry
BaseDataset o-- BaseStorage
%% --- Dependency (uses temporarily) --- %% --- Dependency (uses temporarily) ---
TrainConfig ..> BaseStrategy : selects TrainConfig ..> BaseStrategy : selects

View File

@ -1,4 +1,4 @@
__version__ = "1.3.6" __version__ = "1.3.5"
__author__ = "ViperEkura" __author__ = "ViperEkura"
from astrai.config import ( from astrai.config import (

View File

@ -11,6 +11,7 @@ __all__ = [
"BaseModelConfig", "BaseModelConfig",
"AutoRegressiveLMConfig", "AutoRegressiveLMConfig",
"EncoderConfig", "EncoderConfig",
"ModelConfig",
"ConfigFactory", "ConfigFactory",
"TrainConfig", "TrainConfig",
] ]

View File

@ -13,7 +13,7 @@ class BaseConfig:
d[fld.name] = v d[fld.name] = v
elif v is None: elif v is None:
d[fld.name] = None d[fld.name] = None
elif isinstance(v, (dict, list)): elif isinstance(v, dict):
try: try:
json.dumps(v) json.dumps(v)
d[fld.name] = v d[fld.name] = v

View File

@ -1,5 +1,5 @@
from dataclasses import dataclass, field, fields from dataclasses import dataclass, field, fields
from typing import Callable, List, Optional from typing import Callable, Optional
import torch.nn as nn import torch.nn as nn
from torch.optim import Optimizer from torch.optim import Optimizer
@ -56,19 +56,6 @@ class TrainConfig(BaseConfig):
default=5000, metadata={"help": "Number of iterations between checkpoints."} default=5000, metadata={"help": "Number of iterations between checkpoints."}
) )
# metric setting
log_dir: str = field(
default="./checkpoint/logs", metadata={"help": "Directory for metric logs."}
)
log_interval: int = field(
default=100,
metadata={"help": "Number of batch iterations between metric logs."},
)
metrics: List[str] = field(
default_factory=lambda: ["loss", "lr"],
metadata={"help": "Metrics to record during training."},
)
# dataloader setting # dataloader setting
random_seed: int = field(default=3407, metadata={"help": "Random seed."}) random_seed: int = field(default=3407, metadata={"help": "Random seed."})
num_workers: int = field( num_workers: int = field(

View File

@ -226,17 +226,6 @@ class OpenAIHandler(ProtocolHandler):
def create_response_id(self) -> str: def create_response_id(self) -> str:
return f"chatcmpl-{uuid.uuid4().hex[:12]}" return f"chatcmpl-{uuid.uuid4().hex[:12]}"
def get_stop_sequences(self) -> List[str]:
stop = self.request.stop
if stop is None:
return []
return [stop] if isinstance(stop, str) else stop
def on_token(
self, ctx: StreamContext, token: str, stop_checker: StopChecker
) -> Optional[str]:
return stop_checker.check(ctx.accumulated)
def format_stream_start(self, ctx: StreamContext) -> List[str]: def format_stream_start(self, ctx: StreamContext) -> List[str]:
return [ return [
_sse_event( _sse_event(

View File

@ -163,4 +163,5 @@ def spawn_parallel_fn(
nprocs=world_size, nprocs=world_size,
start_method=start_method, start_method=start_method,
join=True, join=True,
daemon=True,
) )

View File

@ -38,7 +38,7 @@ class Checkpoint:
meta = { meta = {
"epoch": self.epoch, "epoch": self.epoch,
"iteration": self.iteration, "iteration": self.iteration,
"timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"), "timestamp": time.time(),
} }
meta.update(self.meta) meta.update(self.meta)
with open(save_path / "meta.json", "w") as f: with open(save_path / "meta.json", "w") as f:

View File

@ -1,10 +1,9 @@
import json import json
import logging import logging
import os import os
import sys
import time import time
from pathlib import Path from pathlib import Path
from typing import IO, Callable, List, Optional, Protocol, runtime_checkable from typing import Callable, List, Optional, Protocol, runtime_checkable
import torch import torch
import torch.distributed as dist import torch.distributed as dist
@ -212,12 +211,8 @@ class ProgressBarCallback(TrainCallback):
Progress bar callback for trainer. Progress bar callback for trainer.
""" """
def __init__( def __init__(self, num_epoch: int):
self, num_epoch: int, log_interval: int = 100, file: IO[str] = sys.stdout
):
self.num_epoch = num_epoch self.num_epoch = num_epoch
self.log_interval = log_interval
self.file = file
self.progress_bar: tqdm = None self.progress_bar: tqdm = None
@only_on_rank(0) @only_on_rank(0)
@ -226,7 +221,6 @@ class ProgressBarCallback(TrainCallback):
context.dataloader, context.dataloader,
desc=f"Epoch {context.epoch + 1}/{self.num_epoch}", desc=f"Epoch {context.epoch + 1}/{self.num_epoch}",
dynamic_ncols=True, dynamic_ncols=True,
file=self.file,
) )
@only_on_rank(0) @only_on_rank(0)
@ -280,7 +274,7 @@ class MetricLoggerCallback(TrainCallback):
def _get_log_data(self, context: TrainContext): def _get_log_data(self, context: TrainContext):
return { return {
"timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"), "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
"epoch": context.epoch, "epoch": context.epoch,
"iter": context.iteration, "iter": context.iteration,
**{m: self._metric_funcs[m](context) for m in self.metrics}, **{m: self._metric_funcs[m](context) for m in self.metrics},

View File

@ -36,14 +36,8 @@ class Trainer:
cfg.ckpt_interval, cfg.ckpt_interval,
state_dict_fn=cfg.state_dict_fn, state_dict_fn=cfg.state_dict_fn,
), ),
CallbackFactory.create(
"metric_logger",
log_dir=cfg.log_dir,
save_interval=cfg.ckpt_interval,
log_interval=cfg.log_interval,
metrics=cfg.metrics,
),
CallbackFactory.create("progress_bar", cfg.n_epoch), CallbackFactory.create("progress_bar", cfg.n_epoch),
CallbackFactory.create("metric_logger", cfg.ckpt_dir, cfg.ckpt_interval),
CallbackFactory.create("gradient_clipping", cfg.max_grad_norm), CallbackFactory.create("gradient_clipping", cfg.max_grad_norm),
CallbackFactory.create("validation"), CallbackFactory.create("validation"),
] ]

View File

@ -1,13 +1,12 @@
services: services:
server: server:
build: build: .
context: . image: astrai:latest
dockerfile: Dockerfile
user: "${UID:-1000}:${GID:-1000}"
ports: ports:
- "8000:8000" - "8000:8000"
volumes: volumes:
- ./params:/app/params:ro - ./params:/app/params:ro
- ./checkpoints:/app/checkpoints
command: python -m scripts.tools.server --port 8000 --device cuda command: python -m scripts.tools.server --port 8000 --device cuda
deploy: deploy:
resources: resources:
@ -26,14 +25,13 @@ services:
server-cpu: server-cpu:
profiles: [cpu] profiles: [cpu]
build: build: .
context: . image: astrai:latest
dockerfile: Dockerfile
user: "${UID:-1000}:${GID:-1000}"
ports: ports:
- "8000:8000" - "8000:8000"
volumes: volumes:
- ./params:/app/params:ro - ./params:/app/params:ro
- ./checkpoints:/app/checkpoints
command: python -m scripts.tools.server --port 8000 --device cpu command: python -m scripts.tools.server --port 8000 --device cpu
healthcheck: healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:8000/health"] test: ["CMD", "curl", "-f", "http://localhost:8000/health"]

View File

@ -16,7 +16,6 @@ NC='\033[0m' # No Color
IMAGE_NAME="astrai" IMAGE_NAME="astrai"
IMAGE_TAG="latest" IMAGE_TAG="latest"
REGISTRY="" REGISTRY=""
CONTAINER_ID=""
# Print colored messages # Print colored messages
print_info() { print_info() {
@ -176,10 +175,6 @@ main() {
PORT="$2" PORT="$2"
shift 2 shift 2
;; ;;
--container)
CONTAINER_ID="$2"
shift 2
;;
--gpu) --gpu)
GPU=true GPU=true
shift shift
@ -202,7 +197,6 @@ main() {
echo " --dockerfile FILE Dockerfile path (default: Dockerfile)" echo " --dockerfile FILE Dockerfile path (default: Dockerfile)"
echo " --context PATH Build context (default: .)" echo " --context PATH Build context (default: .)"
echo " --port PORT Port for run (default: 8000)" echo " --port PORT Port for run (default: 8000)"
echo " --container ID Container ID for logs"
echo " --gpu Enable GPU support" echo " --gpu Enable GPU support"
echo " --help Show this help message" echo " --help Show this help message"
echo "" echo ""
@ -211,7 +205,6 @@ main() {
echo " $0 build --tag v1.0.0" echo " $0 build --tag v1.0.0"
echo " $0 run --port 8080" echo " $0 run --port 8080"
echo " $0 run --gpu" echo " $0 run --gpu"
echo " $0 logs --container abc123"
echo " $0 push --registry ghcr.io/username" echo " $0 push --registry ghcr.io/username"
exit 0 exit 0
;; ;;
@ -244,7 +237,7 @@ main() {
show_info show_info
;; ;;
logs) logs)
show_logs "$CONTAINER_ID" show_logs "$2"
;; ;;
"") "")
print_error "No command specified. Use --help for usage" print_error "No command specified. Use --help for usage"

View File

@ -157,60 +157,5 @@ def test_messages_with_system(client, loaded_model):
assert data["type"] == "message" assert data["type"] == "message"
def test_chat_completions_stop_sequence(client, loaded_model):
"""POST /v1/chat/completions with stop parameter truncates at stop sequence."""
async def async_gen():
yield "Hello"
yield "X"
yield "world"
app.state.engine = loaded_model
loaded_model.generate_async.return_value = async_gen()
response = client.post(
"/v1/chat/completions",
json={
"messages": [{"role": "user", "content": "Hello"}],
"max_tokens": 100,
"stream": False,
"stop": ["X"],
},
)
assert response.status_code == 200
data = response.json()
content = data["choices"][0]["message"]["content"]
assert "X" in content
assert "world" not in content
def test_chat_completions_stop_sequence_stream(client, loaded_model):
"""POST /v1/chat/completions with stop parameter truncates SSE stream."""
async def async_gen():
yield "Hello"
yield "X"
yield "world"
app.state.engine = loaded_model
loaded_model.generate_async.return_value = async_gen()
response = client.post(
"/v1/chat/completions",
json={
"messages": [{"role": "user", "content": "Hello"}],
"max_tokens": 100,
"stream": True,
"stop": ["X"],
},
headers={"Accept": "text/event-stream"},
)
assert response.status_code == 200
content = response.content.decode("utf-8")
assert "Hello" in content
assert "world" not in content
assert any(
"finish_reason" in line for line in content.split("\n") if "stop" in line
)
if __name__ == "__main__": if __name__ == "__main__":
pytest.main([__file__, "-v"]) pytest.main([__file__, "-v"])