diff --git a/README.md b/README.md index 955f44f..4c58d72 100644 --- a/README.md +++ b/README.md @@ -65,6 +65,16 @@ For development dependencies: pip install -e ".[dev]" ``` +#### Download Pre-trained Model + +Download pre-trained model weights (1B bilingual checkpoint) to `params/`: + +```bash +python scripts/demo/download.py +``` + +Or download manually from [HuggingFace](https://huggingface.co/ViperEk/KHAOSZ) into `params/`. + #### Train a Model ```bash diff --git a/assets/docs/README-zh-CN.md b/assets/docs/README-zh-CN.md index aa749b5..1b87880 100644 --- a/assets/docs/README-zh-CN.md +++ b/assets/docs/README-zh-CN.md @@ -71,6 +71,16 @@ pip install -e . pip install -e ".[dev]" ``` +#### 下载预训练模型 + +下载预训练模型权重(1B 双语检查点)到 `params/` 目录: + +```bash +python scripts/demo/download.py +``` + +或从 [HuggingFace](https://huggingface.co/ViperEk/KHAOSZ) 手动下载放入 `params/`。 + #### 训练模型 ```bash diff --git a/assets/docs/dataflow.md b/assets/docs/dataflow.md index 906ac1e..313e4d2 100644 --- a/assets/docs/dataflow.md +++ b/assets/docs/dataflow.md @@ -88,7 +88,7 @@ flowchart LR - **`DecoderBlock`**: GQA attention + residual + MLP + RMSNorm - **`GQA`**: Grouped Query Attention (also `MLA` for multi-latent attention) - **`MLP`**: `SiLU(gate(x)) * up(x)` → down projection -- **`RotaryEmbedding`**: RoPE cos/sin cache +- **`RotaryEmbedding`**: RoPE complex cache (freqs_cis) - **`RMSNorm`**: Layer normalization ### 4. Training Module @@ -104,22 +104,23 @@ The training loop is nested: **epoch** → **batch** (with step phase interspers ``` on_train_begin on_epoch_begin - for each batch: - if iteration % accumulation_steps == 0: ← step phase - on_step_begin → optimizer.step() → zero_grad → on_step_end - ← batch phase - on_batch_begin → strategy(batch) → loss → backward → on_batch_end - iteration += 1 + for each accumulation window of batches: ← step phase + on_step_begin + for each batch in window: ← batch phase + on_batch_begin → strategy(batch) → loss → backward → on_batch_end + iteration += 1 + on_step_end + optimizer.step() → zero_grad on_epoch_end on_train_end ``` Key points: -- `on_step_*` wraps optimizer step (fires every `accumulation_steps` batches) -- `on_batch_*` wraps loss computation (fires every batch) -- `SchedulerCallback` fires on `on_batch_end` — LR scheduler steps every batch -- `GradientClippingCallback` fires on `on_step_begin` +- `on_step_*` fires every `accumulation_steps` batches, wrapping optimizer step AFTER the hook +- `on_batch_*` fires every batch, wrapping loss computation +- `GradientClippingCallback` fires on `on_step_end` +- LR scheduler steps inline (no `SchedulerCallback` class) #### 4.3 Strategy (`strategy.py`) - **`SEQStrategy`**: Next-token prediction, cross-entropy with label smoothing @@ -136,8 +137,7 @@ Key points: - **`CheckpointCallback`**: Saves safetensors at `ckpt_interval` iterations - **`ProgressBarCallback`**: tqdm progress display - **`MetricLoggerCallback`**: Writes JSONL metrics to `{ckpt_dir}/logs/` -- **`GradientClippingCallback`**: `clip_grad_norm_` on `on_step_begin` -- **`SchedulerCallback`**: `scheduler.step()` on `on_batch_end` +- **`GradientClippingCallback`**: `clip_grad_norm_` on `on_step_end` ### 5. Inference Module diff --git a/assets/docs/design.md b/assets/docs/design.md index c5fdffb..550587a 100644 --- a/assets/docs/design.md +++ b/assets/docs/design.md @@ -91,8 +91,8 @@ classDiagram } class BaseStorage { - +Dict segments - +List keys + +MultiSegmentFetcher _fetcher + +keys (property) +load(load_path, tokenizer) +fetch(begin, end, keys) +__len__() @@ -145,7 +145,7 @@ classDiagram +ModelConfig config +Registry _registry +register(model_type) decorator - +get_model_class(model_type) Type + +get_component_class(model_type) Type +from_pretrained(path, disable_random_init) nn.Module +save_pretrained(save_directory) +to(*args, **kwargs) Self @@ -214,7 +214,7 @@ classDiagram +int dim +int max_len +float base - +forward(x, position_ids=None) Tuple[Tensor, Tensor] + +forward(x, position_ids=None) Tensor } class Embedding { @@ -225,13 +225,10 @@ classDiagram namespace tokenize { class AutoTokenizer { - +List[int] stop_ids - +int bos_id - +int eos_id - +int pad_id +vocab_size int +encode(tokens, out_ids, add_special_tokens) List[int] +decode(tokens, skip_special_tokens) str + +__getattr__(name) Any (bos_id, eos_id, pad_id, stop_ids) +apply_chat_template(messages, tokenize) Union[str, List[int]] +set_chat_template(template) +load(path) @@ -325,6 +322,8 @@ classDiagram +float clip_eps +float kl_coef +int group_size + +str reduction + +int sync_interval +compute_loss(batch) Tensor } @@ -369,11 +368,6 @@ classDiagram +on_step_begin(context) } - class SchedulerCallback { - +on_train_begin(context) - +on_batch_end(context) - } - class CheckpointCallback { +str save_dir +int interval @@ -409,8 +403,6 @@ classDiagram +nn.Module model +AutoTokenizer tokenizer +InferenceScheduler scheduler - +int max_batch_size - +Optional int max_seq_len +generate(prompt, stream, max_tokens, temperature, top_p, top_k) Union[Generator, str, List[str]] +generate_with_request(request) Union[Generator, str, List[str]] +generate_async(prompt, max_tokens, temperature, top_p, top_k) AsyncGenerator @@ -421,13 +413,12 @@ classDiagram class InferenceScheduler { +nn.Module model +AutoTokenizer tokenizer - +KVCache page_cache + +KVCache _page_cache +int max_batch_size +int max_seq_len +int max_prompt_len +int page_size - +List waiting_queue - +List active_tasks + +TaskManager _task_mgr +add_task(prompt, max_tokens, temperature, top_p, top_k, stream_callback) str +remove_task(task_id) +start() @@ -568,7 +559,7 @@ classDiagram } class GenerateResult { - +List[str] tokens + +List[Tuple[int, str]] tokens +List[str] results +List[bool] _done +append(token, idx) @@ -643,7 +634,6 @@ classDiagram BaseScheduler <|-- SGDRScheduler CallbackFactory ..> TrainCallback : creates TrainCallback <|-- GradientClippingCallback - TrainCallback <|-- SchedulerCallback TrainCallback <|-- CheckpointCallback TrainCallback <|-- ProgressBarCallback TrainCallback <|-- MetricLoggerCallback diff --git a/astrai/__init__.py b/astrai/__init__.py index c8e5fa9..430316a 100644 --- a/astrai/__init__.py +++ b/astrai/__init__.py @@ -1,4 +1,4 @@ -__version__ = "1.3.4" +__version__ = "1.3.5" __author__ = "ViperEkura" from astrai.config import (