diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 30d5738..683c508 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -5,7 +5,7 @@ Thank you for your interest in contributing! This document provides step-by-step ## Quick Start ```bash -git clone https://github.com/your-username/AstrAI.git +git clone https://github.com/ViperEkura/AstrAI.git cd AstrAI pip install -e ".[dev]" # install with dev dependencies (pytest, ruff) ``` diff --git a/assets/docs/architecture.md b/assets/docs/architecture.md index f32e29c..a57a338 100644 --- a/assets/docs/architecture.md +++ b/assets/docs/architecture.md @@ -8,8 +8,8 @@ classDiagram class BaseConfig { +to_dict() Dict +from_dict(d) Self - +from_json(path) Self - +to_json(path) + +from_file(path) Self + +to_file(path) } class BaseModelConfig { @@ -61,31 +61,32 @@ classDiagram } class ConfigFactory { - +Registry _registry + +Dict _entries +register(name) decorator +load(raw) BaseConfig } class InputConfig { - +str type - +str messages_key - +str prompt_key - +str response_key - +str text_key + +Optional[List[Dict]] sections + +Optional[Dict[str, Dict]] sources } class ProcessingConfig { +int max_seq_len +int min_chars +int max_chars - +bool deduplicate +Optional[int] max_items + +str packing_strategy + +int max_packed_len + +str truncation_mode } class OutputConfig { +Optional[str] domain_key +str storage_format +int max_tokens_per_shard + +Dict[str, str] dtype + +str position_ids_mode } class PipelineConfig { @@ -190,13 +191,13 @@ classDiagram } class StoreFactory { - +Registry _registry + +Dict _entries +register(name) decorator +create(storage_type) Store } class DatasetFactory { - +Registry _registry + +Dict _entries +register(name) decorator +create(train_type, window_size, stride) BaseDataset +load(train_type, load_path, window_size, stride, storage_type) BaseDataset @@ -219,7 +220,7 @@ classDiagram namespace model { class AutoModel { +BaseModelConfig config - +Registry _registry + +Dict _entries +register(name) decorator 
+get_component_class(name) Type +from_pretrained(path, disable_random_init, strict) nn.Module @@ -395,24 +396,17 @@ classDiagram } namespace factory { - class Registry { - +Dict _entries - +register(name, component_cls, category, priority) - +get(name) Type - +list_names() List[str] - } - class BaseFactory { - +Registry _registry - +register(name, category, priority) decorator + +Dict _entries + +register(name) decorator +create(name, *args, **kwargs) T +list_registered() list } class MaskBuilderFactory { - +Registry _registry + +Dict _entries +register(name) decorator - +create(input_type, config, tokenizer) BaseMaskBuilder + +create(name, *args, **kwargs) BaseMaskBuilder } } @@ -461,7 +455,7 @@ classDiagram } class StrategyFactory { - +Registry _registry + +Dict _entries +register(name) decorator +create(train_type, model, device, **kwargs) BaseStrategy } @@ -502,9 +496,9 @@ classDiagram } class SchedulerFactory { - +Registry _registry + +Dict _entries +register(name) decorator - +create(optimizer, schedule_type, **kwargs) BaseScheduler + +create(name, *args, **kwargs) BaseScheduler } class CosineScheduler { @@ -521,6 +515,13 @@ classDiagram +int t_mult } + class WSDScheduler { + +int warmup_steps + +int stable_steps + +int decay_steps + +float min_rate + } + class TrainCallback { <> +on_train_begin(context) @@ -581,7 +582,7 @@ classDiagram } class CallbackFactory { - +Registry _registry + +Dict _entries +register(name) decorator +create(name, **kwargs) TrainCallback } @@ -891,9 +892,9 @@ classDiagram +str yielded } - class app { - <> - +FastAPI app + class get_app { + <> + +get_app() FastAPI } } @@ -975,7 +976,7 @@ classDiagram } class ExecutorFactory { - +Registry _registry + +Dict _entries +register(name) decorator +create(parallel_mode, **kwargs) BaseExecutor } @@ -1018,6 +1019,7 @@ classDiagram BaseStrategy <|-- GRPOStrategy BaseScheduler <|-- CosineScheduler BaseScheduler <|-- SGDRScheduler + BaseScheduler <|-- WSDScheduler TrainCallback <|-- 
GradientClippingCallback TrainCallback <|-- GradientCheckpointingCallback TrainCallback <|-- CheckpointCallback @@ -1080,7 +1082,6 @@ classDiagram DecoderBlock *-- RMSNorm ChatCompletionRequest *-- ChatMessage MessagesRequest *-- AnthropicMessage - BaseFactory *-- Registry BaseExecutor *-- GradientState AccumOptimizer o-- GradientState AccumScheduler o-- GradientState @@ -1157,13 +1158,13 @@ classDiagram | Module | Components | Description | |--------|------------|-------------| -| **astrai.config** | BaseConfig, BaseModelConfig, AutoRegressiveLMConfig, EncoderConfig, ConfigFactory, TrainConfig, PipelineConfig, InputConfig, ProcessingConfig, OutputConfig | Configuration management (to_dict/from_dict, to_file/from_file, from_json/to_json) | +| **astrai.config** | BaseConfig, BaseModelConfig, AutoRegressiveLMConfig, EncoderConfig, ConfigFactory, TrainConfig, PipelineConfig, InputConfig, ProcessingConfig, OutputConfig | Configuration management (to_dict/from_dict, to_file/from_file) | | **astrai.preprocessing** | BaseMaskBuilder, MaskBuilderFactory, SectionedMaskBuilder, Pipeline, filter_by_length, PackingStrategy, PackingStrategyFactory, PositionIdStrategy, PositionIdStrategyFactory, StoreWriter, StoreWriterFactory | Declarative JSON-driven data preprocessing | | **astrai.dataset** | BaseDataset–GRPODataset, Store–MmapStore, StoreFactory, ResumableDistributedSampler, DatasetFactory | Dataset loading and management | | **astrai.serialization** | Checkpoint | Model serialization | | **astrai.model** | AutoModel, AutoRegressiveLM, EmbeddingEncoder, DecoderBlock, GQA, MLA, MLP, DeepSeekMoE, AttnFactory, FFNFactory, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model | | **astrai.tokenize** | AutoTokenizer, ChatTemplate | Tokenizer and chat template | -| **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy–GRPOStrategy, StrategyFactory, BaseScheduler–SGDRScheduler, SchedulerFactory, TrainCallback(Protocol)–ValidationCallback, 
CallbackFactory, Muon | Training workflow | +| **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy–GRPOStrategy, StrategyFactory, BaseScheduler–WSDScheduler, SchedulerFactory, TrainCallback(Protocol)–ValidationCallback, CallbackFactory, Muon | Training workflow | | **astrai.inference** | InferenceEngine, InferenceScheduler, Executor, KVCache–KvcacheView, Allocator–Storage, Task, TaskManager, TaskStatus, GenerationRequest, GenerateResult, BaseSamplingStrategy–SamplingPipeline, ProtocolHandler, ResponseBuilder, OpenAIResponseBuilder, AnthropicResponseBuilder, StopChecker, GenContext, ChatMessage–MessagesRequest, app | Inference service | | **astrai.parallel** | spawn_parallel_fn, setup_parallel, get_rank/get_world_size/get_current_device, only_on_rank, BaseExecutor, ExecutorFactory, NoneExecutor, DDPExecutor, FSDPExecutor, GradientState, AccumOptimizer, AccumScheduler, ParallelModel, RowParallelLinear, ColumnParallelLinear | Distributed parallel & gradient accumulation | | **astrai.factory** | Registry, BaseFactory[T] | Component registration | @@ -1174,7 +1175,7 @@ classDiagram | Pattern | Classes | Purpose | |---------|---------|---------| | **Factory** | `AttnFactory`, `FFNFactory`, `StrategyFactory`, `DatasetFactory`, `SchedulerFactory`, `CallbackFactory`, `StoreFactory`, `ConfigFactory`, `ExecutorFactory` | Decorator-based component creation | -| **Registry** | `BaseFactory`, `Registry` | Component registration with category/priority | +| **Registry** | `BaseFactory` | Component registration | | **Strategy** | `SEQStrategy`, `SFTStrategy`, `DPOStrategy`, `GRPOStrategy` | Training strategy switching | | **Strategy (Sampling)** | `TemperatureStrategy`, `TopKStrategy`, `TopPStrategy`, `SamplingPipeline` | Composable logit transformations | | **Strategy (API)** | `ResponseBuilder`, `OpenAIResponseBuilder`, `AnthropicResponseBuilder` | HTTP API handler with format hooks | @@ -1197,7 +1198,7 @@ classDiagram 6. 
**Distributed**: `spawn_parallel_fn` + `setup_parallel` for multi-process DDP 7. **Dataset Loading**: `DatasetFactory` creates datasets, `Store` (H5Store/MmapStore) loads data with explicit `_length` and multi-segment `_data` 8. **Checkpoint**: `Checkpoint` saves/loads safetensors + metadata (rank-0 only), extra state saved as `{key}.pt` -9. **Scheduler**: `SchedulerFactory` creates `CosineScheduler`/`SGDRScheduler` +9. **Scheduler**: `SchedulerFactory` creates `CosineScheduler`/`SGDRScheduler`/`WSDScheduler` 10. **AutoModel**: `from_pretrained()` loads `config.json` + `model.safetensors`, `_disable_random_init` replaces `nn.init.*` with no-ops 11. **Protocols**: `OptimizerProtocol` / `SchedulerProtocol` — structural subtyping for `AccumOptimizer` / `AccumScheduler` wrappers diff --git a/assets/docs/params.md b/assets/docs/params.md index 2f663e4..65150f3 100644 --- a/assets/docs/params.md +++ b/assets/docs/params.md @@ -86,11 +86,12 @@ | Parameter | Description | Default | Used by | |-----------|-------------|---------|---------| | `--dpo_beta` | DPO beta value | 0.1 | `dpo` | -| `--label_smoothing` | Label smoothing for cross-entropy loss | 0.05 | `seq`, `sft` | +| `--label_smoothing` | Label smoothing for cross-entropy loss | 0.0 | `seq`, `sft` | | `--group_size` | GRPO group size | 4 | `grpo` | | `--grpo_clip_eps` | GRPO clipping epsilon | 0.2 | `grpo` | | `--grpo_kl_coef` | GRPO KL penalty coefficient | 0.01 | `grpo` | | `--grpo_sync_interval` | GRPO ref_model sync interval (steps) | 200 | `grpo` | +| `--neftune_alpha` | NEFTune noise alpha (0=disabled, typical: 5.0) | 0.0 | `sft` | ### Usage Example diff --git a/assets/docs/training.md b/assets/docs/training.md index 3dbaf0d..a885361 100644 --- a/assets/docs/training.md +++ b/assets/docs/training.md @@ -127,8 +127,9 @@ Keys: `prompts`, `responses`, `masks`, `rewards`. 
|------|-------|-------------| | Cosine | `CosineScheduler` | Linear warmup → cosine decay to `min_rate` | | SGDR | `SGDRScheduler` | Cosine annealing with warm restarts (`t_mult=2`) | +| WSD | `WSDScheduler` | Warmup-Stable-Decay with sqrt cooldown | -Created by `SchedulerFactory.create(optimizer, schedule_type, **kwargs)`. Valid types: `"cosine"`, `"sgdr"`. Omit to use no scheduler. +Created by `SchedulerFactory.create(schedule_type, optimizer, **kwargs)`. Valid types: `"cosine"`, `"sgdr"`, `"wsd"`. Omit the schedule type to train without a scheduler. ## Gradient Checkpointing