Compare commits
9 Commits
65ab69543b
...
0a708fff24
| Author | SHA1 | Date |
|---|---|---|
|
|
0a708fff24 | |
|
|
6e150ea6d0 | |
|
|
cb8dcb97ea | |
|
|
2d5dc93b3d | |
|
|
4145d35e3c | |
|
|
34c6c45bd6 | |
|
|
e9def84ce7 | |
|
|
836e02a166 | |
|
|
b558e61f63 |
|
|
@ -22,7 +22,8 @@ classDiagram
|
|||
+int n_layers
|
||||
+float norm_eps
|
||||
+int dim_ffn
|
||||
+bool tie_weight
|
||||
+Optional[bool] tie_weight
|
||||
+Optional[dict] rope_scaling
|
||||
+int max_len
|
||||
+float rope_theta
|
||||
+str attn_type
|
||||
|
|
@ -52,6 +53,7 @@ classDiagram
|
|||
+int n_kv_heads
|
||||
+bool use_qk_norm
|
||||
+bool use_gated_attention
|
||||
+Optional[dict] rope_scaling
|
||||
+Optional[str] pooling_type
|
||||
+Optional[bool] normalize_embeddings
|
||||
}
|
||||
|
|
@ -80,6 +82,7 @@ classDiagram
|
|||
+str log_dir
|
||||
+int log_interval
|
||||
+List[str] metrics
|
||||
+Optional[LoRAConfig] lora
|
||||
+int random_seed
|
||||
+int num_workers
|
||||
+Optional[int] prefetch_factor
|
||||
|
|
@ -104,7 +107,7 @@ classDiagram
|
|||
class BaseDataset {
|
||||
+int window_size
|
||||
+int stride
|
||||
+Optional[BaseStorage] storage
|
||||
+Optional[Store] storage
|
||||
+load(load_path, storage_type, tokenizer)
|
||||
+__getitem__(index)
|
||||
+__len__()
|
||||
|
|
@ -126,38 +129,29 @@ classDiagram
|
|||
+__getitem__(index) Dict
|
||||
}
|
||||
|
||||
class BaseSegmentFetcher {
|
||||
+List[Tensor] segments
|
||||
+List[int] cum_lengths
|
||||
+int total_length
|
||||
+fetch_data(begin_idx, end_idx) Tensor
|
||||
}
|
||||
|
||||
class BaseStorage {
|
||||
+MultiSegmentFetcher _fetcher
|
||||
class Store {
|
||||
+Dict[str, List[Tensor]] _data
|
||||
+Dict[str, List[int]] _cum
|
||||
+int _length
|
||||
+keys (property)
|
||||
+load(load_path, tokenizer)
|
||||
+load(path, tokenizer)
|
||||
+fetch(begin, end, keys)
|
||||
+__len__()
|
||||
-_fetch_key(key, begin, end) Tensor
|
||||
-_normalize(raw)
|
||||
}
|
||||
|
||||
class H5Storage {
|
||||
+load(load_path, tokenizer)
|
||||
+fetch(begin, end, keys) Dict
|
||||
+keys() List
|
||||
class H5Store {
|
||||
+load(path, tokenizer)
|
||||
}
|
||||
|
||||
class JSONStorage {
|
||||
+load(load_path, tokenizer)
|
||||
+fetch(begin, end, keys) Dict
|
||||
+keys() List
|
||||
class JSONStore {
|
||||
+load(path, tokenizer)
|
||||
}
|
||||
|
||||
class MultiSegmentFetcher {
|
||||
+Dict multi_fetchers
|
||||
+List multi_keys
|
||||
+key_fetch(begin_idx, end_idx, keys) Dict
|
||||
+fetch_data(begin_idx, end_idx) Dict
|
||||
class MmapStore {
|
||||
+List _mmap_refs
|
||||
+load(path, tokenizer)
|
||||
}
|
||||
|
||||
class ResumableDistributedSampler {
|
||||
|
|
@ -165,10 +159,10 @@ classDiagram
|
|||
+int iter
|
||||
}
|
||||
|
||||
class StorageFactory {
|
||||
class StoreFactory {
|
||||
+Registry _registry
|
||||
+register(name) decorator
|
||||
+create(storage_type) BaseStorage
|
||||
+create(storage_type) Store
|
||||
}
|
||||
|
||||
class DatasetFactory {
|
||||
|
|
@ -457,16 +451,15 @@ classDiagram
|
|||
+on_train_end(context)
|
||||
+on_epoch_begin(context)
|
||||
+on_epoch_end(context)
|
||||
+on_step_begin(context)
|
||||
+on_step_end(context)
|
||||
+on_batch_begin(context)
|
||||
+on_batch_end(context)
|
||||
+on_optimizer_step(context)
|
||||
+on_error(context)
|
||||
}
|
||||
|
||||
class GradientClippingCallback {
|
||||
+float max_grad_norm
|
||||
+on_step_begin(context)
|
||||
+on_optimizer_step(context)
|
||||
}
|
||||
|
||||
class GradientCheckpointingCallback {
|
||||
|
|
@ -512,7 +505,7 @@ classDiagram
|
|||
|
||||
class ValidationCallback {
|
||||
+_run_validation(context)
|
||||
+on_step_end(context)
|
||||
+on_optimizer_step(context)
|
||||
}
|
||||
|
||||
class CallbackFactory {
|
||||
|
|
@ -747,56 +740,58 @@ classDiagram
|
|||
+str model
|
||||
+List[AnthropicMessage] messages
|
||||
+Optional[str] system
|
||||
+float temperature
|
||||
+float top_p
|
||||
+int top_k
|
||||
+Optional[float] temperature
|
||||
+Optional[float] top_p
|
||||
+Optional[int] top_k
|
||||
+int max_tokens
|
||||
+bool stream
|
||||
+Optional[bool] stream
|
||||
+Optional[List[str]] stop_sequences
|
||||
}
|
||||
|
||||
class ProtocolHandler {
|
||||
class ResponseBuilder {
|
||||
<<abstract>>
|
||||
+prepare(request, engine) Tuple[str, GenContext, List[str]]
|
||||
+format_stream_start(ctx) List[str]
|
||||
+format_chunk(token) str
|
||||
+format_stream_end(ctx, stop) List[str]
|
||||
+format_response(ctx, content, stop) Dict
|
||||
}
|
||||
|
||||
class OpenAIResponseBuilder {
|
||||
+prepare(request, engine) Tuple
|
||||
+format_stream_start(ctx) List[str]
|
||||
+format_chunk(token) str
|
||||
+format_stream_end(ctx, stop) List[str]
|
||||
+format_response(ctx, content, stop) Dict
|
||||
}
|
||||
|
||||
class AnthropicResponseBuilder {
|
||||
+prepare(request, engine) Tuple
|
||||
+format_stream_start(ctx) List[str]
|
||||
+format_chunk(token) str
|
||||
+format_stream_end(ctx, stop) List[str]
|
||||
+format_response(ctx, content, stop) Dict
|
||||
}
|
||||
|
||||
class ProtocolHandler {
|
||||
+request
|
||||
+engine
|
||||
+build_prompt() str
|
||||
+create_response_id() str
|
||||
+get_stop_sequences() List[str]
|
||||
+create_stop_checker() StopChecker
|
||||
+on_token(ctx, token, stop_checker) Optional[str]
|
||||
+format_stream_start(ctx) List[str]
|
||||
+format_stream_token(ctx, token) str
|
||||
+format_stream_end(ctx) List[str]
|
||||
+format_non_stream_response(ctx, content) Dict
|
||||
+builder: ResponseBuilder
|
||||
+handle() Union[StreamingResponse, Dict]
|
||||
}
|
||||
|
||||
class OpenAIHandler {
|
||||
+build_prompt() str
|
||||
+create_response_id() str
|
||||
}
|
||||
|
||||
class AnthropicHandler {
|
||||
+build_prompt() str
|
||||
+create_response_id() str
|
||||
+on_token(ctx, token, stop_checker) Optional[str]
|
||||
-_handle_stream(agen, ctx, stops) StreamingResponse
|
||||
-_handle_non_stream(agen, ctx, stops) Dict
|
||||
}
|
||||
|
||||
class StopChecker {
|
||||
+has_sequences (property) bool
|
||||
+check(text) Optional[str]
|
||||
+trim(text, matched) str
|
||||
}
|
||||
|
||||
class StreamContext {
|
||||
class GenContext {
|
||||
+str resp_id
|
||||
+int created
|
||||
+str model
|
||||
+int prompt_tokens
|
||||
+int completion_tokens
|
||||
+str accumulated
|
||||
+Optional[str] stop_matched
|
||||
+str last_yield_trimmed
|
||||
}
|
||||
|
||||
class app {
|
||||
|
|
@ -876,6 +871,11 @@ classDiagram
|
|||
+unwrap_model(model) nn.Module
|
||||
}
|
||||
|
||||
class FSDPExecutor {
|
||||
+_prepare_model(model) nn.Module
|
||||
+unwrap_model(model) nn.Module
|
||||
}
|
||||
|
||||
class ExecutorFactory {
|
||||
+Registry _registry
|
||||
+register(name) decorator
|
||||
|
|
@ -911,12 +911,14 @@ classDiagram
|
|||
TrainCallback <|-- CheckpointCallback
|
||||
TrainCallback <|-- ProgressBarCallback
|
||||
TrainCallback <|-- MetricLoggerCallback
|
||||
TrainCallback <|-- ValidationCallback
|
||||
BaseDataset <|-- SEQDataset
|
||||
BaseDataset <|-- SFTDataset
|
||||
BaseDataset <|-- DPODataset
|
||||
BaseDataset <|-- GRPODataset
|
||||
BaseStorage <|-- H5Storage
|
||||
BaseStorage <|-- JSONStorage
|
||||
Store <|-- H5Store
|
||||
Store <|-- JSONStore
|
||||
Store <|-- MmapStore
|
||||
BaseSamplingStrategy <|-- TemperatureStrategy
|
||||
BaseSamplingStrategy <|-- TopKStrategy
|
||||
BaseSamplingStrategy <|-- TopPStrategy
|
||||
|
|
@ -936,20 +938,19 @@ classDiagram
|
|||
BaseFactory <|-- StrategyFactory
|
||||
BaseFactory <|-- SchedulerFactory
|
||||
BaseFactory <|-- CallbackFactory
|
||||
BaseFactory <|-- StorageFactory
|
||||
BaseFactory <|-- StoreFactory
|
||||
BaseFactory <|-- ExecutorFactory
|
||||
BaseFactory <|-- ConfigFactory
|
||||
BaseExecutor <|-- NoneExecutor
|
||||
BaseExecutor <|-- DDPExecutor
|
||||
ProtocolHandler <|-- OpenAIHandler
|
||||
ProtocolHandler <|-- AnthropicHandler
|
||||
BaseExecutor <|-- FSDPExecutor
|
||||
ResponseBuilder <|-- OpenAIResponseBuilder
|
||||
ResponseBuilder <|-- AnthropicResponseBuilder
|
||||
|
||||
%% --- Composition (strong ownership, part destroyed with whole) ---
|
||||
KVCache *-- PagePool
|
||||
KVCache *-- Storage
|
||||
KVCache *-- TaskTable
|
||||
PagePool *-- Allocator
|
||||
PagePool *-- PrefixCache
|
||||
InferenceEngine *-- InferenceScheduler
|
||||
InferenceScheduler *-- KVCache
|
||||
InferenceScheduler *-- Executor
|
||||
|
|
@ -963,7 +964,6 @@ classDiagram
|
|||
DecoderBlock *-- RMSNorm
|
||||
ChatCompletionRequest *-- ChatMessage
|
||||
MessagesRequest *-- AnthropicMessage
|
||||
AutoTokenizer *-- ChatTemplate
|
||||
BaseFactory *-- Registry
|
||||
BaseExecutor *-- GradientState
|
||||
AccumOptimizer o-- GradientState
|
||||
|
|
@ -971,6 +971,9 @@ classDiagram
|
|||
|
||||
%% --- Aggregation (weak ownership) ---
|
||||
AutoModel o-- BaseModelConfig
|
||||
AutoTokenizer o-- ChatTemplate
|
||||
PagePool o-- Allocator
|
||||
PagePool o-- PrefixCache
|
||||
Trainer o-- TrainCallback
|
||||
TrainContext o-- BaseStrategy
|
||||
TrainContext o-- BaseScheduler
|
||||
|
|
@ -978,7 +981,7 @@ classDiagram
|
|||
TrainContext o-- BaseExecutor
|
||||
KvcacheView o-- Storage
|
||||
SamplingPipeline o-- BaseSamplingStrategy
|
||||
BaseDataset o-- BaseStorage
|
||||
BaseDataset o-- Store
|
||||
|
||||
%% --- Dependency (uses temporarily) ---
|
||||
TrainConfig ..> BaseStrategy : selects
|
||||
|
|
@ -992,12 +995,14 @@ classDiagram
|
|||
FFNFactory ..> DeepSeekMoE : creates
|
||||
DecoderBlock ..> AttnFactory : uses
|
||||
DecoderBlock ..> FFNFactory : uses
|
||||
StorageFactory ..> H5Storage : creates
|
||||
StorageFactory ..> JSONStorage : creates
|
||||
StoreFactory ..> H5Store : creates
|
||||
StoreFactory ..> JSONStore : creates
|
||||
StoreFactory ..> MmapStore : creates
|
||||
ConfigFactory ..> AutoRegressiveLMConfig : creates
|
||||
ConfigFactory ..> EncoderConfig : creates
|
||||
ExecutorFactory ..> NoneExecutor : creates
|
||||
ExecutorFactory ..> DDPExecutor : creates
|
||||
ExecutorFactory ..> FSDPExecutor : creates
|
||||
TrainContextBuilder ..> ExecutorFactory : creates
|
||||
Trainer ..> TrainContextBuilder : uses
|
||||
TrainContextBuilder ..> TrainContext : creates
|
||||
|
|
@ -1009,10 +1014,10 @@ classDiagram
|
|||
KVCache ..> KvcacheView : binds
|
||||
InferenceEngine ..> GenerationRequest : uses
|
||||
InferenceEngine ..> GenerateResult : creates
|
||||
OpenAIHandler ..> ChatCompletionRequest : receives
|
||||
AnthropicHandler ..> MessagesRequest : receives
|
||||
OpenAIResponseBuilder ..> ChatCompletionRequest : receives
|
||||
AnthropicResponseBuilder ..> MessagesRequest : receives
|
||||
ProtocolHandler ..> StopChecker : creates
|
||||
ProtocolHandler ..> StreamContext : creates
|
||||
ProtocolHandler ..> GenContext : creates
|
||||
|
||||
%% --- Association (general usage) ---
|
||||
Trainer --> TrainConfig
|
||||
|
|
@ -1025,8 +1030,6 @@ classDiagram
|
|||
Executor --> AutoModel
|
||||
Executor --> AutoTokenizer
|
||||
TaskManager --> AutoTokenizer
|
||||
MultiSegmentFetcher --> BaseSegmentFetcher
|
||||
ResumableDistributedSampler --> BaseDataset
|
||||
|
||||
```
|
||||
|
||||
|
|
@ -1036,13 +1039,13 @@ classDiagram
|
|||
| Module | Components | Description |
|
||||
|--------|------------|-------------|
|
||||
| **astrai.config** | BaseConfig, BaseModelConfig, AutoRegressiveLMConfig, EncoderConfig, ConfigFactory, TrainConfig | Configuration management (to_dict/from_dict, to_file/from_file) |
|
||||
| **astrai.dataset** | BaseDataset–GRPODataset, BaseStorage–JSONStorage, StorageFactory, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory | Dataset loading and management |
|
||||
| **astrai.dataset** | BaseDataset–GRPODataset, Store–MmapStore, StoreFactory, ResumableDistributedSampler, DatasetFactory | Dataset loading and management |
|
||||
| **astrai.serialization** | Checkpoint | Model serialization |
|
||||
| **astrai.model** | AutoModel, AutoRegressiveLM, EmbeddingEncoder, DecoderBlock, GQA, MLA, MLP, DeepSeekMoE, AttnFactory, FFNFactory, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
|
||||
| **astrai.tokenize** | AutoTokenizer, ChatTemplate | Tokenizer and chat template |
|
||||
| **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy–GRPOStrategy, StrategyFactory, BaseScheduler–SGDRScheduler, SchedulerFactory, TrainCallback(Protocol)–ValidationCallback, CallbackFactory, Muon | Training workflow |
|
||||
| **astrai.inference** | InferenceEngine, InferenceScheduler, Executor, KVCache–KvcacheView, Allocator–Storage, Task, TaskManager, TaskStatus, GenerationRequest, GenerateResult, BaseSamplingStrategy–SamplingPipeline, ProtocolHandler–AnthropicHandler, StopChecker, StreamContext, ChatMessage–MessagesRequest, app | Inference service |
|
||||
| **astrai.parallel** | spawn_parallel_fn, setup_parallel, get_rank/get_world_size/get_current_device, only_on_rank, BaseExecutor, ExecutorFactory, NoneExecutor, DDPExecutor, GradientState, AccumOptimizer, AccumScheduler, ParallelModel, RowParallelLinear, ColumnParallelLinear | Distributed parallel & gradient accumulation |
|
||||
| **astrai.inference** | InferenceEngine, InferenceScheduler, Executor, KVCache–KvcacheView, Allocator–Storage, Task, TaskManager, TaskStatus, GenerationRequest, GenerateResult, BaseSamplingStrategy–SamplingPipeline, ProtocolHandler, ResponseBuilder, OpenAIResponseBuilder, AnthropicResponseBuilder, StopChecker, GenContext, ChatMessage–MessagesRequest, app | Inference service |
|
||||
| **astrai.parallel** | spawn_parallel_fn, setup_parallel, get_rank/get_world_size/get_current_device, only_on_rank, BaseExecutor, ExecutorFactory, NoneExecutor, DDPExecutor, FSDPExecutor, GradientState, AccumOptimizer, AccumScheduler, ParallelModel, RowParallelLinear, ColumnParallelLinear | Distributed parallel & gradient accumulation |
|
||||
| **astrai.factory** | Registry, BaseFactory[T] | Component registration |
|
||||
| **astrai.protocols** | OptimizerProtocol, SchedulerProtocol | Structural subtyping for optimizer/scheduler wrappers |
|
||||
|
||||
|
|
@ -1050,17 +1053,17 @@ classDiagram
|
|||
|
||||
| Pattern | Classes | Purpose |
|
||||
|---------|---------|---------|
|
||||
| **Factory** | `AttnFactory`, `FFNFactory`, `StrategyFactory`, `DatasetFactory`, `SchedulerFactory`, `CallbackFactory`, `StorageFactory`, `ConfigFactory`, `ExecutorFactory` | Decorator-based component creation |
|
||||
| **Factory** | `AttnFactory`, `FFNFactory`, `StrategyFactory`, `DatasetFactory`, `SchedulerFactory`, `CallbackFactory`, `StoreFactory`, `ConfigFactory`, `ExecutorFactory` | Decorator-based component creation |
|
||||
| **Registry** | `BaseFactory`, `Registry` | Component registration with category/priority |
|
||||
| **Strategy** | `SEQStrategy`, `SFTStrategy`, `DPOStrategy`, `GRPOStrategy` | Training strategy switching |
|
||||
| **Strategy (Sampling)** | `TemperatureStrategy`, `TopKStrategy`, `TopPStrategy`, `SamplingPipeline` | Composable logit transformations |
|
||||
| **Template Method** | `ProtocolHandler`, `OpenAIHandler`, `AnthropicHandler` | HTTP API handler with format hooks |
|
||||
| **Strategy (API)** | `ResponseBuilder`, `OpenAIResponseBuilder`, `AnthropicResponseBuilder` | HTTP API handler with format hooks |
|
||||
| **Builder** | `TrainContextBuilder` | Chain-building training context |
|
||||
| **Observer** | `TrainCallback`, callback implementations | Training process monitoring |
|
||||
| **Context** | `TrainContext` | Unified training state bag |
|
||||
| **Object Pool** | `Allocator`, `PagePool` | Page-based KV cache with LRU eviction |
|
||||
| **Executor** | `BaseExecutor`, `NoneExecutor`, `DDPExecutor` | Gradient accumulation & model distribution |
|
||||
| **Storage** | `BaseStorage`, `H5Storage`, `JSONStorage` | Format-agnostic data access |
|
||||
| **Storage** | `Store`, `H5Store`, `JSONStore`, `MmapStore` | Format-agnostic data access with multi-segment support |
|
||||
| **Producer-Consumer** | `InferenceScheduler`, `Task`, queues | Continuous batching |
|
||||
| **AutoModel Registry** | `AutoModel`, `AutoRegressiveLM`, `EmbeddingEncoder` | Model-type dynamic loading |
|
||||
|
||||
|
|
@ -1069,10 +1072,10 @@ classDiagram
|
|||
1. **Config → Training**: `TrainConfig` holds model, dataset, optimizer_fn, scheduler_fn, `parallel_mode`, `executor_kwargs`
|
||||
2. **Training Flow**: `Trainer` → `TrainContextBuilder` → `TrainContext`, uses `BaseStrategy` for loss, `BaseExecutor` for gradient accumulation + model distribution
|
||||
3. **Strategy Selection**: `StrategyFactory` creates strategy by `train_type`
|
||||
4. **Executor Selection**: `ExecutorFactory.create(parallel_mode, **executor_kwargs)` → `NoneExecutor` (single) / `DDPExecutor` (distributed)
|
||||
4. **Executor Selection**: `ExecutorFactory.create(cfg.parallel_mode, grad_accum_steps=cfg.grad_accum_steps, **cfg.executor_kwargs)` → `NoneExecutor` / `DDPExecutor` / `FSDPExecutor`
|
||||
5. **Inference Flow**: `InferenceEngine` → `InferenceScheduler` → `AutoRegressiveLM`, backed by `KVCache` + `SamplingPipeline`
|
||||
6. **Distributed**: `spawn_parallel_fn` + `setup_parallel` for multi-process DDP
|
||||
7. **Dataset Loading**: `DatasetFactory` creates datasets, `BaseStorage` (H5Storage/JSONStorage) loads via `BaseSegmentFetcher` + `MultiSegmentFetcher`
|
||||
7. **Dataset Loading**: `DatasetFactory` creates datasets, `Store` (H5Store/JSONStore/MmapStore) loads data with explicit `_length` and multi-segment `_data`
|
||||
8. **Checkpoint**: `Checkpoint` saves/loads safetensors + metadata (rank-0 only), extra state saved as `{key}.pt`
|
||||
9. **Scheduler**: `SchedulerFactory` creates `CosineScheduler`/`SGDRScheduler`
|
||||
10. **AutoModel**: `from_pretrained()` loads `config.json` + `model.safetensors`, `_disable_random_init` replaces `nn.init.*` with no-ops
|
||||
|
|
|
|||
|
|
@ -5,21 +5,22 @@ This document describes the data pipeline: from raw text to model input tensors.
|
|||
## Overview
|
||||
|
||||
```
|
||||
Raw Text → AutoTokenizer → Token IDs → .h5/.json → Dataset → Sampler → DataLoader → Training/Inference
|
||||
Raw Text → AutoTokenizer → Token IDs → .h5/.json/.bin → Dataset → Sampler → DataLoader → Training/Inference
|
||||
```
|
||||
|
||||
## Data Preparation
|
||||
|
||||
Raw text is tokenized via `AutoTokenizer.encode()` and saved as HDF5 (`.h5`) or JSON (`.json`/`.jsonl`) files with keyed tensor groups.
|
||||
Raw text is tokenized via `AutoTokenizer.encode()` and saved as HDF5 (`.h5`), JSON (`.json`/`.jsonl`), or binary (`.bin` + `meta.json`) files with keyed tensor groups.
|
||||
|
||||
Storage format is auto-detected by `detect_format()`; backends are dispatched via registry:
|
||||
|
||||
```
|
||||
StorageFactory.create("h5") → H5Storage
|
||||
StorageFactory.create("json") → JSONStorage
|
||||
StoreFactory.create("h5") → H5Store
|
||||
StoreFactory.create("json") → JSONStore
|
||||
StoreFactory.create("bin") → MmapStore
|
||||
```
|
||||
|
||||
Both support shared memory via `.share_memory_()`.
|
||||
H5 and JSON backends support shared memory via `.share_memory_()`. Bin (mmap) uses OS page-cache sharing natively.
|
||||
|
||||
## Data Keys by Training Type
|
||||
|
||||
|
|
@ -33,14 +34,14 @@ Both support shared memory via `.share_memory_()`.
|
|||
## Dataset Architecture
|
||||
|
||||
```
|
||||
DatasetFactory.load(train_type, path, window_size, stride)
|
||||
→ StorageFactory.create(detect_format(path))
|
||||
→ MultiSegmentFetcher(BaseSegmentFetcher per key)
|
||||
DatasetFactory.load(train_type, path, window_size, stride, storage_type, tokenizer)
|
||||
→ StoreFactory.create(detect_format(path))
|
||||
→ Store._data[Dict[str, List[Tensor]]] + _cum[Dict[str, List[int]]]
|
||||
→ BaseDataset.__getitem__(idx)
|
||||
→ sliding window [begin, end) via get_index(idx)
|
||||
```
|
||||
|
||||
`window_size` = max input length, `stride` = step between consecutive samples.
|
||||
`window_size` = max input length, `stride` = step between consecutive samples (defaults to `window_size`).
|
||||
|
||||
## Sampler
|
||||
|
||||
|
|
|
|||
|
|
@ -46,20 +46,22 @@ BaseSamplingStrategy → TemperatureStrategy → TopKStrategy → TopPStrategy
|
|||
`SamplingPipeline` composes them: Temperature → Top-K → Top-P → softmax → multinomial.
|
||||
`sample()` is a convenience shortcut for one-shot usage.
|
||||
|
||||
## Protocol Handlers (Template Method)
|
||||
## Protocol Handlers (Strategy Pattern)
|
||||
|
||||
```python
|
||||
class ProtocolHandler(ABC):
|
||||
def handle(self):
|
||||
ctx = StreamContext(...)
|
||||
class ProtocolHandler: # concrete orchestrator
|
||||
def handle(self, request):
|
||||
prompt, ctx, stops = builder.prepare(request, engine)
|
||||
agen = engine.generate_async(prompt, ...)
|
||||
if stream: self._handle_stream(agen, ctx)
|
||||
else: self._handle_non_stream(agen, ctx)
|
||||
if stream: self._handle_stream(agen, ctx, stops)
|
||||
else: self._handle_non_stream(agen, ctx, stops)
|
||||
```
|
||||
|
||||
Subclass hooks: `build_prompt()`, `create_response_id()`, `format_stream_start/token/end()`, `format_non_stream_response()`.
|
||||
`ResponseBuilder` (ABC): `prepare()`, `format_stream_start()`, `format_chunk()`, `format_stream_end()`, `format_response()`.
|
||||
|
||||
`OpenAIHandler` → `/v1/chat/completions`, `AnthropicHandler` → `/v1/messages`.
|
||||
`OpenAIResponseBuilder` → `/v1/chat/completions`, `AnthropicResponseBuilder` → `/v1/messages`.
|
||||
|
||||
Adding a protocol = one builder file, no handler subclassing needed.
|
||||
|
||||
## Engine & GenerateResult
|
||||
|
||||
|
|
@ -116,7 +118,7 @@ Supports `stop_sequences` and streaming via `event: content_block_delta`.
|
|||
| Param | Type | Default | Description |
|
||||
|-------|------|---------|-------------|
|
||||
| `messages` | List[dict] | required | Chat messages (role, content) |
|
||||
| `temperature` | float | 1.0 | Sampling temperature (0.0–2.0) |
|
||||
| `temperature` | float | 1.0 | Sampling temperature (>= 0.0) |
|
||||
| `top_p` | float | 1.0 | Nucleus threshold |
|
||||
| `top_k` | int | 50 | Top-k count |
|
||||
| `max_tokens` | int | None | Max generation length |
|
||||
|
|
|
|||
|
|
@ -53,7 +53,7 @@
|
|||
| Parameter | Description | Default |
|
||||
|-----------|-------------|---------|
|
||||
| `--nprocs` | Number of GPUs / processes | 1 |
|
||||
| `--parallel_mode` | Parallel strategy (`none` or `ddp`) | none |
|
||||
| `--parallel_mode` | Parallel strategy (`none`, `ddp`, or `fsdp`) | none |
|
||||
| `--device_type` | Device type | cuda |
|
||||
| `--start_method` | Multiprocessing start method (`spawn`, `fork`, `forkserver`) | spawn |
|
||||
|
||||
|
|
|
|||
|
|
@ -82,8 +82,7 @@ on_train_begin
|
|||
on_optimizer_step
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
scheduler.step() # called every iteration
|
||||
scheduler.step()
|
||||
on_epoch_end
|
||||
on_train_end
|
||||
```
|
||||
|
|
@ -190,7 +189,7 @@ context = (
|
|||
```
|
||||
|
||||
- Loads checkpoint weights if provided
|
||||
- Creates executor via `ExecutorFactory.create(parallel_mode, **executor_kwargs)`
|
||||
- Creates executor via `ExecutorFactory.create(cfg.parallel_mode, grad_accum_steps=cfg.grad_accum_steps, **cfg.executor_kwargs)`
|
||||
- Calls `executor.prepare(model, optimizer, dataloader, scheduler)` for model distribution (e.g. DDP) + gradient accumulation wrappers
|
||||
- Creates `ResumableDistributedSampler` for shuffle+resume
|
||||
- Builds strategy via `StrategyFactory.create(train_type, ...)`
|
||||
|
|
|
|||
|
|
@ -17,8 +17,8 @@ def required(**kw):
|
|||
@dataclass
|
||||
class TrainConfig(BaseConfig):
|
||||
# basic setting
|
||||
model: nn.Module = field(
|
||||
default=None, metadata=required(help="Model for training.")
|
||||
model_fn: Callable[[], nn.Module] = field(
|
||||
default=None, metadata=required(help="Model factory for training.")
|
||||
)
|
||||
strategy: str = field(default=None, metadata=required(help="Training strategy."))
|
||||
dataset: Dataset = field(
|
||||
|
|
|
|||
|
|
@ -4,15 +4,17 @@ from astrai.dataset.dataset import (
|
|||
)
|
||||
from astrai.dataset.sampler import ResumableDistributedSampler
|
||||
from astrai.dataset.storage import (
|
||||
BaseSegmentFetcher,
|
||||
BaseStorage,
|
||||
H5Storage,
|
||||
JSONStorage,
|
||||
MultiSegmentFetcher,
|
||||
StorageFactory,
|
||||
H5Store,
|
||||
JSONStore,
|
||||
MmapStore,
|
||||
Store,
|
||||
StoreFactory,
|
||||
detect_format,
|
||||
json_to_bin,
|
||||
load_bin,
|
||||
load_h5,
|
||||
load_json,
|
||||
save_bin,
|
||||
save_h5,
|
||||
save_json,
|
||||
)
|
||||
|
|
@ -20,16 +22,18 @@ from astrai.dataset.storage import (
|
|||
__all__ = [
|
||||
"BaseDataset",
|
||||
"DatasetFactory",
|
||||
"BaseSegmentFetcher",
|
||||
"MultiSegmentFetcher",
|
||||
"BaseStorage",
|
||||
"H5Storage",
|
||||
"JSONStorage",
|
||||
"StorageFactory",
|
||||
"Store",
|
||||
"StoreFactory",
|
||||
"H5Store",
|
||||
"JSONStore",
|
||||
"MmapStore",
|
||||
"detect_format",
|
||||
"save_h5",
|
||||
"load_h5",
|
||||
"save_json",
|
||||
"load_json",
|
||||
"save_bin",
|
||||
"load_bin",
|
||||
"json_to_bin",
|
||||
"ResumableDistributedSampler",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -8,8 +8,8 @@ from torch import Tensor
|
|||
from torch.utils.data import Dataset
|
||||
|
||||
from astrai.dataset.storage import (
|
||||
BaseStorage,
|
||||
StorageFactory,
|
||||
Store,
|
||||
StoreFactory,
|
||||
detect_format,
|
||||
)
|
||||
from astrai.factory import BaseFactory
|
||||
|
|
@ -26,7 +26,7 @@ class BaseDataset(Dataset, ABC):
|
|||
super().__init__()
|
||||
self.window_size = window_size
|
||||
self.stride = stride
|
||||
self.storage: Optional[BaseStorage] = None
|
||||
self.storage: Optional[Store] = None
|
||||
|
||||
@property
|
||||
def required_keys(self) -> List[str]:
|
||||
|
|
@ -65,7 +65,7 @@ class BaseDataset(Dataset, ABC):
|
|||
"""
|
||||
if storage_type is None:
|
||||
storage_type = detect_format(load_path)
|
||||
self.storage = StorageFactory.create(storage_type)
|
||||
self.storage = StoreFactory.create(storage_type)
|
||||
self._load_path = load_path
|
||||
self.storage.load(load_path, tokenizer=tokenizer)
|
||||
self._validate_keys()
|
||||
|
|
@ -148,7 +148,7 @@ class DatasetFactory(BaseFactory["BaseDataset"]):
|
|||
"""
|
||||
|
||||
@classmethod
|
||||
def _validate_component(cls, dataset_cls: type) -> None:
|
||||
def _validate_component(cls, dataset_cls: type):
|
||||
"""Validate that the dataset class inherits from BaseDataset."""
|
||||
if not issubclass(dataset_cls, BaseDataset):
|
||||
raise TypeError(f"{dataset_cls.__name__} must inherit from BaseDataset")
|
||||
|
|
|
|||
|
|
@ -1,7 +1,20 @@
|
|||
"""Storage backends for different data formats.
|
||||
|
||||
Each storage handles format-specific loading (HDF5, JSON, etc.) and provides
|
||||
a uniform interface for data access and length observation via fetchers.
|
||||
Layers:
|
||||
- I/O layer: save_* / load_* functions, read/write raw files (HDF5/JSON/bin)
|
||||
return Dict[str, List[Tensor]] — format-specific, no state
|
||||
- Store (ABC): central abstraction, normalizes multi-segment into
|
||||
Dict[str, List[Tensor]] per key via _normalize(),
|
||||
fetch() uses bisect across segments — no forced concat
|
||||
- Dataset layer: BaseDataset owns a Store, only calls store.fetch(begin, end, key)
|
||||
|
||||
Key properties:
|
||||
- Multi-segment: segments kept as-is, no forced concatenation — safe for
|
||||
datasets larger than RAM
|
||||
- Explicit length: _length = min(total elements across keys), set at load,
|
||||
__len__ returns O(1)
|
||||
- Zero-copy mmap: MmapStore wraps np.memmap(mode="r"), all DataLoader
|
||||
workers share OS page-cache pages
|
||||
"""
|
||||
|
||||
import bisect
|
||||
|
|
@ -12,6 +25,7 @@ from pathlib import Path
|
|||
from typing import Callable, Dict, List, Optional, Union
|
||||
|
||||
import h5py
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
|
|
@ -104,6 +118,38 @@ def load_json(
|
|||
return tensor_group
|
||||
|
||||
|
||||
def save_bin(file_path: str, tensor_group: Dict[str, List[Tensor]]):
|
||||
os.makedirs(file_path, exist_ok=True)
|
||||
meta = {}
|
||||
for key, tensors in tensor_group.items():
|
||||
cat = torch.cat(tensors, dim=0)
|
||||
meta[key] = {"shape": list(cat.shape), "dtype": str(cat.dtype).split(".")[-1]}
|
||||
np.asarray(cat.cpu().numpy()).tofile(os.path.join(file_path, f"{key}.bin"))
|
||||
save_json(meta, os.path.join(file_path, "meta.json"))
|
||||
|
||||
|
||||
def load_bin(file_path: str) -> Dict[str, List[Tensor]]:
|
||||
meta = load_json(os.path.join(file_path, "meta.json"))
|
||||
segments: Dict[str, List[Tensor]] = {}
|
||||
for key, info in meta.items():
|
||||
arr = np.memmap(
|
||||
os.path.join(file_path, f"{key}.bin"),
|
||||
dtype=info["dtype"],
|
||||
mode="r",
|
||||
shape=tuple(info["shape"]),
|
||||
)
|
||||
segments[key] = [torch.from_numpy(arr)]
|
||||
return segments
|
||||
|
||||
|
||||
def json_to_bin(json_path: str, bin_path: str, tokenizer=None):
|
||||
segments = load_json(json_path, share_memory=False, tokenizer=tokenizer)
|
||||
merged = {}
|
||||
for key, tensors in segments.items():
|
||||
merged[key] = [torch.cat(tensors, dim=0)]
|
||||
save_bin(bin_path, merged)
|
||||
|
||||
|
||||
def detect_format(load_path: str) -> str:
|
||||
"""Auto-detect storage format from files in the directory.
|
||||
|
||||
|
|
@ -111,7 +157,7 @@ def detect_format(load_path: str) -> str:
|
|||
load_path: Directory or file path
|
||||
|
||||
Returns:
|
||||
Format string ("h5" or "json")
|
||||
Format string ("h5", "bin", or "json")
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If no supported data files are found
|
||||
|
|
@ -128,166 +174,118 @@ def detect_format(load_path: str) -> str:
|
|||
h5_files = list(root.rglob("*.h5")) + list(root.rglob("*.hdf5"))
|
||||
if h5_files:
|
||||
return "h5"
|
||||
bin_files = list(root.rglob("*.bin"))
|
||||
if bin_files and (root / "meta.json").exists():
|
||||
return "bin"
|
||||
json_files = list(root.rglob("*.json")) + list(root.rglob("*.jsonl"))
|
||||
if json_files:
|
||||
return "json"
|
||||
raise FileNotFoundError(f"No supported data files found at {load_path}")
|
||||
|
||||
|
||||
class BaseSegmentFetcher:
|
||||
"""Fetches data segments across multiple tensor segments.
|
||||
class Store(ABC):
|
||||
"""String keys -> segmented tensors with ``fetch(begin, end, keys)``.
|
||||
|
||||
Maintains cumulative lengths for efficient range queries across
|
||||
multiple discontinuous segments.
|
||||
"""
|
||||
Each key maps to one or more tensor segments (no forced concatenation).
|
||||
``len(store)`` returns ``self._length`` (explicit, O(1)), the minimum
|
||||
total element count across all keys.
|
||||
|
||||
def __init__(self, segments: List[Tensor]):
|
||||
self.segments = segments
|
||||
self.cum_lengths = []
|
||||
|
||||
total = 0
|
||||
for seg in segments:
|
||||
total += torch.numel(seg)
|
||||
self.cum_lengths.append(total)
|
||||
|
||||
self.total_length = total
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self.total_length
|
||||
|
||||
def fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
|
||||
"""Fetch data in the range [begin_idx, end_idx)."""
|
||||
if not (
|
||||
0 <= begin_idx < self.total_length and 0 <= end_idx <= self.total_length
|
||||
):
|
||||
raise ValueError("begin_idx or end_idx out of bounds")
|
||||
if begin_idx >= end_idx:
|
||||
return torch.tensor([], dtype=torch.long)
|
||||
|
||||
seg_start_idx = bisect.bisect_right(self.cum_lengths, begin_idx)
|
||||
seg_end_idx = bisect.bisect_left(self.cum_lengths, end_idx)
|
||||
|
||||
result_segments = []
|
||||
|
||||
for i in range(seg_start_idx, seg_end_idx + 1):
|
||||
prev_cum = self.cum_lengths[i - 1] if i > 0 else 0
|
||||
start = max(begin_idx - prev_cum, 0)
|
||||
end = min(end_idx - prev_cum, len(self.segments[i]))
|
||||
result_segments.append(self.segments[i][start:end])
|
||||
|
||||
return torch.cat(result_segments, dim=0)
|
||||
|
||||
|
||||
class MultiSegmentFetcher:
|
||||
"""Manages multiple segment fetchers for different data keys."""
|
||||
|
||||
def __init__(self, multi_segments: Dict):
|
||||
self.multi_keys = list(multi_segments.keys())
|
||||
self.multi_fetchers = {
|
||||
key: BaseSegmentFetcher(segments)
|
||||
for key, segments in multi_segments.items()
|
||||
}
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Returns the minimum length across all fetchers."""
|
||||
if not self.multi_fetchers:
|
||||
return 0
|
||||
len_list = [len(seg) for seg in self.multi_fetchers.values()]
|
||||
return min(len_list)
|
||||
|
||||
def key_fetch(
|
||||
self, begin_idx: int, end_idx: int, keys: Union[str, List[str]]
|
||||
) -> Dict:
|
||||
"""Fetch data for specific keys."""
|
||||
fetch_dict = {}
|
||||
keys = [keys] if isinstance(keys, str) else keys
|
||||
|
||||
for key in keys:
|
||||
fetcher = self.multi_fetchers[key]
|
||||
fetch_tensor = fetcher.fetch_data(begin_idx, end_idx)
|
||||
fetch_dict[key] = fetch_tensor
|
||||
|
||||
return fetch_dict if len(keys) > 1 else fetch_dict[keys[0]]
|
||||
|
||||
def fetch_data(self, begin_idx: int, end_idx: int) -> Dict:
|
||||
"""Fetch all keys."""
|
||||
return self.key_fetch(begin_idx, end_idx, self.multi_keys)
|
||||
|
||||
|
||||
class BaseStorage(ABC):
|
||||
"""Abstract storage backend for loading and dispatching data.
|
||||
|
||||
Storage encapsulates format-specific loading and provides a uniform
|
||||
interface for data access and length observation. Subclasses handle
|
||||
different data formats (HDF5, JSON, etc.) while exposing the same
|
||||
fetch interface.
|
||||
Subclasses fill ``self._data`` and ``self._cum`` during ``load()``
|
||||
via ``_normalize()``.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._fetcher: Optional[MultiSegmentFetcher] = None
|
||||
self._data: Dict[str, List[Tensor]] = {}
|
||||
self._cum: Dict[str, List[int]] = {}
|
||||
self._length: int = 0
|
||||
|
||||
@abstractmethod
|
||||
def load(self, load_path: str, tokenizer=None) -> None:
|
||||
"""Load data from the given path into internal fetcher."""
|
||||
def load(self, path: str, tokenizer=None) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Total number of raw elements (tokens) in storage."""
|
||||
if self._fetcher is None:
|
||||
return 0
|
||||
return len(self._fetcher)
|
||||
|
||||
def fetch(self, begin_idx: int, end_idx: int, keys: Union[str, List[str]]):
|
||||
"""Fetch data for the given keys and index range.
|
||||
|
||||
Args:
|
||||
begin_idx: Starting index (inclusive)
|
||||
end_idx: Ending index (exclusive)
|
||||
keys: Single key or list of keys to fetch
|
||||
|
||||
Returns:
|
||||
Tensor if single key, Dict[str, Tensor] if multiple keys
|
||||
"""
|
||||
if self._fetcher is None:
|
||||
raise RuntimeError("Storage not loaded")
|
||||
return self._fetcher.key_fetch(begin_idx, end_idx, keys)
|
||||
|
||||
@property
|
||||
def keys(self) -> List[str]:
|
||||
"""Return the data keys available in this storage."""
|
||||
if self._fetcher is None:
|
||||
return []
|
||||
return self._fetcher.multi_keys
|
||||
return list(self._data.keys())
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self._length
|
||||
|
||||
def fetch(
|
||||
self,
|
||||
begin: int,
|
||||
end: int,
|
||||
keys: Union[str, List[str]],
|
||||
):
|
||||
if not self._data:
|
||||
raise RuntimeError("Store not loaded")
|
||||
if not (0 <= begin < self._length and 0 <= end <= self._length):
|
||||
raise ValueError(
|
||||
f"Index out of bounds: begin={begin}, end={end}, length={self._length}"
|
||||
)
|
||||
if isinstance(keys, str):
|
||||
return self._fetch_key(keys, begin, end)
|
||||
return {k: self._fetch_key(k, begin, end) for k in keys}
|
||||
|
||||
def _fetch_key(self, key: str, begin: int, end: int) -> Tensor:
|
||||
"""Fetch slice [begin, end) across potentially multiple segments."""
|
||||
segments = self._data[key]
|
||||
cum = self._cum[key]
|
||||
seg_start = bisect.bisect_right(cum, begin)
|
||||
seg_end = bisect.bisect_left(cum, end)
|
||||
|
||||
results = []
|
||||
for i in range(seg_start, seg_end + 1):
|
||||
prev = cum[i - 1] if i > 0 else 0
|
||||
s = max(begin - prev, 0)
|
||||
e = min(end - prev, segments[i].shape[0])
|
||||
results.append(segments[i][s:e])
|
||||
|
||||
return results[0] if len(results) == 1 else torch.cat(results, dim=0)
|
||||
|
||||
def _normalize(self, raw: Dict[str, List[Tensor]]):
|
||||
"""Register segments and pre-compute cumulative lengths.
|
||||
|
||||
Does NOT concatenate — segments are kept as-is to avoid OOM on
|
||||
large datasets. Sets ``self._length`` to the minimum total
|
||||
element count across all keys.
|
||||
"""
|
||||
for key, tensors in raw.items():
|
||||
self._data[key] = tensors
|
||||
cum = []
|
||||
total = 0
|
||||
for t in tensors:
|
||||
total += t.shape[0]
|
||||
cum.append(total)
|
||||
self._cum[key] = cum
|
||||
self._length = min(cum[-1] for cum in self._cum.values()) if self._cum else 0
|
||||
|
||||
|
||||
class StorageFactory(BaseFactory["BaseStorage"]):
|
||||
"""Factory for creating storage backends by type name.
|
||||
class StoreFactory(BaseFactory["Store"]):
|
||||
"""Factory for creating Store instances by type name.
|
||||
|
||||
Example:
|
||||
@StorageFactory.register("custom")
|
||||
class CustomStorage(BaseStorage):
|
||||
Example::
|
||||
|
||||
@StoreFactory.register("custom")
|
||||
class CustomStore(Store):
|
||||
...
|
||||
|
||||
storage = StorageFactory.create("custom")
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def _validate_component(cls, storage_cls: type) -> None:
|
||||
if not issubclass(storage_cls, BaseStorage):
|
||||
raise TypeError(f"{storage_cls.__name__} must inherit from BaseStorage")
|
||||
def _validate_component(cls, store_cls: type):
|
||||
if not issubclass(store_cls, Store):
|
||||
raise TypeError(f"{store_cls.__name__} must inherit from Store")
|
||||
|
||||
|
||||
@StorageFactory.register("h5")
|
||||
class H5Storage(BaseStorage):
|
||||
@StoreFactory.register("h5")
|
||||
class H5Store(Store):
|
||||
"""HDF5-based storage backend (pre-tokenized data)."""
|
||||
|
||||
def load(self, load_path: str, tokenizer=None) -> None:
|
||||
segments = load_h5(load_path)
|
||||
self._fetcher = MultiSegmentFetcher(segments)
|
||||
def load(self, path: str, tokenizer=None):
|
||||
self._normalize(load_h5(path))
|
||||
|
||||
|
||||
@StorageFactory.register("json")
|
||||
class JSONStorage(BaseStorage):
|
||||
@StoreFactory.register("json")
|
||||
class JSONStore(Store):
|
||||
"""JSON-based storage backend.
|
||||
|
||||
Supports two modes:
|
||||
|
|
@ -296,6 +294,28 @@ class JSONStorage(BaseStorage):
|
|||
callable (str -> List[int]) at load time.
|
||||
"""
|
||||
|
||||
def load(self, load_path: str, tokenizer=None) -> None:
|
||||
segments = load_json(load_path, tokenizer=tokenizer)
|
||||
self._fetcher = MultiSegmentFetcher(segments)
|
||||
def load(self, path: str, tokenizer=None):
|
||||
self._normalize(load_json(path, tokenizer=tokenizer))
|
||||
|
||||
|
||||
@StoreFactory.register("bin")
|
||||
class MmapStore(Store):
|
||||
"""Memory-mapped binary storage backend.
|
||||
|
||||
Each key is a single .bin file backed by ``np.memmap(mode="r")``.
|
||||
No per-process memory duplication — all DataLoader workers share the
|
||||
same OS page-cache pages.
|
||||
|
||||
Format on disk::
|
||||
|
||||
data_root/
|
||||
meta.json # {key: {shape, dtype}, ...}
|
||||
<key>.bin # raw numpy array, one per key
|
||||
"""
|
||||
|
||||
def load(self, path: str, tokenizer=None):
|
||||
self._mmap_refs = []
|
||||
raw = load_bin(path)
|
||||
self._normalize(raw)
|
||||
for tensors in self._data.values():
|
||||
self._mmap_refs.extend(tensors)
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ class Registry:
|
|||
component_cls: Type,
|
||||
category: Optional[str] = None,
|
||||
priority: int = 0,
|
||||
) -> None:
|
||||
):
|
||||
"""Register a component class with optional category and priority."""
|
||||
if name in self._entries:
|
||||
raise ValueError(f"Component '{name}' is already registered")
|
||||
|
|
@ -158,7 +158,7 @@ class BaseFactory(ABC, Generic[T]):
|
|||
return component_cls(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def _validate_component(cls, component_cls: Type[T]) -> None:
|
||||
def _validate_component(cls, component_cls: Type[T]):
|
||||
"""Validate that the component class is valid for this factory.
|
||||
|
||||
Override this method in subclasses to add custom validation.
|
||||
|
|
|
|||
|
|
@ -42,7 +42,7 @@ class Allocator:
|
|||
return idx
|
||||
return -1
|
||||
|
||||
def free(self, idx: int, keep_cached: bool = False) -> None:
|
||||
def free(self, idx: int, keep_cached: bool = False):
|
||||
with self._lock:
|
||||
self._refs[idx] -= 1
|
||||
if self._refs[idx] == 0:
|
||||
|
|
@ -51,7 +51,7 @@ class Allocator:
|
|||
else:
|
||||
self._free_mask |= 1 << idx
|
||||
|
||||
def inc_ref(self, idx: int) -> None:
|
||||
def inc_ref(self, idx: int):
|
||||
with self._lock:
|
||||
self._refs[idx] += 1
|
||||
self._lru.pop(idx, None)
|
||||
|
|
@ -60,7 +60,7 @@ class Allocator:
|
|||
with self._lock:
|
||||
return self._refs[idx]
|
||||
|
||||
def touch(self, idx: int) -> None:
|
||||
def touch(self, idx: int):
|
||||
with self._lock:
|
||||
self._lru.move_to_end(idx)
|
||||
|
||||
|
|
@ -74,7 +74,7 @@ class PrefixCache:
|
|||
self._hash_to_page: Dict[int, int] = {}
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def evict(self, idx: int) -> None:
|
||||
def evict(self, idx: int):
|
||||
with self._lock:
|
||||
h = self._page_to_hash.pop(idx, None)
|
||||
if h is not None:
|
||||
|
|
@ -96,9 +96,7 @@ class PrefixCache:
|
|||
hits.append(p)
|
||||
return hits
|
||||
|
||||
def record(
|
||||
self, page_idx: int, token_ids: List[int], logical_page_idx: int
|
||||
) -> None:
|
||||
def record(self, page_idx: int, token_ids: List[int], logical_page_idx: int):
|
||||
with self._lock:
|
||||
h = page_hash(token_ids, logical_page_idx, self._page_size)
|
||||
old_h = self._page_to_hash.pop(page_idx, None)
|
||||
|
|
@ -127,13 +125,13 @@ class PagePool:
|
|||
def alloc(self) -> int:
|
||||
return self._alloc.alloc()
|
||||
|
||||
def free(self, idx: int) -> None:
|
||||
def free(self, idx: int):
|
||||
keep = self._prefix.has_page(idx)
|
||||
self._alloc.free(idx, keep_cached=keep)
|
||||
if not keep:
|
||||
self._prefix.evict(idx)
|
||||
|
||||
def inc_ref(self, idx: int) -> None:
|
||||
def inc_ref(self, idx: int):
|
||||
self._alloc.inc_ref(idx)
|
||||
|
||||
def lookup(self, token_ids: List[int]) -> List[int]:
|
||||
|
|
@ -142,9 +140,7 @@ class PagePool:
|
|||
self._alloc.touch(p)
|
||||
return hits
|
||||
|
||||
def record(
|
||||
self, page_idx: int, token_ids: List[int], logical_page_idx: int
|
||||
) -> None:
|
||||
def record(self, page_idx: int, token_ids: List[int], logical_page_idx: int):
|
||||
self._prefix.record(page_idx, token_ids, logical_page_idx)
|
||||
|
||||
|
||||
|
|
@ -157,7 +153,7 @@ class TaskTable:
|
|||
self._cached: Dict[str, int] = {}
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def set(self, task_id: str, page_table: List[int], cached: int) -> None:
|
||||
def set(self, task_id: str, page_table: List[int], cached: int):
|
||||
with self._lock:
|
||||
self._pages[task_id] = page_table
|
||||
self._cached[task_id] = cached
|
||||
|
|
@ -220,7 +216,7 @@ class Storage:
|
|||
start_pos: int,
|
||||
k: Tensor,
|
||||
v: Tensor,
|
||||
) -> None:
|
||||
):
|
||||
seq_len = k.size(1)
|
||||
if seq_len == 0:
|
||||
return
|
||||
|
|
@ -286,7 +282,7 @@ class KvcacheView:
|
|||
self._page_table = page_table
|
||||
self._total_len = total_len
|
||||
|
||||
def write(self, layer_id: int, k: Tensor, v: Tensor) -> None:
|
||||
def write(self, layer_id: int, k: Tensor, v: Tensor):
|
||||
start_pos = self._total_len - k.size(1)
|
||||
self._storage.write(layer_id, self._page_table, start_pos, k, v)
|
||||
|
||||
|
|
@ -339,7 +335,7 @@ class KVCache:
|
|||
self._table.set(task_id, hits + new_pages, cached)
|
||||
return True
|
||||
|
||||
def task_free(self, task_id: str) -> None:
|
||||
def task_free(self, task_id: str):
|
||||
page_table, _ = self._table.pop(task_id)
|
||||
for idx in page_table:
|
||||
self._pool.free(idx)
|
||||
|
|
@ -359,7 +355,7 @@ class KVCache:
|
|||
|
||||
def task_record_hashes(
|
||||
self, task_id: str, prompt_ids: List[int], start_logical_page: int = 0
|
||||
) -> None:
|
||||
):
|
||||
page_table = self._table.get(task_id)
|
||||
full_pages = len(prompt_ids) // self.page_size
|
||||
for i in range(start_logical_page, full_pages):
|
||||
|
|
|
|||
|
|
@ -29,9 +29,7 @@ class Executor:
|
|||
self.device = device or next(model.parameters()).device
|
||||
self.dtype = dtype or next(model.parameters()).dtype
|
||||
|
||||
def execute_prefill(
|
||||
self, tasks: List[Task], prompt_len: int, start_pos: int = 0
|
||||
) -> None:
|
||||
def execute_prefill(self, tasks: List[Task], prompt_len: int, start_pos: int = 0):
|
||||
if start_pos >= prompt_len:
|
||||
return
|
||||
|
||||
|
|
|
|||
|
|
@ -75,14 +75,14 @@ class InferenceScheduler:
|
|||
def add_task(self, prompt: str, **kwargs) -> str:
|
||||
return self._task_mgr.add_task(prompt, **kwargs)
|
||||
|
||||
def remove_task(self, task_id: str) -> None:
|
||||
def remove_task(self, task_id: str):
|
||||
for task in self._task_mgr.remove_task(task_id):
|
||||
self._page_cache.task_free(task.task_id)
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
return self._task_mgr.get_stats()
|
||||
|
||||
def _run_generation_loop(self) -> None:
|
||||
def _run_generation_loop(self):
|
||||
stop_ids = self._task_mgr.tokenizer.stop_ids
|
||||
try:
|
||||
while self._running:
|
||||
|
|
@ -186,14 +186,14 @@ class InferenceScheduler:
|
|||
self._task_mgr.clear_queues()
|
||||
raise
|
||||
|
||||
def start(self) -> None:
|
||||
def start(self):
|
||||
if not self._running:
|
||||
self._running = True
|
||||
t = threading.Thread(target=self._run_generation_loop, daemon=True)
|
||||
t.start()
|
||||
self._loop_thread = t
|
||||
|
||||
def stop(self) -> None:
|
||||
def stop(self):
|
||||
self._running = False
|
||||
self._task_mgr.wake()
|
||||
if hasattr(self, "_loop_thread"):
|
||||
|
|
|
|||
|
|
@ -172,12 +172,12 @@ class TaskManager:
|
|||
to_add.append(self.waiting_queue.popleft())
|
||||
return to_add
|
||||
|
||||
def activate(self, task: Task) -> None:
|
||||
def activate(self, task: Task):
|
||||
task.status = TaskStatus.RUNNING
|
||||
with self._lock:
|
||||
self.active_tasks.append(task)
|
||||
|
||||
def return_to_waiting(self, tasks: List[Task]) -> None:
|
||||
def return_to_waiting(self, tasks: List[Task]):
|
||||
with self._lock:
|
||||
for task in reversed(tasks):
|
||||
self.waiting_queue.appendleft(task)
|
||||
|
|
@ -185,7 +185,7 @@ class TaskManager:
|
|||
def has_work(self) -> bool:
|
||||
return bool(self.active_tasks or self.waiting_queue)
|
||||
|
||||
def wait_for_tasks(self, timeout: float = 1.0) -> None:
|
||||
def wait_for_tasks(self, timeout: float = 1.0):
|
||||
self._task_event.clear()
|
||||
self._task_event.wait(timeout=timeout)
|
||||
|
||||
|
|
@ -197,10 +197,10 @@ class TaskManager:
|
|||
with self._lock:
|
||||
return list(self.waiting_queue)
|
||||
|
||||
def clear_queues(self) -> None:
|
||||
def clear_queues(self):
|
||||
with self._lock:
|
||||
self.waiting_queue.clear()
|
||||
self.active_tasks.clear()
|
||||
|
||||
def wake(self) -> None:
|
||||
def wake(self):
|
||||
self._task_event.set()
|
||||
|
|
|
|||
|
|
@ -48,7 +48,7 @@ class GenerateResult:
|
|||
def wait(self, timeout: Optional[float] = None) -> bool:
|
||||
return self._event.wait(timeout=timeout)
|
||||
|
||||
def wait_completion(self, timeout: float = 300.0) -> None:
|
||||
def wait_completion(self, timeout: float = 300.0):
|
||||
with self._cond:
|
||||
if not self._cond.wait_for(
|
||||
lambda: self._completed >= self._total, timeout=timeout
|
||||
|
|
@ -281,7 +281,7 @@ class InferenceEngine:
|
|||
def get_stats(self) -> Dict[str, Any]:
|
||||
return self.scheduler.get_stats()
|
||||
|
||||
def shutdown(self) -> None:
|
||||
def shutdown(self):
|
||||
self.scheduler.stop()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
|
|
|||
|
|
@ -15,7 +15,11 @@ from astrai.serialization import load_model_config, load_model_weights, save_mod
|
|||
|
||||
@contextmanager
|
||||
def _disable_random_init(enable: bool = True):
|
||||
init_functions = [
|
||||
if not enable:
|
||||
yield
|
||||
return
|
||||
|
||||
names = (
|
||||
"xavier_normal_",
|
||||
"xavier_uniform_",
|
||||
"kaiming_normal_",
|
||||
|
|
@ -25,18 +29,15 @@ def _disable_random_init(enable: bool = True):
|
|||
"constant_",
|
||||
"normal_",
|
||||
"uniform_",
|
||||
]
|
||||
original_funcs = {}
|
||||
for name in init_functions:
|
||||
if enable and hasattr(nn.init, name):
|
||||
original_funcs[name] = getattr(nn.init, name)
|
||||
setattr(nn.init, name, lambda *args, **kwargs: None)
|
||||
)
|
||||
orig = {n: getattr(nn.init, n) for n in names if hasattr(nn.init, n)}
|
||||
for n in orig:
|
||||
setattr(nn.init, n, lambda *a, **kw: None)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
if enable:
|
||||
for name, orig_func in original_funcs.items():
|
||||
setattr(nn.init, name, orig_func)
|
||||
for n, fn in orig.items():
|
||||
setattr(nn.init, n, fn)
|
||||
|
||||
|
||||
class AutoModel(BaseFactory["AutoModel"], nn.Module):
|
||||
|
|
@ -82,7 +83,7 @@ class AutoModel(BaseFactory["AutoModel"], nn.Module):
|
|||
def save_pretrained(
|
||||
self,
|
||||
save_directory: Union[str, Path],
|
||||
) -> None:
|
||||
):
|
||||
save_model(
|
||||
config=self.config.to_dict(),
|
||||
state_dict=self.state_dict(),
|
||||
|
|
|
|||
|
|
@ -68,9 +68,6 @@ class EmbeddingEncoder(AutoModel):
|
|||
|
||||
x = self.embed_tokens(input_ids)
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = torch.arange(S, device=x.device).unsqueeze(0).expand(B, -1)
|
||||
|
||||
rotary_emb = self.rotary_embedding(x, position_ids)
|
||||
attn_mask = process_attention_mask(x, position_ids, input_mask, is_causal=False)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Any, Mapping, Optional
|
||||
from typing import Any, Dict, Mapping, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
|
@ -136,7 +136,7 @@ class AutoRegressiveLM(AutoModel):
|
|||
input_mask: Optional[Tensor] = None,
|
||||
paged_cache: Optional[KvcacheView] = None,
|
||||
position_ids: Optional[Tensor] = None,
|
||||
) -> Tensor:
|
||||
) -> Dict[str, Tensor]:
|
||||
assert input_ids.ndim == 2
|
||||
|
||||
x = self.embed_tokens(input_ids)
|
||||
|
|
|
|||
|
|
@ -203,9 +203,45 @@ class DDPExecutor(BaseExecutor):
|
|||
|
||||
@ExecutorFactory.register("fsdp")
|
||||
class FSDPExecutor(BaseExecutor):
|
||||
def __init__(self, grad_accum_steps: int = 1, **fsdp_kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
grad_accum_steps: int = 1,
|
||||
process_group=None,
|
||||
sharding_strategy=None,
|
||||
cpu_offload=None,
|
||||
auto_wrap_policy=None,
|
||||
backward_prefetch=None,
|
||||
mixed_precision=None,
|
||||
ignored_modules=None,
|
||||
param_init_fn=None,
|
||||
sync_module_states: bool = False,
|
||||
forward_prefetch: bool = False,
|
||||
limit_all_gathers: bool = True,
|
||||
use_orig_params: bool = False,
|
||||
ignored_states=None,
|
||||
device_mesh=None,
|
||||
):
|
||||
super().__init__(grad_accum_steps=grad_accum_steps)
|
||||
self._fsdp_kwargs = fsdp_kwargs
|
||||
self._fsdp_kwargs = {
|
||||
k: v
|
||||
for k, v in dict(
|
||||
process_group=process_group,
|
||||
sharding_strategy=sharding_strategy,
|
||||
cpu_offload=cpu_offload,
|
||||
auto_wrap_policy=auto_wrap_policy,
|
||||
backward_prefetch=backward_prefetch,
|
||||
mixed_precision=mixed_precision,
|
||||
ignored_modules=ignored_modules,
|
||||
param_init_fn=param_init_fn,
|
||||
sync_module_states=sync_module_states,
|
||||
forward_prefetch=forward_prefetch,
|
||||
limit_all_gathers=limit_all_gathers,
|
||||
use_orig_params=use_orig_params,
|
||||
ignored_states=ignored_states,
|
||||
device_mesh=device_mesh,
|
||||
).items()
|
||||
if v is not None
|
||||
}
|
||||
self._original_model: Optional[nn.Module] = None
|
||||
|
||||
def _prepare_model(self, model: nn.Module) -> nn.Module:
|
||||
|
|
|
|||
|
|
@ -1,8 +1,9 @@
|
|||
import io
|
||||
import json
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
import safetensors.torch as st
|
||||
import torch
|
||||
|
|
@ -11,11 +12,11 @@ import torch.distributed as dist
|
|||
from astrai.parallel.setup import get_rank
|
||||
|
||||
_META_FILE = "meta.json"
|
||||
_CONFIG_FILE = "config.json"
|
||||
_WEIGHTS_FILE = "model.safetensors"
|
||||
_MODEL_CONFIG_FILE = "config.json"
|
||||
|
||||
|
||||
def save_safetensors(state_dict: dict, path: str | Path) -> None:
|
||||
def save_safetensors(state_dict: dict, path: str | Path):
|
||||
st.save_file(state_dict, str(path))
|
||||
|
||||
|
||||
|
|
@ -23,7 +24,7 @@ def load_safetensors(path: str | Path) -> dict:
|
|||
return st.load_file(str(path))
|
||||
|
||||
|
||||
def save_json(data: dict, path: str | Path) -> None:
|
||||
def save_json(data: dict, path: str | Path):
|
||||
with open(str(path), "w") as f:
|
||||
json.dump(data, f, indent=2)
|
||||
|
||||
|
|
@ -33,13 +34,92 @@ def load_json(path: str | Path) -> dict:
|
|||
return json.load(f)
|
||||
|
||||
|
||||
def save_torch(obj: Any, path: str | Path) -> None:
|
||||
def save_torch(obj: Any, path: str | Path):
|
||||
torch.save(obj, str(path))
|
||||
|
||||
|
||||
def load_torch(path: str | Path) -> Any:
|
||||
def load_torch(path: str | Path, broadcast: bool = False) -> Any:
|
||||
if not broadcast or not dist.is_initialized():
|
||||
return torch.load(str(path), map_location="cpu", weights_only=False)
|
||||
|
||||
path = Path(path)
|
||||
rank = get_rank()
|
||||
|
||||
if rank == 0:
|
||||
with open(path, "rb") as f:
|
||||
raw = f.read()
|
||||
data_tensor = torch.frombuffer(bytearray(raw), dtype=torch.uint8)
|
||||
num_bytes = torch.tensor([len(raw)], dtype=torch.long)
|
||||
else:
|
||||
num_bytes = torch.tensor([0], dtype=torch.long)
|
||||
|
||||
dist.broadcast(num_bytes, src=0)
|
||||
|
||||
if rank != 0:
|
||||
data_tensor = torch.empty(num_bytes.item(), dtype=torch.uint8)
|
||||
|
||||
dist.broadcast(data_tensor, src=0)
|
||||
|
||||
buf = io.BytesIO(data_tensor.numpy().tobytes())
|
||||
return torch.load(buf, map_location="cpu", weights_only=False)
|
||||
|
||||
|
||||
def save_model(config: dict, state_dict: dict, save_directory: str):
|
||||
save_path = Path(save_directory)
|
||||
save_path.mkdir(parents=True, exist_ok=True)
|
||||
save_json(config, save_path / _CONFIG_FILE)
|
||||
save_safetensors(state_dict, save_path / _WEIGHTS_FILE)
|
||||
|
||||
|
||||
def load_model_config(save_directory: str) -> dict:
|
||||
return load_json(Path(save_directory) / _CONFIG_FILE)
|
||||
|
||||
|
||||
def load_model_weights(save_directory: str) -> dict:
|
||||
return load_safetensors(Path(save_directory) / _WEIGHTS_FILE)
|
||||
|
||||
|
||||
def _get_meta(save_path: Path) -> dict:
|
||||
meta = {}
|
||||
if get_rank() == 0:
|
||||
meta = load_json(save_path / _META_FILE)
|
||||
if dist.is_initialized():
|
||||
meta_list = [meta]
|
||||
dist.broadcast_object_list(meta_list, src=0)
|
||||
meta = meta_list[0]
|
||||
return meta
|
||||
|
||||
|
||||
def _load_state_dict(save_path: Path, broadcast: bool = False) -> dict:
|
||||
if not broadcast or not dist.is_initialized():
|
||||
return load_safetensors(save_path / _WEIGHTS_FILE)
|
||||
|
||||
rank = get_rank()
|
||||
if rank == 0:
|
||||
state_dict = load_safetensors(save_path / _WEIGHTS_FILE)
|
||||
specs: List[Tuple[str, List[int], str]] = [
|
||||
(k, list(state_dict[k].shape), str(state_dict[k].dtype).split(".")[-1])
|
||||
for k in sorted(state_dict)
|
||||
]
|
||||
else:
|
||||
state_dict = {}
|
||||
specs = []
|
||||
|
||||
specs_list = [specs]
|
||||
dist.broadcast_object_list(specs_list, src=0)
|
||||
specs = specs_list[0]
|
||||
|
||||
for key, shape, dtype_name in specs:
|
||||
dtype = getattr(torch, dtype_name)
|
||||
if rank != 0:
|
||||
tensor = torch.empty(shape, dtype=dtype, device="cpu")
|
||||
else:
|
||||
tensor = state_dict[key].contiguous().cpu()
|
||||
dist.broadcast(tensor, src=0)
|
||||
if rank != 0:
|
||||
state_dict[key] = tensor
|
||||
return state_dict
|
||||
|
||||
|
||||
@dataclass
|
||||
class Checkpoint:
|
||||
|
|
@ -49,7 +129,7 @@ class Checkpoint:
|
|||
extra: Dict[str, Any] = field(default_factory=dict)
|
||||
meta: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def save(self, save_dir: str) -> None:
|
||||
def save(self, save_dir: str):
|
||||
save_path = Path(save_dir)
|
||||
save_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
|
@ -68,24 +148,16 @@ class Checkpoint:
|
|||
save_torch(value, save_path / f"{key}.pt")
|
||||
|
||||
@classmethod
|
||||
def load(cls, save_dir: str) -> "Checkpoint":
|
||||
def load(cls, save_dir: str, broadcast: bool = False) -> "Checkpoint":
|
||||
save_path = Path(save_dir)
|
||||
|
||||
meta = {}
|
||||
if get_rank() == 0:
|
||||
meta = load_json(save_path / _META_FILE)
|
||||
|
||||
if dist.is_initialized():
|
||||
meta_list = [meta]
|
||||
dist.broadcast_object_list(meta_list, src=0)
|
||||
meta = meta_list[0]
|
||||
|
||||
state_dict = load_safetensors(save_path / _WEIGHTS_FILE)
|
||||
meta = _get_meta(save_path)
|
||||
state_dict = _load_state_dict(save_path, broadcast=broadcast)
|
||||
|
||||
extra = {}
|
||||
for f in save_path.iterdir():
|
||||
for f in sorted(save_path.iterdir()):
|
||||
if f.suffix == ".pt":
|
||||
extra[f.stem] = load_torch(f)
|
||||
extra[f.stem] = load_torch(f, broadcast=broadcast)
|
||||
|
||||
return cls(
|
||||
state_dict=state_dict,
|
||||
|
|
@ -93,18 +165,3 @@ class Checkpoint:
|
|||
iteration=meta.get("iteration", 0),
|
||||
extra=extra,
|
||||
)
|
||||
|
||||
|
||||
def save_model(config: dict, state_dict: dict, save_directory: str) -> None:
|
||||
save_path = Path(save_directory)
|
||||
save_path.mkdir(parents=True, exist_ok=True)
|
||||
save_json(config, save_path / _MODEL_CONFIG_FILE)
|
||||
save_safetensors(state_dict, save_path / _WEIGHTS_FILE)
|
||||
|
||||
|
||||
def load_model_config(save_directory: str) -> dict:
|
||||
return load_json(Path(save_directory) / _MODEL_CONFIG_FILE)
|
||||
|
||||
|
||||
def load_model_weights(save_directory: str) -> dict:
|
||||
return load_safetensors(Path(save_directory) / _WEIGHTS_FILE)
|
||||
|
|
|
|||
|
|
@ -42,7 +42,7 @@ class SchedulerFactory(BaseFactory["BaseScheduler"]):
|
|||
"""
|
||||
|
||||
@classmethod
|
||||
def _validate_component(cls, scheduler_cls: Type[BaseScheduler]) -> None:
|
||||
def _validate_component(cls, scheduler_cls: Type[BaseScheduler]):
|
||||
"""Validate that the scheduler class inherits from BaseScheduler."""
|
||||
if not issubclass(scheduler_cls, BaseScheduler):
|
||||
raise TypeError(f"{scheduler_cls.__name__} must inherit from BaseScheduler")
|
||||
|
|
|
|||
|
|
@ -125,7 +125,7 @@ class StrategyFactory(BaseFactory["BaseStrategy"]):
|
|||
"""
|
||||
|
||||
@classmethod
|
||||
def _validate_component(cls, strategy_cls: type) -> None:
|
||||
def _validate_component(cls, strategy_cls: type):
|
||||
"""Validate that the strategy class inherits from BaseStrategy."""
|
||||
if not issubclass(strategy_cls, BaseStrategy):
|
||||
raise TypeError(f"{strategy_cls.__name__} must inherit from BaseStrategy")
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ from tqdm import tqdm
|
|||
|
||||
from astrai.factory import BaseFactory
|
||||
from astrai.parallel import only_on_rank
|
||||
from astrai.parallel.setup import get_current_device
|
||||
from astrai.parallel.setup import get_current_device, get_rank
|
||||
from astrai.serialization import Checkpoint
|
||||
from astrai.trainer.metric_util import (
|
||||
ctx_get_grad_max,
|
||||
|
|
@ -139,27 +139,27 @@ class CheckpointCallback(TrainCallback):
|
|||
weight_only: bool = False,
|
||||
state_dict_fn: Optional[Callable[[nn.Module], dict]] = None,
|
||||
save_extra_fn: Optional[Callable[["TrainContext"], dict]] = None,
|
||||
load_extra_fn: Optional[Callable[[dict, "TrainContext"], None]] = None,
|
||||
):
|
||||
self.save_dir = save_dir
|
||||
self.interval = interval
|
||||
self.weight_only = weight_only
|
||||
self.state_dict_fn = state_dict_fn
|
||||
self.save_extra_fn = save_extra_fn or CheckpointCallback.save_extra
|
||||
self.load_extra_fn = load_extra_fn or CheckpointCallback.load_extra
|
||||
self.last_ckpt_iter = 0
|
||||
|
||||
@only_on_rank(0)
|
||||
def _save_checkpoint(self, context: TrainContext):
|
||||
save_path = os.path.join(
|
||||
self.save_dir, f"epoch_{context.epoch}_iter_{context.iteration}"
|
||||
)
|
||||
# All ranks gather state_dict — collective for FSDP, local for DDP
|
||||
state_dict = (
|
||||
self.state_dict_fn(context.model)
|
||||
if self.state_dict_fn
|
||||
else context.model.state_dict()
|
||||
)
|
||||
self.last_ckpt_iter = context.iteration
|
||||
|
||||
if get_rank() == 0:
|
||||
save_path = os.path.join(
|
||||
self.save_dir, f"epoch_{context.epoch}_iter_{context.iteration}"
|
||||
)
|
||||
extra = self.save_extra_fn(context)
|
||||
context.checkpoint = Checkpoint(
|
||||
state_dict=state_dict,
|
||||
|
|
@ -168,13 +168,7 @@ class CheckpointCallback(TrainCallback):
|
|||
extra=extra,
|
||||
meta=context.config.to_dict(),
|
||||
)
|
||||
|
||||
context.checkpoint.save(save_path)
|
||||
self.last_ckpt_iter = context.iteration
|
||||
|
||||
def on_train_begin(self, context: TrainContext):
|
||||
if context.checkpoint and context.checkpoint.extra:
|
||||
self.load_extra_fn(context.checkpoint.extra, context)
|
||||
|
||||
def on_batch_end(self, context: TrainContext):
|
||||
if context.iteration - self.last_ckpt_iter >= self.interval:
|
||||
|
|
@ -196,12 +190,6 @@ class CheckpointCallback(TrainCallback):
|
|||
extra[name] = obj.state_dict()
|
||||
return extra
|
||||
|
||||
@staticmethod
|
||||
def load_extra(extra: dict, context: TrainContext):
|
||||
for name in CheckpointCallback.extra_keys:
|
||||
if name in extra:
|
||||
getattr(context, name).load_state_dict(extra[name])
|
||||
|
||||
|
||||
@CallbackFactory.register("progress_bar")
|
||||
class ProgressBarCallback(TrainCallback):
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Optional, Self
|
||||
|
||||
import torch.nn as nn
|
||||
|
|
@ -10,7 +11,7 @@ from astrai.model.components.lora import inject_lora
|
|||
from astrai.parallel.executor import BaseExecutor, ExecutorFactory
|
||||
from astrai.parallel.setup import get_current_device, get_rank, get_world_size
|
||||
from astrai.protocols import OptimizerProtocol, SchedulerProtocol
|
||||
from astrai.serialization import Checkpoint
|
||||
from astrai.serialization import Checkpoint, load_model_weights
|
||||
from astrai.trainer.strategy import BaseStrategy, StrategyFactory
|
||||
|
||||
|
||||
|
|
@ -42,10 +43,10 @@ class TrainContextBuilder:
|
|||
config: TrainConfig,
|
||||
):
|
||||
self.config = config
|
||||
self._checkpoint: Optional[Checkpoint] = None
|
||||
self._resume_dir: Optional[str] = None
|
||||
|
||||
def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self:
|
||||
self._checkpoint = checkpoint
|
||||
def with_resume_dir(self, resume_dir: Optional[str]) -> Self:
|
||||
self._resume_dir = resume_dir
|
||||
return self
|
||||
|
||||
def build(self) -> TrainContext:
|
||||
|
|
@ -58,36 +59,40 @@ class TrainContextBuilder:
|
|||
**cfg.executor_kwargs,
|
||||
)
|
||||
|
||||
model = cfg.model_fn()
|
||||
model = model.to(device=device)
|
||||
|
||||
context = TrainContext(
|
||||
model=cfg.model,
|
||||
model=model,
|
||||
world_size=get_world_size(),
|
||||
rank=get_rank(),
|
||||
config=cfg,
|
||||
executor=executor,
|
||||
)
|
||||
|
||||
context.model = context.model.to(device=device)
|
||||
|
||||
if self._checkpoint is not None:
|
||||
context.epoch = max(self._checkpoint.epoch, cfg.start_epoch)
|
||||
context.iteration = max(self._checkpoint.iteration, cfg.start_batch)
|
||||
if self._checkpoint.state_dict:
|
||||
context.model.load_state_dict(self._checkpoint.state_dict)
|
||||
context.checkpoint = self._checkpoint
|
||||
if self._resume_dir is not None:
|
||||
resume_path = Path(self._resume_dir)
|
||||
if (resume_path / "meta.json").exists():
|
||||
checkpoint = Checkpoint.load(self._resume_dir)
|
||||
state_dict = checkpoint.state_dict
|
||||
else:
|
||||
context.checkpoint = Checkpoint(
|
||||
state_dict=context.model.state_dict(),
|
||||
)
|
||||
checkpoint = None
|
||||
state_dict = load_model_weights(self._resume_dir)
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
if checkpoint is not None:
|
||||
context.epoch = max(checkpoint.epoch, cfg.start_epoch)
|
||||
context.iteration = max(checkpoint.iteration, cfg.start_batch)
|
||||
context.checkpoint = checkpoint
|
||||
|
||||
if cfg.lora is not None:
|
||||
inject_lora(
|
||||
context.model,
|
||||
model,
|
||||
r=cfg.lora.r,
|
||||
alpha=cfg.lora.alpha,
|
||||
target_modules=set(cfg.lora.target_modules),
|
||||
)
|
||||
|
||||
context.optimizer = cfg.optimizer_fn(context.model)
|
||||
context.optimizer = cfg.optimizer_fn(model)
|
||||
context.scheduler = cfg.scheduler_fn(context.optimizer)
|
||||
|
||||
sampler_offset = context.iteration * cfg.batch_per_device
|
||||
|
|
@ -125,13 +130,21 @@ class TrainContextBuilder:
|
|||
|
||||
context.model, context.optimizer, context.dataloader, context.scheduler = (
|
||||
executor.prepare(
|
||||
context.model,
|
||||
model,
|
||||
context.optimizer,
|
||||
context.dataloader,
|
||||
context.scheduler,
|
||||
)
|
||||
)
|
||||
|
||||
if context.checkpoint and context.checkpoint.extra:
|
||||
extra = context.checkpoint.extra
|
||||
for name in ("optimizer", "scheduler"):
|
||||
if name in extra:
|
||||
obj = getattr(context, name, None)
|
||||
if obj is not None:
|
||||
obj.load_state_dict(extra[name])
|
||||
|
||||
context.strategy = StrategyFactory.create(
|
||||
model=context.model,
|
||||
train_type=cfg.strategy,
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@ from typing import List, Optional
|
|||
|
||||
from astrai.config import TrainConfig
|
||||
from astrai.parallel.setup import spawn_parallel_fn
|
||||
from astrai.serialization import Checkpoint
|
||||
from astrai.trainer.train_callback import (
|
||||
CallbackFactory,
|
||||
TrainCallback,
|
||||
|
|
@ -54,9 +53,9 @@ class Trainer:
|
|||
if method:
|
||||
method(context)
|
||||
|
||||
def _trainer_loop(self, checkpoint: Optional[Checkpoint] = None):
|
||||
def _trainer_loop(self, resume_dir: Optional[str] = None):
|
||||
context = (
|
||||
TrainContextBuilder(self.train_config).with_checkpoint(checkpoint).build()
|
||||
TrainContextBuilder(self.train_config).with_resume_dir(resume_dir).build()
|
||||
)
|
||||
executor = context.executor
|
||||
self._call_callbacks("on_train_begin", context)
|
||||
|
|
@ -90,13 +89,13 @@ class Trainer:
|
|||
self._call_callbacks("on_epoch_end", context)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Training failed: {str(e)}", exc_info=True)
|
||||
logger.error("Training failed: %s", str(e), exc_info=True)
|
||||
self._call_callbacks("on_error", context)
|
||||
raise
|
||||
finally:
|
||||
self._call_callbacks("on_train_end", context)
|
||||
|
||||
def train(self, checkpoint: Optional[Checkpoint] = None):
|
||||
def train(self, resume_dir: Optional[str] = None):
|
||||
cfg = self.train_config
|
||||
spawn_parallel_fn(
|
||||
self._trainer_loop,
|
||||
|
|
@ -106,5 +105,5 @@ class Trainer:
|
|||
master_port=cfg.master_port,
|
||||
device_type=cfg.device_type,
|
||||
start_method=cfg.start_method,
|
||||
checkpoint=checkpoint,
|
||||
resume_dir=resume_dir,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,279 @@
|
|||
"""MMLU evaluation via log-likelihood ranking."""
|
||||
|
||||
import argparse
|
||||
import csv
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import urllib.request
|
||||
import zipfile
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import tqdm
|
||||
|
||||
from astrai.model import AutoModel
|
||||
from astrai.tokenize import AutoTokenizer
|
||||
|
||||
MMLU_URL = "https://github.com/hendrycks/test/archive/refs/heads/master.zip"
|
||||
MMLU_SUBJECTS = [
|
||||
"abstract_algebra",
|
||||
"anatomy",
|
||||
"astronomy",
|
||||
"business_ethics",
|
||||
"clinical_knowledge",
|
||||
"college_biology",
|
||||
"college_chemistry",
|
||||
"college_computer_science",
|
||||
"college_mathematics",
|
||||
"college_medicine",
|
||||
"college_physics",
|
||||
"computer_security",
|
||||
"conceptual_physics",
|
||||
"econometrics",
|
||||
"electrical_engineering",
|
||||
"elementary_mathematics",
|
||||
"formal_logic",
|
||||
"global_facts",
|
||||
"high_school_biology",
|
||||
"high_school_chemistry",
|
||||
"high_school_computer_science",
|
||||
"high_school_european_history",
|
||||
"high_school_geography",
|
||||
"high_school_government_and_politics",
|
||||
"high_school_macroeconomics",
|
||||
"high_school_mathematics",
|
||||
"high_school_microeconomics",
|
||||
"high_school_physics",
|
||||
"high_school_psychology",
|
||||
"high_school_statistics",
|
||||
"high_school_us_history",
|
||||
"high_school_world_history",
|
||||
"human_aging",
|
||||
"human_sexuality",
|
||||
"international_law",
|
||||
"jurisprudence",
|
||||
"logical_fallacies",
|
||||
"machine_learning",
|
||||
"management",
|
||||
"marketing",
|
||||
"medical_genetics",
|
||||
"miscellaneous",
|
||||
"moral_disputes",
|
||||
"moral_scenarios",
|
||||
"nutrition",
|
||||
"philosophy",
|
||||
"prehistory",
|
||||
"professional_accounting",
|
||||
"professional_law",
|
||||
"professional_medicine",
|
||||
"professional_psychology",
|
||||
"public_relations",
|
||||
"security_studies",
|
||||
"sociology",
|
||||
"us_foreign_policy",
|
||||
"virology",
|
||||
"world_religions",
|
||||
]
|
||||
|
||||
|
||||
def _download_and_extract(url: str, data_dir: str):
|
||||
zip_path = os.path.join(data_dir, "mmlu.zip")
|
||||
os.makedirs(data_dir, exist_ok=True)
|
||||
print(f"Downloading MMLU data from {url}...")
|
||||
urllib.request.urlretrieve(url, zip_path)
|
||||
print("Extracting...")
|
||||
with zipfile.ZipFile(zip_path, "r") as zf:
|
||||
zf.extractall(data_dir)
|
||||
os.remove(zip_path)
|
||||
|
||||
|
||||
def download_mmlu(data_dir: str):
|
||||
_download_and_extract(MMLU_URL, data_dir)
|
||||
src = os.path.join(data_dir, "test-master", "data")
|
||||
if os.path.exists(src):
|
||||
for item in os.listdir(src):
|
||||
os.rename(os.path.join(src, item), os.path.join(data_dir, item))
|
||||
shutil.rmtree(os.path.join(data_dir, "test-master"))
|
||||
print(f"MMLU data saved to {data_dir}")
|
||||
|
||||
|
||||
def _strip_prefix(text: str, prefix: str) -> str:
|
||||
if text.startswith(prefix):
|
||||
return text[len(prefix) :].strip()
|
||||
return text
|
||||
|
||||
|
||||
def load_csv(path: str) -> list[dict]:
|
||||
data = []
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
for row in csv.reader(f):
|
||||
if len(row) < 6:
|
||||
continue
|
||||
if row[0].strip().lower() == "question":
|
||||
continue
|
||||
data.append(
|
||||
{
|
||||
"question": row[0].strip(),
|
||||
"A": _strip_prefix(row[1].strip(), "A)"),
|
||||
"B": _strip_prefix(row[2].strip(), "B)"),
|
||||
"C": _strip_prefix(row[3].strip(), "C)"),
|
||||
"D": _strip_prefix(row[4].strip(), "D)"),
|
||||
"answer": row[5].strip(),
|
||||
}
|
||||
)
|
||||
return data
|
||||
|
||||
|
||||
def build_prompt(
|
||||
question: str, choices: dict, subject: str, n_shot: int, dev_data: list[dict]
|
||||
) -> str:
|
||||
prompt = ""
|
||||
if n_shot > 0 and dev_data:
|
||||
prompt = f"The following are multiple choice questions (with answers) about {subject}.\n\n"
|
||||
for item in dev_data[:n_shot]:
|
||||
prompt += f"Question: {item['question']}\n"
|
||||
for k in ("A", "B", "C", "D"):
|
||||
prompt += f"{k}. {item[k]}\n"
|
||||
prompt += f"Answer: {item['answer']}\n\n"
|
||||
prompt += f"Question: {question}\n"
|
||||
for k in ("A", "B", "C", "D"):
|
||||
prompt += f"{k}. {choices[k]}\n"
|
||||
prompt += "Answer:"
|
||||
return prompt
|
||||
|
||||
|
||||
def choice_logprob(
|
||||
model, tokenizer, context_ids: list[int], choice_letter: str, device: str
|
||||
) -> float:
|
||||
choice_text = f" {choice_letter}"
|
||||
choice_ids = tokenizer.encode(choice_text, add_special_tokens=False)
|
||||
input_ids = context_ids + choice_ids
|
||||
max_len = model.config.max_len
|
||||
if len(input_ids) > max_len:
|
||||
overflow = len(input_ids) - max_len
|
||||
input_ids = input_ids[overflow:]
|
||||
ctx_len = len(input_ids) - len(choice_ids)
|
||||
else:
|
||||
ctx_len = len(context_ids)
|
||||
|
||||
input_tensor = torch.tensor([input_ids], device=device, dtype=torch.long)
|
||||
with torch.inference_mode():
|
||||
logits = model(input_tensor)["logits"][0]
|
||||
|
||||
score = 0.0
|
||||
for i, tid in enumerate(choice_ids):
|
||||
pos = ctx_len - 1 + i
|
||||
if pos >= len(logits):
|
||||
break
|
||||
score += F.log_softmax(logits[pos], dim=-1)[tid].item()
|
||||
return score
|
||||
|
||||
|
||||
def evaluate_subject(
|
||||
model,
|
||||
tokenizer,
|
||||
subject: str,
|
||||
test_data: list[dict],
|
||||
dev_data: list[dict] | None,
|
||||
device: str,
|
||||
n_shot: int,
|
||||
) -> tuple[float, int, int]:
|
||||
correct = 0
|
||||
total = 0
|
||||
for item in tqdm.tqdm(test_data, desc=f"{subject:40s}", leave=False):
|
||||
prompt = build_prompt(item["question"], item, subject, n_shot, dev_data or [])
|
||||
context_ids = tokenizer.encode(prompt)
|
||||
scores = {
|
||||
c: choice_logprob(model, tokenizer, context_ids, c, device)
|
||||
for c in ("A", "B", "C", "D")
|
||||
}
|
||||
if max(scores, key=scores.get) == item["answer"]:
|
||||
correct += 1
|
||||
total += 1
|
||||
return correct / total, correct, total
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="MMLU evaluation")
|
||||
parser.add_argument(
|
||||
"--param_path", type=str, default="./params", help="Model directory"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--data_dir", type=str, default="./mmlu_data", help="MMLU data directory"
|
||||
)
|
||||
parser.add_argument("--download", action="store_true", help="Download MMLU data")
|
||||
parser.add_argument(
|
||||
"--n_shot", type=int, default=5, help="Few-shot examples (0 for zero-shot)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--subjects", type=str, nargs="+", help="Specific subjects (default: all)"
|
||||
)
|
||||
parser.add_argument("--output", type=str, help="Output JSON path")
|
||||
parser.add_argument("--split", type=str, default="test", choices=["test", "val"])
|
||||
parser.add_argument(
|
||||
"--device",
|
||||
type=str,
|
||||
default="cuda" if torch.cuda.is_available() else "cpu",
|
||||
help="Device",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dtype",
|
||||
type=str,
|
||||
default="bfloat16" if torch.cuda.is_available() else "float32",
|
||||
help="Torch dtype",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.download or not os.path.exists(args.data_dir):
|
||||
download_mmlu(args.data_dir)
|
||||
|
||||
model = AutoModel.from_pretrained(args.param_path)
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.param_path)
|
||||
device = args.device
|
||||
dtype = getattr(torch, args.dtype)
|
||||
model.to(device=device, dtype=dtype)
|
||||
|
||||
subjects = args.subjects or MMLU_SUBJECTS
|
||||
results = {}
|
||||
total_correct = 0
|
||||
total_questions = 0
|
||||
|
||||
for subject in subjects:
|
||||
dev_path = os.path.join(args.data_dir, "dev", f"{subject}_dev.csv")
|
||||
test_path = os.path.join(
|
||||
args.data_dir, args.split, f"{subject}_{args.split}.csv"
|
||||
)
|
||||
|
||||
if not os.path.exists(test_path):
|
||||
print(f" Skipping {subject}: test file not found")
|
||||
continue
|
||||
|
||||
dev_data = load_csv(dev_path) if os.path.exists(dev_path) else None
|
||||
test_data = load_csv(test_path)
|
||||
|
||||
acc, corr, tot = evaluate_subject(
|
||||
model, tokenizer, subject, test_data, dev_data, device, args.n_shot
|
||||
)
|
||||
results[subject] = {"accuracy": round(acc, 4), "correct": corr, "total": tot}
|
||||
total_correct += corr
|
||||
total_questions += tot
|
||||
print(f" {subject:40s} {acc:.2%} ({corr}/{tot})")
|
||||
|
||||
overall = total_correct / total_questions if total_questions else 0
|
||||
print(f"\n{'=' * 70}")
|
||||
print(f" Overall: {overall:.2%} ({total_correct}/{total_questions})")
|
||||
results["_overall"] = {
|
||||
"accuracy": round(overall, 4),
|
||||
"correct": total_correct,
|
||||
"total": total_questions,
|
||||
}
|
||||
|
||||
if args.output:
|
||||
with open(args.output, "w", encoding="utf-8") as f:
|
||||
json.dump(results, f, indent=2)
|
||||
print(f"Results saved to {args.output}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -10,11 +10,11 @@ from astrai.tokenize import AutoTokenizer
|
|||
|
||||
|
||||
def process_file(
|
||||
model_dir: str, input_file: str, output_file: str, batch_size: int, text_key: str
|
||||
param_path: str, input_file: str, output_file: str, batch_size: int, text_key: str
|
||||
):
|
||||
# Load model and tokenizer
|
||||
model = AutoModel.from_pretrained(model_dir)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_dir)
|
||||
model = AutoModel.from_pretrained(param_path)
|
||||
tokenizer = AutoTokenizer.from_pretrained(param_path)
|
||||
model.to(device="cuda", dtype=torch.bfloat16)
|
||||
|
||||
with open(input_file, "r", encoding="utf-8") as f:
|
||||
|
|
@ -44,8 +44,8 @@ def process_file(
|
|||
|
||||
for seq in batch_encoded:
|
||||
pad_len = max_len - len(seq)
|
||||
padded_seq = [tokenizer.pad_id] * pad_len + seq
|
||||
mask = [False] * pad_len + [True] * len(seq)
|
||||
padded_seq = seq + [tokenizer.pad_id] * pad_len
|
||||
mask = [True] * len(seq) + [False] * pad_len
|
||||
padded_ids.append(padded_seq)
|
||||
masks.append(mask)
|
||||
|
||||
|
|
@ -88,7 +88,7 @@ def process_file(
|
|||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Run perplexity with a Khaosz model.")
|
||||
parser.add_argument(
|
||||
"--model_dir", type=str, required=True, help="Path to the model directory."
|
||||
"--param_path", type=str, required=True, help="Path to the model directory."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--input_file", type=str, required=True, help="Path to the input file."
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ def main():
|
|||
"--reload", action="store_true", help="Enable auto-reload for development"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--param-path",
|
||||
"--param_path",
|
||||
type=Path,
|
||||
default=None,
|
||||
help="Path to model parameters (default: project_root/params)",
|
||||
|
|
|
|||
|
|
@ -8,7 +8,6 @@ import torch.optim as optim
|
|||
from astrai.config import AutoRegressiveLMConfig, TrainConfig
|
||||
from astrai.dataset import DatasetFactory
|
||||
from astrai.model import AutoRegressiveLM
|
||||
from astrai.serialization import Checkpoint
|
||||
from astrai.trainer import SchedulerFactory, Trainer
|
||||
|
||||
|
||||
|
|
@ -147,8 +146,8 @@ def parse_args() -> argparse.Namespace:
|
|||
"--parallel_mode",
|
||||
type=str,
|
||||
default="none",
|
||||
choices=["none", "ddp"],
|
||||
help="Parallel training strategy.",
|
||||
choices=["none", "ddp", "fsdp"],
|
||||
help="Parallel training strategy (none, ddp, fsdp).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device_type", type=str, default="cuda", help="Device type to use."
|
||||
|
|
@ -166,6 +165,10 @@ def parse_args() -> argparse.Namespace:
|
|||
return args
|
||||
|
||||
|
||||
def create_model(config):
|
||||
return AutoRegressiveLM(config).to(dtype=torch.bfloat16)
|
||||
|
||||
|
||||
def create_optimizer(model, **kwargs) -> optim.Optimizer:
|
||||
return optim.AdamW(model.parameters(), fused=True, **kwargs)
|
||||
|
||||
|
|
@ -228,6 +231,8 @@ def train(
|
|||
):
|
||||
assert train_type in ["seq", "sft", "dpo", "grpo"]
|
||||
assert os.path.exists(param_path)
|
||||
if nprocs > 1 and parallel_mode == "none":
|
||||
raise ValueError("--nprocs > 1 requires --parallel_mode to be 'ddp' or 'fsdp'")
|
||||
|
||||
# Load config
|
||||
config_path = os.path.join(param_path, "config.json")
|
||||
|
|
@ -236,15 +241,6 @@ def train(
|
|||
if window_size is None:
|
||||
window_size = config.max_len
|
||||
|
||||
# Create model and load full checkpoint (state_dict + optimizer + scheduler + meta)
|
||||
checkpoint = Checkpoint.load(param_path)
|
||||
model = AutoRegressiveLM(config).to(dtype=torch.bfloat16)
|
||||
model.load_state_dict(checkpoint.state_dict, strict=False)
|
||||
|
||||
# Strip state_dict to avoid pickling ~7GB through mp.spawn pipe
|
||||
# (model weights already loaded into model above)
|
||||
checkpoint.state_dict = {}
|
||||
|
||||
strategy_kwargs = {
|
||||
"beta": dpo_beta,
|
||||
"label_smoothing": label_smoothing,
|
||||
|
|
@ -259,6 +255,7 @@ def train(
|
|||
"broadcast_buffers": False,
|
||||
}
|
||||
|
||||
model_fn = partial(create_model, config)
|
||||
dataset = DatasetFactory.load(
|
||||
train_type=train_type,
|
||||
load_path=data_root_path,
|
||||
|
|
@ -290,7 +287,7 @@ def train(
|
|||
)
|
||||
|
||||
train_config = TrainConfig(
|
||||
model=model,
|
||||
model_fn=model_fn,
|
||||
strategy=train_type,
|
||||
dataset=dataset,
|
||||
optimizer_fn=optimizer_fn,
|
||||
|
|
@ -315,7 +312,7 @@ def train(
|
|||
)
|
||||
|
||||
trainer = Trainer(train_config)
|
||||
trainer.train(checkpoint=checkpoint)
|
||||
trainer.train(resume_dir=param_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -7,10 +7,8 @@ import torch
|
|||
|
||||
from astrai.dataset.dataset import DatasetFactory, SEQDataset
|
||||
from astrai.dataset.storage import (
|
||||
BaseSegmentFetcher,
|
||||
H5Storage,
|
||||
MultiSegmentFetcher,
|
||||
StorageFactory,
|
||||
H5Store,
|
||||
StoreFactory,
|
||||
detect_format,
|
||||
load_json,
|
||||
save_h5,
|
||||
|
|
@ -318,37 +316,48 @@ def test_unloaded_dataset_len():
|
|||
assert len(dataset) == 0
|
||||
|
||||
|
||||
def test_base_segment_fetcher_empty():
|
||||
"""BaseSegmentFetcher with empty segments list"""
|
||||
fetcher = BaseSegmentFetcher([])
|
||||
assert len(fetcher) == 0
|
||||
with pytest.raises(ValueError, match="out of bounds"):
|
||||
fetcher.fetch_data(0, 1)
|
||||
def test_store_unloaded_len():
|
||||
"""Unloaded Store has __len__ == 0"""
|
||||
store = H5Store()
|
||||
assert len(store) == 0
|
||||
assert store.keys == []
|
||||
|
||||
|
||||
def test_base_segment_fetcher_begin_equals_end(base_test_env):
|
||||
"""fetch_data with begin == end returns empty tensor"""
|
||||
def test_store_fetch_begin_equals_end(base_test_env):
|
||||
"""Store.fetch with begin == end returns empty tensor"""
|
||||
test_dir = base_test_env["test_dir"]
|
||||
dummy = {"sequence": [torch.randint(0, 1000, (100,), dtype=torch.int64)]}
|
||||
save_h5(test_dir, "empty_fetch", dummy)
|
||||
|
||||
dataset = DatasetFactory.load("seq", test_dir, window_size=32)
|
||||
fetcher = dataset.storage._fetcher.multi_fetchers["sequence"]
|
||||
result = fetcher.fetch_data(10, 10)
|
||||
result = dataset.storage.fetch(10, 10, "sequence")
|
||||
assert result.numel() == 0
|
||||
|
||||
|
||||
def test_multi_segment_fetcher_empty_dict():
|
||||
"""MultiSegmentFetcher with empty dict has __len__ == 0"""
|
||||
fetcher = MultiSegmentFetcher({})
|
||||
assert len(fetcher) == 0
|
||||
def test_store_empty_data_len(base_test_env):
|
||||
"""Store loaded with empty data has __len__ == 0"""
|
||||
import os
|
||||
|
||||
test_dir = base_test_env["test_dir"]
|
||||
data_dir = os.path.join(test_dir, "empty_store")
|
||||
os.makedirs(data_dir, exist_ok=True)
|
||||
|
||||
with open(os.path.join(data_dir, "data.json"), "w") as f:
|
||||
json.dump({"sequence": [[1, 2, 3]]}, f)
|
||||
|
||||
store = StoreFactory.create("json")
|
||||
store.load(data_dir)
|
||||
assert len(store) > 0
|
||||
|
||||
empty_store = H5Store()
|
||||
assert len(empty_store) == 0
|
||||
|
||||
|
||||
def test_storage_fetch_before_load():
|
||||
"""BaseStorage.fetch before load raises RuntimeError"""
|
||||
storage = H5Storage()
|
||||
def test_store_fetch_before_load():
|
||||
"""Store.fetch before load raises RuntimeError"""
|
||||
store = H5Store()
|
||||
with pytest.raises(RuntimeError, match="not loaded"):
|
||||
storage.fetch(0, 10, "sequence")
|
||||
store.fetch(0, 10, "sequence")
|
||||
|
||||
|
||||
def test_detect_format_nonexistent_path():
|
||||
|
|
@ -367,10 +376,10 @@ def test_detect_format_unsupported_file(base_test_env):
|
|||
detect_format(path)
|
||||
|
||||
|
||||
def test_create_storage_invalid_type():
|
||||
"""StorageFactory.create raises ValueError for unknown type"""
|
||||
def test_create_store_invalid_type():
|
||||
"""StoreFactory.create raises ValueError for unknown type"""
|
||||
with pytest.raises(ValueError, match="Unknown component"):
|
||||
StorageFactory.create("parquet")
|
||||
StoreFactory.create("parquet")
|
||||
|
||||
|
||||
def test_json_pretokenized_without_tokenizer(base_test_env):
|
||||
|
|
@ -407,14 +416,23 @@ def test_load_json_skips_config_file(base_test_env):
|
|||
assert len(result["sequence"]) == 1
|
||||
|
||||
|
||||
def test_base_segment_fetcher_multi_segment():
|
||||
"""fetch_data across multiple segment boundaries"""
|
||||
def test_store_multi_segment_concat(base_test_env):
|
||||
"""Multi-segment H5 data is concatenated into single tensor at load time"""
|
||||
import os
|
||||
|
||||
test_dir = base_test_env["test_dir"]
|
||||
data_dir = os.path.join(test_dir, "multi_seg")
|
||||
os.makedirs(data_dir, exist_ok=True)
|
||||
|
||||
segs = [
|
||||
torch.tensor([1, 2, 3]),
|
||||
torch.tensor([4, 5, 6, 7]),
|
||||
torch.tensor([8, 9]),
|
||||
]
|
||||
fetcher = BaseSegmentFetcher(segs)
|
||||
assert len(fetcher) == 9
|
||||
result = fetcher.fetch_data(2, 7)
|
||||
save_h5(data_dir, "data", {"sequence": segs})
|
||||
|
||||
store = StoreFactory.create("h5")
|
||||
store.load(data_dir)
|
||||
assert len(store) == 9
|
||||
result = store.fetch(2, 7, "sequence")
|
||||
assert result.tolist() == [3, 4, 5, 6, 7]
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@ class TrainerDataset(Dataset):
|
|||
|
||||
|
||||
def create_train_config(
|
||||
model: torch.nn.Module,
|
||||
model_fn,
|
||||
dataset: Dataset,
|
||||
test_dir: str,
|
||||
device: str,
|
||||
|
|
@ -43,7 +43,7 @@ def create_train_config(
|
|||
"""Factory function to create common TrainConfig for tests.
|
||||
|
||||
Args:
|
||||
model: The model to train
|
||||
model_fn: Model factory (callable returning nn.Module)
|
||||
dataset: Training dataset
|
||||
test_dir: Checkpoint directory
|
||||
device: Device type ("cuda" or "cpu")
|
||||
|
|
@ -70,7 +70,7 @@ def create_train_config(
|
|||
|
||||
return TrainConfig(
|
||||
strategy=strategy,
|
||||
model=model,
|
||||
model_fn=model_fn,
|
||||
dataset=dataset,
|
||||
optimizer_fn=optimizer_fn,
|
||||
scheduler_fn=scheduler_fn,
|
||||
|
|
|
|||
|
|
@ -106,7 +106,7 @@ def test_gradient_checkpointing_trainer_integration(base_test_env, random_datase
|
|||
)
|
||||
|
||||
train_config = TrainConfig(
|
||||
model=base_test_env["model"],
|
||||
model_fn=lambda: base_test_env["model"],
|
||||
strategy="seq",
|
||||
dataset=random_dataset,
|
||||
optimizer_fn=optimizer_fn,
|
||||
|
|
@ -140,7 +140,7 @@ def test_callback_integration(base_test_env, random_dataset):
|
|||
)
|
||||
|
||||
train_config = TrainConfig(
|
||||
model=base_test_env["model"],
|
||||
model_fn=lambda: base_test_env["model"],
|
||||
strategy="seq",
|
||||
dataset=random_dataset,
|
||||
optimizer_fn=optimizer_fn,
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@ import numpy as np
|
|||
import torch
|
||||
|
||||
from astrai.config.train_config import TrainConfig
|
||||
from astrai.serialization import Checkpoint
|
||||
from astrai.trainer.schedule import SchedulerFactory
|
||||
from astrai.trainer.trainer import Trainer
|
||||
|
||||
|
|
@ -24,7 +23,7 @@ def test_early_stopping_simulation(base_test_env, early_stopping_dataset):
|
|||
strategy="seq",
|
||||
optimizer_fn=optimizer_fn,
|
||||
scheduler_fn=scheduler_fn,
|
||||
model=base_test_env["model"],
|
||||
model_fn=lambda: base_test_env["model"],
|
||||
dataset=early_stopping_dataset,
|
||||
ckpt_dir=base_test_env["test_dir"],
|
||||
log_dir=os.path.join(base_test_env["test_dir"], "logs"),
|
||||
|
|
@ -39,17 +38,20 @@ def test_early_stopping_simulation(base_test_env, early_stopping_dataset):
|
|||
trainer = Trainer(train_config)
|
||||
|
||||
# Should handle early stopping gracefully
|
||||
checkpoint = None
|
||||
try:
|
||||
checkpoint = trainer.train()
|
||||
trainer.train()
|
||||
except Exception:
|
||||
# Handle any exceptions
|
||||
pass
|
||||
|
||||
# Resume from latest checkpoint
|
||||
load_dir = os.path.join(base_test_env["test_dir"], "epoch_0_iter_2")
|
||||
checkpoint = Checkpoint.load(load_dir)
|
||||
trainer.train(checkpoint)
|
||||
trainer = Trainer(train_config)
|
||||
trainer.train(resume_dir=load_dir)
|
||||
|
||||
# Verify checkpoint was saved at expected iteration
|
||||
load_dir = os.path.join(base_test_env["test_dir"], "epoch_1_iter_10")
|
||||
checkpoint = Checkpoint.load(load_dir)
|
||||
assert checkpoint.iteration == 10
|
||||
import json
|
||||
|
||||
with open(os.path.join(load_dir, "meta.json")) as f:
|
||||
meta = json.load(f)
|
||||
assert meta["iteration"] == 10
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ def test_different_batch_sizes(base_test_env, random_dataset, train_config_facto
|
|||
|
||||
for batch_per_device in batch_sizes:
|
||||
train_config = train_config_factory(
|
||||
model=base_test_env["model"],
|
||||
model_fn=lambda: base_test_env["model"],
|
||||
dataset=random_dataset,
|
||||
test_dir=base_test_env["test_dir"],
|
||||
device=base_test_env["device"],
|
||||
|
|
@ -25,7 +25,7 @@ def test_gradient_accumulation(base_test_env, random_dataset, train_config_facto
|
|||
|
||||
for grad_accum_steps in grad_accum_steps_list:
|
||||
train_config = train_config_factory(
|
||||
model=base_test_env["model"],
|
||||
model_fn=lambda: base_test_env["model"],
|
||||
dataset=random_dataset,
|
||||
test_dir=base_test_env["test_dir"],
|
||||
device=base_test_env["device"],
|
||||
|
|
@ -50,7 +50,7 @@ def test_memory_efficient_training(base_test_env, random_dataset, train_config_f
|
|||
|
||||
for config in small_batch_configs:
|
||||
train_config = train_config_factory(
|
||||
model=base_test_env["model"],
|
||||
model_fn=lambda: base_test_env["model"],
|
||||
dataset=random_dataset,
|
||||
test_dir=base_test_env["test_dir"],
|
||||
device=base_test_env["device"],
|
||||
|
|
|
|||
Loading…
Reference in New Issue