Compare commits

...

3 Commits

Author SHA1 Message Date
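ViperEkura bc7c82977e feat: wire GRPO into the CLI with on-policy mode, parameterize top_k in the OpenAI API, complete the training parameter table
- train.py adds --train_type=grpo and its parameters (--grpo_clip_eps, --grpo_kl_coef, --group_size, --grpo_sync_interval, --start_epoch)
- GRPOStrategy unified to on-policy mode: ratio = exp(logπ_θ - logπ_ref), PPO clipped objective, ref_model synced automatically every sync_interval steps
- ChatCompletionRequest gains a top_k parameter instead of a hard-coded value
- README gains the full training parameter table (including the previously missing max_grad_norm / adamw / window_size / stride, etc.)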
2026-05-09 12:22:33 +08:00
ViperEkura 34a511e36e feat: add one-command Docker Compose deployment with GPU and CPU modes 2026-05-09 11:57:46 +08:00
ViperEkura d73f52a2f8 feat: add Anthropic-compatible /v1/messages API, remove legacy /generate endpoint
- Add /v1/messages endpoint, compatible with the Anthropic Messages API format
- Support streaming SSE (message_start → content_block_delta → message_stop)
- Support top-level system prompt and stop_sequences
- Add AnthropicMessage / MessagesRequest Pydantic models
- Remove legacy /generate endpoint and its test cases
- Update README.md / README-zh-CN.md / introduction.md docs
2026-05-09 11:47:22 +08:00
9 changed files with 541 additions and 188 deletions

1
.gitignore vendored
View File

@ -15,6 +15,7 @@
!/.gitattributes !/.gitattributes
!/.dockerignore !/.dockerignore
!/Dockerfile !/Dockerfile
!/docker-compose.yml
!/assets/** !/assets/**
!/CONTRIBUTING.md !/CONTRIBUTING.md
!/LICENSE !/LICENSE

100
README.md
View File

@ -47,6 +47,7 @@
- 📦 **Lightweight**: Minimal dependencies, easy to deploy. - 📦 **Lightweight**: Minimal dependencies, easy to deploy.
- 🔬 **ResearchFriendly**: Modular design, easy to experiment with new ideas. - 🔬 **ResearchFriendly**: Modular design, easy to experiment with new ideas.
- 🤗 **HuggingFace Integration**: Compatible with HuggingFace models and datasets. - 🤗 **HuggingFace Integration**: Compatible with HuggingFace models and datasets.
- 🔌 **Dual API Compatibility**: Supports both OpenAI and Anthropic chat completion APIs out of the box.
### Quick Start ### Quick Start
@ -67,44 +68,48 @@ pip install -e ".[dev]"
#### Train a Model #### Train a Model
```bash ```bash
python scripts/tools/train.py \ python scripts/tools/train.py --train_type=seq --data_root_path=/path/to/dataset --param_path=/path/to/model
--train_type=seq \
--data_root_path=/path/to/dataset \
--param_path=/path/to/model \
--n_epoch=3 \
--batch_size=4 \
--accumulation_steps=8 \
--max_lr=3e-4 \
--warmup_steps=2000 \
--ckpt_interval=5000 \
--ckpt_dir=./checkpoints
``` ```
| Parameter | Description | Default |
|-----------|-------------|---------|
| `--train_type` | Training type (`seq`, `sft`, `dpo`, `grpo`) | required |
| `--data_root_path` | Dataset root directory | required |
| `--param_path` | Model / checkpoint path | required |
| `--n_epoch` | Training epochs | 1 |
| `--batch_size` | Batch size | 1 |
| `--accumulation_steps` | Gradient accumulation steps | 1 |
| `--warmup_steps` | LR warmup steps | 1000 |
| `--max_lr` | Peak learning rate (cosine decay) | 3e-4 |
| `--max_grad_norm` | Max gradient norm for clipping | 1.0 |
| `--adamw_beta1` | AdamW beta1 | 0.9 |
| `--adamw_beta2` | AdamW beta2 | 0.95 |
| `--adamw_weight_decay` | AdamW weight decay | 0.01 |
| `--random_seed` | Random seed | 3407 |
| `--num_workers` | DataLoader workers | 4 |
| `--window_size` | Max input sequence length | auto |
| `--stride` | Sequence stride | auto |
| `--label_smoothing` | Label smoothing for cross entropy | 0.1 |
| `--dpo_beta` | DPO beta | 0.1 |
| `--grpo_clip_eps` | GRPO clip epsilon | 0.2 |
| `--grpo_kl_coef` | GRPO KL penalty coefficient | 0.01 |
| `--group_size` | GRPO group size | 4 |
| `--grpo_sync_interval` | GRPO ref model sync interval (steps) | 200 |
| `--ckpt_interval` | Checkpoint interval (iters) | 5000 |
| `--ckpt_dir` | Checkpoint directory | checkpoint |
| `--start_epoch` | Start epoch (for resume) | 0 |
| `--start_batch` | Start batch (for resume) | 0 |
| `--nprocs` | Number of GPUs | 1 |
| `--device_type` | Device type | cuda |
Full reference at [Parameter Guide](./assets/docs/params.md#training-parameters).
#### Generate Text #### Generate Text
```bash ```bash
python scripts/tools/generate.py --param_path=/path/to/param_path python scripts/tools/generate.py --param_path=/path/to/param_path
``` ```
#### Training Parameters
| Parameter | Description | Default |
|-----------|-------------|---------|
| `--train_type` | Training type (`seq`, `sft`, `dpo`) | required |
| `--data_root_path` | Dataset root directory | required |
| `--param_path` | Model / checkpoint path | required |
| `--n_epoch` | Training epochs | 1 |
| `--batch_size` | Batch size | 1 |
| `--accumulation_steps` | Gradient accumulation steps | 1 |
| `--max_lr` | Peak learning rate (cosine decay) | 3e-4 |
| `--warmup_steps` | LR warmup steps | 1000 |
| `--ckpt_interval` | Checkpoint interval (iters) | 5000 |
| `--ckpt_dir` | Checkpoint directory | checkpoint |
| `--num_workers` | DataLoader workers | 4 |
| `--nprocs` | Number of GPUs | 1 |
Full reference at [Parameter Guide](./assets/docs/params.md#training-parameters).
#### Docker #### Docker
Build and run with Docker (recommended for GPU environments): Build and run with Docker (recommended for GPU environments):
@ -125,13 +130,21 @@ docker run --gpus all -p 8000:8000 astrai:latest \
# Run with volume mount for data # Run with volume mount for data
docker run --gpus all -v /path/to/data:/data -it astrai:latest docker run --gpus all -v /path/to/data:/data -it astrai:latest
# Docker Compose (GPU, default)
docker compose up -d
# Docker Compose (CPU only)
docker compose --profile cpu up -d
``` ```
> **Note**: `--gpus all` is required for CUDA support. Without it, `torch.cuda.is_available()` will return `False`. > **Note**: `--gpus all` is required for CUDA support. Without it, `torch.cuda.is_available()` will return `False`.
#### Start HTTP Server #### Start HTTP Server
Start the inference server with OpenAI-compatible HTTP API: Start the inference server with OpenAI and Anthropic-compatible HTTP API:
```bash ```bash
python -m scripts.tools.server --port 8000 --device cuda python -m scripts.tools.server --port 8000 --device cuda
@ -140,7 +153,7 @@ python -m scripts.tools.server --port 8000 --device cuda
Make requests: Make requests:
```bash ```bash
# Chat API (OpenAI compatible) # OpenAI-compatible
curl -X POST http://localhost:8000/v1/chat/completions \ curl -X POST http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \ -H "Content-Type: application/json" \
-d '{ -d '{
@ -148,7 +161,7 @@ curl -X POST http://localhost:8000/v1/chat/completions \
"max_tokens": 512 "max_tokens": 512
}' }'
# Streaming response # OpenAI-compatible streaming
curl -X POST http://localhost:8000/v1/chat/completions \ curl -X POST http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \ -H "Content-Type: application/json" \
-d '{ -d '{
@ -157,6 +170,27 @@ curl -X POST http://localhost:8000/v1/chat/completions \
"max_tokens": 500 "max_tokens": 500
}' }'
# Anthropic-compatible
curl -X POST http://localhost:8000/v1/messages \
-H "Content-Type: application/json" \
-d '{
"model": "astrai",
"system": "You are a helpful assistant.",
"messages": [{"role": "user", "content": "Hello"}],
"max_tokens": 512
}'
# Anthropic-compatible streaming with stop sequences
curl -X POST http://localhost:8000/v1/messages \
-H "Content-Type: application/json" \
-d '{
"model": "astrai",
"messages": [{"role": "user", "content": "Write a story"}],
"max_tokens": 500,
"stream": true,
"stop_sequences": ["The end"]
}'
# Health check # Health check
curl http://localhost:8000/health curl http://localhost:8000/health
``` ```
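For callers who prefer Python over curl, the non-streaming `/v1/messages` call above can be sketched with the standard library alone. This is a minimal sketch, not code from the repository: `build_messages_request` and `extract_text` are hypothetical helper names, and the base URL assumes the server is running locally on port 8000 as in the curl examples.

```python
import json
import urllib.request
from typing import Optional


def build_messages_request(
    user_text: str,
    system: Optional[str] = None,
    max_tokens: int = 512,
    base_url: str = "http://localhost:8000",  # assumed local server address
) -> urllib.request.Request:
    """Build a POST request for the Anthropic-compatible /v1/messages endpoint."""
    payload = {
        "model": "astrai",
        "messages": [{"role": "user", "content": user_text}],
        "max_tokens": max_tokens,
    }
    if system is not None:
        payload["system"] = system
    return urllib.request.Request(
        url=f"{base_url}/v1/messages",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def extract_text(response_body: dict) -> str:
    """Join the text blocks of a non-streaming response body."""
    return "".join(
        block.get("text", "")
        for block in response_body.get("content", [])
        if block.get("type") == "text"
    )
```

With a server running, the request would be sent with `urllib.request.urlopen(req)` and the decoded JSON body passed to `extract_text`.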

View File

@ -53,6 +53,7 @@
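- 📦 **Lightweight**: Few dependencies, easy to deploy. - 📦 **Lightweight**: Few dependencies, easy to deploy.
- 🔬 **Research-Friendly**: Modular design, easy to experiment with new ideas. - 🔬 **Research-Friendly**: Modular design, easy to experiment with new ideas.
- 🤗 **HuggingFace Integration**: Compatible with HuggingFace models and datasets. - 🤗 **HuggingFace Integration**: Compatible with HuggingFace models and datasets.
- 🔌 **Dual API Compatibility**: Supports both OpenAI and Anthropic chat completion APIs out of the box.
### Quick Start ### Quick Start
@ -73,44 +74,48 @@ pip install -e ".[dev]"
#### Train a Model #### Train a Model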
```bash ```bash
python scripts/tools/train.py \ python scripts/tools/train.py --train_type=seq --data_root_path=/path/to/dataset --param_path=/path/to/model
--train_type=seq \
--data_root_path=/path/to/dataset \
--param_path=/path/to/model \
--n_epoch=3 \
--batch_size=4 \
--accumulation_steps=8 \
--max_lr=3e-4 \
--warmup_steps=2000 \
--ckpt_interval=5000 \
--ckpt_dir=./checkpoints
``` ```
| Parameter | Description | Default |
|-----------|-------------|---------|
| `--train_type` | Training type (`seq`, `sft`, `dpo`, `grpo`) | required |
| `--data_root_path` | Dataset root directory | required |
| `--param_path` | Model parameters or checkpoint path | required |
| `--n_epoch` | Training epochs | 1 |
| `--batch_size` | Batch size | 1 |
| `--accumulation_steps` | Gradient accumulation steps | 1 |
| `--warmup_steps` | LR warmup steps | 1000 |
| `--max_lr` | Peak learning rate (cosine decay) | 3e-4 |
| `--max_grad_norm` | Max gradient norm for clipping | 1.0 |
| `--adamw_beta1` | AdamW beta1 | 0.9 |
| `--adamw_beta2` | AdamW beta2 | 0.95 |
| `--adamw_weight_decay` | AdamW weight decay | 0.01 |
| `--random_seed` | Random seed | 3407 |
| `--num_workers` | DataLoader workers | 4 |
| `--window_size` | Max input sequence length | auto |
| `--stride` | Sequence stride | auto |
| `--label_smoothing` | Label smoothing for cross entropy | 0.1 |
| `--dpo_beta` | DPO beta | 0.1 |
| `--grpo_clip_eps` | GRPO clip epsilon | 0.2 |
| `--grpo_kl_coef` | GRPO KL penalty coefficient | 0.01 |
| `--group_size` | GRPO group size | 4 |
| `--grpo_sync_interval` | GRPO ref_model sync interval (steps) | 200 |
| `--ckpt_interval` | Checkpoint interval (iterations) | 5000 |
| `--ckpt_dir` | Checkpoint directory | checkpoint |
| `--start_epoch` | Start epoch (for resuming from a checkpoint) | 0 |
| `--start_batch` | Start batch (for resuming from a checkpoint) | 0 |
| `--nprocs` | Number of GPUs | 1 |
| `--device_type` | Device type | cuda |
Full reference at [Parameter Guide](./params.md#training-parameters).
#### Generate Text #### Generate Text
```bash ```bash
python scripts/tools/generate.py --param_path=/path/to/param_path python scripts/tools/generate.py --param_path=/path/to/param_path
``` ```
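#### Training Parameters
| Parameter | Description | Default |
|-----------|-------------|---------|
| `--train_type` | Training type (`seq`, `sft`, `dpo`) | required |
| `--data_root_path` | Dataset root directory | required |
| `--param_path` | Model parameters or checkpoint path | required |
| `--n_epoch` | Training epochs | 1 |
| `--batch_size` | Batch size | 1 |
| `--accumulation_steps` | Gradient accumulation steps | 1 |
| `--max_lr` | Peak learning rate (cosine decay) | 3e-4 |
| `--warmup_steps` | LR warmup steps | 1000 |
| `--ckpt_interval` | Checkpoint interval (iterations) | 5000 |
| `--ckpt_dir` | Checkpoint directory | checkpoint |
| `--num_workers` | DataLoader workers | 4 |
| `--nprocs` | Number of GPUs | 1 |
Full reference at [Parameter Guide](./params.md#training-parameters).
#### Docker #### Docker
Build and run with Docker (recommended for GPU environments): Build and run with Docker (recommended for GPU environments):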
@ -131,13 +136,19 @@ docker run --gpus all -p 8000:8000 astrai:latest \
# Run with volume mount for data # Run with volume mount for data
docker run --gpus all -v /path/to/data:/data -it astrai:latest docker run --gpus all -v /path/to/data:/data -it astrai:latest
# Docker Compose (GPU, default)
docker compose up -d
# Docker Compose (CPU only)
docker compose --profile cpu up -d
``` ```
> **Note**: `--gpus all` is required to enable CUDA support; otherwise `torch.cuda.is_available()` returns `False`. > **Note**: `--gpus all` is required to enable CUDA support; otherwise `torch.cuda.is_available()` returns `False`.
#### Start HTTP Server #### Start HTTP Server
Start the inference server with an OpenAI-compatible HTTP API: Start the inference server with OpenAI- and Anthropic-compatible HTTP APIs:
```bash ```bash
python -m scripts.tools.server --port 8000 --device cuda python -m scripts.tools.server --port 8000 --device cuda
@ -146,7 +157,7 @@ python -m scripts.tools.server --port 8000 --device cuda
Make requests: Make requests:
```bash ```bash
# Chat API (OpenAI-compatible) # OpenAI-compatible
curl -X POST http://localhost:8000/v1/chat/completions \ curl -X POST http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \ -H "Content-Type: application/json" \
-d '{ -d '{
@ -154,7 +165,7 @@ curl -X POST http://localhost:8000/v1/chat/completions \
"max_tokens": 512 "max_tokens": 512
}' }'
# Streaming response # OpenAI-compatible streaming
curl -X POST http://localhost:8000/v1/chat/completions \ curl -X POST http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \ -H "Content-Type: application/json" \
-d '{ -d '{
@ -163,6 +174,27 @@ curl -X POST http://localhost:8000/v1/chat/completions \
"max_tokens": 500 "max_tokens": 500
}' }'
# Anthropic-compatible
curl -X POST http://localhost:8000/v1/messages \
-H "Content-Type: application/json" \
-d '{
"model": "astrai",
"system": "You are a helpful assistant.",
"messages": [{"role": "user", "content": "Hello"}],
"max_tokens": 512
}'
# Anthropic-compatible streaming with stop sequences
curl -X POST http://localhost:8000/v1/messages \
-H "Content-Type: application/json" \
-d '{
"model": "astrai",
"messages": [{"role": "user", "content": "Write a story"}],
"max_tokens": 500,
"stream": true,
"stop_sequences": ["The end"]
}'
# Health check # Health check
curl http://localhost:8000/health curl http://localhost:8000/health
``` ```

View File

@ -262,25 +262,60 @@ curl -X POST http://localhost:8000/v1/chat/completions \
The server uses Server-Sent Events (SSE) with content type `text/event-stream`. The server uses Server-Sent Events (SSE) with content type `text/event-stream`.
### Simple Generation Endpoint ### Health Check
For basic text generation without chat format:
### Anthropic-Compatible Endpoint
The server also provides an Anthropic-compatible endpoint at `/v1/messages`:
```bash ```bash
curl -X POST "http://localhost:8000/generate?query=Hello&max_len=1000" \ curl -X POST http://localhost:8000/v1/messages \
-H "Content-Type: application/json"
```
Or with conversation history:
```bash
curl -X POST "http://localhost:8000/generate" \
-H "Content-Type: application/json" \ -H "Content-Type: application/json" \
-d '{ -d '{
"query": "What is AI?", "model": "astrai",
"history": [["Hello", "Hi there!"], ["How are you?", "I'm doing well"]], "system": "You are a helpful assistant.",
"temperature": 0.8, "messages": [{"role": "user", "content": "Hello, how are you?"}],
"max_len": 2048 "max_tokens": 2048
}'
```
Response:
```json
{
"id": "msg_abc123...",
"type": "message",
"role": "assistant",
"model": "astrai",
"content": [{"type": "text", "text": "Hello! I am doing well..."}],
"stop_reason": "end_turn",
"stop_sequence": null,
"usage": {"input_tokens": 20, "output_tokens": 15}
}
```
Streaming:
```bash
curl -X POST http://localhost:8000/v1/messages \
-H "Content-Type: application/json" \
-d '{
"model": "astrai",
"system": "You are a helpful assistant.",
"messages": [{"role": "user", "content": "Write a short poem"}],
"max_tokens": 500,
"stream": true
}'
```
Supports `stop_sequences` for early termination:
```bash
curl -X POST http://localhost:8000/v1/messages \
-H "Content-Type: application/json" \
-d '{
"model": "astrai",
"messages": [{"role": "user", "content": "Write a story"}],
"max_tokens": 500,
"stop_sequences": ["The end", "THE END"]
}' }'
``` ```
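On the client side, the streaming response arrives as `event:` / `data:` frames separated by blank lines. A minimal sketch of decoding such a body in Python follows. It assumes the frame layout emitted by `_make_anthropic_sse` in this change set (one `event:` line, one JSON `data:` line per frame); `parse_sse_frames` and `collect_text` are hypothetical helper names, not part of the server.

```python
import json
from typing import Any, Dict, Iterator, List, Tuple


def parse_sse_frames(raw: str) -> Iterator[Tuple[str, Dict[str, Any]]]:
    """Yield (event, data) pairs from an SSE body with blank-line-separated frames."""
    for frame in raw.split("\n\n"):
        event, data = None, None
        for line in frame.splitlines():
            if line.startswith("event: "):
                event = line[len("event: "):]
            elif line.startswith("data: "):
                data = json.loads(line[len("data: "):])
        # Skip empty or incomplete frames (e.g. the trailing split remainder).
        if event is not None and data is not None:
            yield event, data


def collect_text(raw: str) -> str:
    """Concatenate the text of every content_block_delta frame."""
    parts: List[str] = []
    for event, data in parse_sse_frames(raw):
        if event == "content_block_delta":
            parts.append(data["delta"]["text"])
    return "".join(parts)
```

A real client would feed the response incrementally instead of buffering the whole body; the sketch only illustrates the frame format.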

View File

@ -1,5 +1,5 @@
""" """
OpenAI-compatible chat completion server backed by continuous-batching inference. OpenAI / Anthropic-compatible chat completion server backed by continuous-batching inference.
""" """
import json import json
@ -51,6 +51,7 @@ class ChatCompletionRequest(BaseModel):
messages: List[ChatMessage] messages: List[ChatMessage]
temperature: Optional[float] = Field(default=1.0, ge=0.0, le=2.0) temperature: Optional[float] = Field(default=1.0, ge=0.0, le=2.0)
top_p: Optional[float] = Field(default=1.0, ge=0.0, le=1.0) top_p: Optional[float] = Field(default=1.0, ge=0.0, le=1.0)
top_k: Optional[int] = Field(default=50, ge=1)
stream: Optional[bool] = False stream: Optional[bool] = False
stop: Optional[Union[str, List[str]]] = None stop: Optional[Union[str, List[str]]] = None
max_tokens: Optional[int] = Field(default=2048, ge=1) max_tokens: Optional[int] = Field(default=2048, ge=1)
@ -61,6 +62,25 @@ class ChatCompletionRequest(BaseModel):
user: Optional[str] = None user: Optional[str] = None
class AnthropicMessage(BaseModel):
role: str
content: Union[str, List[Dict[str, Any]]]
class MessagesRequest(BaseModel):
"""Anthropic Messages API request body."""
model: str = "astrai"
max_tokens: int = Field(default=1024, ge=1)
messages: List[AnthropicMessage]
system: Optional[str] = None
temperature: Optional[float] = Field(default=1.0, ge=0.0, le=2.0)
top_p: Optional[float] = Field(default=1.0, ge=0.0, le=1.0)
top_k: Optional[int] = Field(default=50, ge=1)
stream: Optional[bool] = False
stop_sequences: Optional[List[str]] = None
def configure_server( def configure_server(
device: str = "cuda", device: str = "cuda",
dtype: torch.dtype = torch.bfloat16, dtype: torch.dtype = torch.bfloat16,
@ -185,7 +205,7 @@ async def chat_completion(request: ChatCompletionRequest):
max_tokens=request.max_tokens, max_tokens=request.max_tokens,
temperature=request.temperature, temperature=request.temperature,
top_p=request.top_p, top_p=request.top_p,
top_k=50, top_k=request.top_k,
) )
async def event_stream(): async def event_stream():
@ -237,7 +257,7 @@ async def chat_completion(request: ChatCompletionRequest):
max_tokens=request.max_tokens, max_tokens=request.max_tokens,
temperature=request.temperature, temperature=request.temperature,
top_p=request.top_p, top_p=request.top_p,
top_k=50, top_k=request.top_k,
) )
async for token in agen: async for token in agen:
chunks.append(token) chunks.append(token)
@ -264,55 +284,183 @@ async def chat_completion(request: ChatCompletionRequest):
} }
@app.post("/generate") def _make_anthropic_sse(event: str, data: Dict[str, Any]) -> str:
async def generate( return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"
query: str,
history: Optional[List[List[str]]] = None,
temperature: float = 0.8, def _check_stop_sequence(text: str, stop_sequences: List[str]) -> Optional[str]:
top_p: float = 0.95, for seq in stop_sequences:
top_k: int = 50, if seq and seq in text:
max_len: int = 2048, return seq
stream: bool = False, return None
):
"""Legacy non-OpenAI generation endpoint (kept for backward compat)."""
def _extract_text_content(content: Union[str, List[Dict[str, Any]]]) -> str:
if isinstance(content, str):
return content
if isinstance(content, list):
for block in content:
if isinstance(block, dict) and block.get("type") == "text":
return block.get("text", "")
return ""
def _build_anthropic_messages(
messages: List[AnthropicMessage], system: Optional[str]
) -> List[Dict[str, str]]:
result: List[Dict[str, str]] = []
if system:
result.append({"role": "system", "content": system})
for m in messages:
content = _extract_text_content(m.content)
if content:
result.append({"role": m.role, "content": content})
return result
@app.post("/v1/messages")
async def create_message(request: MessagesRequest):
"""Anthropic-compatible Messages API endpoint (streaming + non-streaming)."""
engine = _get_engine() engine = _get_engine()
resp_id = f"msg_{uuid.uuid4().hex[:24]}"
model = request.model
messages = [] chat_messages = _build_anthropic_messages(request.messages, request.system)
if history: prompt = engine.tokenizer.apply_chat_template(chat_messages, tokenize=False)
for h in history: prompt_tokens = len(engine.tokenizer.encode(prompt))
if len(h) >= 2:
messages.append({"role": "user", "content": h[0]})
messages.append({"role": "assistant", "content": h[1]})
messages.append({"role": "user", "content": query})
prompt = engine.tokenizer.apply_chat_template(messages, tokenize=False) stop_sequences = request.stop_sequences or []
if stream: if request.stream:
agen = engine.generate_async( agen = engine.generate_async(
prompt=prompt, prompt=prompt,
max_tokens=max_len, max_tokens=request.max_tokens,
temperature=temperature, temperature=request.temperature,
top_p=top_p, top_p=request.top_p,
top_k=top_k, top_k=request.top_k,
) )
async def text_stream(): async def event_stream():
async for token in agen: yield _make_anthropic_sse(
yield token + "\n" "message_start",
{
"type": "message_start",
"message": {
"id": resp_id,
"type": "message",
"role": "assistant",
"model": model,
"content": [],
"usage": {"input_tokens": prompt_tokens},
},
},
)
return StreamingResponse(text_stream(), media_type="text/plain") yield _make_anthropic_sse(
else: "content_block_start",
chunks = [] {
for token in engine.generate( "type": "content_block_start",
"index": 0,
"content_block": {"type": "text", "text": ""},
},
)
completion_tokens = 0
accumulated = ""
stopped_seq: Optional[str] = None
async for token in agen:
accumulated += token
completion_tokens += 1
matched = _check_stop_sequence(accumulated, stop_sequences)
if matched:
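# Earlier tokens were already streamed; emit only the unsent text before the stop sequence.
sent = len(accumulated) - len(token)
text = accumulated[sent : accumulated.find(matched)]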
stopped_seq = matched
if text:
yield _make_anthropic_sse(
"content_block_delta",
{
"type": "content_block_delta",
"index": 0,
"delta": {"type": "text_delta", "text": text},
},
)
break
yield _make_anthropic_sse(
"content_block_delta",
{
"type": "content_block_delta",
"index": 0,
"delta": {"type": "text_delta", "text": token},
},
)
yield _make_anthropic_sse(
"content_block_stop",
{"type": "content_block_stop", "index": 0},
)
stop_reason = "stop_sequence" if stopped_seq else "end_turn"
yield _make_anthropic_sse(
"message_delta",
{
"type": "message_delta",
"delta": {"stop_reason": stop_reason, "stop_sequence": stopped_seq},
"usage": {"output_tokens": completion_tokens},
},
)
yield _make_anthropic_sse(
"message_stop",
{"type": "message_stop"},
)
return StreamingResponse(
event_stream(),
media_type="text/event-stream",
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
)
completion_tokens = 0
chunks: List[str] = []
agen = engine.generate_async(
prompt=prompt, prompt=prompt,
stream=True, max_tokens=request.max_tokens,
max_tokens=max_len, temperature=request.temperature,
temperature=temperature, top_p=request.top_p,
top_p=top_p, top_k=request.top_k,
top_k=top_k, )
): stopped_seq: Optional[str] = None
accumulated = ""
async for token in agen:
chunks.append(token) chunks.append(token)
return {"response": "".join(chunks)} completion_tokens += 1
accumulated += token
matched = _check_stop_sequence(accumulated, stop_sequences)
if matched:
stopped_seq = matched
break
content = "".join(chunks)
if stopped_seq:
idx = content.rfind(stopped_seq)
if idx != -1:
content = content[:idx]
return {
"id": resp_id,
"type": "message",
"role": "assistant",
"model": model,
"content": [{"type": "text", "text": content}],
"stop_reason": "stop_sequence" if stopped_seq else "end_turn",
"stop_sequence": stopped_seq,
"usage": {
"input_tokens": prompt_tokens,
"output_tokens": completion_tokens,
},
}
def run_server( def run_server(
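Review note on the streaming branch of `create_message`: a stop sequence that spans several tokens is only detected after its first tokens have already gone out as deltas. One possible way to avoid that is to hold back a tail that could still grow into a stop sequence. The sketch below is an illustration under that assumption, not code from this commit: `stream_with_stops` is a hypothetical helper and it operates on a plain iterable of strings rather than the engine's async generator.

```python
from typing import Iterable, Iterator, List, Optional, Tuple


def stream_with_stops(
    tokens: Iterable[str], stop_sequences: List[str]
) -> Iterator[Tuple[str, Optional[str]]]:
    """Yield (text_delta, matched_stop) pairs without leaking any part of a stop sequence."""
    stops = [s for s in stop_sequences if s]
    # Longest tail that might still be the beginning of a stop sequence.
    hold = max((len(s) for s in stops), default=1) - 1
    buffer = ""
    for token in tokens:
        buffer += token
        hits = [(buffer.find(s), s) for s in stops if s in buffer]
        if hits:
            idx, matched = min(hits)  # earliest stop sequence wins
            yield buffer[:idx], matched
            return
        # Everything except the held-back tail is safe to emit.
        safe = len(buffer) - hold
        if safe > 0:
            yield buffer[:safe], None
            buffer = buffer[safe:]
    if buffer:
        yield buffer, None  # flush the tail when generation ends naturally
```

The trade-off is latency: output lags by at most `len(longest stop sequence) - 1` characters, in exchange for never having to retract text that was already sent.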

View File

@ -265,7 +265,9 @@ class DPOStrategy(BaseStrategy):
class GRPOStrategy(BaseStrategy): class GRPOStrategy(BaseStrategy):
"""Group Relative Policy Optimization strategy. """Group Relative Policy Optimization strategy.
Implements GRPO with clipping and KL penalty. On-policy GRPO following DeepSeek-R1: the policy model is updated while
a frozen ref_model stores the old-policy log-probs. ratio = exp(logπ_θ - logπ_ref),
clipped PPO objective. ``compute_loss`` calls ``sync_ref_model()`` every ``sync_interval`` steps; call it manually after a data-generation round if needed.
""" """
def __init__( def __init__(
@ -276,6 +278,7 @@ class GRPOStrategy(BaseStrategy):
kl_coef: float = 0.01, kl_coef: float = 0.01,
group_size: int = 4, group_size: int = 4,
reduction: str = "mean", reduction: str = "mean",
sync_interval: int = 200,
**kwargs, **kwargs,
): ):
super().__init__(model, device, **kwargs) super().__init__(model, device, **kwargs)
@ -284,8 +287,19 @@ class GRPOStrategy(BaseStrategy):
self.kl_coef = kl_coef self.kl_coef = kl_coef
self.group_size = group_size self.group_size = group_size
self.reduction = reduction self.reduction = reduction
self.sync_interval = sync_interval
self._step = 0
def sync_ref_model(self):
"""Copy current model weights to ref model."""
ref_state = self.model.state_dict()
self.ref_model.load_state_dict(ref_state)
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor: def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
self._step += 1
if self._step % self.sync_interval == 0:
self.sync_ref_model()
batch = move_to_device(batch, self.device) batch = move_to_device(batch, self.device)
prompts = batch["prompts"] prompts = batch["prompts"]
responses = batch["responses"] responses = batch["responses"]
@ -297,7 +311,6 @@ class GRPOStrategy(BaseStrategy):
masks_flat = masks.view(-1, response_len) masks_flat = masks.view(-1, response_len)
prompt_expanded = prompts.unsqueeze(1).repeat(1, group_size, 1).flatten(0, 1) prompt_expanded = prompts.unsqueeze(1).repeat(1, group_size, 1).flatten(0, 1)
# Shape: (batch_size * group_size, seq_len + response_len)
full_sequences = torch.cat([prompt_expanded, responses_flat], dim=-1) full_sequences = torch.cat([prompt_expanded, responses_flat], dim=-1)
full_masks = torch.cat([torch.ones_like(prompt_expanded), masks_flat], dim=-1) full_masks = torch.cat([torch.ones_like(prompt_expanded), masks_flat], dim=-1)
@ -312,14 +325,13 @@ class GRPOStrategy(BaseStrategy):
) )
log_probs_ref = log_probs_ref.view(batch_size, group_size) log_probs_ref = log_probs_ref.view(batch_size, group_size)
# Compute advantages from rewards with normalization
eps = torch.finfo(log_probs_policy.dtype).eps eps = torch.finfo(log_probs_policy.dtype).eps
mean = rewards.mean(dim=-1, keepdim=True) mean = rewards.mean(dim=-1, keepdim=True)
std = rewards.std(dim=-1, keepdim=True) std = rewards.std(dim=-1, keepdim=True)
advantages = (rewards - mean) / (std + eps) advantages = (rewards - mean) / (std + eps)
# PPO-style clipped surrogate objective ratio = torch.exp(log_probs_policy - log_probs_ref)
ratio = torch.exp(0) # Off-policy: policy_model = old_model
surr1 = ratio * advantages surr1 = ratio * advantages
surr2 = torch.clamp(ratio, 1 - self.clip_eps, 1 + self.clip_eps) * advantages surr2 = torch.clamp(ratio, 1 - self.clip_eps, 1 + self.clip_eps) * advantages
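To make the arithmetic in this hunk easier to check by hand, here is a minimal pure-Python sketch of the group-normalized advantage and the clipped surrogate. It assumes the standard PPO form `min(surr1, surr2)`, which this hunk does not show, and it omits the KL penalty; `group_advantages` and `clipped_surrogate` are hypothetical names, not functions from the repository.

```python
import math
import statistics
from typing import List


def group_advantages(rewards: List[float], eps: float = 1e-8) -> List[float]:
    """Normalize rewards within one group: (r - mean) / (std + eps)."""
    mean = statistics.fmean(rewards)
    std = statistics.stdev(rewards)  # sample std, matching torch.Tensor.std's default
    return [(r - mean) / (std + eps) for r in rewards]


def clipped_surrogate(
    logp_policy: List[float],
    logp_ref: List[float],
    advantages: List[float],
    clip_eps: float = 0.2,
) -> float:
    """Mean of min(ratio * A, clip(ratio, 1 - eps, 1 + eps) * A) over the group."""
    terms = []
    for lp, lr, adv in zip(logp_policy, logp_ref, advantages):
        ratio = math.exp(lp - lr)  # importance ratio against the reference policy
        clipped = min(max(ratio, 1.0 - clip_eps), 1.0 + clip_eps)
        terms.append(min(ratio * adv, clipped * adv))
    return sum(terms) / len(terms)
```

Right after a sync the two log-probabilities coincide, so the ratio is 1 and clipping is inactive; it only starts to matter as the policy drifts from the reference between syncs.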

42
docker-compose.yml Normal file
View File

@ -0,0 +1,42 @@
services:
server:
build: .
image: astrai:latest
ports:
- "8000:8000"
volumes:
- ./params:/app/params:ro
- ./checkpoints:/app/checkpoints
command: python -m scripts.tools.server --port 8000 --device cuda
deploy:
resources:
reservations:
devices:
- driver: nvidia
count: 1
capabilities: [gpu]
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
interval: 30s
timeout: 10s
retries: 3
start_period: 60s
restart: unless-stopped
server-cpu:
profiles: [cpu]
build: .
image: astrai:latest
ports:
- "8000:8000"
volumes:
- ./params:/app/params:ro
- ./checkpoints:/app/checkpoints
command: python -m scripts.tools.server --port 8000 --device cpu
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
interval: 30s
timeout: 10s
retries: 3
start_period: 120s
restart: unless-stopped

View File

@ -23,7 +23,7 @@ def parse_args() -> argparse.Namespace:
"--train_type", "--train_type",
type=str, type=str,
required=True, required=True,
choices=["seq", "sft", "dpo"], choices=["seq", "sft", "dpo", "grpo"],
help="Train type.", help="Train type.",
) )
parser.add_argument( parser.add_argument(
@ -42,9 +42,7 @@ def parse_args() -> argparse.Namespace:
parser.add_argument( parser.add_argument(
"--n_epoch", type=int, default=1, help="Number of epochs to train." "--n_epoch", type=int, default=1, help="Number of epochs to train."
) )
parser.add_argument( parser.add_argument(
"--batch_size", type=int, default=1, help="Batch size for training." "--batch_size", type=int, default=1, help="Batch size for training."
) )
parser.add_argument( parser.add_argument(
"--accumulation_steps", "--accumulation_steps",
type=int, type=int,
@ -106,6 +104,17 @@ def parse_args() -> argparse.Namespace:
"--stride", type=int, default=None, help="the step size of the input sequence." "--stride", type=int, default=None, help="the step size of the input sequence."
) )
parser.add_argument("--dpo_beta", type=float, default=0.1, help="DPO beta value.") parser.add_argument("--dpo_beta", type=float, default=0.1, help="DPO beta value.")
parser.add_argument("--grpo_clip_eps", type=float, default=0.2, help="GRPO clip epsilon.")
parser.add_argument(
"--on_policy",
action="store_true",
default=False,
help="Enable on-policy GRPO mode.",
)
parser.add_argument(
"--grpo_kl_coef", type=float, default=0.01, help="GRPO KL penalty coefficient."
)
parser.add_argument("--group_size", type=int, default=4, help="GRPO group size.")
parser.add_argument( parser.add_argument(
"--label_smoothing", "--label_smoothing",
type=float, type=float,
@ -125,6 +134,13 @@ def parse_args() -> argparse.Namespace:
default="checkpoint", default="checkpoint",
help="Directory to save checkpoints.", help="Directory to save checkpoints.",
) )
parser.add_argument(
"--grpo_sync_interval",
type=int,
default=200,
help="GRPO ref model sync interval (steps).",
)
parser.add_argument( parser.add_argument(
"--start_epoch", type=int, default=0, help="Start epoch for training." "--start_epoch", type=int, default=0, help="Start epoch for training."
) )
@ -182,6 +198,10 @@ def train(
ckpt_interval: int, ckpt_interval: int,
ckpt_dir: str, ckpt_dir: str,
dpo_beta: float, dpo_beta: float,
grpo_clip_eps: float,
grpo_kl_coef: float,
group_size: int,
grpo_sync_interval: int,
adamw_beta1: float, adamw_beta1: float,
adamw_beta2: float, adamw_beta2: float,
adamw_weight_decay: float, adamw_weight_decay: float,
@ -195,7 +215,7 @@ def train(
nprocs: int, nprocs: int,
device_type: str, device_type: str,
): ):
assert train_type in ["seq", "sft", "dpo"] assert train_type in ["seq", "sft", "dpo", "grpo"]
assert os.path.exists(param_path) assert os.path.exists(param_path)
# Load config # Load config
@ -216,7 +236,14 @@ def train(
state_dict = st.load_file(weights_path) state_dict = st.load_file(weights_path)
model.load_state_dict(state_dict, strict=False) model.load_state_dict(state_dict, strict=False)
strategy_kwargs = {"dpo_beta": dpo_beta, "label_smoothing": label_smoothing} strategy_kwargs = {
"dpo_beta": dpo_beta,
"label_smoothing": label_smoothing,
"clip_eps": grpo_clip_eps,
"kl_coef": grpo_kl_coef,
"group_size": group_size,
"sync_interval": grpo_sync_interval,
}
dataset = DatasetFactory.load( dataset = DatasetFactory.load(
train_type=train_type, train_type=train_type,
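As a cross-check of the flag-to-kwarg mapping in this file, here is a minimal standalone sketch. It uses the flag names and defaults from the README table and the `strategy_kwargs` keys from this hunk; `build_grpo_parser` and `grpo_strategy_kwargs` are hypothetical helpers, not functions in train.py. Note that argparse raises `ArgumentError` when the same option string is registered twice, so each flag must appear in exactly one `add_argument` call.

```python
import argparse
from typing import Any, Dict, List, Optional


def build_grpo_parser() -> argparse.ArgumentParser:
    """Register each GRPO flag exactly once; duplicates make argparse raise."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--grpo_clip_eps", type=float, default=0.2)
    parser.add_argument("--grpo_kl_coef", type=float, default=0.01)
    parser.add_argument("--group_size", type=int, default=4)
    parser.add_argument("--grpo_sync_interval", type=int, default=200)
    return parser


def grpo_strategy_kwargs(argv: Optional[List[str]] = None) -> Dict[str, Any]:
    """Map CLI flags onto the keyword names that GRPOStrategy accepts."""
    args = build_grpo_parser().parse_args(argv)
    return {
        "clip_eps": args.grpo_clip_eps,
        "kl_coef": args.grpo_kl_coef,
        "group_size": args.group_size,
        "sync_interval": args.grpo_sync_interval,
    }
```

For example, `grpo_strategy_kwargs(["--group_size", "8"])` overrides only the group size and leaves the other three values at their defaults.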

View File

@ -1,7 +1,5 @@
"""Unit tests for the inference HTTP server.""" """Unit tests for the inference HTTP server."""
from unittest.mock import MagicMock
import pytest import pytest
@ -24,52 +22,6 @@ def test_health_with_model(client, loaded_model):
assert data["model_loaded"] is True assert data["model_loaded"] is True
def test_generate_non_stream(client, loaded_model, monkeypatch):
"""POST /generate with stream=false should return JSON response."""
response = client.post(
"/generate",
params={
"query": "Hello",
"temperature": 0.8,
"top_p": 0.95,
"top_k": 50,
"max_len": 100,
"stream": False,
},
)
assert response.status_code == 200
data = response.json()
assert "response" in data
def test_generate_stream(client, loaded_model, monkeypatch):
"""POST /generate with stream=true should return plain text stream."""
async def async_gen():
yield "chunk1"
yield "chunk2"
mock_engine = loaded_model
mock_engine.generate_async.return_value = async_gen()
monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine)
response = client.post(
"/generate",
params={
"query": "Hello",
"temperature": 0.8,
"top_p": 0.95,
"top_k": 50,
"max_len": 100,
"stream": True,
},
headers={"Accept": "text/plain"},
)
assert response.status_code == 200
content = response.content.decode("utf-8")
assert "chunk1" in content
assert "chunk2" in content
def test_chat_completions_non_stream(client, loaded_model, monkeypatch): def test_chat_completions_non_stream(client, loaded_model, monkeypatch):
"""POST /v1/chat/completions with stream=false returns OpenAI-style JSON.""" """POST /v1/chat/completions with stream=false returns OpenAI-style JSON."""
@ -125,17 +77,87 @@ def test_chat_completions_stream(client, loaded_model, monkeypatch):
assert any("[DONE]" in line for line in lines) assert any("[DONE]" in line for line in lines)
def test_generate_with_history(client, loaded_model, monkeypatch): def test_messages_non_stream(client, loaded_model, monkeypatch):
"""POST /generate with history parameter.""" """POST /v1/messages with stream=false returns Anthropic-style JSON."""
async def async_gen():
yield "Assistant reply"
mock_engine = loaded_model
mock_engine.generate_async.return_value = async_gen()
monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine)
response = client.post( response = client.post(
"/generate", "/v1/messages",
params={ json={
"query": "Hi", "messages": [{"role": "user", "content": "Hello"}],
"history": [["user1", "assistant1"], ["user2", "assistant2"]], "temperature": 0.8,
"max_tokens": 100,
"stream": False, "stream": False,
}, },
) )
assert response.status_code == 200 assert response.status_code == 200
data = response.json()
assert data["type"] == "message"
assert data["role"] == "assistant"
assert len(data["content"]) == 1
assert data["content"][0]["type"] == "text"
assert "usage" in data
assert "input_tokens" in data["usage"]
def test_messages_stream(client, loaded_model, monkeypatch):
"""POST /v1/messages with stream=true returns Anthropic SSE stream."""
async def async_gen():
yield "cumulative1"
yield "cumulative2"
mock_engine = loaded_model
mock_engine.generate_async.return_value = async_gen()
monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine)
response = client.post(
"/v1/messages",
json={
"messages": [{"role": "user", "content": "Hello"}],
"temperature": 0.8,
"max_tokens": 100,
"stream": True,
},
headers={"Accept": "text/event-stream"},
)
assert response.status_code == 200
content = response.content.decode("utf-8")
assert "message_start" in content
assert "content_block_start" in content
assert "content_block_delta" in content
assert "cumulative1" in content
assert "cumulative2" in content
assert "content_block_stop" in content
assert "message_delta" in content
assert "message_stop" in content
def test_messages_with_system(client, loaded_model, monkeypatch):
"""POST /v1/messages with system prompt."""
async def async_gen():
yield "Reply"
mock_engine = loaded_model
mock_engine.generate_async.return_value = async_gen()
monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine)
response = client.post(
"/v1/messages",
json={
"messages": [{"role": "user", "content": "Hello"}],
"system": "You are a helpful assistant.",
"max_tokens": 100,
"stream": False,
},
)
assert response.status_code == 200
data = response.json()
assert data["type"] == "message"
if __name__ == "__main__": if __name__ == "__main__":