chore: 版本号升至 1.3.5

This commit is contained in:
ViperEkura 2026-05-15 18:15:59 +08:00
parent 9096e413c3
commit 19532440b4
5 changed files with 44 additions and 34 deletions

View File

@ -65,6 +65,16 @@ For development dependencies:
pip install -e ".[dev]" pip install -e ".[dev]"
``` ```
#### Download Pre-trained Model
Download pre-trained model weights (1B bilingual checkpoint) to `params/`:
```bash
python scripts/demo/download.py
```
Or download manually from [HuggingFace](https://huggingface.co/ViperEk/KHAOSZ) into `params/`.
#### Train a Model #### Train a Model
```bash ```bash

View File

@ -71,6 +71,16 @@ pip install -e .
pip install -e ".[dev]" pip install -e ".[dev]"
``` ```
#### 下载预训练模型
下载预训练模型权重(1B 双语检查点)到 `params/` 目录:
```bash
python scripts/demo/download.py
```
或从 [HuggingFace](https://huggingface.co/ViperEk/KHAOSZ) 手动下载并放入 `params/`。
#### 训练模型 #### 训练模型
```bash ```bash

View File

@ -88,7 +88,7 @@ flowchart LR
- **`DecoderBlock`**: GQA attention + residual + MLP + RMSNorm - **`DecoderBlock`**: GQA attention + residual + MLP + RMSNorm
- **`GQA`**: Grouped Query Attention (also `MLA` for multi-latent attention) - **`GQA`**: Grouped Query Attention (also `MLA` for multi-latent attention)
- **`MLP`**: `SiLU(gate(x)) * up(x)` → down projection - **`MLP`**: `SiLU(gate(x)) * up(x)` → down projection
- **`RotaryEmbedding`**: RoPE cos/sin cache - **`RotaryEmbedding`**: RoPE complex cache (freqs_cis)
- **`RMSNorm`**: Root-mean-square layer normalization - **`RMSNorm`**: Root-mean-square layer normalization
### 4. Training Module ### 4. Training Module
@ -104,22 +104,23 @@ The training loop is nested: **epoch** → **batch** (with step phase interspers
``` ```
on_train_begin on_train_begin
on_epoch_begin on_epoch_begin
for each batch: for each accumulation window of batches: ← step phase
if iteration % accumulation_steps == 0: ← step phase on_step_begin
on_step_begin → optimizer.step() → zero_grad → on_step_end for each batch in window: ← batch phase
← batch phase
on_batch_begin → strategy(batch) → loss → backward → on_batch_end on_batch_begin → strategy(batch) → loss → backward → on_batch_end
iteration += 1 iteration += 1
on_step_end
optimizer.step() → zero_grad
on_epoch_end on_epoch_end
on_train_end on_train_end
``` ```
Key points: Key points:
- `on_step_*` wraps optimizer step (fires every `accumulation_steps` batches) - `on_step_*` fires once per `accumulation_steps` batches; `optimizer.step()` and `zero_grad` run after `on_step_end`
- `on_batch_*` wraps loss computation (fires every batch) - `on_batch_*` fires every batch, wrapping loss computation
- `SchedulerCallback` fires on `on_batch_end` — LR scheduler steps every batch - `GradientClippingCallback` fires on `on_step_end`
- `GradientClippingCallback` fires on `on_step_begin` - LR scheduler steps inline (no `SchedulerCallback` class)
#### 4.3 Strategy (`strategy.py`) #### 4.3 Strategy (`strategy.py`)
- **`SEQStrategy`**: Next-token prediction, cross-entropy with label smoothing - **`SEQStrategy`**: Next-token prediction, cross-entropy with label smoothing
@ -136,8 +137,7 @@ Key points:
- **`CheckpointCallback`**: Saves safetensors at `ckpt_interval` iterations - **`CheckpointCallback`**: Saves safetensors at `ckpt_interval` iterations
- **`ProgressBarCallback`**: tqdm progress display - **`ProgressBarCallback`**: tqdm progress display
- **`MetricLoggerCallback`**: Writes JSONL metrics to `{ckpt_dir}/logs/` - **`MetricLoggerCallback`**: Writes JSONL metrics to `{ckpt_dir}/logs/`
- **`GradientClippingCallback`**: `clip_grad_norm_` on `on_step_begin` - **`GradientClippingCallback`**: `clip_grad_norm_` on `on_step_end`
- **`SchedulerCallback`**: `scheduler.step()` on `on_batch_end`
### 5. Inference Module ### 5. Inference Module

View File

@ -91,8 +91,8 @@ classDiagram
} }
class BaseStorage { class BaseStorage {
+Dict segments +MultiSegmentFetcher _fetcher
+List keys +keys (property)
+load(load_path, tokenizer) +load(load_path, tokenizer)
+fetch(begin, end, keys) +fetch(begin, end, keys)
+__len__() +__len__()
@ -145,7 +145,7 @@ classDiagram
+ModelConfig config +ModelConfig config
+Registry _registry +Registry _registry
+register(model_type) decorator +register(model_type) decorator
+get_model_class(model_type) Type +get_component_class(model_type) Type
+from_pretrained(path, disable_random_init) nn.Module +from_pretrained(path, disable_random_init) nn.Module
+save_pretrained(save_directory) +save_pretrained(save_directory)
+to(*args, **kwargs) Self +to(*args, **kwargs) Self
@ -214,7 +214,7 @@ classDiagram
+int dim +int dim
+int max_len +int max_len
+float base +float base
+forward(x, position_ids=None) Tuple[Tensor, Tensor] +forward(x, position_ids=None) Tensor
} }
class Embedding { class Embedding {
@ -225,13 +225,10 @@ classDiagram
namespace tokenize { namespace tokenize {
class AutoTokenizer { class AutoTokenizer {
+List[int] stop_ids
+int bos_id
+int eos_id
+int pad_id
+vocab_size int +vocab_size int
+encode(tokens, out_ids, add_special_tokens) List[int] +encode(tokens, out_ids, add_special_tokens) List[int]
+decode(tokens, skip_special_tokens) str +decode(tokens, skip_special_tokens) str
+__getattr__(name) Any (bos_id, eos_id, pad_id, stop_ids)
+apply_chat_template(messages, tokenize) Union[str, List[int]] +apply_chat_template(messages, tokenize) Union[str, List[int]]
+set_chat_template(template) +set_chat_template(template)
+load(path) +load(path)
@ -325,6 +322,8 @@ classDiagram
+float clip_eps +float clip_eps
+float kl_coef +float kl_coef
+int group_size +int group_size
+str reduction
+int sync_interval
+compute_loss(batch) Tensor +compute_loss(batch) Tensor
} }
@ -369,11 +368,6 @@ classDiagram
+on_step_begin(context) +on_step_begin(context)
} }
class SchedulerCallback {
+on_train_begin(context)
+on_batch_end(context)
}
class CheckpointCallback { class CheckpointCallback {
+str save_dir +str save_dir
+int interval +int interval
@ -409,8 +403,6 @@ classDiagram
+nn.Module model +nn.Module model
+AutoTokenizer tokenizer +AutoTokenizer tokenizer
+InferenceScheduler scheduler +InferenceScheduler scheduler
+int max_batch_size
+Optional int max_seq_len
+generate(prompt, stream, max_tokens, temperature, top_p, top_k) Union[Generator, str, List[str]] +generate(prompt, stream, max_tokens, temperature, top_p, top_k) Union[Generator, str, List[str]]
+generate_with_request(request) Union[Generator, str, List[str]] +generate_with_request(request) Union[Generator, str, List[str]]
+generate_async(prompt, max_tokens, temperature, top_p, top_k) AsyncGenerator +generate_async(prompt, max_tokens, temperature, top_p, top_k) AsyncGenerator
@ -421,13 +413,12 @@ classDiagram
class InferenceScheduler { class InferenceScheduler {
+nn.Module model +nn.Module model
+AutoTokenizer tokenizer +AutoTokenizer tokenizer
+KVCache page_cache +KVCache _page_cache
+int max_batch_size +int max_batch_size
+int max_seq_len +int max_seq_len
+int max_prompt_len +int max_prompt_len
+int page_size +int page_size
+List waiting_queue +TaskManager _task_mgr
+List active_tasks
+add_task(prompt, max_tokens, temperature, top_p, top_k, stream_callback) str +add_task(prompt, max_tokens, temperature, top_p, top_k, stream_callback) str
+remove_task(task_id) +remove_task(task_id)
+start() +start()
@ -568,7 +559,7 @@ classDiagram
} }
class GenerateResult { class GenerateResult {
+List[str] tokens +List[Tuple[int, str]] tokens
+List[str] results +List[str] results
+List[bool] _done +List[bool] _done
+append(token, idx) +append(token, idx)
@ -643,7 +634,6 @@ classDiagram
BaseScheduler <|-- SGDRScheduler BaseScheduler <|-- SGDRScheduler
CallbackFactory ..> TrainCallback : creates CallbackFactory ..> TrainCallback : creates
TrainCallback <|-- GradientClippingCallback TrainCallback <|-- GradientClippingCallback
TrainCallback <|-- SchedulerCallback
TrainCallback <|-- CheckpointCallback TrainCallback <|-- CheckpointCallback
TrainCallback <|-- ProgressBarCallback TrainCallback <|-- ProgressBarCallback
TrainCallback <|-- MetricLoggerCallback TrainCallback <|-- MetricLoggerCallback

View File

@ -1,4 +1,4 @@
__version__ = "1.3.4" __version__ = "1.3.5"
__author__ = "ViperEkura" __author__ = "ViperEkura"
from astrai.config import ( from astrai.config import (