Compare commits
No commits in common. "bc7c82977ee6725e98f909aa9260540fe811289d" and "9d96b0431d04585a7b2fc24d49e08936b03fda84" have entirely different histories.
bc7c82977e...9d96b0431d
@@ -15,7 +15,6 @@
 !/.gitattributes
 !/.dockerignore
 !/Dockerfile
-!/docker-compose.yml
 !/assets/**
 !/CONTRIBUTING.md
 !/LICENSE
100 README.md
@@ -47,7 +47,6 @@
 - 📦 **Lightweight**: Minimal dependencies, easy to deploy.
 - 🔬 **Research‑Friendly**: Modular design, easy to experiment with new ideas.
 - 🤗 **HuggingFace Integration**: Compatible with HuggingFace models and datasets.
-- 🔌 **Dual API Compatibility**: Supports both OpenAI and Anthropic chat completion APIs out of the box.
 
 ### Quick Start
 
@@ -68,48 +67,44 @@ pip install -e ".[dev]"
 #### Train a Model
 
 ```bash
-python scripts/tools/train.py --train_type=seq --data_root_path=/path/to/dataset --param_path=/path/to/model
+python scripts/tools/train.py \
+  --train_type=seq \
+  --data_root_path=/path/to/dataset \
+  --param_path=/path/to/model \
+  --n_epoch=3 \
+  --batch_size=4 \
+  --accumulation_steps=8 \
+  --max_lr=3e-4 \
+  --warmup_steps=2000 \
+  --ckpt_interval=5000 \
+  --ckpt_dir=./checkpoints
 ```
 
-| Parameter | Description | Default |
-|-----------|-------------|---------|
-| `--train_type` | Training type (`seq`, `sft`, `dpo`, `grpo`) | required |
-| `--data_root_path` | Dataset root directory | required |
-| `--param_path` | Model / checkpoint path | required |
-| `--n_epoch` | Training epochs | 1 |
-| `--batch_size` | Batch size | 1 |
-| `--accumulation_steps` | Gradient accumulation steps | 1 |
-| `--warmup_steps` | LR warmup steps | 1000 |
-| `--max_lr` | Peak learning rate (cosine decay) | 3e-4 |
-| `--max_grad_norm` | Max gradient norm for clipping | 1.0 |
-| `--adamw_beta1` | AdamW beta1 | 0.9 |
-| `--adamw_beta2` | AdamW beta2 | 0.95 |
-| `--adamw_weight_decay` | AdamW weight decay | 0.01 |
-| `--random_seed` | Random seed | 3407 |
-| `--num_workers` | DataLoader workers | 4 |
-| `--window_size` | Max input sequence length | auto |
-| `--stride` | Sequence stride | auto |
-| `--label_smoothing` | Label smoothing for cross entropy | 0.1 |
-| `--dpo_beta` | DPO beta | 0.1 |
-| `--grpo_clip_eps` | GRPO clip epsilon | 0.2 |
-| `--grpo_kl_coef` | GRPO KL penalty coefficient | 0.01 |
-| `--group_size` | GRPO group size | 4 |
-| `--grpo_sync_interval` | GRPO ref model sync interval (steps) | 200 |
-| `--ckpt_interval` | Checkpoint interval (iters) | 5000 |
-| `--ckpt_dir` | Checkpoint directory | checkpoint |
-| `--start_epoch` | Start epoch (for resume) | 0 |
-| `--start_batch` | Start batch (for resume) | 0 |
-| `--nprocs` | Number of GPUs | 1 |
-| `--device_type` | Device type | cuda |
-
-Full reference at [Parameter Guide](./assets/docs/params.md#training-parameters).
-
 #### Generate Text
 
 ```bash
 python scripts/tools/generate.py --param_path=/path/to/param_path
 ```
 
+#### Training Parameters
+
+| Parameter | Description | Default |
+|-----------|-------------|---------|
+| `--train_type` | Training type (`seq`, `sft`, `dpo`) | required |
+| `--data_root_path` | Dataset root directory | required |
+| `--param_path` | Model / checkpoint path | required |
+| `--n_epoch` | Training epochs | 1 |
+| `--batch_size` | Batch size | 1 |
+| `--accumulation_steps` | Gradient accumulation steps | 1 |
+| `--max_lr` | Peak learning rate (cosine decay) | 3e-4 |
+| `--warmup_steps` | LR warmup steps | 1000 |
+| `--ckpt_interval` | Checkpoint interval (iters) | 5000 |
+| `--ckpt_dir` | Checkpoint directory | checkpoint |
+| `--num_workers` | DataLoader workers | 4 |
+| `--nprocs` | Number of GPUs | 1 |
+
+Full reference at [Parameter Guide](./assets/docs/params.md#training-parameters).
+
 #### Docker
 
 Build and run with Docker (recommended for GPU environments):
@@ -130,21 +125,13 @@ docker run --gpus all -p 8000:8000 astrai:latest \
 
 # Run with volume mount for data
 docker run --gpus all -v /path/to/data:/data -it astrai:latest
-
-# Docker Compose (GPU, default)
-docker compose up -d
-
-# Docker Compose (CPU only)
-docker compose --profile cpu up -d
 ```
 
 > **Note**: `--gpus all` is required for CUDA support. Without it, `torch.cuda.is_available()` will return `False`.
 
-> **Note**: `--gpus all` is required for CUDA support. Without it, `torch.cuda.is_available()` will return `False`.
-
 #### Start HTTP Server
 
-Start the inference server with OpenAI and Anthropic-compatible HTTP API:
+Start the inference server with OpenAI-compatible HTTP API:
 
 ```bash
 python -m scripts.tools.server --port 8000 --device cuda
@@ -153,7 +140,7 @@ python -m scripts.tools.server --port 8000 --device cuda
 Make requests:
 
 ```bash
-# OpenAI-compatible
+# Chat API (OpenAI compatible)
 curl -X POST http://localhost:8000/v1/chat/completions \
   -H "Content-Type: application/json" \
   -d '{
@@ -161,7 +148,7 @@ curl -X POST http://localhost:8000/v1/chat/completions \
     "max_tokens": 512
   }'
 
-# OpenAI-compatible streaming
+# Streaming response
 curl -X POST http://localhost:8000/v1/chat/completions \
   -H "Content-Type: application/json" \
   -d '{
@@ -170,27 +157,6 @@ curl -X POST http://localhost:8000/v1/chat/completions \
     "max_tokens": 500
   }'
 
-# Anthropic-compatible
-curl -X POST http://localhost:8000/v1/messages \
-  -H "Content-Type: application/json" \
-  -d '{
-    "model": "astrai",
-    "system": "You are a helpful assistant.",
-    "messages": [{"role": "user", "content": "Hello"}],
-    "max_tokens": 512
-  }'
-
-# Anthropic-compatible streaming with stop sequences
-curl -X POST http://localhost:8000/v1/messages \
-  -H "Content-Type: application/json" \
-  -d '{
-    "model": "astrai",
-    "messages": [{"role": "user", "content": "Write a story"}],
-    "max_tokens": 500,
-    "stream": true,
-    "stop_sequences": ["The end"]
-  }'
-
 # Health check
 curl http://localhost:8000/health
 ```
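
The training parameter table in the README hunks above describes `--max_lr` as a peak learning rate with cosine decay and `--warmup_steps` as LR warmup steps. A minimal sketch of how such a schedule is commonly computed is given below. The linear warmup shape, `total_steps`, `min_lr`, and the helper names `lr_at_step` and `effective_batch_size` are assumptions made for illustration and are not taken from the repository.

```python
import math


def lr_at_step(step, max_lr=3e-4, warmup_steps=2000, total_steps=20000, min_lr=0.0):
    """Linear warmup to max_lr, then cosine decay to min_lr (assumed shape)."""
    if total_steps <= warmup_steps:
        raise ValueError("total_steps must exceed warmup_steps")
    if step < warmup_steps:
        # Ramp linearly so that the last warmup step reaches max_lr.
        return max_lr * (step + 1) / warmup_steps
    # Fraction of the decay phase completed, capped at 1.0 past total_steps.
    progress = min(1.0, (step - warmup_steps) / (total_steps - warmup_steps))
    return min_lr + 0.5 * (max_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


def effective_batch_size(batch_size, accumulation_steps, nprocs=1):
    """Samples contributing to one optimizer step across all processes."""
    return batch_size * accumulation_steps * nprocs


if __name__ == "__main__":
    # Values from the example command: --batch_size=4 --accumulation_steps=8
    print(effective_batch_size(4, 8))
    for s in (0, 1999, 2000, 11000, 20000):
        print(s, lr_at_step(s))
```

With the values from the example command, one optimizer step would cover 4 × 8 = 32 samples per process, assuming gradient accumulation simply multiplies the batch size.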
@@ -53,7 +53,6 @@
 - 📦 **Lightweight**: Few dependencies, simple to deploy.
 - 🔬 **Research-Friendly**: Modular design that makes it easy to try out new ideas.
 - 🤗 **HuggingFace Integration**: Compatible with HuggingFace models and datasets.
-- 🔌 **Dual API Compatibility**: Supports both the OpenAI and Anthropic chat completion APIs out of the box.
 
 ### Quick Start
 
@@ -74,48 +73,44 @@ pip install -e ".[dev]"
 #### Train a Model
 
 ```bash
-python scripts/tools/train.py --train_type=seq --data_root_path=/path/to/dataset --param_path=/path/to/model
+python scripts/tools/train.py \
+  --train_type=seq \
+  --data_root_path=/path/to/dataset \
+  --param_path=/path/to/model \
+  --n_epoch=3 \
+  --batch_size=4 \
+  --accumulation_steps=8 \
+  --max_lr=3e-4 \
+  --warmup_steps=2000 \
+  --ckpt_interval=5000 \
+  --ckpt_dir=./checkpoints
 ```
 
-| Parameter | Description | Default |
-|------|------|--------|
-| `--train_type` | Training type (`seq`, `sft`, `dpo`, `grpo`) | required |
-| `--data_root_path` | Dataset root directory | required |
-| `--param_path` | Model parameters or checkpoint path | required |
-| `--n_epoch` | Number of training epochs | 1 |
-| `--batch_size` | Batch size | 1 |
-| `--accumulation_steps` | Gradient accumulation steps | 1 |
-| `--warmup_steps` | Warmup steps | 1000 |
-| `--max_lr` | Peak learning rate (cosine decay) | 3e-4 |
-| `--max_grad_norm` | Maximum value for gradient clipping | 1.0 |
-| `--adamw_beta1` | AdamW beta1 | 0.9 |
-| `--adamw_beta2` | AdamW beta2 | 0.95 |
-| `--adamw_weight_decay` | AdamW weight decay | 0.01 |
-| `--random_seed` | Random seed | 3407 |
-| `--num_workers` | Number of data-loading threads | 4 |
-| `--window_size` | Maximum input sequence length | auto |
-| `--stride` | Sequence stride | auto |
-| `--label_smoothing` | Label smoothing for cross entropy | 0.1 |
-| `--dpo_beta` | DPO beta | 0.1 |
-| `--grpo_clip_eps` | GRPO clip epsilon | 0.2 |
-| `--grpo_kl_coef` | GRPO KL penalty coefficient | 0.01 |
-| `--group_size` | GRPO group size | 4 |
-| `--grpo_sync_interval` | GRPO ref_model sync interval (steps) | 200 |
-| `--ckpt_interval` | Checkpoint interval (iterations) | 5000 |
-| `--ckpt_dir` | Checkpoint save directory | checkpoint |
-| `--start_epoch` | Start epoch (for resuming training) | 0 |
-| `--start_batch` | Start batch (for resuming training) | 0 |
-| `--nprocs` | Number of GPUs | 1 |
-| `--device_type` | Device type | cuda |
-
-For the full parameter list, see the [Parameter Guide](./params.md#training-parameters).
-
 #### Text Generation
 
 ```bash
 python scripts/tools/generate.py --param_path=/path/to/param_path
 ```
 
+#### Training Parameters
+
+| Parameter | Description | Default |
+|------|------|--------|
+| `--train_type` | Training type (`seq`, `sft`, `dpo`) | required |
+| `--data_root_path` | Dataset root directory | required |
+| `--param_path` | Model parameters or checkpoint path | required |
+| `--n_epoch` | Number of training epochs | 1 |
+| `--batch_size` | Batch size | 1 |
+| `--accumulation_steps` | Gradient accumulation steps | 1 |
+| `--max_lr` | Peak learning rate (cosine decay) | 3e-4 |
+| `--warmup_steps` | Warmup steps | 1000 |
+| `--ckpt_interval` | Checkpoint interval (iterations) | 5000 |
+| `--ckpt_dir` | Checkpoint save directory | checkpoint |
+| `--num_workers` | Number of data-loading threads | 4 |
+| `--nprocs` | Number of GPUs | 1 |
+
+For the full parameter list, see the [Parameter Guide](./params.md#training-parameters).
+
 #### Docker
 
 Build and run with Docker (recommended for GPU environments):
@@ -136,19 +131,13 @@ docker run --gpus all -p 8000:8000 astrai:latest \
 
 # Mount a data volume
 docker run --gpus all -v /path/to/data:/data -it astrai:latest
-
-# Docker Compose (GPU, default)
-docker compose up -d
-
-# Docker Compose (CPU only)
-docker compose --profile cpu up -d
 ```
 
 > **Note**: `--gpus all` is required to enable CUDA support; otherwise `torch.cuda.is_available()` will return `False`.
 
 #### Start the HTTP Server
 
-Start the inference server with OpenAI- and Anthropic-compatible HTTP APIs:
+Start the inference server with an OpenAI-compatible HTTP API:
 
 ```bash
 python -m scripts.tools.server --port 8000 --device cuda
@@ -157,7 +146,7 @@ python -m scripts.tools.server --port 8000 --device cuda
 Make requests:
 
 ```bash
-# OpenAI-compatible
+# Chat API (OpenAI-compatible)
 curl -X POST http://localhost:8000/v1/chat/completions \
   -H "Content-Type: application/json" \
   -d '{
@@ -165,7 +154,7 @@ curl -X POST http://localhost:8000/v1/chat/completions \
     "max_tokens": 512
   }'
 
-# OpenAI-compatible streaming
+# Streaming response
 curl -X POST http://localhost:8000/v1/chat/completions \
   -H "Content-Type: application/json" \
   -d '{
@@ -174,27 +163,6 @@ curl -X POST http://localhost:8000/v1/chat/completions \
     "max_tokens": 500
   }'
 
-# Anthropic-compatible
-curl -X POST http://localhost:8000/v1/messages \
-  -H "Content-Type: application/json" \
-  -d '{
-    "model": "astrai",
-    "system": "You are a helpful assistant.",
-    "messages": [{"role": "user", "content": "Hello"}],
-    "max_tokens": 512
-  }'
-
-# Anthropic-compatible streaming with stop sequences
-curl -X POST http://localhost:8000/v1/messages \
-  -H "Content-Type: application/json" \
-  -d '{
-    "model": "astrai",
-    "messages": [{"role": "user", "content": "Write a story"}],
-    "max_tokens": 500,
-    "stream": true,
-    "stop_sequences": ["End"]
-  }'
-
 # Health check
 curl http://localhost:8000/health
 ```
@@ -262,60 +262,25 @@ curl -X POST http://localhost:8000/v1/chat/completions \
 
 The server uses Server-Sent Events (SSE) with content type `text/event-stream`.
 
-### Health Check
+### Simple Generation Endpoint
 
-
-### Anthropic-Compatible Endpoint
-
-The server also provides an Anthropic-compatible endpoint at `/v1/messages`:
+For basic text generation without chat format:
 
 ```bash
-curl -X POST http://localhost:8000/v1/messages \
-  -H "Content-Type: application/json" \
-  -d '{
-    "model": "astrai",
-    "system": "You are a helpful assistant.",
-    "messages": [{"role": "user", "content": "Hello, how are you?"}],
-    "max_tokens": 2048
-  }'
+curl -X POST "http://localhost:8000/generate?query=Hello&max_len=1000" \
+  -H "Content-Type: application/json"
 ```
 
-Response:
-```json
-{
-  "id": "msg_abc123...",
-  "type": "message",
-  "role": "assistant",
-  "model": "astrai",
-  "content": [{"type": "text", "text": "Hello! I am doing well..."}],
-  "stop_reason": "end_turn",
-  "stop_sequence": null,
-  "usage": {"input_tokens": 20, "output_tokens": 15}
-}
-```
-
-Streaming:
+Or with conversation history:
+
 ```bash
-curl -X POST http://localhost:8000/v1/messages \
+curl -X POST "http://localhost:8000/generate" \
   -H "Content-Type: application/json" \
   -d '{
-    "model": "astrai",
-    "system": "You are a helpful assistant.",
-    "messages": [{"role": "user", "content": "Write a short poem"}],
-    "max_tokens": 500,
-    "stream": true
-  }'
-```
-
-Supports `stop_sequences` for early termination:
-```bash
-curl -X POST http://localhost:8000/v1/messages \
-  -H "Content-Type: application/json" \
-  -d '{
-    "model": "astrai",
-    "messages": [{"role": "user", "content": "Write a story"}],
-    "max_tokens": 500,
-    "stop_sequences": ["The end", "THE END"]
+    "query": "What is AI?",
+    "history": [["Hello", "Hi there!"], ["How are you?", "I'm doing well"]],
+    "temperature": 0.8,
+    "max_len": 2048
   }'
 ```
 
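
The documentation hunk above states that the server streams with Server-Sent Events and content type `text/event-stream`. The sketch below shows how a client might pull text out of such a stream. It assumes OpenAI-style chunks (`choices[0].delta.content`) terminated by a `data: [DONE]` line, which this diff does not confirm for this server; `iter_sse_data` and `collect_text` are hypothetical helper names.

```python
import json
from typing import Iterable, Iterator


def iter_sse_data(lines: Iterable[str]) -> Iterator[str]:
    """Yield the payload of each `data:` line until a `[DONE]` sentinel."""
    for raw in lines:
        line = raw.rstrip("\r\n")
        if not line.startswith("data:"):
            continue  # skip blank separators, comments and other SSE fields
        payload = line[len("data:"):].strip()
        if payload == "[DONE]":  # assumed end-of-stream marker
            return
        yield payload


def collect_text(lines: Iterable[str]) -> str:
    """Concatenate `choices[0].delta.content` from OpenAI-style chunks."""
    parts = []
    for payload in iter_sse_data(lines):
        chunk = json.loads(payload)
        delta = chunk["choices"][0].get("delta", {})
        parts.append(delta.get("content") or "")
    return "".join(parts)


if __name__ == "__main__":
    sample = [
        'data: {"choices": [{"delta": {"role": "assistant"}}]}\n',
        "\n",
        'data: {"choices": [{"delta": {"content": "Hel"}}]}\n',
        "\n",
        'data: {"choices": [{"delta": {"content": "lo"}}]}\n',
        "data: [DONE]\n",
    ]
    print(collect_text(sample))
```

The parser treats every `data:` line as one complete JSON payload. The SSE format also allows a single event to span several `data:` lines, so a general-purpose client would need to join those before decoding.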
@@ -1,5 +1,5 @@
 """
-OpenAI / Anthropic-compatible chat completion server backed by continuous-batching inference.
+OpenAI-compatible chat completion server backed by continuous-batching inference.
 """
 
 import json
@@ -51,7 +51,6 @@ class ChatCompletionRequest(BaseModel):
     messages: List[ChatMessage]
     temperature: Optional[float] = Field(default=1.0, ge=0.0, le=2.0)
     top_p: Optional[float] = Field(default=1.0, ge=0.0, le=1.0)
-    top_k: Optional[int] = Field(default=50, ge=1)
     stream: Optional[bool] = False
     stop: Optional[Union[str, List[str]]] = None
     max_tokens: Optional[int] = Field(default=2048, ge=1)
@@ -62,25 +61,6 @@ class ChatCompletionRequest(BaseModel):
     user: Optional[str] = None
 
 
-class AnthropicMessage(BaseModel):
-    role: str
-    content: Union[str, List[Dict[str, Any]]]
-
-
-class MessagesRequest(BaseModel):
-    """Anthropic Messages API request body."""
-
-    model: str = "astrai"
-    max_tokens: int = Field(default=1024, ge=1)
-    messages: List[AnthropicMessage]
-    system: Optional[str] = None
-    temperature: Optional[float] = Field(default=1.0, ge=0.0, le=2.0)
-    top_p: Optional[float] = Field(default=1.0, ge=0.0, le=1.0)
-    top_k: Optional[int] = Field(default=50, ge=1)
-    stream: Optional[bool] = False
-    stop_sequences: Optional[List[str]] = None
-
-
 def configure_server(
     device: str = "cuda",
     dtype: torch.dtype = torch.bfloat16,
@@ -205,7 +185,7 @@ async def chat_completion(request: ChatCompletionRequest):
             max_tokens=request.max_tokens,
             temperature=request.temperature,
             top_p=request.top_p,
-            top_k=request.top_k,
+            top_k=50,
         )
 
         async def event_stream():
@@ -257,7 +237,7 @@ async def chat_completion(request: ChatCompletionRequest):
         max_tokens=request.max_tokens,
         temperature=request.temperature,
         top_p=request.top_p,
-        top_k=request.top_k,
+        top_k=50,
     )
     async for token in agen:
         chunks.append(token)
@@ -284,183 +264,55 @@ async def chat_completion(request: ChatCompletionRequest):
     }
 
 
-def _make_anthropic_sse(event: str, data: Dict[str, Any]) -> str:
-    return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"
-
-
-def _check_stop_sequence(text: str, stop_sequences: List[str]) -> Optional[str]:
-    for seq in stop_sequences:
-        if seq and seq in text:
-            return seq
-    return None
-
-
-def _extract_text_content(content: Union[str, List[Dict[str, Any]]]) -> str:
-    if isinstance(content, str):
-        return content
-    if isinstance(content, list):
-        for block in content:
-            if isinstance(block, dict) and block.get("type") == "text":
-                return block.get("text", "")
-    return ""
-
-
-def _build_anthropic_messages(
-    messages: List[AnthropicMessage], system: Optional[str]
-) -> List[Dict[str, str]]:
-    result: List[Dict[str, str]] = []
-    if system:
-        result.append({"role": "system", "content": system})
-    for m in messages:
-        content = _extract_text_content(m.content)
-        if content:
-            result.append({"role": m.role, "content": content})
-    return result
-
-
-@app.post("/v1/messages")
-async def create_message(request: MessagesRequest):
-    """Anthropic-compatible Messages API endpoint (streaming + non-streaming)."""
+@app.post("/generate")
+async def generate(
+    query: str,
+    history: Optional[List[List[str]]] = None,
+    temperature: float = 0.8,
+    top_p: float = 0.95,
+    top_k: int = 50,
+    max_len: int = 2048,
+    stream: bool = False,
+):
+    """Legacy non-OpenAI generation endpoint (kept for backward compat)."""
     engine = _get_engine()
-    resp_id = f"msg_{uuid.uuid4().hex[:24]}"
-    model = request.model
 
-    chat_messages = _build_anthropic_messages(request.messages, request.system)
-    prompt = engine.tokenizer.apply_chat_template(chat_messages, tokenize=False)
-    prompt_tokens = len(engine.tokenizer.encode(prompt))
+    messages = []
+    if history:
+        for h in history:
+            if len(h) >= 2:
+                messages.append({"role": "user", "content": h[0]})
+                messages.append({"role": "assistant", "content": h[1]})
+    messages.append({"role": "user", "content": query})
 
-    stop_sequences = request.stop_sequences or []
+    prompt = engine.tokenizer.apply_chat_template(messages, tokenize=False)
 
-    if request.stream:
+    if stream:
         agen = engine.generate_async(
             prompt=prompt,
-            max_tokens=request.max_tokens,
-            temperature=request.temperature,
-            top_p=request.top_p,
-            top_k=request.top_k,
+            max_tokens=max_len,
+            temperature=temperature,
+            top_p=top_p,
+            top_k=top_k,
         )
 
-        async def event_stream():
-            yield _make_anthropic_sse(
-                "message_start",
-                {
-                    "type": "message_start",
-                    "message": {
-                        "id": resp_id,
-                        "type": "message",
-                        "role": "assistant",
-                        "model": model,
-                        "content": [],
-                        "usage": {"input_tokens": prompt_tokens},
-                    },
-                },
-            )
-
-            yield _make_anthropic_sse(
-                "content_block_start",
-                {
-                    "type": "content_block_start",
-                    "index": 0,
-                    "content_block": {"type": "text", "text": ""},
-                },
-            )
-
-            completion_tokens = 0
-            accumulated = ""
-            stopped_seq: Optional[str] = None
+        async def text_stream():
             async for token in agen:
-                accumulated += token
-                completion_tokens += 1
+                yield token + "\n"
 
-                matched = _check_stop_sequence(accumulated, stop_sequences)
-                if matched:
-                    text = accumulated[: accumulated.rfind(matched)]
-                    stopped_seq = matched
-                    if text:
-                        yield _make_anthropic_sse(
-                            "content_block_delta",
-                            {
-                                "type": "content_block_delta",
-                                "index": 0,
-                                "delta": {"type": "text_delta", "text": text},
-                            },
-                        )
-                    break
-
-                yield _make_anthropic_sse(
-                    "content_block_delta",
-                    {
-                        "type": "content_block_delta",
-                        "index": 0,
-                        "delta": {"type": "text_delta", "text": token},
-                    },
-                )
-
-            yield _make_anthropic_sse(
-                "content_block_stop",
-                {"type": "content_block_stop", "index": 0},
-            )
-
-            stop_reason = "stop_sequence" if stopped_seq else "end_turn"
-            yield _make_anthropic_sse(
-                "message_delta",
-                {
-                    "type": "message_delta",
-                    "delta": {"stop_reason": stop_reason, "stop_sequence": stopped_seq},
-                    "usage": {"output_tokens": completion_tokens},
-                },
-            )
-
-            yield _make_anthropic_sse(
-                "message_stop",
-                {"type": "message_stop"},
-            )
-
-        return StreamingResponse(
-            event_stream(),
-            media_type="text/event-stream",
-            headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
-        )
-
-    completion_tokens = 0
-    chunks: List[str] = []
-    agen = engine.generate_async(
-        prompt=prompt,
-        max_tokens=request.max_tokens,
-        temperature=request.temperature,
-        top_p=request.top_p,
-        top_k=request.top_k,
-    )
-    stopped_seq: Optional[str] = None
-    accumulated = ""
-    async for token in agen:
-        chunks.append(token)
-        completion_tokens += 1
-        accumulated += token
-        matched = _check_stop_sequence(accumulated, stop_sequences)
-        if matched:
-            stopped_seq = matched
-            break
-
-    content = "".join(chunks)
-    if stopped_seq:
-        idx = content.rfind(stopped_seq)
-        if idx != -1:
-            content = content[:idx]
-
-    return {
-        "id": resp_id,
-        "type": "message",
-        "role": "assistant",
-        "model": model,
-        "content": [{"type": "text", "text": content}],
-        "stop_reason": "stop_sequence" if stopped_seq else "end_turn",
-        "stop_sequence": stopped_seq,
-        "usage": {
-            "input_tokens": prompt_tokens,
-            "output_tokens": completion_tokens,
-        },
-    }
+        return StreamingResponse(text_stream(), media_type="text/plain")
+    else:
+        chunks = []
+        for token in engine.generate(
+            prompt=prompt,
+            stream=True,
+            max_tokens=max_len,
+            temperature=temperature,
+            top_p=top_p,
+            top_k=top_k,
+        ):
+            chunks.append(token)
+        return {"response": "".join(chunks)}
 
 
 def run_server(
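
The `/v1/messages` handler on the old side of this diff checks the accumulated output against `stop_sequences` after every token. The sketch below illustrates one way to apply stop sequences while streaming so that no part of a stop sequence is ever sent to the client: it holds back the last few characters until they can no longer be the start of a match. This is a different technique from the handler in the diff, which emits each token immediately, and the names `find_stop` and `stream_until_stop` are hypothetical.

```python
from typing import Iterable, Iterator, List, Optional, Tuple


def find_stop(text: str, stop_sequences: List[str]) -> Optional[Tuple[int, str]]:
    """Return (index, sequence) for the earliest stop sequence in text, or None."""
    best: Optional[Tuple[int, str]] = None
    for seq in stop_sequences:
        if not seq:
            continue  # ignore empty strings, as the handler in the diff does
        idx = text.find(seq)
        if idx != -1 and (best is None or idx < best[0]):
            best = (idx, seq)
    return best


def stream_until_stop(tokens: Iterable[str], stop_sequences: List[str]) -> Iterator[str]:
    """Yield text chunks, withholding anything that could begin a stop sequence."""
    # A partial match can be at most (longest stop sequence - 1) characters long.
    holdback = max((len(s) for s in stop_sequences if s), default=1) - 1
    pending = ""
    for token in tokens:
        pending += token
        hit = find_stop(pending, stop_sequences)
        if hit is not None:
            if hit[0] > 0:
                yield pending[: hit[0]]  # text before the stop sequence
            return
        safe = len(pending) - holdback
        if safe > 0:
            yield pending[:safe]
            pending = pending[safe:]
    if pending:
        yield pending  # generation ended without hitting a stop sequence


if __name__ == "__main__":
    toks = ["Once ", "upon", " a time. The", " end", " of it"]
    print("".join(stream_until_stop(toks, ["The end"])))
```

The trade-off is latency: output lags by up to one stop-sequence length, in exchange for never having to retract text that was already sent.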
@@ -265,9 +265,7 @@ class DPOStrategy(BaseStrategy):
 class GRPOStrategy(BaseStrategy):
     """Group Relative Policy Optimization strategy.
 
-    On-policy GRPO following DeepSeek-R1: the policy model is updated while
-    a frozen ref_model stores the old-policy log-probs. ratio = exp(logπ_θ - logπ_ref),
-    clipped PPO objective. Call ``sync_ref_model()`` after each data-generation round.
+    Implements GRPO with clipping and KL penalty.
     """
 
     def __init__(
@@ -278,7 +276,6 @@ class GRPOStrategy(BaseStrategy):
         kl_coef: float = 0.01,
         group_size: int = 4,
         reduction: str = "mean",
-        sync_interval: int = 200,
         **kwargs,
     ):
         super().__init__(model, device, **kwargs)
@@ -287,19 +284,8 @@ class GRPOStrategy(BaseStrategy):
         self.kl_coef = kl_coef
         self.group_size = group_size
         self.reduction = reduction
-        self.sync_interval = sync_interval
-        self._step = 0
-
-    def sync_ref_model(self):
-        """Copy current model weights to ref model."""
-        ref_state = self.model.state_dict()
-        self.ref_model.load_state_dict(ref_state)
 
     def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
-        self._step += 1
-        if self._step % self.sync_interval == 0:
-            self.sync_ref_model()
-
         batch = move_to_device(batch, self.device)
         prompts = batch["prompts"]
         responses = batch["responses"]
@@ -311,6 +297,7 @@ class GRPOStrategy(BaseStrategy):
         masks_flat = masks.view(-1, response_len)
         prompt_expanded = prompts.unsqueeze(1).repeat(1, group_size, 1).flatten(0, 1)
 
+        # Shape: (batch_size * group_size, seq_len + response_len)
         full_sequences = torch.cat([prompt_expanded, responses_flat], dim=-1)
         full_masks = torch.cat([torch.ones_like(prompt_expanded), masks_flat], dim=-1)
 
@@ -325,13 +312,14 @@ class GRPOStrategy(BaseStrategy):
         )
         log_probs_ref = log_probs_ref.view(batch_size, group_size)
 
+        # Compute advantages from rewards with normalization
         eps = torch.finfo(log_probs_policy.dtype).eps
         mean = rewards.mean(dim=-1, keepdim=True)
         std = rewards.std(dim=-1, keepdim=True)
         advantages = (rewards - mean) / (std + eps)
 
-        ratio = torch.exp(log_probs_policy - log_probs_ref)
-
+        # PPO-style clipped surrogate objective
+        ratio = torch.exp(0)  # Off-policy: policy_model = old_model
         surr1 = ratio * advantages
         surr2 = torch.clamp(ratio, 1 - self.clip_eps, 1 + self.clip_eps) * advantages
 
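
The `compute_loss` hunks above normalize rewards within each group and then build a clipped surrogate from the policy/reference probability ratio. A dependency-free sketch of those two steps for a single group follows. It assumes the usual PPO-style `min(surr1, surr2)` combination, which lies outside the lines shown in this diff, and it takes `eps` as a plain argument instead of `torch.finfo(...).eps`; the function names are hypothetical. Note that `torch.exp` expects a tensor argument, so the `torch.exp(0)` call on the new side may raise a `TypeError` as written.

```python
import math
import statistics
from typing import List


def group_advantages(rewards: List[float], eps: float = 1e-8) -> List[float]:
    """Normalize rewards within one group: (r - mean) / (std + eps)."""
    mean = statistics.fmean(rewards)
    # Sample standard deviation (n - 1), matching the default of Tensor.std().
    std = statistics.stdev(rewards)
    return [(r - mean) / (std + eps) for r in rewards]


def clipped_surrogate(logp_policy: float, logp_ref: float,
                      advantage: float, clip_eps: float = 0.2) -> float:
    """PPO-style clipped objective for one sample (assumed min of both terms)."""
    ratio = math.exp(logp_policy - logp_ref)
    clipped_ratio = min(max(ratio, 1.0 - clip_eps), 1.0 + clip_eps)
    surr1 = ratio * advantage
    surr2 = clipped_ratio * advantage
    return min(surr1, surr2)


if __name__ == "__main__":
    adv = group_advantages([1.0, 2.0, 3.0, 4.0])  # one group of size 4
    print(adv)
    print(clipped_surrogate(math.log(2.0), 0.0, adv[-1]))
```

When the policy and reference log-probabilities are equal, the ratio is 1 and clipping has no effect, which is the situation the new side hard-codes.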
@@ -1,42 +0,0 @@
-services:
-  server:
-    build: .
-    image: astrai:latest
-    ports:
-      - "8000:8000"
-    volumes:
-      - ./params:/app/params:ro
-      - ./checkpoints:/app/checkpoints
-    command: python -m scripts.tools.server --port 8000 --device cuda
-    deploy:
-      resources:
-        reservations:
-          devices:
-            - driver: nvidia
-              count: 1
-              capabilities: [gpu]
-    healthcheck:
-      test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
-      interval: 30s
-      timeout: 10s
-      retries: 3
-      start_period: 60s
-    restart: unless-stopped
-
-  server-cpu:
-    profiles: [cpu]
-    build: .
-    image: astrai:latest
-    ports:
-      - "8000:8000"
-    volumes:
-      - ./params:/app/params:ro
-      - ./checkpoints:/app/checkpoints
-    command: python -m scripts.tools.server --port 8000 --device cpu
-    healthcheck:
-      test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
-      interval: 30s
-      timeout: 10s
-      retries: 3
-      start_period: 120s
-    restart: unless-stopped
@ -23,7 +23,7 @@ def parse_args() -> argparse.Namespace:
|
||||||
"--train_type",
|
"--train_type",
|
||||||
type=str,
|
type=str,
|
||||||
required=True,
|
required=True,
|
||||||
choices=["seq", "sft", "dpo", "grpo"],
|
choices=["seq", "sft", "dpo"],
|
||||||
help="Train type.",
|
help="Train type.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
|
@ -42,7 +42,9 @@ def parse_args() -> argparse.Namespace:
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--n_epoch", type=int, default=1, help="Number of epochs to train."
|
"--n_epoch", type=int, default=1, help="Number of epochs to train."
|
||||||
)
|
)
|
||||||
parser.add_argument("--group_size", type=int, default=4, help="GRPO group size.")
|
parser.add_argument(
|
||||||
|
"--batch_size", type=int, default=1, help="Batch size for training."
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--accumulation_steps",
|
"--accumulation_steps",
|
||||||
type=int,
|
type=int,
|
||||||
|
|
@ -104,17 +106,6 @@ def parse_args() -> argparse.Namespace:
|
||||||
"--stride", type=int, default=None, help="the step size of the input sequence."
|
"--stride", type=int, default=None, help="the step size of the input sequence."
|
||||||
)
|
)
|
||||||
parser.add_argument("--dpo_beta", type=float, default=0.1, help="DPO beta value.")
|
parser.add_argument("--dpo_beta", type=float, default=0.1, help="DPO beta value.")
|
||||||
parser.add_argument("--group_size", type=int, default=4, help="GRPO group size.")
|
|
||||||
parser.add_argument(
|
|
||||||
"--on_policy",
|
|
||||||
action="store_true",
|
|
||||||
default=False,
|
|
||||||
help="Enable on-policy GRPO mode.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--grpo_kl_coef", type=float, default=0.01, help="GRPO KL penalty coefficient."
|
|
||||||
)
|
|
||||||
parser.add_argument("--group_size", type=int, default=4, help="GRPO group size.")
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--label_smoothing",
|
"--label_smoothing",
|
||||||
type=float,
|
type=float,
|
||||||
|
|
@ -134,13 +125,6 @@ def parse_args() -> argparse.Namespace:
|
||||||
default="checkpoint",
|
default="checkpoint",
|
||||||
help="Directory to save checkpoints.",
|
help="Directory to save checkpoints.",
|
||||||
)
|
)
|
||||||
parser.add_argument("--group_size", type=int, default=4, help="GRPO group size.")
|
|
||||||
parser.add_argument(
|
|
||||||
"--grpo_sync_interval",
|
|
||||||
type=int,
|
|
||||||
default=200,
|
|
||||||
help="GRPO ref model sync interval (steps).",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--start_epoch", type=int, default=0, help="Start epoch for training."
|
"--start_epoch", type=int, default=0, help="Start epoch for training."
|
||||||
)
|
)
|
||||||
|
|
@ -198,10 +182,6 @@ def train(
|
||||||
ckpt_interval: int,
|
ckpt_interval: int,
|
||||||
ckpt_dir: str,
|
ckpt_dir: str,
|
||||||
dpo_beta: float,
|
dpo_beta: float,
|
||||||
grpo_clip_eps: float,
|
|
||||||
grpo_kl_coef: float,
|
|
||||||
group_size: int,
|
|
||||||
grpo_sync_interval: int,
|
|
||||||
adamw_beta1: float,
|
adamw_beta1: float,
|
||||||
adamw_beta2: float,
|
adamw_beta2: float,
|
||||||
adamw_weight_decay: float,
|
adamw_weight_decay: float,
|
||||||
|
|
@ -215,7 +195,7 @@ def train(
|
||||||
nprocs: int,
|
nprocs: int,
|
||||||
device_type: str,
|
device_type: str,
|
||||||
):
|
):
|
||||||
assert train_type in ["seq", "sft", "dpo", "grpo"]
|
assert train_type in ["seq", "sft", "dpo"]
|
||||||
assert os.path.exists(param_path)
|
assert os.path.exists(param_path)
|
||||||
|
|
||||||
# Load config
|
# Load config
|
||||||
|
|
@ -236,14 +216,7 @@ def train(
|
||||||
state_dict = st.load_file(weights_path)
|
state_dict = st.load_file(weights_path)
|
||||||
model.load_state_dict(state_dict, strict=False)
|
model.load_state_dict(state_dict, strict=False)
|
||||||
|
|
||||||
strategy_kwargs = {
|
strategy_kwargs = {"dpo_beta": dpo_beta, "label_smoothing": label_smoothing}
|
||||||
"dpo_beta": dpo_beta,
|
|
||||||
"label_smoothing": label_smoothing,
|
|
||||||
"clip_eps": grpo_clip_eps,
|
|
||||||
"kl_coef": grpo_kl_coef,
|
|
||||||
"group_size": group_size,
|
|
||||||
"sync_interval": grpo_sync_interval,
|
|
||||||
}
|
|
||||||
|
|
||||||
dataset = DatasetFactory.load(
|
dataset = DatasetFactory.load(
|
||||||
train_type=train_type,
|
train_type=train_type,
|
||||||
|
|
|
||||||
|
|
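Review note on the `parse_args` and `train` hunks above: the GRPO flags can be exercised in isolation to see how they would feed `strategy_kwargs`. This is a minimal sketch assuming only the flags visible in the diff; `build_parser` is a hypothetical helper, not a function from `scripts/tools/train.py`. It also shows that `argparse` rejects a second registration of the same option string, which matters because `--group_size` appears in more than one hunk on the left-hand side.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical helper holding only the GRPO-related flags from the diff."""
    parser = argparse.ArgumentParser(description="GRPO flag sketch")
    parser.add_argument(
        "--train_type",
        type=str,
        required=True,
        choices=["seq", "sft", "dpo", "grpo"],
        help="Train type.",
    )
    parser.add_argument("--group_size", type=int, default=4, help="GRPO group size.")
    parser.add_argument(
        "--on_policy",
        action="store_true",
        default=False,
        help="Enable on-policy GRPO mode.",
    )
    parser.add_argument(
        "--grpo_kl_coef", type=float, default=0.01, help="GRPO KL penalty coefficient."
    )
    parser.add_argument(
        "--grpo_sync_interval",
        type=int,
        default=200,
        help="GRPO ref model sync interval (steps).",
    )
    return parser


args = build_parser().parse_args(["--train_type", "grpo", "--group_size", "8"])

# Same key names as the strategy_kwargs dict on the left-hand side of the diff.
strategy_kwargs = {
    "kl_coef": args.grpo_kl_coef,
    "group_size": args.group_size,
    "sync_interval": args.grpo_sync_interval,
}

# Registering an option string twice fails with the default conflict handler.
try:
    build_parser().add_argument("--group_size", type=int, default=4)
    duplicate_rejected = False
except argparse.ArgumentError:
    duplicate_rejected = True
```

If the repeated `--group_size` lines reflect the actual file rather than the page rendering, the parser would fail at construction time, so it is worth confirming against the source before merging.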
@ -1,5 +1,7 @@
|
||||||
"""Unit tests for the inference HTTP server."""
|
"""Unit tests for the inference HTTP server."""
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -22,6 +24,52 @@ def test_health_with_model(client, loaded_model):
|
||||||
assert data["model_loaded"] is True
|
assert data["model_loaded"] is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_generate_non_stream(client, loaded_model, monkeypatch):
|
||||||
|
"""POST /generate with stream=false should return JSON response."""
|
||||||
|
response = client.post(
|
||||||
|
"/generate",
|
||||||
|
params={
|
||||||
|
"query": "Hello",
|
||||||
|
"temperature": 0.8,
|
||||||
|
"top_p": 0.95,
|
||||||
|
"top_k": 50,
|
||||||
|
"max_len": 100,
|
||||||
|
"stream": False,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert "response" in data
|
||||||
|
|
||||||
|
|
||||||
|
def test_generate_stream(client, loaded_model, monkeypatch):
|
||||||
|
"""POST /generate with stream=true should return plain text stream."""
|
||||||
|
|
||||||
|
async def async_gen():
|
||||||
|
yield "chunk1"
|
||||||
|
yield "chunk2"
|
||||||
|
|
||||||
|
mock_engine = loaded_model
|
||||||
|
mock_engine.generate_async.return_value = async_gen()
|
||||||
|
monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine)
|
||||||
|
response = client.post(
|
||||||
|
"/generate",
|
||||||
|
params={
|
||||||
|
"query": "Hello",
|
||||||
|
"temperature": 0.8,
|
||||||
|
"top_p": 0.95,
|
||||||
|
"top_k": 50,
|
||||||
|
"max_len": 100,
|
||||||
|
"stream": True,
|
||||||
|
},
|
||||||
|
headers={"Accept": "text/plain"},
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
content = response.content.decode("utf-8")
|
||||||
|
assert "chunk1" in content
|
||||||
|
assert "chunk2" in content
|
||||||
|
|
||||||
|
|
||||||
def test_chat_completions_non_stream(client, loaded_model, monkeypatch):
|
def test_chat_completions_non_stream(client, loaded_model, monkeypatch):
|
||||||
"""POST /v1/chat/completions with stream=false returns OpenAI-style JSON."""
|
"""POST /v1/chat/completions with stream=false returns OpenAI-style JSON."""
|
||||||
|
|
||||||
|
|
@ -77,87 +125,17 @@ def test_chat_completions_stream(client, loaded_model, monkeypatch):
|
||||||
assert any("[DONE]" in line for line in lines)
|
assert any("[DONE]" in line for line in lines)
|
||||||
|
|
||||||
|
|
||||||
def test_messages_non_stream(client, loaded_model, monkeypatch):
|
def test_generate_with_history(client, loaded_model, monkeypatch):
|
||||||
"""POST /v1/messages with stream=false returns Anthropic-style JSON."""
|
"""POST /generate with history parameter."""
|
||||||
|
|
||||||
async def async_gen():
|
|
||||||
yield "Assistant reply"
|
|
||||||
|
|
||||||
mock_engine = loaded_model
|
|
||||||
mock_engine.generate_async.return_value = async_gen()
|
|
||||||
monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine)
|
|
||||||
response = client.post(
|
response = client.post(
|
||||||
"/v1/messages",
|
"/generate",
|
||||||
json={
|
params={
|
||||||
"messages": [{"role": "user", "content": "Hello"}],
|
"query": "Hi",
|
||||||
"temperature": 0.8,
|
"history": [["user1", "assistant1"], ["user2", "assistant2"]],
|
||||||
"max_tokens": 100,
|
|
||||||
"stream": False,
|
"stream": False,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
data = response.json()
|
|
||||||
assert data["type"] == "message"
|
|
||||||
assert data["role"] == "assistant"
|
|
||||||
assert len(data["content"]) == 1
|
|
||||||
assert data["content"][0]["type"] == "text"
|
|
||||||
assert "usage" in data
|
|
||||||
assert "input_tokens" in data["usage"]
|
|
||||||
|
|
||||||
|
|
||||||
def test_messages_stream(client, loaded_model, monkeypatch):
|
|
||||||
"""POST /v1/messages with stream=true returns Anthropic SSE stream."""
|
|
||||||
|
|
||||||
async def async_gen():
|
|
||||||
yield "cumulative1"
|
|
||||||
yield "cumulative2"
|
|
||||||
|
|
||||||
mock_engine = loaded_model
|
|
||||||
mock_engine.generate_async.return_value = async_gen()
|
|
||||||
monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine)
|
|
||||||
response = client.post(
|
|
||||||
"/v1/messages",
|
|
||||||
json={
|
|
||||||
"messages": [{"role": "user", "content": "Hello"}],
|
|
||||||
"temperature": 0.8,
|
|
||||||
"max_tokens": 100,
|
|
||||||
"stream": True,
|
|
||||||
},
|
|
||||||
headers={"Accept": "text/event-stream"},
|
|
||||||
)
|
|
||||||
assert response.status_code == 200
|
|
||||||
content = response.content.decode("utf-8")
|
|
||||||
assert "message_start" in content
|
|
||||||
assert "content_block_start" in content
|
|
||||||
assert "content_block_delta" in content
|
|
||||||
assert "cumulative1" in content
|
|
||||||
assert "cumulative2" in content
|
|
||||||
assert "content_block_stop" in content
|
|
||||||
assert "message_delta" in content
|
|
||||||
assert "message_stop" in content
|
|
||||||
|
|
||||||
|
|
||||||
def test_messages_with_system(client, loaded_model, monkeypatch):
|
|
||||||
"""POST /v1/messages with system prompt."""
|
|
||||||
|
|
||||||
async def async_gen():
|
|
||||||
yield "Reply"
|
|
||||||
|
|
||||||
mock_engine = loaded_model
|
|
||||||
mock_engine.generate_async.return_value = async_gen()
|
|
||||||
monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine)
|
|
||||||
response = client.post(
|
|
||||||
"/v1/messages",
|
|
||||||
json={
|
|
||||||
"messages": [{"role": "user", "content": "Hello"}],
|
|
||||||
"system": "You are a helpful assistant.",
|
|
||||||
"max_tokens": 100,
|
|
||||||
"stream": False,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
assert response.status_code == 200
|
|
||||||
data = response.json()
|
|
||||||
assert data["type"] == "message"
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
||||||
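Review note on the server tests above: each streaming test assigns an async generator object to `mock_engine.generate_async.return_value`. The sketch below reproduces that pattern with only the standard library so the behaviour can be checked outside the test suite. `collect` is a hypothetical consumer standing in for the server handler, whose real code is not part of this diff.

```python
import asyncio
from unittest.mock import MagicMock


async def async_gen():
    yield "chunk1"
    yield "chunk2"


async def collect(engine, query):
    """Hypothetical consumer: drain the engine's async stream into a list."""
    chunks = []
    async for chunk in engine.generate_async(query):
        chunks.append(chunk)
    return chunks


mock_engine = MagicMock()

# Pattern used in the tests: one generator object shared by every call.
mock_engine.generate_async.return_value = async_gen()
first = asyncio.run(collect(mock_engine, "Hello"))
# The same generator is returned again and is already exhausted.
second = asyncio.run(collect(mock_engine, "Hello again"))

# Alternative: side_effect builds a fresh generator on each call.
mock_engine.generate_async.side_effect = lambda *args, **kwargs: async_gen()
third = asyncio.run(collect(mock_engine, "Hello"))
fourth = asyncio.run(collect(mock_engine, "Hello again"))
```

With `return_value`, a test that calls the endpoint twice would see an empty stream on the second request. Using `side_effect` avoids that, at the cost of one extra line per test.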