chore: 版本号升至 1.3.5
This commit is contained in:
parent
9096e413c3
commit
19532440b4
10
README.md
10
README.md
|
|
@ -65,6 +65,16 @@ For development dependencies:
|
||||||
pip install -e ".[dev]"
|
pip install -e ".[dev]"
|
||||||
```
|
```
|
||||||
|
|
||||||
|
#### Download Pre-trained Model
|
||||||
|
|
||||||
|
Download pre-trained model weights (1B bilingual checkpoint) to `params/`:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python scripts/demo/download.py
|
||||||
|
```
|
||||||
|
|
||||||
|
Or download manually from [HuggingFace](https://huggingface.co/ViperEk/KHAOSZ) into `params/`.
|
||||||
|
|
||||||
#### Train a Model
|
#### Train a Model
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
|
|
|
||||||
|
|
@ -71,6 +71,16 @@ pip install -e .
|
||||||
pip install -e ".[dev]"
|
pip install -e ".[dev]"
|
||||||
```
|
```
|
||||||
|
|
||||||
|
#### 下载预训练模型
|
||||||
|
|
||||||
|
下载预训练模型权重(1B 双语检查点)到 `params/` 目录:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python scripts/demo/download.py
|
||||||
|
```
|
||||||
|
|
||||||
|
或从 [HuggingFace](https://huggingface.co/ViperEk/KHAOSZ) 手动下载放入 `params/`。
|
||||||
|
|
||||||
#### 训练模型
|
#### 训练模型
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
|
|
|
||||||
|
|
@ -88,7 +88,7 @@ flowchart LR
|
||||||
- **`DecoderBlock`**: GQA attention + residual + MLP + RMSNorm
|
- **`DecoderBlock`**: GQA attention + residual + MLP + RMSNorm
|
||||||
- **`GQA`**: Grouped Query Attention (also `MLA` for multi-latent attention)
|
- **`GQA`**: Grouped Query Attention (also `MLA` for multi-latent attention)
|
||||||
- **`MLP`**: `SiLU(gate(x)) * up(x)` → down projection
|
- **`MLP`**: `SiLU(gate(x)) * up(x)` → down projection
|
||||||
- **`RotaryEmbedding`**: RoPE cos/sin cache
|
- **`RotaryEmbedding`**: RoPE complex cache (freqs_cis)
|
||||||
- **`RMSNorm`**: Layer normalization
|
- **`RMSNorm`**: Layer normalization
|
||||||
|
|
||||||
### 4. Training Module
|
### 4. Training Module
|
||||||
|
|
@ -104,22 +104,23 @@ The training loop is nested: **epoch** → **batch** (with step phase interspers
|
||||||
```
|
```
|
||||||
on_train_begin
|
on_train_begin
|
||||||
on_epoch_begin
|
on_epoch_begin
|
||||||
for each batch:
|
for each accumulation window of batches: ← step phase
|
||||||
if iteration % accumulation_steps == 0: ← step phase
|
on_step_begin
|
||||||
on_step_begin → optimizer.step() → zero_grad → on_step_end
|
for each batch in window: ← batch phase
|
||||||
← batch phase
|
on_batch_begin → strategy(batch) → loss → backward → on_batch_end
|
||||||
on_batch_begin → strategy(batch) → loss → backward → on_batch_end
|
iteration += 1
|
||||||
iteration += 1
|
on_step_end
|
||||||
|
optimizer.step() → zero_grad
|
||||||
|
|
||||||
on_epoch_end
|
on_epoch_end
|
||||||
on_train_end
|
on_train_end
|
||||||
```
|
```
|
||||||
|
|
||||||
Key points:
|
Key points:
|
||||||
- `on_step_*` wraps optimizer step (fires every `accumulation_steps` batches)
|
- `on_step_*` fires every `accumulation_steps` batches, wrapping optimizer step AFTER the hook
|
||||||
- `on_batch_*` wraps loss computation (fires every batch)
|
- `on_batch_*` fires every batch, wrapping loss computation
|
||||||
- `SchedulerCallback` fires on `on_batch_end` — LR scheduler steps every batch
|
- `GradientClippingCallback` fires on `on_step_end`
|
||||||
- `GradientClippingCallback` fires on `on_step_begin`
|
- LR scheduler steps inline (no `SchedulerCallback` class)
|
||||||
|
|
||||||
#### 4.3 Strategy (`strategy.py`)
|
#### 4.3 Strategy (`strategy.py`)
|
||||||
- **`SEQStrategy`**: Next-token prediction, cross-entropy with label smoothing
|
- **`SEQStrategy`**: Next-token prediction, cross-entropy with label smoothing
|
||||||
|
|
@ -136,8 +137,7 @@ Key points:
|
||||||
- **`CheckpointCallback`**: Saves safetensors at `ckpt_interval` iterations
|
- **`CheckpointCallback`**: Saves safetensors at `ckpt_interval` iterations
|
||||||
- **`ProgressBarCallback`**: tqdm progress display
|
- **`ProgressBarCallback`**: tqdm progress display
|
||||||
- **`MetricLoggerCallback`**: Writes JSONL metrics to `{ckpt_dir}/logs/`
|
- **`MetricLoggerCallback`**: Writes JSONL metrics to `{ckpt_dir}/logs/`
|
||||||
- **`GradientClippingCallback`**: `clip_grad_norm_` on `on_step_begin`
|
- **`GradientClippingCallback`**: `clip_grad_norm_` on `on_step_end`
|
||||||
- **`SchedulerCallback`**: `scheduler.step()` on `on_batch_end`
|
|
||||||
|
|
||||||
### 5. Inference Module
|
### 5. Inference Module
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -91,8 +91,8 @@ classDiagram
|
||||||
}
|
}
|
||||||
|
|
||||||
class BaseStorage {
|
class BaseStorage {
|
||||||
+Dict segments
|
+MultiSegmentFetcher _fetcher
|
||||||
+List keys
|
+keys (property)
|
||||||
+load(load_path, tokenizer)
|
+load(load_path, tokenizer)
|
||||||
+fetch(begin, end, keys)
|
+fetch(begin, end, keys)
|
||||||
+__len__()
|
+__len__()
|
||||||
|
|
@ -145,7 +145,7 @@ classDiagram
|
||||||
+ModelConfig config
|
+ModelConfig config
|
||||||
+Registry _registry
|
+Registry _registry
|
||||||
+register(model_type) decorator
|
+register(model_type) decorator
|
||||||
+get_model_class(model_type) Type
|
+get_component_class(model_type) Type
|
||||||
+from_pretrained(path, disable_random_init) nn.Module
|
+from_pretrained(path, disable_random_init) nn.Module
|
||||||
+save_pretrained(save_directory)
|
+save_pretrained(save_directory)
|
||||||
+to(*args, **kwargs) Self
|
+to(*args, **kwargs) Self
|
||||||
|
|
@ -214,7 +214,7 @@ classDiagram
|
||||||
+int dim
|
+int dim
|
||||||
+int max_len
|
+int max_len
|
||||||
+float base
|
+float base
|
||||||
+forward(x, position_ids=None) Tuple[Tensor, Tensor]
|
+forward(x, position_ids=None) Tensor
|
||||||
}
|
}
|
||||||
|
|
||||||
class Embedding {
|
class Embedding {
|
||||||
|
|
@ -225,13 +225,10 @@ classDiagram
|
||||||
|
|
||||||
namespace tokenize {
|
namespace tokenize {
|
||||||
class AutoTokenizer {
|
class AutoTokenizer {
|
||||||
+List[int] stop_ids
|
|
||||||
+int bos_id
|
|
||||||
+int eos_id
|
|
||||||
+int pad_id
|
|
||||||
+vocab_size int
|
+vocab_size int
|
||||||
+encode(tokens, out_ids, add_special_tokens) List[int]
|
+encode(tokens, out_ids, add_special_tokens) List[int]
|
||||||
+decode(tokens, skip_special_tokens) str
|
+decode(tokens, skip_special_tokens) str
|
||||||
|
+__getattr__(name) Any (bos_id, eos_id, pad_id, stop_ids)
|
||||||
+apply_chat_template(messages, tokenize) Union[str, List[int]]
|
+apply_chat_template(messages, tokenize) Union[str, List[int]]
|
||||||
+set_chat_template(template)
|
+set_chat_template(template)
|
||||||
+load(path)
|
+load(path)
|
||||||
|
|
@ -325,6 +322,8 @@ classDiagram
|
||||||
+float clip_eps
|
+float clip_eps
|
||||||
+float kl_coef
|
+float kl_coef
|
||||||
+int group_size
|
+int group_size
|
||||||
|
+str reduction
|
||||||
|
+int sync_interval
|
||||||
+compute_loss(batch) Tensor
|
+compute_loss(batch) Tensor
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -369,11 +368,6 @@ classDiagram
|
||||||
+on_step_begin(context)
|
+on_step_begin(context)
|
||||||
}
|
}
|
||||||
|
|
||||||
class SchedulerCallback {
|
|
||||||
+on_train_begin(context)
|
|
||||||
+on_batch_end(context)
|
|
||||||
}
|
|
||||||
|
|
||||||
class CheckpointCallback {
|
class CheckpointCallback {
|
||||||
+str save_dir
|
+str save_dir
|
||||||
+int interval
|
+int interval
|
||||||
|
|
@ -409,8 +403,6 @@ classDiagram
|
||||||
+nn.Module model
|
+nn.Module model
|
||||||
+AutoTokenizer tokenizer
|
+AutoTokenizer tokenizer
|
||||||
+InferenceScheduler scheduler
|
+InferenceScheduler scheduler
|
||||||
+int max_batch_size
|
|
||||||
+Optional int max_seq_len
|
|
||||||
+generate(prompt, stream, max_tokens, temperature, top_p, top_k) Union[Generator, str, List[str]]
|
+generate(prompt, stream, max_tokens, temperature, top_p, top_k) Union[Generator, str, List[str]]
|
||||||
+generate_with_request(request) Union[Generator, str, List[str]]
|
+generate_with_request(request) Union[Generator, str, List[str]]
|
||||||
+generate_async(prompt, max_tokens, temperature, top_p, top_k) AsyncGenerator
|
+generate_async(prompt, max_tokens, temperature, top_p, top_k) AsyncGenerator
|
||||||
|
|
@ -421,13 +413,12 @@ classDiagram
|
||||||
class InferenceScheduler {
|
class InferenceScheduler {
|
||||||
+nn.Module model
|
+nn.Module model
|
||||||
+AutoTokenizer tokenizer
|
+AutoTokenizer tokenizer
|
||||||
+KVCache page_cache
|
+KVCache _page_cache
|
||||||
+int max_batch_size
|
+int max_batch_size
|
||||||
+int max_seq_len
|
+int max_seq_len
|
||||||
+int max_prompt_len
|
+int max_prompt_len
|
||||||
+int page_size
|
+int page_size
|
||||||
+List waiting_queue
|
+TaskManager _task_mgr
|
||||||
+List active_tasks
|
|
||||||
+add_task(prompt, max_tokens, temperature, top_p, top_k, stream_callback) str
|
+add_task(prompt, max_tokens, temperature, top_p, top_k, stream_callback) str
|
||||||
+remove_task(task_id)
|
+remove_task(task_id)
|
||||||
+start()
|
+start()
|
||||||
|
|
@ -568,7 +559,7 @@ classDiagram
|
||||||
}
|
}
|
||||||
|
|
||||||
class GenerateResult {
|
class GenerateResult {
|
||||||
+List[str] tokens
|
+List[Tuple[int, str]] tokens
|
||||||
+List[str] results
|
+List[str] results
|
||||||
+List[bool] _done
|
+List[bool] _done
|
||||||
+append(token, idx)
|
+append(token, idx)
|
||||||
|
|
@ -643,7 +634,6 @@ classDiagram
|
||||||
BaseScheduler <|-- SGDRScheduler
|
BaseScheduler <|-- SGDRScheduler
|
||||||
CallbackFactory ..> TrainCallback : creates
|
CallbackFactory ..> TrainCallback : creates
|
||||||
TrainCallback <|-- GradientClippingCallback
|
TrainCallback <|-- GradientClippingCallback
|
||||||
TrainCallback <|-- SchedulerCallback
|
|
||||||
TrainCallback <|-- CheckpointCallback
|
TrainCallback <|-- CheckpointCallback
|
||||||
TrainCallback <|-- ProgressBarCallback
|
TrainCallback <|-- ProgressBarCallback
|
||||||
TrainCallback <|-- MetricLoggerCallback
|
TrainCallback <|-- MetricLoggerCallback
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
__version__ = "1.3.4"
|
__version__ = "1.3.5"
|
||||||
__author__ = "ViperEkura"
|
__author__ = "ViperEkura"
|
||||||
|
|
||||||
from astrai.config import (
|
from astrai.config import (
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue