chore: 版本号升至 1.3.5
This commit is contained in:
parent
9096e413c3
commit
19532440b4
10
README.md
10
README.md
|
|
@ -65,6 +65,16 @@ For development dependencies:
|
|||
pip install -e ".[dev]"
|
||||
```
|
||||
|
||||
#### Download Pre-trained Model
|
||||
|
||||
Download pre-trained model weights (1B bilingual checkpoint) to `params/`:
|
||||
|
||||
```bash
|
||||
python scripts/demo/download.py
|
||||
```
|
||||
|
||||
Or download manually from [HuggingFace](https://huggingface.co/ViperEk/KHAOSZ) into `params/`.
|
||||
|
||||
#### Train a Model
|
||||
|
||||
```bash
|
||||
|
|
|
|||
|
|
@ -71,6 +71,16 @@ pip install -e .
|
|||
pip install -e ".[dev]"
|
||||
```
|
||||
|
||||
#### 下载预训练模型
|
||||
|
||||
下载预训练模型权重(1B 双语检查点)到 `params/` 目录:
|
||||
|
||||
```bash
|
||||
python scripts/demo/download.py
|
||||
```
|
||||
|
||||
或从 [HuggingFace](https://huggingface.co/ViperEk/KHAOSZ) 手动下载放入 `params/`。
|
||||
|
||||
#### 训练模型
|
||||
|
||||
```bash
|
||||
|
|
|
|||
|
|
@ -88,7 +88,7 @@ flowchart LR
|
|||
- **`DecoderBlock`**: GQA attention + residual + MLP + RMSNorm
|
||||
- **`GQA`**: Grouped Query Attention (also `MLA` for multi-latent attention)
|
||||
- **`MLP`**: `SiLU(gate(x)) * up(x)` → down projection
|
||||
- **`RotaryEmbedding`**: RoPE cos/sin cache
|
||||
- **`RotaryEmbedding`**: RoPE complex cache (freqs_cis)
|
||||
- **`RMSNorm`**: Layer normalization
|
||||
|
||||
### 4. Training Module
|
||||
|
|
@ -104,22 +104,23 @@ The training loop is nested: **epoch** → **batch** (with step phase interspers
|
|||
```
|
||||
on_train_begin
|
||||
on_epoch_begin
|
||||
for each batch:
|
||||
if iteration % accumulation_steps == 0: ← step phase
|
||||
on_step_begin → optimizer.step() → zero_grad → on_step_end
|
||||
← batch phase
|
||||
on_batch_begin → strategy(batch) → loss → backward → on_batch_end
|
||||
iteration += 1
|
||||
for each accumulation window of batches: ← step phase
|
||||
on_step_begin
|
||||
for each batch in window: ← batch phase
|
||||
on_batch_begin → strategy(batch) → loss → backward → on_batch_end
|
||||
iteration += 1
|
||||
on_step_end
|
||||
optimizer.step() → zero_grad
|
||||
|
||||
on_epoch_end
|
||||
on_train_end
|
||||
```
|
||||
|
||||
Key points:
|
||||
- `on_step_*` wraps optimizer step (fires every `accumulation_steps` batches)
|
||||
- `on_batch_*` wraps loss computation (fires every batch)
|
||||
- `SchedulerCallback` fires on `on_batch_end` — LR scheduler steps every batch
|
||||
- `GradientClippingCallback` fires on `on_step_begin`
|
||||
- `on_step_*` fires every `accumulation_steps` batches, wrapping optimizer step AFTER the hook
|
||||
- `on_batch_*` fires every batch, wrapping loss computation
|
||||
- `GradientClippingCallback` fires on `on_step_end`
|
||||
- LR scheduler steps inline (no `SchedulerCallback` class)
|
||||
|
||||
#### 4.3 Strategy (`strategy.py`)
|
||||
- **`SEQStrategy`**: Next-token prediction, cross-entropy with label smoothing
|
||||
|
|
@ -136,8 +137,7 @@ Key points:
|
|||
- **`CheckpointCallback`**: Saves safetensors at `ckpt_interval` iterations
|
||||
- **`ProgressBarCallback`**: tqdm progress display
|
||||
- **`MetricLoggerCallback`**: Writes JSONL metrics to `{ckpt_dir}/logs/`
|
||||
- **`GradientClippingCallback`**: `clip_grad_norm_` on `on_step_begin`
|
||||
- **`SchedulerCallback`**: `scheduler.step()` on `on_batch_end`
|
||||
- **`GradientClippingCallback`**: `clip_grad_norm_` on `on_step_end`
|
||||
|
||||
### 5. Inference Module
|
||||
|
||||
|
|
|
|||
|
|
@ -91,8 +91,8 @@ classDiagram
|
|||
}
|
||||
|
||||
class BaseStorage {
|
||||
+Dict segments
|
||||
+List keys
|
||||
+MultiSegmentFetcher _fetcher
|
||||
+keys (property)
|
||||
+load(load_path, tokenizer)
|
||||
+fetch(begin, end, keys)
|
||||
+__len__()
|
||||
|
|
@ -145,7 +145,7 @@ classDiagram
|
|||
+ModelConfig config
|
||||
+Registry _registry
|
||||
+register(model_type) decorator
|
||||
+get_model_class(model_type) Type
|
||||
+get_component_class(model_type) Type
|
||||
+from_pretrained(path, disable_random_init) nn.Module
|
||||
+save_pretrained(save_directory)
|
||||
+to(*args, **kwargs) Self
|
||||
|
|
@ -214,7 +214,7 @@ classDiagram
|
|||
+int dim
|
||||
+int max_len
|
||||
+float base
|
||||
+forward(x, position_ids=None) Tuple[Tensor, Tensor]
|
||||
+forward(x, position_ids=None) Tensor
|
||||
}
|
||||
|
||||
class Embedding {
|
||||
|
|
@ -225,13 +225,10 @@ classDiagram
|
|||
|
||||
namespace tokenize {
|
||||
class AutoTokenizer {
|
||||
+List[int] stop_ids
|
||||
+int bos_id
|
||||
+int eos_id
|
||||
+int pad_id
|
||||
+vocab_size int
|
||||
+encode(tokens, out_ids, add_special_tokens) List[int]
|
||||
+decode(tokens, skip_special_tokens) str
|
||||
+__getattr__(name) Any (bos_id, eos_id, pad_id, stop_ids)
|
||||
+apply_chat_template(messages, tokenize) Union[str, List[int]]
|
||||
+set_chat_template(template)
|
||||
+load(path)
|
||||
|
|
@ -325,6 +322,8 @@ classDiagram
|
|||
+float clip_eps
|
||||
+float kl_coef
|
||||
+int group_size
|
||||
+str reduction
|
||||
+int sync_interval
|
||||
+compute_loss(batch) Tensor
|
||||
}
|
||||
|
||||
|
|
@ -369,11 +368,6 @@ classDiagram
|
|||
+on_step_begin(context)
|
||||
}
|
||||
|
||||
class SchedulerCallback {
|
||||
+on_train_begin(context)
|
||||
+on_batch_end(context)
|
||||
}
|
||||
|
||||
class CheckpointCallback {
|
||||
+str save_dir
|
||||
+int interval
|
||||
|
|
@ -409,8 +403,6 @@ classDiagram
|
|||
+nn.Module model
|
||||
+AutoTokenizer tokenizer
|
||||
+InferenceScheduler scheduler
|
||||
+int max_batch_size
|
||||
+Optional int max_seq_len
|
||||
+generate(prompt, stream, max_tokens, temperature, top_p, top_k) Union[Generator, str, List[str]]
|
||||
+generate_with_request(request) Union[Generator, str, List[str]]
|
||||
+generate_async(prompt, max_tokens, temperature, top_p, top_k) AsyncGenerator
|
||||
|
|
@ -421,13 +413,12 @@ classDiagram
|
|||
class InferenceScheduler {
|
||||
+nn.Module model
|
||||
+AutoTokenizer tokenizer
|
||||
+KVCache page_cache
|
||||
+KVCache _page_cache
|
||||
+int max_batch_size
|
||||
+int max_seq_len
|
||||
+int max_prompt_len
|
||||
+int page_size
|
||||
+List waiting_queue
|
||||
+List active_tasks
|
||||
+TaskManager _task_mgr
|
||||
+add_task(prompt, max_tokens, temperature, top_p, top_k, stream_callback) str
|
||||
+remove_task(task_id)
|
||||
+start()
|
||||
|
|
@ -568,7 +559,7 @@ classDiagram
|
|||
}
|
||||
|
||||
class GenerateResult {
|
||||
+List[str] tokens
|
||||
+List[Tuple[int, str]] tokens
|
||||
+List[str] results
|
||||
+List[bool] _done
|
||||
+append(token, idx)
|
||||
|
|
@ -643,7 +634,6 @@ classDiagram
|
|||
BaseScheduler <|-- SGDRScheduler
|
||||
CallbackFactory ..> TrainCallback : creates
|
||||
TrainCallback <|-- GradientClippingCallback
|
||||
TrainCallback <|-- SchedulerCallback
|
||||
TrainCallback <|-- CheckpointCallback
|
||||
TrainCallback <|-- ProgressBarCallback
|
||||
TrainCallback <|-- MetricLoggerCallback
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
__version__ = "1.3.4"
|
||||
__version__ = "1.3.5"
|
||||
__author__ = "ViperEkura"
|
||||
|
||||
from astrai.config import (
|
||||
|
|
|
|||
Loading…
Reference in New Issue