Compare commits

..

No commits in common. "main" and "v1.3.6" have entirely different histories.
main ... v1.3.6

73 changed files with 1551 additions and 6120 deletions

View File

@ -82,7 +82,6 @@ export CUDA_VISIBLE_DEVICES=0,1,2,3
nohup python scripts/tools/train.py \
--nprocs=4 \
--parallel_mode=ddp \
--train_type=seq \
--data_root_path=/path/to/dataset \
--param_path=/path/to/model \
@ -109,8 +108,8 @@ Full reference at [Parameter Guide](assets/docs/params.md).
```bash
python scripts/tools/generate.py \
--param_path /path/to/model \
--input_json_file /path/to/input.jsonl \
--output_json_file /path/to/output.jsonl
--input_json_file /path/to/input.json \
--output_json_file /path/to/output.json
```
#### Docker
@ -225,7 +224,6 @@ Watch a video walkthrough on [bilibili](https://www.bilibili.com/video/BV1fuLB6y
| [Training](./assets/docs/training.md) | Training loop, strategies & formulas |
| [Inference](./assets/docs/inference.md) | KVCache, continuous batching, sampling & HTTP API |
| [Data Flow](./assets/docs/dataflow.md) | Data pipeline, storage backends & dataset architecture |
| [Preprocessing](./assets/docs/preprocessing.md) | Declarative JSON-driven data preprocessing |
### Contributing

View File

@ -88,7 +88,6 @@ export CUDA_VISIBLE_DEVICES=0,1,2,3
nohup python scripts/tools/train.py \
--nprocs=4 \
--parallel_mode=ddp \
--train_type=seq \
--data_root_path=/path/to/dataset \
--param_path=/path/to/model \
@ -115,8 +114,8 @@ nohup python scripts/tools/train.py \
```bash
python scripts/tools/generate.py \
--param_path /path/to/model \
--input_json_file /path/to/input.jsonl \
--output_json_file /path/to/output.jsonl
--input_json_file /path/to/input.json \
--output_json_file /path/to/output.json
```
#### Docker
@ -231,7 +230,6 @@ python scripts/demo/generate_ar.py
| [训练文档](./training.md) | 训练循环、策略与公式 |
| [推理文档](./inference.md) | KVCache、连续批处理、采样与 HTTP API |
| [数据流程](./dataflow.md) | 数据管道、存储后端与数据集架构 |
| [数据预处理](./preprocessing.md) | 声明式 JSON 驱动数据预处理 |
### 贡献

View File

@ -8,8 +8,6 @@ classDiagram
class BaseConfig {
+to_dict() Dict
+from_dict(d) Self
+from_json(path) Self
+to_json(path)
}
class BaseModelConfig {
@ -19,43 +17,41 @@ classDiagram
}
class AutoRegressiveLMConfig {
+Optional[int] vocab_size
+Optional[int] dim
+Optional[int] n_layers
+Optional[float] norm_eps
+Optional[int] dim_ffn
+Optional[bool] tie_weight
+Optional[dict] rope_scaling
+Optional[int] max_len
+Optional[float] rope_theta
+int vocab_size
+int dim
+int n_layers
+float norm_eps
+int dim_ffn
+bool tie_weight
+int max_len
+float rope_theta
+str attn_type
+Optional[int] n_heads
+Optional[int] n_kv_heads
+Optional[bool] use_qk_norm
+Optional[bool] use_gated_attention
+int n_heads
+int n_kv_heads
+bool use_qk_norm
+bool use_gated_attention
+Optional[int] kv_lora_rank
+Optional[int] qk_nope_head_dim
+Optional[int] qk_rope_head_dim
+str ffn_type
+Optional[int] n_routed_experts
+Optional[int] n_shared_experts
+Optional[int] n_activated_experts
+int n_routed_experts
+int n_shared_experts
+int n_activated_experts
+Optional[str] topk_method
}
class EncoderConfig {
+Optional[int] vocab_size
+Optional[int] dim
+Optional[int] n_layers
+Optional[float] norm_eps
+Optional[int] dim_ffn
+Optional[int] max_len
+Optional[float] rope_theta
+Optional[int] n_heads
+Optional[int] n_kv_heads
+Optional[bool] use_qk_norm
+Optional[bool] use_gated_attention
+Optional[dict] rope_scaling
+int vocab_size
+int dim
+int n_layers
+float norm_eps
+int dim_ffn
+int max_len
+float rope_theta
+int n_heads
+int n_kv_heads
+bool use_qk_norm
+bool use_gated_attention
+Optional[str] pooling_type
+Optional[bool] normalize_embeddings
}
@ -66,40 +62,8 @@ classDiagram
+load(raw) BaseConfig
}
class InputConfig {
+str type
+str messages_key
+str prompt_key
+str response_key
+str text_key
}
class ProcessingConfig {
+int max_seq_len
+int min_chars
+int max_chars
+bool deduplicate
+Optional[int] max_items
}
class OutputConfig {
+Optional[str] domain_key
+str storage_format
+int max_tokens_per_shard
}
class PipelineConfig {
+int version
+InputConfig input
+dict mask
+str mask_default
+ProcessingConfig preprocessing
+OutputConfig output
+from_dict(d) Self
}
class TrainConfig {
+Callable[[], nn.Module] model_fn
+nn.Module model
+str strategy
+Dataset dataset
+Callable optimizer_fn
@ -116,7 +80,6 @@ classDiagram
+str log_dir
+int log_interval
+List[str] metrics
+Optional[LoRAConfig] lora
+int random_seed
+int num_workers
+Optional[int] prefetch_factor
@ -125,12 +88,12 @@ classDiagram
+str backend
+str master_addr
+str master_port
+Callable parallel_wrapper
+Callable state_dict_fn
+str start_method
+str device_type
+Optional[Dataset] val_dataset
+int val_step
+str parallel_mode
+dict executor_kwargs
+dict extra_kwargs
+validate()
}
@ -141,8 +104,8 @@ classDiagram
class BaseDataset {
+int window_size
+int stride
+Optional[Store] storage
+load(load_path, storage_type)
+Optional[BaseStorage] storage
+load(load_path, storage_type, tokenizer)
+__getitem__(index)
+__len__()
}
@ -163,25 +126,38 @@ classDiagram
+__getitem__(index) Dict
}
class Store {
+Dict[str, List[Tensor]] _data
+Dict[str, List[int]] _cum
+int _length
class BaseSegmentFetcher {
+List[Tensor] segments
+List[int] cum_lengths
+int total_length
+fetch_data(begin_idx, end_idx) Tensor
}
class BaseStorage {
+MultiSegmentFetcher _fetcher
+keys (property)
+load(path)
+load(load_path, tokenizer)
+fetch(begin, end, keys)
+__len__()
-_fetch_key(key, begin, end) Tensor
-_normalize(raw)
}
class H5Store {
+load(path)
class H5Storage {
+load(load_path, tokenizer)
+fetch(begin, end, keys) Dict
+keys() List
}
class MmapStore {
+List _mmap_refs
+load(path)
class JSONStorage {
+load(load_path, tokenizer)
+fetch(begin, end, keys) Dict
+keys() List
}
class MultiSegmentFetcher {
+Dict multi_fetchers
+List multi_keys
+key_fetch(begin_idx, end_idx, keys) Dict
+fetch_data(begin_idx, end_idx) Dict
}
class ResumableDistributedSampler {
@ -189,17 +165,17 @@ classDiagram
+int iter
}
class StoreFactory {
class StorageFactory {
+Registry _registry
+register(name) decorator
+create(storage_type) Store
+create(storage_type) BaseStorage
}
class DatasetFactory {
+Registry _registry
+register(name) decorator
+create(train_type, window_size, stride) BaseDataset
+load(train_type, load_path, window_size, stride, storage_type) BaseDataset
+load(train_type, load_path, window_size, stride, storage_type, tokenizer) BaseDataset
}
}
@ -210,9 +186,8 @@ classDiagram
+int iteration
+dict extra
+dict meta
+dict config
+save(save_dir)
+load(save_dir, broadcast) Checkpoint
+load(save_dir) Checkpoint
}
}
@ -220,8 +195,8 @@ classDiagram
class AutoModel {
+BaseModelConfig config
+Registry _registry
+register(name) decorator
+get_component_class(name) Type
+register(model_type) decorator
+get_component_class(model_type) Type
+from_pretrained(path, disable_random_init, strict) nn.Module
+save_pretrained(save_directory)
+to(*args, **kwargs) Self
@ -235,7 +210,7 @@ classDiagram
+RMSNorm norm
+Linear lm_head
+forward(input_ids, input_mask, paged_cache, position_ids) Dict[str, Tensor]
+load_state_dict(state_dict, strict, assign)
+load_state_dict(state_dict)
+state_dict()
}
@ -260,7 +235,6 @@ classDiagram
}
class GQA {
+int dim
+int n_heads
+int n_kv_heads
+int head_dim
@ -275,7 +249,6 @@ classDiagram
}
class MLA {
+int dim
+int n_heads
+int n_kv_heads
+int head_dim
@ -284,13 +257,11 @@ classDiagram
+int qk_rope_head_dim
+int n_rep
+int layer_id
+bool use_qk_norm
+bool use_gated_attention
+Linear q_proj, kv_a_proj, kv_b_proj
+Linear o_proj
+Linear gate # only if use_gated_attention
+RMSNorm kv_norm
+RMSNorm q_norm, k_norm # only if use_qk_norm
+forward(x, rotary_emb, attn_mask, paged_cache) Tensor
}
@ -336,7 +307,6 @@ classDiagram
+int dim
+int max_len
+float base
+Optional[Dict] rope_scaling
+forward(x, position_ids=None) Tensor
}
@ -346,42 +316,13 @@ classDiagram
}
}
namespace preprocessing {
class BaseMaskBuilder {
<<abstract>>
+build(item, config, tokenizer) Optional[dict]
}
class ChatMaskBuilder {
+build(item, config, tokenizer) Optional[dict]
}
class InstructionMaskBuilder {
+build(item, config, tokenizer) Optional[dict]
}
class TextMaskBuilder {
+build(item, config, tokenizer) Optional[dict]
}
class Pipeline {
+PipelineConfig config
+List[str] paths
+str output_dir
+str tokenizer_path
+BaseMaskBuilder mask_builder
+transform(item) Optional[dict]
+run()
}
}
namespace tokenize {
class AutoTokenizer {
+vocab_size int
+encode(tokens, out_ids, is_pretokenized, add_special_tokens) List
+encode(tokens, out_ids, add_special_tokens) List[int]
+decode(tokens, skip_special_tokens) str
+__getattr__(name) Any (bos_id, eos_id, pad_id, stop_ids)
+apply_chat_template(messages, system_prompt, tokenize, add_generation_prompt) Union[str, List[int]]
+apply_chat_template(messages, tokenize) Union[str, List[int]]
+set_chat_template(template)
+load(path)
+from_pretrained(path) AutoTokenizer
@ -389,7 +330,7 @@ classDiagram
}
class ChatTemplate {
+str template_str
+String template_str
+render(messages, system_prompt, **extra_variables) str
+from_string(template) ChatTemplate
}
@ -409,32 +350,24 @@ classDiagram
+create(name, *args, **kwargs) T
+list_registered() list
}
class MaskBuilderFactory {
+Registry _registry
+register(name) decorator
+create(input_type, config, tokenizer) BaseMaskBuilder
}
}
namespace trainer {
class Trainer {
+TrainConfig train_config
+List[TrainCallback] callbacks
+train(resume_dir)
-_get_default_callbacks() List[TrainCallback]
+train(checkpoint)
+_get_default_callbacks() List[TrainCallback]
}
class TrainContext {
+nn.Module model
+BaseStrategy strategy
+DataLoader dataloader
+OptimizerProtocol optimizer
+SchedulerProtocol scheduler
+Optimizer optimizer
+LRScheduler scheduler
+Checkpoint checkpoint
+TrainConfig config
+dict model_config
+BaseExecutor executor
+int epoch
+int iteration
+float loss
@ -447,17 +380,13 @@ classDiagram
class TrainContextBuilder {
+TrainConfig config
+with_resume_dir(resume_dir) TrainContextBuilder
+with_checkpoint(checkpoint) TrainContextBuilder
+build() TrainContext
}
class BaseStrategy {
+Callable model
+Optional[BaseExecutor] executor
+Optional[Callable] model_fn
+dict extra_kwargs
+Union[Callable, nn.Module] model
+str device
+__call__(batch) Tensor
+compute_loss(batch) Tensor
}
@ -498,8 +427,6 @@ classDiagram
class BaseScheduler {
+get_lr() List[float]
+step()
+state_dict() dict
+load_state_dict(d)
}
class SchedulerFactory {
@ -511,7 +438,6 @@ classDiagram
class CosineScheduler {
+int warmup_steps
+int lr_decay_steps
+int total_steps
+float min_rate
}
@ -528,15 +454,16 @@ classDiagram
+on_train_end(context)
+on_epoch_begin(context)
+on_epoch_end(context)
+on_step_begin(context)
+on_step_end(context)
+on_batch_begin(context)
+on_batch_end(context)
+on_optimizer_step(context)
+on_error(context)
}
class GradientClippingCallback {
+float max_grad_norm
+on_optimizer_step(context)
+on_step_begin(context)
}
class GradientCheckpointingCallback {
@ -549,12 +476,16 @@ classDiagram
+str save_dir
+int interval
+bool weight_only
+Callable state_dict_fn
+Callable save_extra_fn
-_save_checkpoint(context)
+Callable load_extra_fn
+_save_checkpoint(context)
+on_train_begin(context)
+on_batch_end(context)
+on_train_end(context)
+on_error(context)
+save_extra(context) dict$
+save_extra(context)$
+load_extra(extra, context)$
}
class ProgressBarCallback {
@ -567,7 +498,7 @@ classDiagram
}
class MetricLoggerCallback {
+Path log_dir
+str log_dir
+int save_interval
+int log_interval
+List[str] metrics
@ -577,8 +508,8 @@ classDiagram
}
class ValidationCallback {
-_run_validation(context)
+on_optimizer_step(context)
+_run_validation(context)
+on_step_end(context)
}
class CallbackFactory {
@ -591,12 +522,7 @@ classDiagram
+float lr
+float momentum
+float weight_decay
+bool nesterov
+int ns_steps
+Optional[float] adamw_lr
+tuple adamw_betas
+float adamw_eps
+float adamw_wd
+step(closure) Optional[float]
}
}
@ -617,8 +543,6 @@ classDiagram
+AutoModel model
+AutoTokenizer tokenizer
+KVCache page_cache
+Optional[str] device
+Optional[torch.dtype] dtype
+execute_prefill(tasks, prompt_len, start_pos)
+execute_decode(tasks) List[int]
}
@ -630,9 +554,7 @@ classDiagram
+bool _running
+Thread _loop_thread
+int max_seq_len
+str device
+torch.dtype dtype
+add_task(prompt, **kwargs) str
+add_task(prompt, max_tokens, temperature, top_p, top_k, stream_callback) str
+remove_task(task_id)
+start()
+stop()
@ -710,7 +632,7 @@ classDiagram
class Task {
+str task_id
+List prompt_ids
+Optional[int] max_tokens
+int max_tokens
+float temperature
+float top_p
+int top_k
@ -719,8 +641,8 @@ classDiagram
+int input_tokens
+int output_tokens
+float arrival_time
+Optional[float] finish_time
+Optional[Callable] stream_callback
+float finish_time
+Callable stream_callback
+int next_pos
+is_finished(stop_ids) bool
}
@ -735,24 +657,15 @@ classDiagram
class TaskManager {
+AutoTokenizer tokenizer
+int max_batch_size
+int max_seq_len
+int max_prompt_len
+Deque waiting_queue
+List active_tasks
+add_task(prompt, max_tokens, temperature, top_p, top_k, stream_callback) str
+add_task(prompt, **kwargs) str
+remove_task(task_id) List[Task]
+remove_finished_tasks(stop_ids) List[Task]
+pull_candidates(n) List[Task]
+activate(task)
+return_to_waiting(tasks)
+get_active_tasks() List[Task]
+has_work() bool
+wait_for_tasks(timeout)
+get_waiting_tasks() List[Task]
+clear_queues()
+wake()
+get_stats() Dict
}
class GenerationRequest {
@ -831,65 +744,56 @@ classDiagram
+str model
+List[AnthropicMessage] messages
+Optional[str] system
+Optional[float] temperature
+Optional[float] top_p
+Optional[int] top_k
+float temperature
+float top_p
+int top_k
+int max_tokens
+Optional[bool] stream
+bool stream
+Optional[List[str]] stop_sequences
}
class ResponseBuilder {
<<abstract>>
+prepare(request, tokenizer) Tuple[str, GenContext, List[str]]
+format_stream_start(ctx) List[str]
+format_chunk(token) str
+format_stream_end(ctx, stop) List[str]
+format_response(ctx, content, stop) Dict
}
class OpenAIResponseBuilder {
+prepare(request, tokenizer) Tuple
+format_stream_start(ctx) List[str]
+format_chunk(token) str
+format_stream_end(ctx, stop) List[str]
+format_response(ctx, content, stop) Dict
}
class AnthropicResponseBuilder {
+prepare(request, tokenizer) Tuple
+format_stream_start(ctx) List[str]
+format_chunk(token) str
+format_stream_end(ctx, stop) List[str]
+format_response(ctx, content, stop) Dict
}
class ProtocolHandler {
<<abstract>>
+request
+engine
+builder: ResponseBuilder
+async handle() Union[StreamingResponse, Dict]
-_handle_stream(agen, ctx, stop_sequences) StreamingResponse
-async _handle_non_stream(agen, ctx, stop_sequences) Dict
+build_prompt() str
+create_response_id() str
+get_stop_sequences() List[str]
+create_stop_checker() StopChecker
+on_token(ctx, token, stop_checker) Optional[str]
+format_stream_start(ctx) List[str]
+format_stream_token(ctx, token) str
+format_stream_end(ctx) List[str]
+format_non_stream_response(ctx, content) Dict
+handle() Union[StreamingResponse, Dict]
}
class OpenAIHandler {
+build_prompt() str
+create_response_id() str
}
class AnthropicHandler {
+build_prompt() str
+create_response_id() str
+on_token(ctx, token, stop_checker) Optional[str]
}
class StopChecker {
+__init__(sequences)
+has_sequences (property) bool
+check(text) Optional[str]
+trim(text, matched) str
}
class GenContext {
class StreamContext {
+str resp_id
+int created
+str model
+int prompt_tokens
+int completion_tokens
}
class StopInfo {
+Optional[str] matched
+str body
+str yielded
+str accumulated
+Optional[str] stop_matched
+str last_yield_trimmed
}
class app {
@ -898,87 +802,15 @@ classDiagram
}
}
namespace protocols {
class OptimizerProtocol {
<<protocol>>
+step(closure)
+zero_grad()
+state_dict() dict
+load_state_dict(d)
}
class SchedulerProtocol {
<<protocol>>
+step()
+state_dict() dict
+load_state_dict(d)
+get_last_lr()
}
}
namespace parallel {
class setup {
class Functions {
<<module>>
+spawn_parallel_fn(func, world_size, backend, master_addr, master_port, device_type, start_method, **kwargs)
+setup_parallel(rank, world_size, backend, master_addr, master_port, device_type) contextmanager
+setup_parallel(rank, world_size, backend, master_addr, master_port, device_type)
+get_current_device() str
+get_world_size() int
+get_rank() int
+only_on_rank(rank, sync=False) decorator
}
class GradientState {
+int num_steps
+sync_gradients (property) bool
}
class AccumOptimizer {
+Optimizer optimizer
+GradientState gradient_state
+param_groups (property)
+step(closure)
+zero_grad()
+state_dict() dict
+load_state_dict(d)
}
class AccumScheduler {
+LRScheduler scheduler
+GradientState gradient_state
+step()
+state_dict() dict
+load_state_dict(d)
+get_last_lr()
}
class BaseExecutor {
+GradientState gradient_state
+prepare(model, optimizer, dataloader, scheduler) tuple
+accumulate(model) context manager
+backward(loss)
+unwrap_model(model) dict
+sync_gradients (property) bool
+grad_accum_steps (property) int
}
class NoneExecutor {
}
class DDPExecutor {
-_prepare_model(model) nn.Module
-_no_sync(model) context manager
+unwrap_model(model) dict
}
class FSDPExecutor {
-_prepare_model(model) nn.Module
+unwrap_model(model) dict
}
class ExecutorFactory {
+Registry _registry
+register(name) decorator
+create(parallel_mode, **kwargs) BaseExecutor
+only_on_rank(rank, sync) decorator
}
class ParallelModel {
@ -988,25 +820,11 @@ classDiagram
}
class ColumnParallelLinear {
+int in_features
+int out_features
+int out_features_per_rank
+bool gather_results
+Parameter weight
+Optional[Parameter] bias
+forward(x) Tensor
+load_state_dict(state_dict)
}
class RowParallelLinear {
+int in_features
+int out_features
+int in_features_per_rank
+bool reduce_results
+Parameter weight
+Optional[Parameter] bias
+forward(x) Tensor
+load_state_dict(state_dict)
}
}
@ -1024,13 +842,12 @@ classDiagram
TrainCallback <|-- CheckpointCallback
TrainCallback <|-- ProgressBarCallback
TrainCallback <|-- MetricLoggerCallback
TrainCallback <|-- ValidationCallback
BaseDataset <|-- SEQDataset
BaseDataset <|-- SFTDataset
BaseDataset <|-- DPODataset
BaseDataset <|-- GRPODataset
Store <|-- H5Store
Store <|-- MmapStore
BaseStorage <|-- H5Storage
BaseStorage <|-- JSONStorage
BaseSamplingStrategy <|-- TemperatureStrategy
BaseSamplingStrategy <|-- TopKStrategy
BaseSamplingStrategy <|-- TopPStrategy
@ -1041,10 +858,6 @@ classDiagram
AutoModel <|-- EmbeddingEncoder
BaseConfig <|-- BaseModelConfig
BaseConfig <|-- TrainConfig
BaseConfig <|-- InputConfig
BaseConfig <|-- ProcessingConfig
BaseConfig <|-- OutputConfig
BaseConfig <|-- PipelineConfig
BaseModelConfig <|-- AutoRegressiveLMConfig
BaseModelConfig <|-- EncoderConfig
BaseFactory <|-- AutoModel
@ -1054,23 +867,18 @@ classDiagram
BaseFactory <|-- StrategyFactory
BaseFactory <|-- SchedulerFactory
BaseFactory <|-- CallbackFactory
BaseFactory <|-- StoreFactory
BaseFactory <|-- ExecutorFactory
BaseFactory <|-- StorageFactory
BaseFactory <|-- ConfigFactory
BaseFactory <|-- MaskBuilderFactory
BaseExecutor <|-- NoneExecutor
BaseExecutor <|-- DDPExecutor
BaseExecutor <|-- FSDPExecutor
ResponseBuilder <|-- OpenAIResponseBuilder
ResponseBuilder <|-- AnthropicResponseBuilder
BaseMaskBuilder <|-- ChatMaskBuilder
BaseMaskBuilder <|-- InstructionMaskBuilder
BaseMaskBuilder <|-- TextMaskBuilder
TrainCallback <|-- ValidationCallback
ProtocolHandler <|-- OpenAIHandler
ProtocolHandler <|-- AnthropicHandler
%% --- Composition (strong ownership, part destroyed with whole) ---
KVCache *-- PagePool
KVCache *-- Storage
KVCache *-- TaskTable
PagePool *-- Allocator
PagePool *-- PrefixCache
InferenceEngine *-- InferenceScheduler
InferenceScheduler *-- KVCache
InferenceScheduler *-- Executor
@ -1084,31 +892,21 @@ classDiagram
DecoderBlock *-- RMSNorm
ChatCompletionRequest *-- ChatMessage
MessagesRequest *-- AnthropicMessage
AutoTokenizer *-- ChatTemplate
BaseFactory *-- Registry
BaseExecutor *-- GradientState
AccumOptimizer o-- GradientState
AccumScheduler o-- GradientState
%% --- Aggregation (weak ownership) ---
AutoModel o-- BaseModelConfig
AutoTokenizer o-- ChatTemplate
PagePool o-- Allocator
PagePool o-- PrefixCache
Trainer o-- TrainCallback
TrainContext o-- BaseStrategy
TrainContext o-- BaseScheduler
TrainContext o-- Checkpoint
TrainContext o-- BaseExecutor
KvcacheView o-- Storage
SamplingPipeline o-- BaseSamplingStrategy
BaseDataset o-- Store
Pipeline o-- PipelineConfig
Pipeline o-- BaseMaskBuilder
BaseDataset o-- BaseStorage
%% --- Dependency (uses temporarily) ---
TrainConfig ..> BaseStrategy : selects
PipelineConfig ..> MaskBuilderFactory : selects
MaskBuilderFactory ..> BaseMaskBuilder : creates
StrategyFactory ..> BaseStrategy : creates
SchedulerFactory ..> BaseScheduler : creates
DatasetFactory ..> BaseDataset : creates
@ -1119,14 +917,10 @@ classDiagram
FFNFactory ..> DeepSeekMoE : creates
DecoderBlock ..> AttnFactory : uses
DecoderBlock ..> FFNFactory : uses
StoreFactory ..> H5Store : creates
StoreFactory ..> MmapStore : creates
StorageFactory ..> H5Storage : creates
StorageFactory ..> JSONStorage : creates
ConfigFactory ..> AutoRegressiveLMConfig : creates
ConfigFactory ..> EncoderConfig : creates
ExecutorFactory ..> NoneExecutor : creates
ExecutorFactory ..> DDPExecutor : creates
ExecutorFactory ..> FSDPExecutor : creates
TrainContextBuilder ..> ExecutorFactory : creates
Trainer ..> TrainContextBuilder : uses
TrainContextBuilder ..> TrainContext : creates
Trainer ..> Functions : spawns
@ -1137,10 +931,10 @@ classDiagram
KVCache ..> KvcacheView : binds
InferenceEngine ..> GenerationRequest : uses
InferenceEngine ..> GenerateResult : creates
OpenAIResponseBuilder ..> ChatCompletionRequest : receives
AnthropicResponseBuilder ..> MessagesRequest : receives
OpenAIHandler ..> ChatCompletionRequest : receives
AnthropicHandler ..> MessagesRequest : receives
ProtocolHandler ..> StopChecker : creates
ProtocolHandler ..> GenContext : creates
ProtocolHandler ..> StreamContext : creates
%% --- Association (general usage) ---
Trainer --> TrainConfig
@ -1153,6 +947,8 @@ classDiagram
Executor --> AutoModel
Executor --> AutoTokenizer
TaskManager --> AutoTokenizer
MultiSegmentFetcher --> BaseSegmentFetcher
ResumableDistributedSampler --> BaseDataset
```
@ -1161,48 +957,43 @@ classDiagram
| Module | Components | Description |
|--------|------------|-------------|
| **astrai.config** | BaseConfig, BaseModelConfig, AutoRegressiveLMConfig, EncoderConfig, ConfigFactory, TrainConfig, PipelineConfig, InputConfig, ProcessingConfig, OutputConfig | Configuration management (to_dict/from_dict, to_file/from_file, from_json/to_json) |
| **astrai.preprocessing** | BaseMaskBuilder, MaskBuilderFactory, ChatMaskBuilder, InstructionMaskBuilder, TextMaskBuilder, Pipeline, filter_by_length, dedup_signature | Declarative JSON-driven data preprocessing |
| **astrai.dataset** | BaseDatasetGRPODataset, StoreMmapStore, StoreFactory, ResumableDistributedSampler, DatasetFactory | Dataset loading and management |
| **astrai.config** | BaseConfig, BaseModelConfig, AutoRegressiveLMConfig, EncoderConfig, ConfigFactory, TrainConfig | Configuration management (to_dict/from_dict, to_file/from_file) |
| **astrai.dataset** | BaseDatasetGRPODataset, BaseStorageJSONStorage, StorageFactory, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory | Dataset loading and management |
| **astrai.serialization** | Checkpoint | Model serialization |
| **astrai.model** | AutoModel, AutoRegressiveLM, EmbeddingEncoder, DecoderBlock, GQA, MLA, MLP, DeepSeekMoE, AttnFactory, FFNFactory, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
| **astrai.tokenize** | AutoTokenizer, ChatTemplate | Tokenizer and chat template |
| **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategyGRPOStrategy, StrategyFactory, BaseSchedulerSGDRScheduler, SchedulerFactory, TrainCallback(Protocol)ValidationCallback, CallbackFactory, Muon | Training workflow |
| **astrai.inference** | InferenceEngine, InferenceScheduler, Executor, KVCacheKvcacheView, AllocatorStorage, Task, TaskManager, TaskStatus, GenerationRequest, GenerateResult, BaseSamplingStrategySamplingPipeline, ProtocolHandler, ResponseBuilder, OpenAIResponseBuilder, AnthropicResponseBuilder, StopChecker, GenContext, ChatMessageMessagesRequest, app | Inference service |
| **astrai.parallel** | spawn_parallel_fn, setup_parallel, get_rank/get_world_size/get_current_device, only_on_rank, BaseExecutor, ExecutorFactory, NoneExecutor, DDPExecutor, FSDPExecutor, GradientState, AccumOptimizer, AccumScheduler, ParallelModel, RowParallelLinear, ColumnParallelLinear | Distributed parallel & gradient accumulation |
| **astrai.inference** | InferenceEngine, InferenceScheduler, Executor, KVCacheKvcacheView, AllocatorStorage, Task, TaskManager, TaskStatus, GenerationRequest, BaseSamplingStrategySamplingPipeline, ProtocolHandlerAnthropicHandler, ChatMessageMessagesRequest, app | Inference service |
| **astrai.parallel** | spawn_parallel_fn, setup_parallel, get_rank/get_world_size/get_current_device, only_on_rank, ParallelModel, RowParallelLinear, ColumnParallelLinear | Distributed parallel |
| **astrai.factory** | Registry, BaseFactory[T] | Component registration |
| **astrai.protocols** | OptimizerProtocol, SchedulerProtocol | Structural subtyping for optimizer/scheduler wrappers |
## Design Patterns
| Pattern | Classes | Purpose |
|---------|---------|---------|
| **Factory** | `AttnFactory`, `FFNFactory`, `StrategyFactory`, `DatasetFactory`, `SchedulerFactory`, `CallbackFactory`, `StoreFactory`, `ConfigFactory`, `ExecutorFactory` | Decorator-based component creation |
| **Factory** | `AttnFactory`, `FFNFactory`, `StrategyFactory`, `DatasetFactory`, `SchedulerFactory`, `CallbackFactory`, `StorageFactory`, `ConfigFactory` | Decorator-based component creation |
| **Registry** | `BaseFactory`, `Registry` | Component registration with category/priority |
| **Strategy** | `SEQStrategy`, `SFTStrategy`, `DPOStrategy`, `GRPOStrategy` | Training strategy switching |
| **Strategy (Sampling)** | `TemperatureStrategy`, `TopKStrategy`, `TopPStrategy`, `SamplingPipeline` | Composable logit transformations |
| **Strategy (API)** | `ResponseBuilder`, `OpenAIResponseBuilder`, `AnthropicResponseBuilder` | HTTP API handler with format hooks |
| **Template Method** | `ProtocolHandler`, `OpenAIHandler`, `AnthropicHandler` | HTTP API handler with format hooks |
| **Builder** | `TrainContextBuilder` | Chain-building training context |
| **Observer** | `TrainCallback`, callback implementations | Training process monitoring |
| **Context** | `TrainContext` | Unified training state bag |
| **Object Pool** | `Allocator`, `PagePool` | Page-based KV cache with LRU eviction |
| **Executor** | `BaseExecutor`, `NoneExecutor`, `DDPExecutor`, `FSDPExecutor` | Gradient accumulation & model distribution |
| **Storage** | `Store`, `H5Store`, `MmapStore` | Format-agnostic data access with multi-segment support |
| **Storage** | `BaseStorage`, `H5Storage`, `JSONStorage` | Format-agnostic data access |
| **Producer-Consumer** | `InferenceScheduler`, `Task`, queues | Continuous batching |
| **AutoModel Registry** | `AutoModel`, `AutoRegressiveLM`, `EmbeddingEncoder` | Model-type dynamic loading |
## Core Relationships
1. **Config → Training**: `TrainConfig` holds `model_fn`, `dataset`, `optimizer_fn`, `scheduler_fn`, `parallel_mode`, `executor_kwargs`
2. **Training Flow**: `Trainer``TrainContextBuilder``TrainContext`, uses `BaseStrategy` for loss, `BaseExecutor` for gradient accumulation + model distribution
1. **Config → Training**: `TrainConfig` holds model, dataset, optimizer_fn, scheduler_fn
2. **Training Flow**: `Trainer``TrainContextBuilder``TrainContext`, uses `BaseStrategy` for loss
3. **Strategy Selection**: `StrategyFactory` creates strategy by `train_type`
4. **Executor Selection**: `ExecutorFactory.create(cfg.parallel_mode, grad_accum_steps=cfg.grad_accum_steps, **cfg.executor_kwargs)``NoneExecutor` / `DDPExecutor` / `FSDPExecutor`
5. **Inference Flow**: `InferenceEngine``InferenceScheduler``AutoRegressiveLM`, backed by `KVCache` + `SamplingPipeline`
6. **Distributed**: `spawn_parallel_fn` + `setup_parallel` for multi-process DDP
7. **Dataset Loading**: `DatasetFactory` creates datasets, `Store` (H5Store/MmapStore) loads data with explicit `_length` and multi-segment `_data`
8. **Checkpoint**: `Checkpoint` saves/loads safetensors + metadata (rank-0 only), extra state saved as `{key}.pt`
9. **Scheduler**: `SchedulerFactory` creates `CosineScheduler`/`SGDRScheduler`
10. **AutoModel**: `from_pretrained()` loads `config.json` + `model.safetensors`, `_disable_random_init` replaces `nn.init.*` with no-ops
11. **Protocols**: `OptimizerProtocol` / `SchedulerProtocol` — structural subtyping for `AccumOptimizer` / `AccumScheduler` wrappers
4. **Inference Flow**: `InferenceEngine``InferenceScheduler``AutoRegressiveLM`, backed by `KVCache` + `SamplingPipeline`
5. **Distributed**: `spawn_parallel_fn` + `setup_parallel` for multi-process DDP
6. **Dataset Loading**: `DatasetFactory` creates datasets, `BaseStorage` (H5Storage/JSONStorage) loads via `BaseSegmentFetcher` + `MultiSegmentFetcher`
7. **Checkpoint**: `Checkpoint` saves/loads safetensors + metadata (rank-0 only)
8. **Scheduler**: `SchedulerFactory` creates `CosineScheduler`/`SGDRScheduler`
9. **AutoModel**: `from_pretrained()` loads `config.json` + `model.safetensors`, `_disable_random_init` replaces `nn.init.*` with no-ops
> Document Update Time: 2026-05-30
> Document Update Time: 2026-05-17

View File

@ -5,21 +5,21 @@ This document describes the data pipeline: from raw text to model input tensors.
## Overview
```
Raw Text → AutoTokenizer → Token IDs → .h5/.bin → Store.load() → Store.fetch() → Dataset → Sampler → DataLoader → Training/Inference
Raw Text → AutoTokenizer → Token IDs → .h5/.json → Dataset → Sampler → DataLoader → Training/Inference
```
## Data Preparation
Raw text is tokenized via `AutoTokenizer.encode()` and saved as HDF5 (`.h5`) or binary (`.bin` + `meta.json`) files with keyed tensor groups.
Raw text is tokenized via `AutoTokenizer.encode()` and saved as HDF5 (`.h5`) or JSON (`.json`/`.jsonl`) files with keyed tensor groups.
Storage format is auto-detected by `detect_format()`; backends are dispatched via registry:
```
StoreFactory.create("h5") → H5Store
StoreFactory.create("bin") → MmapStore
StorageFactory.create("h5") → H5Storage
StorageFactory.create("json") → JSONStorage
```
H5 backend supports shared memory via `.share_memory_()`. Bin (mmap) uses OS page-cache sharing natively.
Both support shared memory via `.share_memory_()`.
## Data Keys by Training Type
@ -33,21 +33,14 @@ H5 backend supports shared memory via `.share_memory_()`. Bin (mmap) uses OS pag
## Dataset Architecture
```
DatasetFactory.load(train_type, load_path, window_size, stride=None, storage_type=None)
→ BaseDataset.load(load_path, storage_type=None)
→ detect_format(load_path)
→ StoreFactory.create(storage_type)
→ Store.load(load_path)
→ H5Store._normalize() / MmapStore._normalize()
→ Store._data[Dict[str, List[Tensor]]] + _cum[Dict[str, List[int]]]
DatasetFactory.load(train_type, path, window_size, stride)
→ StorageFactory.create(detect_format(path))
→ MultiSegmentFetcher(BaseSegmentFetcher per key)
→ BaseDataset.__getitem__(idx)
→ get_index(idx) → [begin, end)
→ Store.fetch(begin, end, keys) → Tensor / Dict[str, Tensor]
→ sliding window [begin, end) via get_index(idx)
```
`window_size` = max input length, `stride` = step between consecutive samples (defaults to `window_size`, optional). `storage_type` defaults to `None` (auto-detect via `detect_format`).
`Store.fetch(begin, end, keys)` accepts a single key (`str`) returning a `Tensor`, or a list of keys returning `Dict[str, Tensor]`. Internally uses `bisect` across multi-segment tensors. Raises `RuntimeError("Store not loaded")` if called before `load()`.
`window_size` = max input length, `stride` = step between consecutive samples.
## Sampler
@ -61,4 +54,4 @@ DatasetFactory.load(train_type, load_path, window_size, stride=None, storage_typ
Standard PyTorch `DataLoader` with configurable `batch_size`, `num_workers`, `pin_memory`, `prefetch_factor`. Sampler produces indices; dataloader fetches tensor batches via `__getitem__`.
> Document Update Time: 2026-05-30
> Document Update Time: 2026-05-17

View File

@ -12,16 +12,16 @@ RoPE is applied **before** KV cache write, not after — otherwise position enco
## KVCache System
Six classes (plus two helpers) working together:
Six classes working together:
```
KVCache (facade)
├── PagePool orchestrates page allocation + prefix matching
│ ├── Allocator bitmask-based page allocator + ref-count + LRU eviction (inside PagePool)
│ └── PrefixCache hash-based prefix matching (page_hash via polynomial hash) (inside PagePool)
├── Allocator bitmask-based page allocator + ref-count + LRU eviction
├── PrefixCache hash-based prefix matching (page_hash via rolling hash)
├── PagePool orchestrates Allocator + PrefixCache
├── TaskTable maps task_id → page_table + cached token count
├── Storage k_cache / v_cache tensors (n_layers × n_pages × page_size × n_kv_heads × head_dim)
└── KvcacheView bundles Storage + page_table + total_len for attention layers (returned by bind())
└── KvcacheView bundles Storage + page_table + total_len for attention layers
```
`KVCache.bind(page_table, total_len)` returns a `KvcacheView` used by attention layers via `write()` / `gather()`.
@ -40,33 +40,26 @@ KVCache (facade)
## Sampling (Strategy Pattern)
```
BaseSamplingStrategy (ABC)
├── TemperatureStrategy
├── TopKStrategy
├── TopPStrategy
└── SamplingPipeline
BaseSamplingStrategy → TemperatureStrategy → TopKStrategy → TopPStrategy
```
`SamplingPipeline` composes them: Temperature → Top-K → Top-P → softmax → multinomial.
`sample()` is a convenience shortcut for one-shot usage.
## Protocol Handlers (Strategy Pattern)
## Protocol Handlers (Template Method)
```python
class ProtocolHandler: # concrete orchestrator
def __init__(self, request, engine, builder): ...
async def handle(self):
prompt, ctx, stops = builder.prepare(request, engine)
class ProtocolHandler(ABC):
def handle(self):
ctx = StreamContext(...)
agen = engine.generate_async(prompt, ...)
if stream: self._handle_stream(agen, ctx, stops)
else: return await self._handle_non_stream(agen, ctx, stops)
if stream: self._handle_stream(agen, ctx)
else: self._handle_non_stream(agen, ctx)
```
`ResponseBuilder` (ABC): `prepare()`, `format_stream_start()`, `format_chunk()`, `format_stream_end()`, `format_response()`.
Subclass hooks: `build_prompt()`, `create_response_id()`, `format_stream_start/token/end()`, `format_non_stream_response()`.
`OpenAIResponseBuilder``/v1/chat/completions`, `AnthropicResponseBuilder``/v1/messages`.
Adding a protocol = one builder file, no handler subclassing needed.
`OpenAIHandler``/v1/chat/completions`, `AnthropicHandler``/v1/messages`.
## Engine & GenerateResult
@ -74,9 +67,7 @@ Adding a protocol = one builder file, no handler subclassing needed.
InferenceEngine
├── generate(prompt, stream, ...) → str | List[str] | Generator
├── generate_with_request(req) → same
├── generate_async(prompt, ...) → AsyncGenerator
├── get_stats() → Dict
└── shutdown()
└── generate_async(prompt, ...) → AsyncGenerator
```
`GenerateResult` uses `Condition` for non-streaming (`wait_completion()`) and `Event` for streaming (`wait()`). Stream callback is `cb(token)`.
@ -103,14 +94,12 @@ Response:
{
"id": "chatcmpl-abc123",
"object": "chat.completion",
"created": 1717000000,
"model": "astrai",
"choices": [{"index": 0, "message": {"role": "assistant", "content": "Hello!"}, "finish_reason": "stop"}],
"choices": [{"message": {"role": "assistant", "content": "Hello!"}, "finish_reason": "stop"}],
"usage": {"prompt_tokens": 5, "completion_tokens": 10, "total_tokens": 15}
}
```
Streaming SSE: `object: "chat.completion.chunk"` — starts with role delta, then token chunks, ends with finish chunk + usage stats, then `data: [DONE]`.
Streaming SSE: `data: {"choices":[{"delta":{"role":"assistant"}}]}` → token chunks → `data: [DONE]`
### Anthropic
@ -127,10 +116,10 @@ Supports `stop_sequences` and streaming via `event: content_block_delta`.
| Param | Type | Default | Description |
|-------|------|---------|-------------|
| `messages` | List[dict] | required | Chat messages (role, content) |
| `top_k` | int | 50 | Top-k count |
| `temperature` | float | 1.0 | Sampling temperature (0.02.0) |
| `top_p` | float | 1.0 | Nucleus threshold |
| `temperature` | float | 1.0 | Sampling temperature (> 0.0) |
| `max_tokens` | Optional[int] | None | Max generation length |
| `top_k` | int | 50 | Top-k count |
| `max_tokens` | int | None | Max generation length |
| `stream` | bool | False | Stream output |
## Engine API
@ -145,8 +134,7 @@ engine.generate("Hello", stream=True) # -> Generator[str]
engine.generate(["A", "B"], stream=True) # -> Generator[Tuple[int, str]]
# Async
async for token in engine.generate_async("Hello", ...): # -> AsyncGenerator[str]
print(token)
await engine.generate_async("Hello", ...) # -> AsyncGenerator[str]
```
> Document Update Time: 2026-05-30
> Document Update Time: 2026-05-17

View File

@ -53,9 +53,7 @@
| Parameter | Description | Default |
|-----------|-------------|---------|
| `--nprocs` | Number of GPUs / processes | 1 |
| `--parallel_mode` | Parallel strategy (`none`, `ddp`, or `fsdp`) | none |
| `--device_type` | Device type | cuda |
| `--start_method` | Multiprocessing start method (`spawn`, `fork`, `forkserver`) | spawn |
### Strategy-specific
@ -75,7 +73,6 @@ export CUDA_VISIBLE_DEVICES=0,1,2,3
nohup python scripts/tools/train.py \
--nprocs=4 \
--parallel_mode=ddp \
--train_type=seq \
--data_root_path=/path/to/dataset \
--param_path=/path/to/model \
@ -97,4 +94,4 @@ nohup python scripts/tools/train.py \
---
> Document Update Time: 2026-05-24
> Document Update Time: 2026-05-17

View File

@ -1,346 +0,0 @@
# Preprocessing Pipeline
Declarative JSON-driven data preprocessing. One `SectionedMaskBuilder` handles all formats via `input.sections` (single-output) or `input.sources` (multi-output).
## Philosophy
| Component | Responsibility |
|-----------|---------------|
| `tokenizer_config.json` (`chat_template`) | Formatting -- how roles become tokens |
| `pipeline.json` (`mask`) | Masking -- which roles participate in training |
A single config file captures the entire pipeline, reusable and version-controllable.
## Config Structure
```json
{
"input": {}, // sections (single) or sources (multi)
"mask": {}, // role → "train" | "mask"
"mask_default": "mask",
"preprocessing": {},
"output": {}
}
```
### Section Fields
| Field | Type | Default | Description |
|-------|------|---------|-------------|
| `field` | str | -- | JSONL key to read |
| `action` | str | -- | `"train"` / `"mask"` / `"$role"` |
| `template` | bool | `false` | Apply `chat_template` per message |
| `add_special_tokens` | bool | `true` for first non-template section | Add special tokens during encode |
### Source Fields (multi-output mode)
| Field | Type | Default | Description |
|-------|------|---------|-------------|
| `sections` | list[dict] | -- | Same as single-output section list |
| `list_field` | bool | `false` | JSONL field holds a list; tokenise each element |
| `mask_key` | str | `"{key}_mask"` | Explicit output key for loss mask |
---
## Quick Start
### SFT Chat
Input JSONL:
```json
{"messages": [{"role": "system", "content": "You are helpful."}, {"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello!"}]}
```
Config:
```json
{
"input": {
"sections": [
{"field": "messages", "action": "$role", "template": true}
]
},
"mask": {
"system": "mask",
"user": "mask",
"assistant": "train"
},
"mask_default": "mask",
"preprocessing": {
"max_seq_len": 2048
},
"output": {
"storage_format": "bin",
"dtype": {"loss_mask": "bool"}
}
}
```
Output keys: `sequence` (int32), `loss_mask` (bool)
### SFT Instruction
Input JSONL:
```json
{"prompt": "Translate to French: Hello", "response": "Bonjour"}
```
Config:
```json
{
"input": {
"sections": [
{"field": "prompt", "action": "mask", "add_special_tokens": true},
{"field": "response", "action": "train"}
]
},
"mask_default": "mask",
"preprocessing": {
"max_seq_len": 2048
}
}
```
Output keys: `sequence`, `loss_mask`
### Pretrain
Input JSONL:
```json
{"text": "Artificial Intelligence is a field of computer science..."}
```
Config:
```json
{
"input": {
"sections": [
{"field": "text", "action": "train"}
]
},
"preprocessing": {
"max_seq_len": 8192,
"min_chars": 100
}
}
```
Output keys: `sequence` (no `loss_mask` — all tokens trained)
### DPO
Input JSONL:
```json
{"chosen": [{"role": "user", "content": "What is 2+2?"}, {"role": "assistant", "content": "4"}], "rejected": [{"role": "user", "content": "What is 2+2?"}, {"role": "assistant", "content": "5"}]}
```
Config:
```json
{
"input": {
"sources": {
"chosen": {
"sections": [
{"field": "chosen", "action": "$role", "template": true}
]
},
"rejected": {
"sections": [
{"field": "rejected", "action": "$role", "template": true}
]
}
}
},
"mask": {
"user": "mask",
"assistant": "train"
},
"mask_default": "mask"
}
```
Output keys: `chosen`, `chosen_mask`, `rejected`, `rejected_mask`
### GRPO
Input JSONL:
```json
{"prompt": [{"role": "user", "content": "What is 2+2?"}], "responses": ["4", "Five", "Four"], "rewards": [1.0, 0.3, 0.8]}
```
Config:
```json
{
"input": {
"sources": {
"prompts": {
"sections": [
{"field": "prompt", "action": "mask", "template": true}
]
},
"responses": {
"sections": [
{"field": "responses", "action": "train"}
],
"list_field": true,
"mask_key": "masks"
},
"rewards": {
"sections": [
{"field": "rewards", "action": "value"}
]
}
}
},
"mask": {
"user": "mask",
"assistant": "train"
},
"mask_default": "mask"
}
```
Output keys: `prompts`, `responses`, `masks`, `rewards` (float32)
- `action: "value"` — extract raw values from JSONL without tokenisation
- `list_field: true` — tokenise each list element independently, then concatenate
- `mask_key: "masks"` — rename the auto-generated mask key (default: `responses_mask`)
---
## Configuration Reference
### `input`
| Field | Type | Default | Description |
|-------|------|---------|-------------|
| `sections` | list[dict] or null | `null` | Section specs for single-output mode |
| `sources` | dict[str, dict] or null | `null` | Source specs for multi-output mode (DPO/GRPO) |
When `sources` is set, `sections` is ignored.
### `mask`
| Field | Type | Default | Description |
|-------|------|---------|-------------|
| `mask` | dict | `{}` | `{role: "train" \| "mask"}` |
| `mask_default` | str | `"mask"` | Default action for unlisted roles |
### `preprocessing`
| Field | Type | Default | Description |
|-------|------|---------|-------------|
| `max_seq_len` | int | `2048` | Truncate sequences to this length |
| `min_chars` | int | `50` | Skip text-mode items shorter than this |
| `max_chars` | int | `2000000` | Skip text-mode items longer than this |
| `max_items` | int or null | `null` | Stop after N documents |
### `output`
| Field | Type | Default | Description |
|-------|------|---------|-------------|
| `domain_key` | str or null | `null` | JSONL key for domain grouping |
| `storage_format` | str | `"bin"` | `"bin"` (mmap) or `"h5"` |
| `max_tokens_per_shard` | int | `100000000` | Flush threshold in cumulative tokens |
| `dtype` | dict[str, str] | `{}` | Per-key tensor dtype override (e.g. `{"loss_mask": "bool"}`) |
---
## Mask Algorithm
### Template mode (`template: true`)
For each message in the field's array:
1. Prepend BOS token (masked)
2. Render through `chat_template` for that single message
3. Encode rendered text
4. Apply mask rule for the message's role
### Non-template mode
Encode the field value as text. Mask value is 1 (train) or 0 (mask) per the section's `action`.
### Text config detection
When no section uses `template` and all sections have `action: "train"`, the builder skips mask generation entirely — all tokens are trained.
---
## Output Layout
### Single-Shard (`bin`)
```
output/
__default__/
meta.json
sequence.bin
loss_mask.bin
wiki/
meta.json
sequence.bin
loss_mask.bin
```
### Multi-Shard (`bin`)
When `max_tokens_per_shard` is exceeded:
```
output/
__default__/
shard_0000/
meta.json
sequence.bin
loss_mask.bin
shard_0001/
meta.json
sequence.bin
loss_mask.bin
```
`MmapStore` discovers all shards under the domain directory via `rglob("meta.json")`.
---
## CLI
```bash
# SFT
python scripts/tools/preprocess.py data/sft/*.jsonl -o output/sft/ -c configs/sft_chat.json
# DPO
python scripts/tools/preprocess.py data/dpo/*.jsonl -o output/dpo/ -c configs/dpo.json --tokenizer_path params
# GRPO
python scripts/tools/preprocess.py data/grpo/*.jsonl -o output/grpo/ -c configs/grpo.json
```
---
## Python API
```python
from astrai.preprocessing.pipeline import Pipeline
from astrai.config.preprocess_config import PipelineConfig
config = PipelineConfig.from_json("sft.json")
Pipeline(
config,
["data_part1.jsonl", "data_part2.jsonl"],
output_dir="output/",
tokenizer_path="params",
).run()
```
> Document Update Time: 2026-06-03

View File

@ -1,5 +1,38 @@
# Training
## Model Architecture
The model uses a decoder-only Transformer with **GQA** (Grouped Query Attention) and optional **MLA** (Multi-head Latent Attention). 1.0 billion parameters, ChineseEnglish bilingual.
```mermaid
flowchart TB
subgraph Layers["Transformer Layers"]
direction TB
A[Input Embedding] --> B[Transformer Block\nLayer 1]
B --> C[Transformer Block\nLayer ...]
C --> D[Transformer Block\nLayer ...]
D --> E[RMSNorm]
E --> F[Linear]
F --> G[SoftMax]
end
subgraph TransformerBlock["Transformer Block"]
direction TB
H[x] --> I[RMSNorm]
I --> J[Linear → Q/K/V]
J --> K[Q]; J --> L[K]; J --> M[V]
K --> N[RoPE]; L --> O[RoPE]
N --> P["Q @ K^T / sqrt(d)"]; O --> P
P --> Q[Masked SoftMax]; Q --> R[S @ V]; M --> R
R --> S[Linear]; S --> T[+]; H --> T
T --> U[RMSNorm]
U --> V["Linear (gate)"]; U --> W["Linear (up)"]
V --> X[SiLU]; X --> Y[×]; W --> Y
Y --> Z["Linear (down)"]; Z --> AA[+]; T --> AA
AA --> BB[x']
end
```
### Autoregression
Given a token sequence, the model predicts the probability of the next token. Each generated token is appended to the input and fed back, repeating until an end-of-sequence token or max length.
@ -36,23 +69,19 @@ Two-level loop: **epoch** → **batch**. Optimizer step fires every `grad_accum_
```
on_train_begin
model.train()
on_epoch_begin
for batch in dataloader:
on_batch_begin
with executor.accumulate(model):
loss = strategy.compute_loss(batch)
context.loss = loss.item()
stand_loss = loss / executor.grad_accum_steps
executor.backward(stand_loss)
context.iteration += 1
loss = strategy(batch)
(loss / grad_accum_steps).backward()
iteration += 1
on_batch_end
if executor.sync_gradients:
on_optimizer_step
if iteration % grad_accum_steps == 0:
on_step_begin
optimizer.step()
optimizer.zero_grad()
if scheduler:
on_step_end
scheduler.step()
on_epoch_end
on_train_end
@ -63,15 +92,12 @@ on_train_end
| Hook | Fires | Default callback |
|------|-------|-----------------|
| `on_train_begin` | Before training starts | `GradientCheckpointingCallback` |
| `on_epoch_begin` | Start of each epoch | `ProgressBarCallback` |
| `on_batch_begin` | Every batch | — |
| `on_optimizer_step` | Every accumulation window | `GradientClippingCallback`, `ValidationCallback` |
| `on_step_begin` | Every accumulation window | `GradientClippingCallback` |
| `on_batch_end` | Every batch | `CheckpointCallback`, `MetricLoggerCallback`, `ProgressBarCallback` |
| `on_epoch_end` | End of each epoch | `ProgressBarCallback` |
| `on_error` | On exception during training | `CheckpointCallback`, `MetricLoggerCallback` |
| `on_train_end` | Training ends (always via finally) | `CheckpointCallback`, `MetricLoggerCallback`, `GradientCheckpointingCallback` |
| `on_step_end` | Every accumulation window | `ValidationCallback` |
| `on_train_end` | Training ends | `CheckpointCallback`, `MetricLoggerCallback` (final save) |
Default callbacks (in order): `gradient_checkpointing` (activation checkpointing, optional), `checkpoint` (safetensors, rank-0), `metric_logger` (JSONL, rank-0), `progress_bar` (tqdm), `gradient_clipping`, `validation` (periodic validation on val_dataset).
Default callbacks: `gradient_checkpointing` (activation checkpointing, optional), `progress_bar` (tqdm), `checkpoint` (safetensors, rank-0), `metric_logger` (JSONL, rank-0), `gradient_clipping`, `validation` (periodic validation on val_dataset).
## Strategies
@ -83,7 +109,7 @@ $$
L_{\text{PT}} = -\sum_{t=1}^{T} \log P(x_t \mid x_{\lt t}; \theta)
$$
Keys: `input_ids`, `target_ids`. Optional: `label_smoothing`.
Keys: `input_ids`, `target_ids`
### SFT (Supervised Fine-Tuning)
@ -93,7 +119,7 @@ $$
L_{\text{SFT}} = -\sum_{t=P+1}^{P+L} \log P(s_t \mid s_{\lt t}; \theta)
$$
Keys: `input_ids`, `target_ids`, `loss_mask`. Optional: `label_smoothing`.
Keys: `input_ids`, `target_ids`, `loss_mask`
### DPO (Direct Preference Optimization)
@ -103,7 +129,7 @@ $$
L_{\text{DPO}} = -\mathbb{E}\left[\log\sigma\left(\beta\log\frac{\pi_\theta(y_w\mid x)}{\pi_{\text{ref}}(y_w\mid x)} - \beta\log\frac{\pi_\theta(y_l\mid x)}{\pi_{\text{ref}}(y_l\mid x)}\right)\right]
$$
Parameters: `beta=0.1`, `reduction="mean"`. Keys: `chosen`, `rejected`, `chosen_mask`, `rejected_mask`.
Parameters: `beta=0.1`. Keys: `chosen`, `rejected`, `chosen_mask`, `rejected_mask`.
### GRPO (Group Relative Policy Optimization)
@ -117,7 +143,7 @@ $$
L_{\text{GRPO}} = -\mathbb{E}\left[\min\left(\frac{\pi_\theta}{\pi_{\text{ref}}}A,\; \text{clip}\left(\frac{\pi_\theta}{\pi_{\text{ref}}}, 1-\epsilon, 1+\epsilon\right)A\right)\right] + \lambda \cdot \mathbb{E}\left[(\log\pi_\theta - \log\pi_{\text{ref}})^2\right]
$$
Parameters: `group_size=4`, `clip_eps=0.2`, `kl_coef=0.01`, `sync_interval=200`, `reduction="mean"`.
Parameters: `group_size=4`, `clip_eps=0.2`, `kl_coef=0.01`, `sync_interval=200`.
Keys: `prompts`, `responses`, `masks`, `rewards`.
@ -128,7 +154,7 @@ Keys: `prompts`, `responses`, `masks`, `rewards`.
| Cosine | `CosineScheduler` | Linear warmup → cosine decay to `min_rate` |
| SGDR | `SGDRScheduler` | Cosine annealing with warm restarts (`t_mult=2`) |
Created by `SchedulerFactory.create(optimizer, schedule_type, **kwargs)`. Valid types: `"cosine"`, `"sgdr"`. Omit to use no scheduler.
Created by `SchedulerFactory.create(optimizer, schedule_type, **kwargs)`.
## Gradient Checkpointing
@ -144,30 +170,29 @@ Callback wraps each `DecoderBlock.forward` with `torch.utils.checkpoint.checkpoi
## Checkpoint
```
Checkpoint(state_dict, epoch, iteration, extra, meta, config)
├── save(save_dir) rank-0 only: meta.json (epoch/iteration/timestamp) + config.json (model config) + model.safetensors + optional {key}.pt (optimizer.pt, scheduler.pt)
└── load(save_dir, broadcast=False) loads from local disk; set broadcast=True to broadcast metadata from rank-0
Checkpoint(state_dict, epoch, iteration, extra, meta)
├── save(save_dir) rank-0 only: meta.json (includes training config) + state_dict.safetensors + optional extra.pt
└── load(save_dir) broadcasts metadata from rank-0
```
Optimizer/scheduler state persisted by default via `Checkpoint.extra`.
Model config (`context.model_config`) saved into `config.json` during training via `CheckpointCallback`.
Training config (`TrainConfig.to_dict()`) saved into `meta.json` during training via `CheckpointCallback`.
## TrainContextBuilder (Builder Pattern)
```python
context = (
TrainContextBuilder(config)
.with_resume_dir(resume_dir)
.with_checkpoint(checkpoint)
.build()
)
# Returns TrainContext with model, strategy, optimizer, scheduler, dataloader, checkpoint
```
- Loads checkpoint weights if provided
- Creates executor via `ExecutorFactory.create(cfg.parallel_mode, grad_accum_steps=cfg.grad_accum_steps, **cfg.executor_kwargs)`
- Calls `executor.prepare(model, optimizer, dataloader, scheduler)` for model distribution (e.g. DDP) + gradient accumulation wrappers
- Wraps model with `parallel_wrapper` if `nprocs > 1`
- Creates `ResumableDistributedSampler` for shuffle+resume
- Builds strategy via `StrategyFactory.create(train_type, model, device, **kwargs)`
- Builds strategy via `StrategyFactory.create(train_type, ...)`
## Training CLI
@ -176,7 +201,6 @@ export CUDA_VISIBLE_DEVICES=0,1,2,3
nohup python scripts/tools/train.py \
--nprocs=4 \
--parallel_mode=ddp \
--train_type=seq \
--data_root_path=/path/to/dataset \
--param_path=/path/to/model \
@ -198,4 +222,4 @@ nohup python scripts/tools/train.py \
Full parameter reference at [params.md](params.md).
> Document Update Time: 2026-05-30
> Document Update Time: 2026-05-17

View File

@ -1,4 +1,4 @@
__version__ = "1.3.7"
__version__ = "1.3.6"
__author__ = "ViperEkura"
from astrai.config import (

View File

@ -4,22 +4,13 @@ from astrai.config.model_config import (
ConfigFactory,
EncoderConfig,
)
from astrai.config.preprocess_config import (
InputConfig,
OutputConfig,
PipelineConfig,
ProcessingConfig,
)
from astrai.config.train_config import TrainConfig
__all__ = [
# Model configuration
"BaseModelConfig",
"AutoRegressiveLMConfig",
"EncoderConfig",
"ConfigFactory",
"TrainConfig",
"InputConfig",
"OutputConfig",
"PipelineConfig",
"ProcessingConfig",
]

View File

@ -1,7 +1,6 @@
import json
from dataclasses import MISSING, dataclass, fields
from pathlib import Path
from typing import Any, Dict, Optional, Self, Union, get_type_hints
from typing import Any, Dict, Optional, Self, get_type_hints
@dataclass
@ -14,21 +13,12 @@ class BaseConfig:
d[fld.name] = v
elif v is None:
d[fld.name] = None
elif isinstance(v, (dict, list, tuple)):
elif isinstance(v, (dict, list)):
try:
val = list(v) if isinstance(v, tuple) else v
json.dumps(val)
d[fld.name] = val
json.dumps(v)
d[fld.name] = v
except (TypeError, ValueError):
pass
elif isinstance(v, BaseConfig):
d[fld.name] = v.to_dict()
elif hasattr(v, "__dataclass_fields__"):
sub = {}
for f in fields(v):
a = getattr(v, f.name)
sub[f.name] = list(a) if isinstance(a, tuple) else a
d[fld.name] = sub
return d
@classmethod
@ -84,15 +74,4 @@ class BaseConfig:
return value
if isinstance(value, target_type):
return value
if isinstance(value, dict) and issubclass(target_type, BaseConfig):
return target_type.from_dict(value)
raise TypeError
@classmethod
def from_json(cls, path: Union[str, Path]) -> Self:
with open(path, "r", encoding="utf-8") as f:
return cls.from_dict(json.load(f))
def to_json(self, path: Union[str, Path]):
with open(path, "w", encoding="utf-8") as f:
json.dump(self.to_dict(), f, indent=2, ensure_ascii=False)

View File

@ -49,7 +49,6 @@ class AutoRegressiveLMConfig(BaseModelConfig):
max_len: Optional[int] = None
rope_theta: Optional[float] = None
rope_scaling: Optional[dict] = None
attn_type: str = "gqa"
n_heads: Optional[int] = None
@ -81,7 +80,6 @@ class EncoderConfig(BaseModelConfig):
max_len: Optional[int] = None
rope_theta: Optional[float] = None
rope_scaling: Optional[dict] = None
n_heads: Optional[int] = None
n_kv_heads: Optional[int] = None

View File

@ -1,109 +0,0 @@
"""Pipeline configuration for JSONL preprocessing.
Supports single-sequence (SFT/pretrain) and multi-output (DPO/GRPO)
modes, both driven declaratively through ``input.sections`` or
``input.sources``.
"""
from dataclasses import dataclass, field
from typing import Dict, List, Optional
from astrai.config.base import BaseConfig
@dataclass
class InputConfig(BaseConfig):
"""Declarative input mapping.
Single-output mode (backward-compatible)::
{"input": {"sections": [{"field": "messages", ...}]}}
Multi-output mode (DPO / GRPO)::
{"input": {"sources": {
"chosen": {"sections": [{"field": "chosen", ...}]},
"rejected": {"sections": [{"field": "rejected", ...}]},
}}}
"""
sections: Optional[List[Dict]] = None
sources: Optional[Dict[str, Dict]] = None
@dataclass
class ProcessingConfig(BaseConfig):
"""Processing configuration.
Parameters
----------
max_seq_len : int
Maximum sequence length (default: 2048).
min_chars : int
Minimum number of characters to keep (default: 50).
max_chars : int
Maximum number of characters to keep (default: 2_000_000).
max_items : Optional[int]
Maximum number of items to process (default: None, unlimited).
packing_strategy : str
How to pack sequences into a contiguous stream.
- ``"simple"``: sequential concatenation (default, backward compatible).
- ``"bfd"``: best-fit decreasing bin packing, minimises wasted tokens.
- ``"bfd_split"``: BFD with over-length sequences split into chunks.
max_packed_len : int
Maximum length of a packed bin. Sequences longer than this are
truncated or split depending on ``packing_strategy`` (default: 8192).
truncation_mode : str
How to truncate sequences longer than ``max_packed_len``.
- ``"keep_start"``: keep the first ``max_packed_len`` tokens (default).
- ``"keep_end"``: keep the last ``max_packed_len`` tokens.
"""
max_seq_len: int = 2048
min_chars: int = 50
max_chars: int = 2_000_000
max_items: Optional[int] = None
packing_strategy: str = "simple"
max_packed_len: int = 8192
truncation_mode: str = "keep_start"
@dataclass
class OutputConfig(BaseConfig):
"""Output configuration.
Parameters
----------
domain_key : Optional[str]
Domain key for the output store (default: None).
storage_format : str
Storage format, one of ``"bin"``, ``"jsonl"`` (default: ``"bin"``).
max_tokens_per_shard : int
Maximum tokens per shard before splitting (default: 100_000_000).
dtype : Dict[str, str]
Per-key dtype overrides, e.g. ``{"input_ids": "int32"}`` (default: {}).
position_ids_mode : Optional[str]
How to compute position_ids in packed sequences.
- ``None`` / ``"none"``: do not generate (backward compatible).
- ``"doc_reset"``: reset to 0 at each document boundary.
- ``"continuous"``: sequential 0, 1, 2, ... (pretrain, single doc).
"""
domain_key: Optional[str] = None
storage_format: str = "bin"
max_tokens_per_shard: int = 100_000_000
dtype: Dict[str, str] = field(default_factory=dict)
position_ids_mode: Optional[str] = None
@dataclass
class PipelineConfig(BaseConfig):
version: int = 1
input: InputConfig = field(default_factory=InputConfig)
mask: Dict[str, str] = field(default_factory=dict)
mask_default: str = "mask"
preprocessing: ProcessingConfig = field(default_factory=ProcessingConfig)
output: OutputConfig = field(default_factory=OutputConfig)

View File

@ -7,7 +7,6 @@ from torch.optim.lr_scheduler import LRScheduler
from torch.utils.data import Dataset
from astrai.config.base import BaseConfig
from astrai.model.components.lora import LoRAConfig
def required(**kw):
@ -17,8 +16,8 @@ def required(**kw):
@dataclass
class TrainConfig(BaseConfig):
# basic setting
model_fn: Callable[[], nn.Module] = field(
default=None, metadata=required(help="Model factory for training.")
model: nn.Module = field(
default=None, metadata=required(help="Model for training.")
)
strategy: str = field(default=None, metadata=required(help="Training strategy."))
dataset: Dataset = field(
@ -57,12 +56,6 @@ class TrainConfig(BaseConfig):
default=5000, metadata={"help": "Number of iterations between checkpoints."}
)
# lora setting
lora: Optional[LoRAConfig] = field(
default=None,
metadata={"help": "LoRA config. None means full fine-tuning."},
)
# metric setting
log_dir: str = field(
default="./checkpoint/logs", metadata={"help": "Directory for metric logs."}
@ -102,9 +95,11 @@ class TrainConfig(BaseConfig):
master_port: str = field(
default="29500", metadata={"help": "Master port for distributed training."}
)
parallel_mode: str = field(
default="none",
metadata={"help": "Parallel strategy: none, ddp, fsdp."},
parallel_wrapper: Optional[Callable] = field(
default=None, metadata={"help": "Parallel function for training."}
)
state_dict_fn: Optional[Callable] = field(
default=None, metadata={"help": "Parallel function for state dict saving."}
)
start_method: str = field(
default="spawn",
@ -118,21 +113,11 @@ class TrainConfig(BaseConfig):
val_dataset: Optional[Dataset] = field(
default=None, metadata={"help": "Dataset for validation."}
)
val_split: Optional[float] = field(
default=None,
metadata={
"help": "Ratio to split from training dataset for validation (e.g. 0.05). Ignored if val_dataset is set."
},
)
val_step: int = field(
default=1000,
metadata={"help": "Number of optimizer steps between validation runs."},
)
executor_kwargs: dict = field(
default_factory=dict,
metadata={"help": "Extra kwargs passed to ExecutorFactory.create()."},
)
extra_kwargs: dict = field(
default_factory=dict, metadata={"help": "Other arguments."}
)

View File

@ -4,28 +4,32 @@ from astrai.dataset.dataset import (
)
from astrai.dataset.sampler import ResumableDistributedSampler
from astrai.dataset.storage import (
H5Store,
MmapStore,
Store,
StoreFactory,
BaseSegmentFetcher,
BaseStorage,
H5Storage,
JSONStorage,
MultiSegmentFetcher,
StorageFactory,
detect_format,
load_bin,
load_h5,
save_bin,
load_json,
save_h5,
save_json,
)
__all__ = [
"BaseDataset",
"DatasetFactory",
"Store",
"StoreFactory",
"H5Store",
"MmapStore",
"BaseSegmentFetcher",
"MultiSegmentFetcher",
"BaseStorage",
"H5Storage",
"JSONStorage",
"StorageFactory",
"detect_format",
"save_h5",
"load_h5",
"save_bin",
"load_bin",
"save_json",
"load_json",
"ResumableDistributedSampler",
]

View File

@ -8,8 +8,8 @@ from torch import Tensor
from torch.utils.data import Dataset
from astrai.dataset.storage import (
Store,
StoreFactory,
BaseStorage,
StorageFactory,
detect_format,
)
from astrai.factory import BaseFactory
@ -26,7 +26,7 @@ class BaseDataset(Dataset, ABC):
super().__init__()
self.window_size = window_size
self.stride = stride
self.storage: Optional[Store] = None
self.storage: Optional[BaseStorage] = None
@property
def required_keys(self) -> List[str]:
@ -48,26 +48,37 @@ class BaseDataset(Dataset, ABC):
f"Missing: {missing}"
)
def load(self, load_path: str, storage_type: Optional[str] = None):
def load(self, load_path: str, storage_type: Optional[str] = None, tokenizer=None):
"""Load dataset from the given path.
Auto-detects the storage format if not specified.
Args:
load_path: Path to the data directory or file
storage_type: Force a specific storage type ("h5", "bin"),
storage_type: Force a specific storage type ("h5", "json"),
or None for auto-detection
tokenizer: Callable str -> List[int], used to tokenize raw text
in JSON files. Ignored for HDF5.
Raises:
KeyError: If the loaded storage is missing required keys.
"""
if storage_type is None:
storage_type = detect_format(load_path)
self.storage = StoreFactory.create(storage_type)
self.storage = StorageFactory.create(storage_type)
self._load_path = load_path
self.storage.load(load_path)
self.storage.load(load_path, tokenizer=tokenizer)
self._validate_keys()
def load_json(self, load_path: str, tokenizer=None):
"""Load dataset from JSON files explicitly.
Args:
load_path: Path to the JSON data file or directory
tokenizer: Optional tokenizer callable for raw text JSON.
"""
self.load(load_path, storage_type="json", tokenizer=tokenizer)
@property
def count(self) -> int:
"""Return the total number of raw elements (tokens) in the dataset."""
@ -137,7 +148,7 @@ class DatasetFactory(BaseFactory["BaseDataset"]):
"""
@classmethod
def _validate_component(cls, dataset_cls: type):
def _validate_component(cls, dataset_cls: type) -> None:
"""Validate that the dataset class inherits from BaseDataset."""
if not issubclass(dataset_cls, BaseDataset):
raise TypeError(f"{dataset_cls.__name__} must inherit from BaseDataset")
@ -164,6 +175,7 @@ class DatasetFactory(BaseFactory["BaseDataset"]):
window_size: int,
stride: Optional[int] = None,
storage_type: Optional[str] = None,
tokenizer=None,
) -> "BaseDataset":
"""Create and load a dataset in one step.
@ -172,7 +184,8 @@ class DatasetFactory(BaseFactory["BaseDataset"]):
load_path: Path to the data file
window_size: Window size for data sampling
stride: Stride between consecutive samples (default: same as window_size)
storage_type: Storage type ("h5", "bin") or None for auto-detection
storage_type: Storage type ("h5", "json") or None for auto-detection
tokenizer: Callable str -> List[int] for raw text JSON tokenization
Returns:
Loaded dataset instance
@ -181,7 +194,7 @@ class DatasetFactory(BaseFactory["BaseDataset"]):
stride = window_size
dataset = cls.create(train_type, window_size, stride)
dataset.load(load_path, storage_type=storage_type)
dataset.load(load_path, storage_type=storage_type, tokenizer=tokenizer)
return dataset
@ -223,7 +236,7 @@ class SFTDataset(BaseDataset):
@property
def required_keys(self) -> List[str]:
return ["sequence", "loss_mask", "position_ids"]
return ["sequence", "loss_mask"]
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
return self.storage.fetch(begin_idx, end_idx, key)
@ -231,17 +244,15 @@ class SFTDataset(BaseDataset):
def __getitem__(self, index):
begin_idx, end_idx = self.get_index(index)
x = self._fetch_data(begin_idx, end_idx, "sequence")
y = self._fetch_data(begin_idx + 1, end_idx + 1, "sequence")
position_ids = self._fetch_data(begin_idx, end_idx, "position_ids")
loss_mask = self._fetch_data(begin_idx + 1, end_idx + 1, "loss_mask")
x = self._fetch_data(begin_idx, end_idx, "sequence").to(dtype=torch.long)
y = self._fetch_data(begin_idx + 1, end_idx + 1, "sequence").to(
dtype=torch.long
)
loss_mask = self._fetch_data(begin_idx + 1, end_idx + 1, "loss_mask").to(
dtype=torch.bool
)
return {
"input_ids": x.to(dtype=torch.long),
"target_ids": y.to(dtype=torch.long),
"position_ids": position_ids.to(dtype=torch.long),
"loss_mask": loss_mask.to(dtype=torch.bool),
}
return {"input_ids": x, "target_ids": y, "loss_mask": loss_mask}
@DatasetFactory.register("dpo")
@ -295,11 +306,9 @@ class GRPODataset(BaseDataset):
def __getitem__(self, index: int) -> Dict[str, Tensor]:
begin_idx, end_idx = self.get_index(index)
prompts = self._fetch_data(begin_idx, end_idx, "prompts").to(dtype=torch.long)
responses = self._fetch_data(begin_idx, end_idx, "responses").to(
dtype=torch.long
)
masks = self._fetch_data(begin_idx, end_idx, "masks").to(dtype=torch.bool)
prompts = self._fetch_data(begin_idx, end_idx, "prompts")
responses = self._fetch_data(begin_idx, end_idx, "responses")
masks = self._fetch_data(begin_idx, end_idx, "masks")
rewards = self._fetch_data(begin_idx, end_idx, "rewards")
return {

View File

@ -43,7 +43,6 @@ class ResumableDistributedSampler(Sampler[int]):
offset = 0 if drop_last else self.num_replicas - 1
self.num_samples_per_replica = (self.num_samples + offset) // self.num_replicas
self.total_size = self.num_samples_per_replica * self.num_replicas
self.iter = self.iter % self.num_samples_per_replica
self._indices = None
@ -75,10 +74,5 @@ class ResumableDistributedSampler(Sampler[int]):
self.epoch += 1
self._indices = None
@property
def _remaining(self):
remaining = self.num_samples_per_replica - self.iter
return max(remaining, 0)
def __len__(self):
return self._remaining
return self.num_samples_per_replica

View File

@ -1,32 +1,17 @@
"""Storage backends for different data formats.
Layers:
- I/O layer: save_* / load_* functions, read/write raw files (HDF5/bin)
return Dict[str, List[Tensor]] format-specific, no state
- Store (ABC): central abstraction, normalizes multi-segment into
Dict[str, List[Tensor]] per key via _normalize(),
fetch() uses bisect across segments no forced concat
- Dataset layer: BaseDataset owns a Store, only calls store.fetch(begin, end, key)
Key properties:
- Multi-segment: segments kept as-is, no forced concatenation safe for
datasets larger than RAM
- Explicit length: _length = min(total elements across keys), set at load,
__len__ returns O(1)
- Zero-copy mmap: MmapStore wraps np.memmap(mode="r"), all DataLoader
workers share OS page-cache pages
Each storage handles format-specific loading (HDF5, JSON, etc.) and provides
a uniform interface for data access and length observation via fetchers.
"""
import bisect
import glob
import json
import os
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Dict, List, Union
from typing import Callable, Dict, List, Optional, Union
import h5py
import numpy as np
import torch
from torch import Tensor
@ -69,30 +54,54 @@ def load_h5(file_path: str, share_memory=True) -> Dict[str, List[Tensor]]:
return tensor_group
def save_bin(file_path: str, tensor_group: Dict[str, List[Tensor]]):
def save_json(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]):
os.makedirs(file_path, exist_ok=True)
meta = {}
full_file_path = os.path.join(file_path, f"{file_name}.json")
json_data = {}
for key, tensors in tensor_group.items():
cat = torch.cat(tensors, dim=0)
meta[key] = {"shape": list(cat.shape), "dtype": str(cat.dtype).split(".")[-1]}
np.asarray(cat.cpu().numpy()).tofile(os.path.join(file_path, f"{key}.bin"))
with open(os.path.join(file_path, "meta.json"), "w") as f:
json.dump(meta, f)
json_data[key] = [tensor.tolist() for tensor in tensors]
with open(full_file_path, "w", encoding="utf-8") as f:
json.dump(json_data, f, ensure_ascii=False)
def load_bin(file_path: str) -> Dict[str, List[Tensor]]:
with open(os.path.join(file_path, "meta.json"), "r") as f:
meta = json.load(f)
segments: Dict[str, List[Tensor]] = {}
for key, info in meta.items():
arr = np.memmap(
os.path.join(file_path, f"{key}.bin"),
dtype=info["dtype"],
mode="r+",
shape=tuple(info["shape"]),
)
segments[key] = [torch.from_numpy(arr)]
return segments
def load_json(
file_path: str,
share_memory: bool = True,
tokenizer: Optional[Callable[[str], List[int]]] = None,
) -> Dict[str, List[Tensor]]:
"""Load tensor data from JSON files.
Supports two modes:
- Pre-tokenized: JSON values are List[List[int]] (token IDs), loaded as-is.
- Raw text: JSON values are List[str], tokenized via ``tokenizer`` callable
at load time. A ``tokenizer`` receives a str and returns List[int].
Non-data JSON files (e.g. config.json) with scalar/object values are
silently skipped.
"""
tensor_group: Dict[str, List[Tensor]] = {}
root_path = Path(file_path)
json_files = list(root_path.rglob("*.json")) + list(root_path.rglob("*.jsonl"))
for json_file in json_files:
with open(json_file, "r", encoding="utf-8") as f:
data = json.load(f)
if not isinstance(data, dict):
continue
for key, sequences in data.items():
if not isinstance(sequences, list):
continue
tensors = []
for seq in sequences:
if tokenizer is not None and isinstance(seq, str):
seq = tokenizer(seq)
tensor = torch.tensor(seq, dtype=torch.long)
if share_memory:
tensor = tensor.share_memory_()
tensors.append(tensor)
if tensor_group.get(key) is None:
tensor_group[key] = []
tensor_group[key].extend(tensors)
return tensor_group
def detect_format(load_path: str) -> str:
@ -102,7 +111,7 @@ def detect_format(load_path: str) -> str:
load_path: Directory or file path
Returns:
Format string ("h5" or "bin")
Format string ("h5" or "json")
Raises:
FileNotFoundError: If no supported data files are found
@ -112,160 +121,181 @@ def detect_format(load_path: str) -> str:
suffix = root.suffix.lower()
if suffix in (".h5", ".hdf5"):
return "h5"
if suffix in (".json", ".jsonl"):
return "json"
raise ValueError(f"Unsupported file format: {suffix}")
h5_files = [
Path(p)
for pattern in ("*.h5", "*.hdf5")
for p in glob.glob(str(root / "**" / pattern), recursive=True)
]
h5_files = list(root.rglob("*.h5")) + list(root.rglob("*.hdf5"))
if h5_files:
return "h5"
bin_files = [Path(p) for p in glob.glob(str(root / "**" / "*.bin"), recursive=True)]
if bin_files:
has_meta = (root / "meta.json").exists() or len(
[Path(p) for p in glob.glob(str(root / "**" / "meta.json"), recursive=True)]
) > 0
if has_meta:
return "bin"
json_files = list(root.rglob("*.json")) + list(root.rglob("*.jsonl"))
if json_files:
return "json"
raise FileNotFoundError(f"No supported data files found at {load_path}")
class Store(ABC):
"""String keys -> segmented tensors with ``fetch(begin, end, keys)``.
class BaseSegmentFetcher:
"""Fetches data segments across multiple tensor segments.
Each key maps to one or more tensor segments (no forced concatenation).
``len(store)`` returns ``self._length`` (explicit, O(1)), the minimum
total element count across all keys.
Maintains cumulative lengths for efficient range queries across
multiple discontinuous segments.
"""
Subclasses fill ``self._data`` and ``self._cum`` during ``load()``
via ``_normalize()``.
def __init__(self, segments: List[Tensor]):
self.segments = segments
self.cum_lengths = []
total = 0
for seg in segments:
total += torch.numel(seg)
self.cum_lengths.append(total)
self.total_length = total
def __len__(self) -> int:
return self.total_length
def fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
"""Fetch data in the range [begin_idx, end_idx)."""
if not (
0 <= begin_idx < self.total_length and 0 <= end_idx <= self.total_length
):
raise ValueError("begin_idx or end_idx out of bounds")
if begin_idx >= end_idx:
return torch.tensor([], dtype=torch.long)
seg_start_idx = bisect.bisect_right(self.cum_lengths, begin_idx)
seg_end_idx = bisect.bisect_left(self.cum_lengths, end_idx)
result_segments = []
for i in range(seg_start_idx, seg_end_idx + 1):
prev_cum = self.cum_lengths[i - 1] if i > 0 else 0
start = max(begin_idx - prev_cum, 0)
end = min(end_idx - prev_cum, len(self.segments[i]))
result_segments.append(self.segments[i][start:end])
return torch.cat(result_segments, dim=0)
class MultiSegmentFetcher:
"""Manages multiple segment fetchers for different data keys."""
def __init__(self, multi_segments: Dict):
self.multi_keys = list(multi_segments.keys())
self.multi_fetchers = {
key: BaseSegmentFetcher(segments)
for key, segments in multi_segments.items()
}
def __len__(self) -> int:
"""Returns the minimum length across all fetchers."""
if not self.multi_fetchers:
return 0
len_list = [len(seg) for seg in self.multi_fetchers.values()]
return min(len_list)
def key_fetch(
self, begin_idx: int, end_idx: int, keys: Union[str, List[str]]
) -> Dict:
"""Fetch data for specific keys."""
fetch_dict = {}
keys = [keys] if isinstance(keys, str) else keys
for key in keys:
fetcher = self.multi_fetchers[key]
fetch_tensor = fetcher.fetch_data(begin_idx, end_idx)
fetch_dict[key] = fetch_tensor
return fetch_dict if len(keys) > 1 else fetch_dict[keys[0]]
def fetch_data(self, begin_idx: int, end_idx: int) -> Dict:
"""Fetch all keys."""
return self.key_fetch(begin_idx, end_idx, self.multi_keys)
class BaseStorage(ABC):
"""Abstract storage backend for loading and dispatching data.
Storage encapsulates format-specific loading and provides a uniform
interface for data access and length observation. Subclasses handle
different data formats (HDF5, JSON, etc.) while exposing the same
fetch interface.
"""
def __init__(self):
self._data: Dict[str, List[Tensor]] = {}
self._cum: Dict[str, List[int]] = {}
self._length: int = 0
self._fetcher: Optional[MultiSegmentFetcher] = None
@abstractmethod
def load(self, path: str) -> None:
def load(self, load_path: str, tokenizer=None) -> None:
"""Load data from the given path into internal fetcher."""
raise NotImplementedError
def __len__(self) -> int:
"""Total number of raw elements (tokens) in storage."""
if self._fetcher is None:
return 0
return len(self._fetcher)
def fetch(self, begin_idx: int, end_idx: int, keys: Union[str, List[str]]):
"""Fetch data for the given keys and index range.
Args:
begin_idx: Starting index (inclusive)
end_idx: Ending index (exclusive)
keys: Single key or list of keys to fetch
Returns:
Tensor if single key, Dict[str, Tensor] if multiple keys
"""
if self._fetcher is None:
raise RuntimeError("Storage not loaded")
return self._fetcher.key_fetch(begin_idx, end_idx, keys)
@property
def keys(self) -> List[str]:
return list(self._data.keys())
def __len__(self) -> int:
return self._length
def fetch(
self,
begin: int,
end: int,
keys: Union[str, List[str]],
):
if not self._data:
raise RuntimeError("Store not loaded")
if not (0 <= begin < self._length and 0 <= end <= self._length):
raise ValueError(
f"Index out of bounds: begin={begin}, end={end}, length={self._length}"
)
if isinstance(keys, str):
return self._fetch_key(keys, begin, end)
return {k: self._fetch_key(k, begin, end) for k in keys}
def _fetch_key(self, key: str, begin: int, end: int) -> Tensor:
"""Fetch slice [begin, end) across potentially multiple segments."""
segments = self._data[key]
cum = self._cum[key]
seg_start = bisect.bisect_right(cum, begin)
seg_end = bisect.bisect_left(cum, end)
results = []
for i in range(seg_start, seg_end + 1):
prev = cum[i - 1] if i > 0 else 0
s = max(begin - prev, 0)
e = min(end - prev, segments[i].shape[0])
results.append(segments[i][s:e])
return results[0] if len(results) == 1 else torch.cat(results, dim=0)
def _normalize(self, raw: Dict[str, List[Tensor]]):
"""Register segments and pre-compute cumulative lengths.
Does NOT concatenate segments are kept as-is to avoid OOM on
large datasets. Sets ``self._length`` to the minimum total
element count across all keys.
"""
for key, tensors in raw.items():
self._data[key] = tensors
cum = []
total = 0
for t in tensors:
total += t.shape[0]
cum.append(total)
self._cum[key] = cum
self._length = (
min((cum[-1] if cum else 0) for cum in self._cum.values())
if self._cum
else 0
)
"""Return the data keys available in this storage."""
if self._fetcher is None:
return []
return self._fetcher.multi_keys
class StoreFactory(BaseFactory["Store"]):
"""Factory for creating Store instances by type name.
class StorageFactory(BaseFactory["BaseStorage"]):
"""Factory for creating storage backends by type name.
Example::
@StoreFactory.register("custom")
class CustomStore(Store):
Example:
@StorageFactory.register("custom")
class CustomStorage(BaseStorage):
...
storage = StorageFactory.create("custom")
"""
@classmethod
def _validate_component(cls, store_cls: type):
if not issubclass(store_cls, Store):
raise TypeError(f"{store_cls.__name__} must inherit from Store")
def _validate_component(cls, storage_cls: type) -> None:
if not issubclass(storage_cls, BaseStorage):
raise TypeError(f"{storage_cls.__name__} must inherit from BaseStorage")
@StoreFactory.register("h5")
class H5Store(Store):
@StorageFactory.register("h5")
class H5Storage(BaseStorage):
"""HDF5-based storage backend (pre-tokenized data)."""
def load(self, path: str):
self._normalize(load_h5(path))
def load(self, load_path: str, tokenizer=None) -> None:
segments = load_h5(load_path)
self._fetcher = MultiSegmentFetcher(segments)
@StoreFactory.register("bin")
class MmapStore(Store):
"""Memory-mapped binary storage backend.
@StorageFactory.register("json")
class JSONStorage(BaseStorage):
"""JSON-based storage backend.
Each key is a single .bin file backed by ``np.memmap(mode="r")``.
No per-process memory duplication all DataLoader workers share the
same OS page-cache pages.
Format on disk::
data_root/
meta.json # {key: {shape, dtype}, ...}
<key>.bin # raw numpy array, one per key
Supports two modes:
- Pre-tokenized: JSON values are List[List[int]], loaded as-is.
- Raw text: JSON values are List[str], tokenized via ``tokenizer``
callable (str -> List[int]) at load time.
"""
def load(self, path: str):
self._mmap_refs = []
root = Path(path)
all_raw: Dict[str, List[Tensor]] = {}
meta_paths = [
Path(p) for p in glob.glob(str(root / "**" / "meta.json"), recursive=True)
]
for meta_path in meta_paths:
raw = load_bin(str(meta_path.parent))
for key, tensors in raw.items():
if key not in all_raw:
all_raw[key] = []
all_raw[key].extend(tensors)
if not meta_paths:
raise FileNotFoundError(f"No meta.json found under {path}")
self._normalize(all_raw)
for tensors in self._data.values():
self._mmap_refs.extend(tensors)
def load(self, load_path: str, tokenizer=None) -> None:
segments = load_json(load_path, tokenizer=tokenizer)
self._fetcher = MultiSegmentFetcher(segments)

View File

@ -23,7 +23,7 @@ class Registry:
component_cls: Type,
category: Optional[str] = None,
priority: int = 0,
):
) -> None:
"""Register a component class with optional category and priority."""
if name in self._entries:
raise ValueError(f"Component '{name}' is already registered")
@ -158,7 +158,7 @@ class BaseFactory(ABC, Generic[T]):
return component_cls(*args, **kwargs)
@classmethod
def _validate_component(cls, component_cls: Type[T]):
def _validate_component(cls, component_cls: Type[T]) -> None:
"""Validate that the component class is valid for this factory.
Override this method in subclasses to add custom validation.

View File

@ -2,26 +2,24 @@
Layers:
- core/: Core inference loop (cache, executor, scheduler, task)
- api/: HTTP orchestration (ProtocolHandler, server)
- protocols/: Response builders (OpenAI, Anthropic)
- transport/: SSE transport utilities
- api/: HTTP protocol handlers (OpenAI, Anthropic)
- engine.py: Facade (InferenceEngine), Value Object (GenerationRequest)
- sample.py: Strategy pattern (TemperatureStrategy, TopKStrategy, TopPStrategy)
"""
from astrai.inference.api import (
AnthropicHandler,
AnthropicMessage,
ChatCompletionRequest,
ChatMessage,
GenContext,
MessagesRequest,
OpenAIHandler,
ProtocolHandler,
StopChecker,
get_app,
StreamContext,
app,
run_server,
)
from astrai.inference.api.anthropic import AnthropicResponseBuilder
from astrai.inference.api.openai import OpenAIResponseBuilder
from astrai.inference.core import (
STOP,
Allocator,
@ -38,7 +36,10 @@ from astrai.inference.core import (
TaskTable,
page_hash,
)
from astrai.inference.engine import GenerationRequest, InferenceEngine
from astrai.inference.engine import (
GenerationRequest,
InferenceEngine,
)
from astrai.inference.sample import (
BaseSamplingStrategy,
SamplingPipeline,
@ -49,14 +50,17 @@ from astrai.inference.sample import (
)
__all__ = [
# Engine / Requests
"InferenceEngine",
"GenerationRequest",
# Core scheduler
"InferenceScheduler",
"Executor",
"STOP",
"Task",
"TaskManager",
"TaskStatus",
# Core cache
"Allocator",
"KVCache",
"KvcacheView",
@ -65,21 +69,24 @@ __all__ = [
"Storage",
"TaskTable",
"page_hash",
# Sampling (Strategy pattern)
"sample",
"BaseSamplingStrategy",
"TemperatureStrategy",
"TopKStrategy",
"TopPStrategy",
"SamplingPipeline",
# Protocol
"ProtocolHandler",
"StopChecker",
"GenContext",
"OpenAIResponseBuilder",
"AnthropicResponseBuilder",
"StreamContext",
"AnthropicHandler",
"OpenAIHandler",
# Server
"ChatMessage",
"ChatCompletionRequest",
"AnthropicMessage",
"MessagesRequest",
"get_app",
"app",
"run_server",
]

View File

@ -1,27 +1,31 @@
"""Inference API: protocol handler, stop checker, and FastAPI server.
"""Inference API: protocol handlers and FastAPI server."""
``app`` is no longer a module-level global. Use :func:`get_app` to access the
lazy singleton FastAPI instance.
"""
from astrai.inference.api.protocol import GenContext, ProtocolHandler, StopChecker
from astrai.inference.api.protocol import (
AnthropicHandler,
OpenAIHandler,
ProtocolHandler,
StopChecker,
StreamContext,
)
from astrai.inference.api.server import (
AnthropicMessage,
ChatCompletionRequest,
ChatMessage,
MessagesRequest,
get_app,
app,
run_server,
)
__all__ = [
"AnthropicHandler",
"OpenAIHandler",
"ProtocolHandler",
"StopChecker",
"GenContext",
"StreamContext",
"AnthropicMessage",
"ChatCompletionRequest",
"ChatMessage",
"MessagesRequest",
"get_app",
"app",
"run_server",
]

View File

@ -1,141 +0,0 @@
"""Anthropic message completion response builder."""
import time
import uuid
from typing import Any, Dict, List, Tuple, Union
from pydantic import BaseModel
from astrai.inference.api.protocol import (
GenContext,
ResponseBuilder,
StopInfo,
sse_event,
)
from astrai.inference.engine import InferenceEngine
def _extract_text(content: Union[str, List[Dict[str, Any]]]) -> str:
if isinstance(content, str):
return content
if isinstance(content, list):
for block in content:
if isinstance(block, dict) and block.get("type") == "text":
return block.get("text", "")
return ""
class AnthropicResponseBuilder(ResponseBuilder):
def prepare(
self, request: BaseModel, engine: InferenceEngine
) -> Tuple[str, GenContext, List[str]]:
messages: List[Dict[str, str]] = []
system = getattr(request, "system", None)
if system:
messages.append({"role": "system", "content": system})
for m in request.messages:
text = _extract_text(m.content)
if text:
messages.append({"role": m.role, "content": text})
prompt = engine.tokenizer.apply_chat_template(messages, tokenize=False)
ctx = GenContext(
resp_id=f"msg_{uuid.uuid4().hex[:24]}",
created=int(time.time()),
model=request.model,
prompt_tokens=0,
)
stop_sequences = getattr(request, "stop_sequences", None) or []
return prompt, ctx, stop_sequences
def format_stream_start(self, ctx: GenContext) -> List[str]:
return [
sse_event(
{
"type": "message_start",
"message": {
"id": ctx.resp_id,
"type": "message",
"role": "assistant",
"model": ctx.model,
"content": [],
"usage": {"input_tokens": ctx.prompt_tokens},
},
},
event="message_start",
),
sse_event(
{
"type": "content_block_start",
"index": 0,
"content_block": {"type": "text", "text": ""},
},
event="content_block_start",
),
]
def format_chunk(self, token: str) -> str:
return sse_event(
{
"type": "content_block_delta",
"index": 0,
"delta": {"type": "text_delta", "text": token},
},
event="content_block_delta",
)
def format_stream_end(self, ctx: GenContext, stop: StopInfo) -> List[str]:
events: List[str] = []
if stop.matched:
trimmed = stop.body[: stop.body.rfind(stop.matched)]
unyielded = trimmed[len(stop.yielded) :]
if unyielded:
events.append(
sse_event(
{
"type": "content_block_delta",
"index": 0,
"delta": {"type": "text_delta", "text": unyielded},
},
event="content_block_delta",
)
)
events.append(
sse_event(
{"type": "content_block_stop", "index": 0},
event="content_block_stop",
)
)
events.append(
sse_event(
{
"type": "message_delta",
"delta": {
"stop_reason": "stop_sequence" if stop.matched else "end_turn",
"stop_sequence": stop.matched,
},
"usage": {"output_tokens": ctx.completion_tokens},
},
event="message_delta",
)
)
events.append(sse_event({"type": "message_stop"}, event="message_stop"))
return events
def format_response(
self, ctx: GenContext, content: str, stop: StopInfo
) -> Dict[str, Any]:
if stop.matched:
content = content[: content.rfind(stop.matched)]
return {
"id": ctx.resp_id,
"type": "message",
"role": "assistant",
"model": ctx.model,
"content": [{"type": "text", "text": content}],
"stop_reason": "stop_sequence" if stop.matched else "end_turn",
"stop_sequence": stop.matched,
"usage": {
"input_tokens": ctx.prompt_tokens,
"output_tokens": ctx.completion_tokens,
},
}

View File

@ -1,140 +0,0 @@
"""OpenAI chat completion response builder."""
import logging
import time
import uuid
from typing import Any, Dict, List, Tuple
from pydantic import BaseModel
from astrai.inference.api.protocol import (
GenContext,
ResponseBuilder,
StopInfo,
sse_event,
)
from astrai.inference.engine import InferenceEngine
logger = logging.getLogger(__name__)
_UNSUPPORTED_PARAMS = (
"n",
"presence_penalty",
"frequency_penalty",
"logit_bias",
"user",
)
class OpenAIResponseBuilder(ResponseBuilder):
def prepare(
self, request: BaseModel, engine: InferenceEngine
) -> Tuple[str, GenContext, List[str]]:
messages = [{"role": m.role, "content": m.content} for m in request.messages]
prompt = engine.tokenizer.apply_chat_template(messages, tokenize=False)
self._resp_id = f"chatcmpl-{uuid.uuid4().hex[:12]}"
self._model = request.model
for param in _UNSUPPORTED_PARAMS:
value = getattr(request, param, None)
fields = getattr(type(request), "model_fields", {})
default = fields[param].default if param in fields else None
if value is not None and value != default:
logger.warning(
"ChatCompletionRequest param '%s'=%r is not supported and will be ignored",
param,
value,
)
if value is not None and value != default:
logger.warning(
"ChatCompletionRequest param '%s'=%r is not supported and will be ignored",
param,
value,
)
ctx = GenContext(
resp_id=self._resp_id,
created=int(time.time()),
model=self._model,
prompt_tokens=0,
)
stop = request.stop
stop_sequences = (
[] if stop is None else [stop] if isinstance(stop, str) else stop
)
return prompt, ctx, stop_sequences
def format_stream_start(self, ctx: GenContext) -> List[str]:
return [
sse_event(
{
"id": self._resp_id,
"object": "chat.completion.chunk",
"created": ctx.created,
"model": self._model,
"choices": [
{
"index": 0,
"delta": {"role": "assistant"},
"finish_reason": None,
}
],
}
)
]
def format_chunk(self, token: str) -> str:
return sse_event(
{
"id": self._resp_id,
"object": "chat.completion.chunk",
"created": 0,
"model": self._model,
"choices": [
{"index": 0, "delta": {"content": token}, "finish_reason": None}
],
}
)
def format_stream_end(self, ctx: GenContext, stop: StopInfo) -> List[str]:
return [
sse_event(
{
"id": self._resp_id,
"object": "chat.completion.chunk",
"created": ctx.created,
"model": self._model,
"choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
}
),
sse_event(
{
"prompt_tokens": ctx.prompt_tokens,
"completion_tokens": ctx.completion_tokens,
"total_tokens": ctx.prompt_tokens + ctx.completion_tokens,
}
),
]
def format_response(
self, ctx: GenContext, content: str, stop: StopInfo
) -> Dict[str, Any]:
return {
"id": self._resp_id,
"object": "chat.completion",
"created": ctx.created,
"model": self._model,
"choices": [
{
"index": 0,
"message": {"role": "assistant", "content": content},
"finish_reason": "stop",
}
],
"usage": {
"prompt_tokens": ctx.prompt_tokens,
"completion_tokens": ctx.completion_tokens,
"total_tokens": ctx.prompt_tokens + ctx.completion_tokens,
},
}

View File

@ -1,13 +1,15 @@
"""Orchestration layer: ProtocolHandler, StopChecker, GenContext, StopInfo, ResponseBuilder, SSE utils.
"""Protocol handlers for OpenAI and Anthropic chat completion APIs.
ProtocolHandler orchestrates the async generation loop and delegates
protocol-specific formatting to a ResponseBuilder.
Template Method + Builder patterns eliminate the 45% code duplication between
stream/non-stream branches and across protocol adapters.
"""
import json
import time
import uuid
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Union
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
@ -15,7 +17,7 @@ from pydantic import BaseModel
from astrai.inference.engine import InferenceEngine
def sse_event(data: Dict[str, Any], event: Optional[str] = None) -> str:
def _sse_event(data: Dict[str, Any], event: Optional[str] = None) -> str:
lines: List[str] = []
if event:
lines.append(f"event: {event}")
@ -24,28 +26,22 @@ def sse_event(data: Dict[str, Any], event: Optional[str] = None) -> str:
return "\n".join(lines)
def sse_done() -> str:
def _sse_done() -> str:
return "data: [DONE]\n\n"
@dataclass
class GenContext:
"""Per-generation metadata passed to builder format methods."""
class StreamContext:
"""Shared state across the streaming generation lifecycle."""
resp_id: str
created: int
model: str
prompt_tokens: int
completion_tokens: int = 0
@dataclass
class StopInfo:
"""Stop-check result passed to format_stream_end / format_response."""
matched: Optional[str] = None
body: str = ""
yielded: str = ""
accumulated: str = ""
stop_matched: Optional[str] = None
last_yield_trimmed: str = ""
class StopChecker:
@ -60,60 +56,95 @@ class StopChecker:
return seq
return None
def trim(self, text: str, matched: str) -> str:
idx = text.rfind(matched)
return text[:idx] if idx != -1 else text
class ResponseBuilder(ABC):
"""Interface for protocol-specific response formatting.
@property
def has_sequences(self) -> bool:
return len(self._sequences) > 0
A new protocol requires one concrete builder implementing 5 methods.
class ProtocolHandler(ABC):
"""Template-method base for API protocol handlers.
Subclasses implement format hooks; the base class orchestrates the
generate-async loop and SSE/JSON response construction.
Lifecycle::
handle()
build_prompt() # protocol-specific prompt assembly
create_response_id() # unique response identifier
[stream]
format_stream_start()
format_stream_token() × N
on_token() hook for stop-sequence interception
format_stream_end()
[non-stream]
(accumulate tokens)
format_non_stream_response()
"""
@abstractmethod
def prepare(
self, request: BaseModel, engine: InferenceEngine
) -> Tuple[str, GenContext, List[str]]:
"""Return (prompt, ctx, stop_sequences) for a generation request."""
request_model: type[BaseModel]
@abstractmethod
def format_stream_start(self, ctx: GenContext) -> List[str]:
"""SSE events that open the stream."""
@abstractmethod
def format_chunk(self, token: str) -> str:
"""SSE event for a single generated token."""
@abstractmethod
def format_stream_end(self, ctx: GenContext, stop: StopInfo) -> List[str]:
"""SSE events that close the stream."""
@abstractmethod
def format_response(
self, ctx: GenContext, content: str, stop: StopInfo
) -> Dict[str, Any]:
"""JSON response body for non-streaming mode."""
class ProtocolHandler:
"""Orchestrates the generation loop, delegates formatting to a builder.
Usage::
handler = ProtocolHandler(request, engine, OpenAIResponseBuilder())
response = await handler.handle()
"""
def __init__(
self, request: BaseModel, engine: InferenceEngine, builder: ResponseBuilder
):
def __init__(self, request: BaseModel, engine: InferenceEngine):
self.request = request
self.engine = engine
self.builder = builder
@abstractmethod
def build_prompt(self) -> str:
"""Build the full prompt string from the request messages."""
@abstractmethod
def create_response_id(self) -> str:
"""Generate a unique response ID following the protocol convention."""
@abstractmethod
def format_stream_start(self, ctx: StreamContext) -> List[str]:
"""Yield SSE events that open the stream (role marker, metadata)."""
@abstractmethod
def format_stream_token(self, ctx: StreamContext, token: str) -> str:
"""Yield an SSE event for a single generated token."""
@abstractmethod
def format_stream_end(self, ctx: StreamContext) -> List[str]:
"""Yield SSE events that close the stream (finish reason, usage stats)."""
@abstractmethod
def format_non_stream_response(
self, ctx: StreamContext, content: str
) -> Dict[str, Any]:
"""Build the JSON response body for non-streaming mode."""
def get_stop_sequences(self) -> List[str]:
return []
def create_stop_checker(self) -> StopChecker:
return StopChecker(self.get_stop_sequences())
def on_token(
self, ctx: StreamContext, token: str, stop_checker: StopChecker
) -> Optional[str]:
"""Hook after each token is appended to accumulated.
Return a matched stop-sequence string to break the loop,
or None to continue.
"""
return None
async def handle(self) -> Union[StreamingResponse, Dict[str, Any]]:
prompt, ctx, stop_sequences = self.builder.prepare(self.request, self.engine)
ctx.prompt_tokens = len(self.engine.tokenizer.encode(prompt))
ctx = StreamContext(
resp_id=self.create_response_id(),
created=int(time.time()),
model=self.request.model,
prompt_tokens=self._count_prompt_tokens(),
)
agen = self.engine.generate_async(
prompt=prompt,
prompt=self.build_prompt(),
max_tokens=self.request.max_tokens,
temperature=self.request.temperature,
top_p=self.request.top_p,
@ -121,37 +152,33 @@ class ProtocolHandler:
)
if self.request.stream:
return self._handle_stream(agen, ctx, stop_sequences)
return self._handle_stream(agen, ctx)
else:
return await self._handle_non_stream(agen, ctx, stop_sequences)
return await self._handle_non_stream(agen, ctx)
def _handle_stream(
self, agen: AsyncGenerator, ctx: GenContext, stop_sequences: List[str]
) -> StreamingResponse:
checker = StopChecker(stop_sequences)
def _count_prompt_tokens(self) -> int:
return len(self.engine.tokenizer.encode(self.build_prompt()))
def _handle_stream(self, agen, ctx: StreamContext) -> StreamingResponse:
stop_checker = self.create_stop_checker()
async def event_stream():
for event in self.builder.format_stream_start(ctx):
for event in self.format_stream_start(ctx):
yield event
body = ""
yielded = ""
matched = None
async for token in agen:
body += token
ctx.completion_tokens += 1
ctx.accumulated += token
matched = checker.check(body)
matched = self.on_token(ctx, token, stop_checker)
if matched:
break
ctx.completion_tokens += 1
yield self.builder.format_chunk(token)
yielded += token
yield self.format_stream_token(ctx, token)
stop = StopInfo(matched=matched, body=body, yielded=yielded)
for event in self.builder.format_stream_end(ctx, stop):
for event in self.format_stream_end(ctx):
yield event
yield sse_done()
yield _sse_done()
return StreamingResponse(
event_stream(),
@ -159,24 +186,260 @@ class ProtocolHandler:
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
)
async def _handle_non_stream(
self, agen: AsyncGenerator, ctx: GenContext, stop_sequences: List[str]
) -> Dict[str, Any]:
checker = StopChecker(stop_sequences)
async def _handle_non_stream(self, agen, ctx: StreamContext) -> Dict[str, Any]:
stop_checker = self.create_stop_checker()
chunks: List[str] = []
body = ""
matched = None
async for token in agen:
ctx.completion_tokens += 1
ctx.accumulated += token
chunks.append(token)
body += token
matched = checker.check(body)
matched = self.on_token(ctx, token, stop_checker)
if matched:
break
ctx.completion_tokens += 1
content = "".join(chunks)
stop = StopInfo(matched=matched, body=body)
return self.builder.format_response(ctx, content, stop)
return self.format_non_stream_response(ctx, content)
def _extract_text_content(content: Union[str, List[Dict[str, Any]]]) -> str:
"""Extract plain text from an Anthropic content block (string or list)."""
if isinstance(content, str):
return content
if isinstance(content, list):
for block in content:
if isinstance(block, dict) and block.get("type") == "text":
return block.get("text", "")
return ""
class OpenAIHandler(ProtocolHandler):
"""OpenAI-compatible /v1/chat/completions handler."""
def build_prompt(self) -> str:
messages = [
{"role": m.role, "content": m.content} for m in self.request.messages
]
return self.engine.tokenizer.apply_chat_template(messages, tokenize=False)
def create_response_id(self) -> str:
return f"chatcmpl-{uuid.uuid4().hex[:12]}"
def get_stop_sequences(self) -> List[str]:
stop = self.request.stop
if stop is None:
return []
return [stop] if isinstance(stop, str) else stop
def on_token(
self, ctx: StreamContext, token: str, stop_checker: StopChecker
) -> Optional[str]:
return stop_checker.check(ctx.accumulated)
def format_stream_start(self, ctx: StreamContext) -> List[str]:
return [
_sse_event(
{
"id": ctx.resp_id,
"object": "chat.completion.chunk",
"created": ctx.created,
"model": ctx.model,
"choices": [
{
"index": 0,
"delta": {"role": "assistant"},
"finish_reason": None,
}
],
}
)
]
def format_stream_token(self, ctx: StreamContext, token: str) -> str:
return _sse_event(
{
"id": ctx.resp_id,
"object": "chat.completion.chunk",
"created": ctx.created,
"model": ctx.model,
"choices": [
{"index": 0, "delta": {"content": token}, "finish_reason": None}
],
}
)
def format_stream_end(self, ctx: StreamContext) -> List[str]:
return [
_sse_event(
{
"id": ctx.resp_id,
"object": "chat.completion.chunk",
"created": ctx.created,
"model": ctx.model,
"choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
}
),
_sse_event(
{
"prompt_tokens": ctx.prompt_tokens,
"completion_tokens": ctx.completion_tokens,
"total_tokens": ctx.prompt_tokens + ctx.completion_tokens,
}
),
]
def format_non_stream_response(
self, ctx: StreamContext, content: str
) -> Dict[str, Any]:
return {
"id": ctx.resp_id,
"object": "chat.completion",
"created": ctx.created,
"model": ctx.model,
"choices": [
{
"index": 0,
"message": {"role": "assistant", "content": content},
"finish_reason": "stop",
}
],
"usage": {
"prompt_tokens": ctx.prompt_tokens,
"completion_tokens": ctx.completion_tokens,
"total_tokens": ctx.prompt_tokens + ctx.completion_tokens,
},
}
class AnthropicHandler(ProtocolHandler):
"""Anthropic-compatible /v1/messages handler."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._yielded = ""
def build_prompt(self) -> str:
messages: List[Dict[str, str]] = []
system = getattr(self.request, "system", None)
if system:
messages.append({"role": "system", "content": system})
for m in self.request.messages:
content = _extract_text_content(m.content)
if content:
messages.append({"role": m.role, "content": content})
return self.engine.tokenizer.apply_chat_template(messages, tokenize=False)
def create_response_id(self) -> str:
return f"msg_{uuid.uuid4().hex[:24]}"
def get_stop_sequences(self) -> List[str]:
return getattr(self.request, "stop_sequences", None) or []
def on_token(
self, ctx: StreamContext, token: str, stop_checker: StopChecker
) -> Optional[str]:
matched = stop_checker.check(ctx.accumulated)
if not matched:
return None
ctx.stop_matched = matched
trimmed = ctx.accumulated[: ctx.accumulated.rfind(matched)]
unyielded = trimmed[len(self._yielded) :]
if unyielded:
ctx.last_yield_trimmed = unyielded
return matched
def format_stream_start(self, ctx: StreamContext) -> List[str]:
return [
_sse_event(
{
"type": "message_start",
"message": {
"id": ctx.resp_id,
"type": "message",
"role": "assistant",
"model": ctx.model,
"content": [],
"usage": {"input_tokens": ctx.prompt_tokens},
},
},
event="message_start",
),
_sse_event(
{
"type": "content_block_start",
"index": 0,
"content_block": {"type": "text", "text": ""},
},
event="content_block_start",
),
]
def format_stream_token(self, ctx: StreamContext, token: str) -> str:
self._yielded += token
return _sse_event(
{
"type": "content_block_delta",
"index": 0,
"delta": {"type": "text_delta", "text": token},
},
event="content_block_delta",
)
def format_stream_end(self, ctx: StreamContext) -> List[str]:
matched = ctx.stop_matched
events: List[str] = []
last_yielded = ctx.last_yield_trimmed
if last_yielded:
events.append(
_sse_event(
{
"type": "content_block_delta",
"index": 0,
"delta": {"type": "text_delta", "text": last_yielded},
},
event="content_block_delta",
)
)
events.append(
_sse_event(
{"type": "content_block_stop", "index": 0},
event="content_block_stop",
)
)
events.append(
_sse_event(
{
"type": "message_delta",
"delta": {
"stop_reason": "stop_sequence" if matched else "end_turn",
"stop_sequence": matched,
},
"usage": {"output_tokens": ctx.completion_tokens},
},
event="message_delta",
)
)
events.append(_sse_event({"type": "message_stop"}, event="message_stop"))
return events
def format_non_stream_response(
self, ctx: StreamContext, content: str
) -> Dict[str, Any]:
matched = ctx.stop_matched
if matched:
content = content[: content.rfind(matched)]
return {
"id": ctx.resp_id,
"type": "message",
"role": "assistant",
"model": ctx.model,
"content": [{"type": "text", "text": content}],
"stop_reason": "stop_sequence" if matched else "end_turn",
"stop_sequence": matched,
"usage": {
"input_tokens": ctx.prompt_tokens,
"output_tokens": ctx.completion_tokens,
},
}

View File

@ -3,9 +3,6 @@ OpenAI / Anthropic-compatible chat completion server backed by continuous-batchi
Protocol-specific formatting is delegated to ``astrai.inference.protocol``.
This module owns the FastAPI app, request/response schemas, and dependency wiring.
``app`` is lazily constructed importing this module does NOT create a FastAPI instance.
Use :func:`get_app` to access the singleton.
"""
import logging
@ -15,19 +12,17 @@ from typing import Any, Dict, List, Optional, Union
import torch
import uvicorn
from fastapi import APIRouter, FastAPI, HTTPException
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from astrai.inference.api.anthropic import AnthropicResponseBuilder
from astrai.inference.api.openai import OpenAIResponseBuilder
from astrai.inference.api.protocol import ProtocolHandler
from astrai.inference.api.protocol import AnthropicHandler, OpenAIHandler
from astrai.inference.engine import InferenceEngine
from astrai.model import AutoModel
from astrai.tokenize import AutoTokenizer
logger = logging.getLogger(__name__)
_app_instance: Optional[FastAPI] = None
_project_root = Path(__file__).parent.parent.parent
class ChatMessage(BaseModel):
@ -87,15 +82,17 @@ async def lifespan(app: FastAPI):
logger.info("Inference engine shutdown complete")
router = APIRouter()
app = FastAPI(title="AstrAI Inference Server", version="0.2.0", lifespan=lifespan)
def _create_engine(
param_path: Path,
param_path: Optional[Path] = None,
device: str = "cuda",
dtype: torch.dtype = torch.bfloat16,
max_batch_size: int = 16,
) -> InferenceEngine:
if param_path is None:
param_path = _project_root / "params"
if not param_path.exists():
raise FileNotFoundError(f"Parameter directory not found: {param_path}")
@ -113,66 +110,49 @@ def _create_engine(
return engine
def get_app() -> FastAPI:
"""Return the singleton FastAPI instance (lazily created on first call)."""
global _app_instance
if _app_instance is None:
_app_instance = FastAPI(
title="AstrAI Inference Server",
version="0.2.0",
lifespan=lifespan,
)
_app_instance.include_router(router)
_app_instance.state.server_config = {}
_app_instance.state.engine = None
return _app_instance
def _get_engine() -> InferenceEngine:
engine = get_app().state.engine
engine = app.state.engine
if engine is None:
raise HTTPException(status_code=503, detail="Engine not initialized")
return engine
@router.get("/health")
@app.get("/health")
async def health():
app = get_app()
return {
"status": "ok",
"model_loaded": app.state.engine is not None,
}
@router.get("/stats")
@app.get("/stats")
async def get_stats():
return _get_engine().get_stats()
@router.post("/v1/chat/completions")
@app.post("/v1/chat/completions")
async def chat_completion(request: ChatCompletionRequest):
engine = _get_engine()
handler = ProtocolHandler(request, engine, OpenAIResponseBuilder())
handler = OpenAIHandler(request, engine)
return await handler.handle()
@router.post("/v1/messages")
@app.post("/v1/messages")
async def create_message(request: MessagesRequest):
engine = _get_engine()
handler = ProtocolHandler(request, engine, AnthropicResponseBuilder())
handler = AnthropicHandler(request, engine)
return await handler.handle()
def run_server(
param_path: Path,
host: str = "0.0.0.0",
port: int = 8000,
reload: bool = False,
device: str = "cuda",
dtype: torch.dtype = torch.bfloat16,
param_path: Optional[Path] = None,
max_batch_size: int = 16,
):
app = get_app()
app.state.server_config = {
"device": device,
"dtype": dtype,

View File

@ -42,7 +42,7 @@ class Allocator:
return idx
return -1
def free(self, idx: int, keep_cached: bool = False):
def free(self, idx: int, keep_cached: bool = False) -> None:
with self._lock:
self._refs[idx] -= 1
if self._refs[idx] == 0:
@ -51,7 +51,7 @@ class Allocator:
else:
self._free_mask |= 1 << idx
def inc_ref(self, idx: int):
def inc_ref(self, idx: int) -> None:
with self._lock:
self._refs[idx] += 1
self._lru.pop(idx, None)
@ -60,7 +60,7 @@ class Allocator:
with self._lock:
return self._refs[idx]
def touch(self, idx: int):
def touch(self, idx: int) -> None:
with self._lock:
self._lru.move_to_end(idx)
@ -74,7 +74,7 @@ class PrefixCache:
self._hash_to_page: Dict[int, int] = {}
self._lock = threading.Lock()
def evict(self, idx: int):
def evict(self, idx: int) -> None:
with self._lock:
h = self._page_to_hash.pop(idx, None)
if h is not None:
@ -96,7 +96,9 @@ class PrefixCache:
hits.append(p)
return hits
def record(self, page_idx: int, token_ids: List[int], logical_page_idx: int):
def record(
self, page_idx: int, token_ids: List[int], logical_page_idx: int
) -> None:
with self._lock:
h = page_hash(token_ids, logical_page_idx, self._page_size)
old_h = self._page_to_hash.pop(page_idx, None)
@ -125,13 +127,13 @@ class PagePool:
def alloc(self) -> int:
return self._alloc.alloc()
def free(self, idx: int):
def free(self, idx: int) -> None:
keep = self._prefix.has_page(idx)
self._alloc.free(idx, keep_cached=keep)
if not keep:
self._prefix.evict(idx)
def inc_ref(self, idx: int):
def inc_ref(self, idx: int) -> None:
self._alloc.inc_ref(idx)
def lookup(self, token_ids: List[int]) -> List[int]:
@ -140,7 +142,9 @@ class PagePool:
self._alloc.touch(p)
return hits
def record(self, page_idx: int, token_ids: List[int], logical_page_idx: int):
def record(
self, page_idx: int, token_ids: List[int], logical_page_idx: int
) -> None:
self._prefix.record(page_idx, token_ids, logical_page_idx)
@ -153,7 +157,7 @@ class TaskTable:
self._cached: Dict[str, int] = {}
self._lock = threading.Lock()
def set(self, task_id: str, page_table: List[int], cached: int):
def set(self, task_id: str, page_table: List[int], cached: int) -> None:
with self._lock:
self._pages[task_id] = page_table
self._cached[task_id] = cached
@ -216,7 +220,7 @@ class Storage:
start_pos: int,
k: Tensor,
v: Tensor,
):
) -> None:
seq_len = k.size(1)
if seq_len == 0:
return
@ -282,7 +286,7 @@ class KvcacheView:
self._page_table = page_table
self._total_len = total_len
def write(self, layer_id: int, k: Tensor, v: Tensor):
def write(self, layer_id: int, k: Tensor, v: Tensor) -> None:
start_pos = self._total_len - k.size(1)
self._storage.write(layer_id, self._page_table, start_pos, k, v)
@ -335,7 +339,7 @@ class KVCache:
self._table.set(task_id, hits + new_pages, cached)
return True
def task_free(self, task_id: str):
def task_free(self, task_id: str) -> None:
page_table, _ = self._table.pop(task_id)
for idx in page_table:
self._pool.free(idx)
@ -355,7 +359,7 @@ class KVCache:
def task_record_hashes(
self, task_id: str, prompt_ids: List[int], start_logical_page: int = 0
):
) -> None:
page_table = self._table.get(task_id)
full_pages = len(prompt_ids) // self.page_size
for i in range(start_logical_page, full_pages):

View File

@ -29,7 +29,9 @@ class Executor:
self.device = device or next(model.parameters()).device
self.dtype = dtype or next(model.parameters()).dtype
def execute_prefill(self, tasks: List[Task], prompt_len: int, start_pos: int = 0):
def execute_prefill(
self, tasks: List[Task], prompt_len: int, start_pos: int = 0
) -> None:
if start_pos >= prompt_len:
return

View File

@ -71,19 +71,18 @@ class InferenceScheduler:
)
self._running = False
self._fatal_error: Optional[Exception] = None
def add_task(self, prompt: str, **kwargs) -> str:
return self._task_mgr.add_task(prompt, **kwargs)
def remove_task(self, task_id: str):
def remove_task(self, task_id: str) -> None:
for task in self._task_mgr.remove_task(task_id):
self._page_cache.task_free(task.task_id)
def get_stats(self) -> Dict[str, Any]:
return self._task_mgr.get_stats()
def _run_generation_loop(self):
def _run_generation_loop(self) -> None:
stop_ids = self._task_mgr.tokenizer.stop_ids
try:
while self._running:
@ -109,10 +108,7 @@ class InferenceScheduler:
continue
to_prefill = [
t
for t in self._task_mgr.get_active_tasks()
if t.output_tokens == 0
and self._page_cache.task_cached(t.task_id) < len(t.prompt_ids)
t for t in self._task_mgr.get_active_tasks() if t.output_tokens == 0
]
if to_prefill:
for t in to_prefill:
@ -160,15 +156,11 @@ class InferenceScheduler:
t.output_ids.append(ntok)
t.output_tokens += 1
pos = t.input_tokens + t.output_tokens
extend_ok = self._page_cache.task_extend(t.task_id, pos)
self._page_cache.task_extend(t.task_id, pos)
if t.stream_callback:
t.stream_callback(
self._task_mgr.tokenizer.decode([ntok])
)
if not extend_ok:
t.status = TaskStatus.ABORTED
if t.stream_callback:
t.stream_callback(STOP)
for t in valid:
if t.is_finished(stop_ids):
@ -176,37 +168,28 @@ class InferenceScheduler:
t.stream_callback(STOP)
except Exception as e:
self._fatal_error = e
self._running = False
logger.error(f"Scheduler loop crashed: {e}", exc_info=True)
for task in self._task_mgr.get_active_tasks():
if task.stream_callback:
task.stream_callback(STOP)
self._page_cache.task_free(task.task_id)
for task in self._task_mgr.get_waiting_tasks():
if task.stream_callback:
task.stream_callback(STOP)
self._task_mgr.clear_queues()
raise
def start(self):
def start(self) -> None:
if not self._running:
self._running = True
t = threading.Thread(target=self._run_generation_loop, daemon=True)
t.start()
self._loop_thread = t
def stop(self):
def stop(self) -> None:
self._running = False
self._task_mgr.wake()
if hasattr(self, "_loop_thread"):
self._loop_thread.join(timeout=2.0)
for task in self._task_mgr.get_active_tasks():
if task.stream_callback:
task.stream_callback(STOP)
self._page_cache.task_free(task.task_id)
for task in self._task_mgr.get_waiting_tasks():
if task.stream_callback:
task.stream_callback(STOP)
self._task_mgr.clear_queues()
if torch.cuda.is_available():
torch.cuda.empty_cache()

View File

@ -172,12 +172,12 @@ class TaskManager:
to_add.append(self.waiting_queue.popleft())
return to_add
def activate(self, task: Task):
def activate(self, task: Task) -> None:
task.status = TaskStatus.RUNNING
with self._lock:
self.active_tasks.append(task)
def return_to_waiting(self, tasks: List[Task]):
def return_to_waiting(self, tasks: List[Task]) -> None:
with self._lock:
for task in reversed(tasks):
self.waiting_queue.appendleft(task)
@ -185,10 +185,7 @@ class TaskManager:
def has_work(self) -> bool:
return bool(self.active_tasks or self.waiting_queue)
def wait_for_tasks(self, timeout: float = 1.0):
with self._lock:
if self.waiting_queue or self.active_tasks:
return
def wait_for_tasks(self, timeout: float = 1.0) -> None:
self._task_event.clear()
self._task_event.wait(timeout=timeout)
@ -196,14 +193,10 @@ class TaskManager:
with self._lock:
return list(self.active_tasks)
def get_waiting_tasks(self) -> List[Task]:
with self._lock:
return list(self.waiting_queue)
def clear_queues(self):
def clear_queues(self) -> None:
with self._lock:
self.waiting_queue.clear()
self.active_tasks.clear()
def wake(self):
def wake(self) -> None:
self._task_event.set()

View File

@ -13,6 +13,17 @@ from astrai.inference.core.task import STOP
from astrai.tokenize import AutoTokenizer
def _validate_sampling_params(
top_k: int, top_p: float, temperature: float, max_tokens: Optional[int] = None
):
if not (isinstance(top_k, int) and top_k >= 0):
raise ValueError("top_k must be a non-negative integer")
if not (0.0 <= top_p <= 1.0):
raise ValueError("top_p must be a float between 0.0 and 1.0")
if not (isinstance(temperature, (int, float)) and temperature >= 0):
raise ValueError("temperature must be a non-negative number")
class GenerateResult:
"""Thread-safe token accumulator for streaming and non-streaming modes."""
@ -48,7 +59,7 @@ class GenerateResult:
def wait(self, timeout: Optional[float] = None) -> bool:
return self._event.wait(timeout=timeout)
def wait_completion(self, timeout: float = 300.0):
def wait_completion(self, timeout: float = 300.0) -> None:
with self._cond:
if not self._cond.wait_for(
lambda: self._completed >= self._total, timeout=timeout
@ -75,12 +86,7 @@ class GenerationRequest:
max_tokens: Optional[int] = None,
stream: bool = False,
):
if not (isinstance(top_k, int) and top_k >= 0):
raise ValueError("top_k must be a non-negative integer")
if not (0.0 <= top_p <= 1.0):
raise ValueError("top_p must be a float between 0.0 and 1.0")
if not (isinstance(temperature, (int, float)) and temperature > 0):
raise ValueError("temperature must be a positive number")
_validate_sampling_params(top_k, top_p, temperature, max_tokens)
self.messages = messages
self.top_k = top_k
@ -131,6 +137,7 @@ class InferenceEngine:
top_p: float = 1.0,
top_k: int = 50,
) -> Union[Generator, str, List[str]]:
_validate_sampling_params(top_k, top_p, temperature, max_tokens)
is_batch = isinstance(prompt, list)
prompts = prompt if is_batch else [prompt]
@ -151,6 +158,7 @@ class InferenceEngine:
top_p: float = 1.0,
top_k: int = 50,
) -> AsyncGenerator[str, None]:
_validate_sampling_params(top_k, top_p, temperature, max_tokens)
sync_gen = self._generate_streaming(
[prompt], False, max_tokens, temperature, top_p, top_k
)
@ -281,7 +289,7 @@ class InferenceEngine:
def get_stats(self) -> Dict[str, Any]:
return self.scheduler.get_stats()
def shutdown(self):
def shutdown(self) -> None:
self.scheduler.stop()
if torch.cuda.is_available():
torch.cuda.empty_cache()

View File

@ -44,12 +44,10 @@ class TemperatureStrategy(BaseSamplingStrategy):
def apply(self, logits, filter_value=-float("inf")):
t = self.temperature
if isinstance(t, Tensor):
t = t.to(logits.device, non_blocking=True).view(-1, 1)
t = torch.clamp(t, min=1e-8)
if (t != 1.0).any():
logits = logits / t
logits = logits / t.to(logits.device, non_blocking=True).view(-1, 1)
elif t != 1.0:
logits = logits / max(t, 1e-8)
logits = logits / t
return logits

View File

@ -2,13 +2,6 @@ from astrai.model.automodel import AutoModel
from astrai.model.components.attention import GQA
from astrai.model.components.decoder_block import DecoderBlock
from astrai.model.components.linear import Linear
from astrai.model.components.lora import (
LoRAConfig,
inject_lora,
load_lora,
merge_lora,
save_lora,
)
from astrai.model.components.mlp import MLP
from astrai.model.components.norm import RMSNorm
from astrai.model.encoder import EmbeddingEncoder
@ -25,10 +18,4 @@ __all__ = [
"AutoRegressiveLM",
"EmbeddingEncoder",
"AutoModel",
# LoRA
"LoRAConfig",
"inject_lora",
"merge_lora",
"save_lora",
"load_lora",
]

View File

@ -2,24 +2,21 @@
AutoModel base class for model loading and saving.
"""
import json
from contextlib import contextmanager
from pathlib import Path
from typing import Self, Union
import safetensors.torch as st
import torch.nn as nn
from astrai.config.model_config import BaseModelConfig, ConfigFactory
from astrai.factory import BaseFactory
from astrai.serialization import load_model_config, load_model_weights, save_model
@contextmanager
def _disable_random_init(enable: bool = True):
if not enable:
yield
return
names = (
init_functions = [
"xavier_normal_",
"xavier_uniform_",
"kaiming_normal_",
@ -29,15 +26,18 @@ def _disable_random_init(enable: bool = True):
"constant_",
"normal_",
"uniform_",
)
orig = {n: getattr(nn.init, n) for n in names if hasattr(nn.init, n)}
for n in orig:
setattr(nn.init, n, lambda *a, **kw: None)
]
original_funcs = {}
for name in init_functions:
if enable and hasattr(nn.init, name):
original_funcs[name] = getattr(nn.init, name)
setattr(nn.init, name, lambda *args, **kwargs: None)
try:
yield
finally:
for n, fn in orig.items():
setattr(nn.init, n, fn)
if enable:
for name, orig_func in original_funcs.items():
setattr(nn.init, name, orig_func)
class AutoModel(BaseFactory["AutoModel"], nn.Module):
@ -60,22 +60,25 @@ class AutoModel(BaseFactory["AutoModel"], nn.Module):
model_path = Path(path)
# Load config
config_path = model_path / "config.json"
if not config_path.exists():
raise FileNotFoundError(f"Config file not found: {config_path}")
raw = load_model_config(str(model_path))
if config_path.exists():
with open(config_path, "r") as f:
raw = json.load(f)
config = ConfigFactory.load(raw)
model_type = config.model_type or "autoregressive_lm"
else:
raise FileNotFoundError(f"Config file not found: {config_path}")
actual_cls = AutoModel.get_component_class(model_type)
with _disable_random_init(enable=disable_random_init):
model = actual_cls(config)
# Load weights
weights_path = model_path / "model.safetensors"
if weights_path.exists():
state_dict = load_model_weights(str(model_path))
state_dict = st.load_file(str(weights_path))
model.load_state_dict(state_dict, strict=strict)
return model
@ -83,12 +86,15 @@ class AutoModel(BaseFactory["AutoModel"], nn.Module):
def save_pretrained(
self,
save_directory: Union[str, Path],
):
save_model(
config=self.config.to_dict(),
state_dict=self.state_dict(),
save_directory=str(save_directory),
)
) -> None:
save_path = Path(save_directory)
save_path.mkdir(parents=True, exist_ok=True)
# Save config
self.config.to_file(str(save_path / "config.json"))
# Save weights
st.save_file(self.state_dict(), str(save_path / "model.safetensors"))
def to(self, *args, **kwargs) -> Self:
"""Move model to device/dtype."""

View File

@ -1,194 +0,0 @@
import logging
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Optional, Set
import torch
import torch.nn as nn
import torch.nn.functional as F
from astrai.model.components.linear import Linear
from astrai.serialization import (
load_json,
load_safetensors,
save_json,
save_safetensors,
)
logger = logging.getLogger(__name__)
TARGET_MODULES_ATTN = {"q_proj", "k_proj", "v_proj", "o_proj"}
TARGET_MODULES_FFN = {"up", "gate", "down"}
@dataclass
class LoRAConfig:
r: int = 16
alpha: int = 32
target_modules: tuple = ("q_proj", "v_proj")
class LoRALinear(nn.Module):
def __init__(self, base: Linear, r: int = 16, alpha: int = 32):
super().__init__()
self.register_parameter("weight", base.weight)
self.weight.requires_grad_(False)
self.bias = base.bias
if self.bias is not None:
self.bias.requires_grad_(False)
self.r = r
self.scaling = alpha / r
self.lora_A = nn.Parameter(torch.randn(r, self.weight.shape[1]) / r)
self.lora_B = nn.Parameter(torch.zeros(self.weight.shape[0], r))
self._merged = False
def forward(self, x):
out = F.linear(x, self.weight, self.bias)
if not self._merged:
out += (F.linear(x, self.lora_A) @ self.lora_B.T) * self.scaling
return out
def merge(self):
if self._merged:
return
self.weight.data += (self.lora_B @ self.lora_A) * self.scaling
self._merged = True
del self.lora_A
del self.lora_B
def _collect_lora_info(model: nn.Module) -> dict:
names = {}
for n, m in model.named_modules():
if isinstance(m, Linear):
_, _, child = n.rpartition(".")
names.setdefault(child, []).append(n)
return names
def _get_lora_count(model: nn.Module) -> int:
return sum(1 for m in model.modules() if isinstance(m, LoRALinear))
def inject_lora(
model: nn.Module,
r: int = 16,
alpha: int = 32,
target_modules: Optional[Set[str]] = None,
) -> LoRAConfig:
if target_modules is None:
target_modules = TARGET_MODULES_ATTN
available = _collect_lora_info(model)
injected = 0
for name, module in list(model.named_modules()):
if not isinstance(module, Linear):
continue
parent_name, _, child_name = name.rpartition(".")
if child_name not in target_modules:
continue
parent = model.get_submodule(parent_name) if parent_name else model
setattr(parent, child_name, LoRALinear(module, r=r, alpha=alpha))
injected += 1
if injected == 0:
logger.warning(
"No LoRA layers injected. Available Linear child names: %s. "
"target_modules: %s. Check model type and target_modules.",
sorted(available),
sorted(target_modules),
)
else:
logger.info("LoRA injected: %d layers (r=%d, alpha=%d)", injected, r, alpha)
return LoRAConfig(r=r, alpha=alpha, target_modules=tuple(target_modules))
def merge_lora(model: nn.Module):
n = 0
for module in model.modules():
if isinstance(module, LoRALinear):
module.merge()
n += 1
if n == 0:
logger.warning("No LoRA layers to merge.")
else:
logger.info("Merged %d LoRA layers", n)
def save_lora(model: nn.Module, save_dir: str, config: LoRAConfig):
lora_sd = {
k: v
for k, v in model.state_dict().items()
if k.endswith((".lora_A", ".lora_B"))
}
if not lora_sd:
raise RuntimeError(
"No LoRA parameters found in model. "
"The model may not have been injected or was already merged."
)
path = Path(save_dir)
path.mkdir(parents=True, exist_ok=True)
save_safetensors(lora_sd, path / "adapter_model.safetensors")
save_json(asdict(config), path / "adapter_config.json")
logger.info("LoRA adapter saved to %s (%d keys)", save_dir, len(lora_sd))
def load_lora(model: nn.Module, load_dir: str) -> LoRAConfig:
path = Path(load_dir)
raw = load_json(path / "adapter_config.json")
config = LoRAConfig(
r=raw["r"], alpha=raw["alpha"], target_modules=tuple(raw["target_modules"])
)
existing = _get_lora_count(model)
if existing > 0:
logger.warning(
"Model already has %d LoRA layers. Skipping injection, "
"loading weights onto existing layers only.",
existing,
)
else:
inject_lora(
model,
r=config.r,
alpha=config.alpha,
target_modules=set(config.target_modules),
)
weights = load_safetensors(path / "adapter_model.safetensors")
try:
missing, unexpected = model.load_state_dict(weights, strict=False)
except RuntimeError as e:
msg = str(e)
if "size mismatch" in msg:
raise RuntimeError(
f"LoRA weight shapes do not match the model. "
f"The adapter config (r={config.r}) may not match the injected layers. "
f"Original error: {msg}"
) from e
raise
injected = _get_lora_count(model)
if injected == 0:
raise RuntimeError(
"No LoRA layers found after loading. "
"Inject LoRA before calling load_lora, or check the adapter config."
)
if missing:
lora_missing = [k for k in missing if "lora" in k]
if lora_missing:
raise RuntimeError(
f"LoRA weight keys not found in model: {lora_missing}. "
f"The adapter config (r={config.r}) may not match the model."
)
logger.debug("LoRA load: %d missing base-weight keys (expected)", len(missing))
if unexpected:
logger.warning("LoRA load: %d unexpected keys", len(unexpected))
logger.info("LoRA adapter loaded from %s", load_dir)
return config

View File

@ -1,4 +1,4 @@
from typing import Dict, Optional
from typing import Optional
import torch
import torch.nn as nn
@ -19,10 +19,6 @@ def get_rotary_emb(
return torch.complex(cos, sin)
def ntk_base(base: float, dim: int, factor: float) -> float:
return base * (factor ** (dim / (dim - 2)))
def apply_rotary_emb(x: torch.Tensor, freqs_cis: Tensor) -> Tensor:
dtype = x.dtype
x_ = x.float().reshape(*x.shape[:-1], -1, 2)
@ -34,25 +30,11 @@ def apply_rotary_emb(x: torch.Tensor, freqs_cis: Tensor) -> Tensor:
class RotaryEmbedding(nn.Module):
def __init__(
self,
dim: int,
max_len: int,
base: float = 10000,
rope_scaling: Optional[Dict] = None,
):
def __init__(self, dim: int, max_len: int, base: float = 10000):
super().__init__()
self.dim = dim
self.max_len = max_len
self.base = base
self.rope_scaling = rope_scaling
if rope_scaling is not None:
scaling_type = rope_scaling.get("type", "ntk")
factor = rope_scaling.get("factor", 1.0)
if scaling_type == "ntk":
self.base = ntk_base(base, dim, factor)
self._set_rotary_buffer(self.max_len)
def _set_rotary_buffer(self, max_len: int):

View File

@ -20,9 +20,7 @@ class EmbeddingEncoder(AutoModel):
self.config = config
rope_dim = config.dim // config.n_heads
rope_base = config.rope_theta if config.rope_theta is not None else 10000
self.rotary_embedding = RotaryEmbedding(
rope_dim, config.max_len, rope_base, rope_scaling=config.rope_scaling
)
self.rotary_embedding = RotaryEmbedding(rope_dim, config.max_len, rope_base)
self.embed_tokens = Embedding(config.vocab_size, config.dim)
self.layers = nn.ModuleList(
@ -68,6 +66,9 @@ class EmbeddingEncoder(AutoModel):
x = self.embed_tokens(input_ids)
if position_ids is None:
position_ids = torch.arange(S, device=x.device).unsqueeze(0).expand(B, -1)
rotary_emb = self.rotary_embedding(x, position_ids)
attn_mask = process_attention_mask(x, position_ids, input_mask, is_causal=False)

View File

@ -1,4 +1,4 @@
from typing import Any, Dict, Mapping, Optional
from typing import Any, Mapping, Optional
import torch
import torch.nn as nn
@ -26,21 +26,24 @@ def process_attention_mask(
return input_mask
device = input_tensor.device
B = input_tensor.size(0)
dtype = input_tensor.dtype
B, S = input_tensor.size()[:2]
T = position_ids.max().item() + 1
if input_mask is None:
if position_ids.min().item() == 0 and is_causal:
return None
attend = torch.ones(B, 1, T, dtype=torch.bool, device=device)
pad = torch.ones(B, T, dtype=torch.bool, device=device)
else:
attend = input_mask[:, :T].to(device=device, dtype=torch.bool).unsqueeze(1)
pad = input_mask[:, :T].to(device=device, dtype=torch.bool)
attend = pad.view(B, 1, T).expand(B, S, T).clone()
if is_causal:
causal = position_ids.unsqueeze(-1) >= torch.arange(T, device=device)
attend = attend & causal
attend &= position_ids.unsqueeze(-1) >= torch.arange(T, device=device)
return attend.unsqueeze(1)
return torch.full(
(B, 1, S, T), -torch.finfo(dtype).max / 2, dtype=dtype, device=device
).masked_fill_(attend.unsqueeze(1), 0.0)
@AutoModel.register("autoregressive_lm")
@ -56,9 +59,7 @@ class AutoRegressiveLM(AutoModel):
else config.dim // config.n_heads
)
rope_base = config.rope_theta if config.rope_theta is not None else 10000
self.rotary_embedding = RotaryEmbedding(
rope_dim, config.max_len, rope_base, rope_scaling=config.rope_scaling
)
self.rotary_embedding = RotaryEmbedding(rope_dim, config.max_len, rope_base)
self.embed_tokens = Embedding(config.vocab_size, config.dim)
self.layers = nn.ModuleList(
@ -133,7 +134,7 @@ class AutoRegressiveLM(AutoModel):
input_mask: Optional[Tensor] = None,
paged_cache: Optional[KvcacheView] = None,
position_ids: Optional[Tensor] = None,
) -> Dict[str, Tensor]:
) -> Tensor:
assert input_ids.ndim == 2
x = self.embed_tokens(input_ids)

View File

@ -1,13 +1,3 @@
from astrai.parallel.executor import (
AccumOptimizer,
AccumScheduler,
BaseExecutor,
DDPExecutor,
ExecutorFactory,
FSDPExecutor,
GradientState,
NoneExecutor,
)
from astrai.parallel.module import ColumnParallelLinear, RowParallelLinear
from astrai.parallel.setup import (
get_current_device,
@ -27,12 +17,4 @@ __all__ = [
"spawn_parallel_fn",
"RowParallelLinear",
"ColumnParallelLinear",
"ExecutorFactory",
"BaseExecutor",
"GradientState",
"AccumOptimizer",
"AccumScheduler",
"NoneExecutor",
"DDPExecutor",
"FSDPExecutor",
]

View File

@ -1,272 +0,0 @@
"""Unified training executor — parallel strategy + gradient accumulation."""
import contextlib
import logging
import os
from contextlib import contextmanager
from typing import Optional, Tuple
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullStateDictConfig, StateDictType
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from torch.utils.data import DataLoader
from astrai.factory import BaseFactory
from astrai.parallel.setup import get_rank, get_world_size
logger = logging.getLogger(__name__)
class GradientState:
def __init__(self, grad_accum_steps: int = 1):
self.num_steps = max(grad_accum_steps, 1)
self._step: int = 0
self._sync_gradients: bool = True
@property
def sync_gradients(self) -> bool:
return self._sync_gradients
def _do_sync(self):
self._step += 1
self._sync_gradients = self._step % self.num_steps == 0
class AccumOptimizer:
def __init__(self, optimizer: Optimizer, gradient_state: GradientState):
self.optimizer = optimizer
self.gradient_state = gradient_state
def step(self, closure=None):
if self.gradient_state.sync_gradients:
self.optimizer.step(closure)
def zero_grad(self):
if self.gradient_state.sync_gradients:
self.optimizer.zero_grad()
@property
def param_groups(self):
return self.optimizer.param_groups
def state_dict(self):
return self.optimizer.state_dict()
def load_state_dict(self, d):
self.optimizer.load_state_dict(d)
class AccumScheduler:
def __init__(self, scheduler: LRScheduler, gradient_state: GradientState):
self.scheduler = scheduler
self.gradient_state = gradient_state
def step(self):
if self.gradient_state.sync_gradients:
self.scheduler.step()
def state_dict(self):
return self.scheduler.state_dict()
def load_state_dict(self, d):
self.scheduler.load_state_dict(d)
def get_last_lr(self):
return self.scheduler.get_last_lr()
class BaseExecutor:
def __init__(self, grad_accum_steps: int = 1):
self.gradient_state = GradientState(grad_accum_steps)
def prepare(
self,
model: nn.Module,
optimizer: Optional[Optimizer] = None,
dataloader: Optional[DataLoader] = None,
scheduler: Optional[LRScheduler] = None,
) -> Tuple[
nn.Module, Optional[Optimizer], Optional[DataLoader], Optional[LRScheduler]
]:
model = self._prepare_model(model)
if optimizer is not None:
optimizer = AccumOptimizer(optimizer, self.gradient_state)
if scheduler is not None:
scheduler = AccumScheduler(scheduler, self.gradient_state)
return model, optimizer, dataloader, scheduler
def _prepare_model(self, model: nn.Module) -> nn.Module:
return model
def _no_sync(self, model: nn.Module):
return contextlib.nullcontext()
@contextmanager
def accumulate(self, model: nn.Module):
self.gradient_state._do_sync()
if not self.gradient_state.sync_gradients:
with self._no_sync(model):
yield
else:
yield
def backward(self, loss: torch.Tensor):
loss.backward()
def unwrap_model(self, model: nn.Module):
return model.state_dict()
@property
def use_distributed(self) -> bool:
return get_world_size() > 1
@property
def sync_gradients(self) -> bool:
return self.gradient_state.sync_gradients
@property
def grad_accum_steps(self) -> int:
return self.gradient_state.num_steps
class ExecutorFactory(BaseFactory[BaseExecutor]):
pass
@ExecutorFactory.register("none")
class NoneExecutor(BaseExecutor):
pass
@ExecutorFactory.register("ddp")
class DDPExecutor(BaseExecutor):
def __init__(
self,
grad_accum_steps: int = 1,
dim: int = 0,
broadcast_buffers: bool = True,
init_sync: bool = True,
process_group=None,
bucket_cap_mb: int = 25,
find_unused_parameters: bool = False,
check_reduction: bool = False,
gradient_as_bucket_view: bool = False,
static_graph: bool = False,
delay_all_reduce_named_params=None,
param_to_hook_all_reduce=None,
mixed_precision=None,
device_mesh=None,
):
super().__init__(grad_accum_steps=grad_accum_steps)
self._ddp_kwargs = dict(
dim=dim,
broadcast_buffers=broadcast_buffers,
init_sync=init_sync,
process_group=process_group,
bucket_cap_mb=bucket_cap_mb,
find_unused_parameters=find_unused_parameters,
check_reduction=check_reduction,
gradient_as_bucket_view=gradient_as_bucket_view,
static_graph=static_graph,
delay_all_reduce_named_params=delay_all_reduce_named_params,
param_to_hook_all_reduce=param_to_hook_all_reduce,
mixed_precision=mixed_precision,
device_mesh=device_mesh,
)
def _prepare_model(self, model: nn.Module) -> nn.Module:
if not self.use_distributed:
logger.warning("DDP backend selected but world_size=1, model not wrapped")
return model
local_rank = int(os.environ.get("LOCAL_RANK", get_rank()))
model = DDP(
model,
device_ids=[local_rank],
output_device=local_rank,
**self._ddp_kwargs,
)
logger.info("Model wrapped with DDP (world_size=%d)", get_world_size())
return model
def _no_sync(self, model: nn.Module):
if isinstance(model, DDP):
return model.no_sync()
return contextlib.nullcontext()
def unwrap_model(self, model: nn.Module):
if isinstance(model, DDP):
return model.module.state_dict()
return model.state_dict()
@ExecutorFactory.register("fsdp")
class FSDPExecutor(BaseExecutor):
def __init__(
self,
grad_accum_steps: int = 1,
process_group=None,
sharding_strategy=None,
cpu_offload=None,
auto_wrap_policy=None,
backward_prefetch=None,
mixed_precision=None,
ignored_modules=None,
param_init_fn=None,
sync_module_states: bool = False,
forward_prefetch: bool = False,
limit_all_gathers: bool = True,
ignored_states=None,
device_mesh=None,
):
super().__init__(grad_accum_steps=grad_accum_steps)
self._fsdp_kwargs = {
k: v
for k, v in dict(
process_group=process_group,
sharding_strategy=sharding_strategy,
cpu_offload=cpu_offload,
auto_wrap_policy=auto_wrap_policy,
backward_prefetch=backward_prefetch,
mixed_precision=mixed_precision,
ignored_modules=ignored_modules,
param_init_fn=param_init_fn,
sync_module_states=sync_module_states,
forward_prefetch=forward_prefetch,
limit_all_gathers=limit_all_gathers,
use_orig_params=True,
ignored_states=ignored_states,
device_mesh=device_mesh,
).items()
if v is not None
}
self._original_model: Optional[nn.Module] = None
def _prepare_model(self, model: nn.Module) -> nn.Module:
if not self.use_distributed:
logger.warning("FSDP backend selected but world_size=1, model not wrapped")
return model
self._original_model = model
device_id = torch.device("cuda", get_rank())
model = FSDP(model, device_id=device_id, **self._fsdp_kwargs)
logger.info("Model wrapped with FSDP (world_size=%d)", get_world_size())
return model
def _no_sync(self, model: nn.Module):
if isinstance(model, FSDP):
return model.no_sync()
return contextlib.nullcontext()
def unwrap_model(self, model: nn.Module):
if isinstance(model, FSDP) and self.use_distributed:
with FSDP.state_dict_type(
model,
StateDictType.FULL_STATE_DICT,
FullStateDictConfig(offload_to_cpu=True, rank0_only=False),
):
return model.state_dict()
return model.state_dict()

View File

@ -1,5 +1,4 @@
import os
from abc import ABC, abstractmethod
from contextlib import contextmanager
from functools import wraps
from typing import Callable
@ -31,7 +30,6 @@ def get_rank() -> int:
def setup_parallel(
rank: int,
world_size: int,
local_rank: int,
backend: str = "nccl",
master_addr: str = "localhost",
master_port: str = "29500",
@ -43,18 +41,14 @@ def setup_parallel(
return
if world_size <= 1:
device_id = torch.device(device_type, local_rank)
os.environ["LOCAL_RANK"] = str(local_rank)
os.environ["WORLD_SIZE"] = "1"
os.environ["LOCAL_DEVICE"] = str(device_id)
yield None
return
device_id = torch.device(device_type, local_rank)
device_id = torch.device(device_type, rank)
os.environ["MASTER_ADDR"] = master_addr
os.environ["MASTER_PORT"] = master_port
os.environ["LOCAL_RANK"] = str(local_rank)
os.environ["LOCAL_RANK"] = str(rank)
os.environ["WORLD_SIZE"] = str(world_size)
os.environ["LOCAL_DEVICE"] = str(device_id)
@ -96,7 +90,7 @@ def only_on_rank(rank, sync=False):
return decorator
def _run_single_rank(
def wrapper_spawn_func(
rank: int,
world_size: int,
backend: str,
@ -106,10 +100,10 @@ def _run_single_rank(
func: Callable,
kwargs: dict,
):
try:
with setup_parallel(
rank=rank,
world_size=world_size,
local_rank=rank,
backend=backend,
master_addr=master_addr,
master_port=master_port,
@ -117,99 +111,11 @@ def _run_single_rank(
):
func(**kwargs)
class LaunchStrategy(ABC):
"""Strategy for launching a function in a distributed context."""
def __init__(
self,
world_size: int,
backend: str,
master_addr: str,
master_port: str,
device_type: str,
start_method: str,
):
self.world_size = world_size
self.backend = backend
self.master_addr = master_addr
self.master_port = master_port
self.device_type = device_type
self.start_method = start_method
@abstractmethod
def launch(self, func: Callable, **kwargs):
raise NotImplementedError
class TorchrunStrategy(LaunchStrategy):
"""External orchestrator (torchrun, SLURM, K8s) — env vars pre-set."""
def launch(self, func: Callable, **kwargs):
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
local_rank = int(os.environ.get("LOCAL_RANK", rank))
with setup_parallel(
rank=rank,
world_size=world_size,
local_rank=local_rank,
backend=self.backend,
master_addr=os.environ.get("MASTER_ADDR", self.master_addr),
master_port=os.environ.get("MASTER_PORT", self.master_port),
device_type=self.device_type,
):
func(**kwargs)
class LocalStrategy(LaunchStrategy):
"""Local launcher — single-process or mp.start_processes."""
def launch(self, func: Callable, **kwargs):
args = (
self.world_size,
self.backend,
self.master_addr,
self.master_port,
self.device_type,
func,
kwargs,
)
if self.world_size == 1:
_run_single_rank(0, *args)
return
ctx = mp.start_processes(
_run_single_rank,
args=args,
nprocs=self.world_size,
start_method=self.start_method,
join=False,
)
try:
while not ctx.join():
pass
except BaseException:
for p in ctx.processes:
p.terminate()
ctx.join()
except Exception as e:
print(f"Error in rank {rank}: {e}")
raise
def _detect_launcher() -> str:
"""Detect the distributed launcher from environment.
Returns one of: "torchelastic", "torchrun", "external", "local".
"""
if dist.is_torchelastic_launched():
return "torchelastic"
if "LOCAL_WORLD_SIZE" in os.environ:
return "torchrun"
if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
return "external"
return "local"
def spawn_parallel_fn(
func: Callable,
world_size: int,
@ -220,13 +126,41 @@ def spawn_parallel_fn(
start_method: str = "spawn",
**kwargs,
):
launcher = _detect_launcher()
if launcher in ("torchelastic", "torchrun", "external"):
strategy = TorchrunStrategy(
world_size, backend, master_addr, master_port, device_type, start_method
# clear environment variables
for key in [
"MASTER_ADDR",
"MASTER_PORT",
"RANK",
"WORLD_SIZE",
"LOCAL_RANK",
"LOCAL_DEVICE",
]:
if key in os.environ:
del os.environ[key]
if world_size == 1:
device_id = torch.device(device_type, 0)
os.environ["LOCAL_RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"
os.environ["LOCAL_DEVICE"] = str(device_id)
func(**kwargs)
return
wrapper_spawn_func_args = (
world_size,
backend,
master_addr,
master_port,
device_type,
func,
kwargs,
)
else:
strategy = LocalStrategy(
world_size, backend, master_addr, master_port, device_type, start_method
mp.start_processes(
wrapper_spawn_func,
args=wrapper_spawn_func_args,
nprocs=world_size,
start_method=start_method,
join=True,
)
strategy.launch(func, **kwargs)

View File

@ -1,14 +0,0 @@
from astrai.preprocessing.builder import (
BaseMaskBuilder,
MaskBuilderFactory,
SectionedMaskBuilder,
)
from astrai.preprocessing.pipeline import Pipeline, filter_by_length
__all__ = [
"BaseMaskBuilder",
"MaskBuilderFactory",
"SectionedMaskBuilder",
"Pipeline",
"filter_by_length",
]

View File

@ -1,338 +0,0 @@
"""Mask building strategies for preprocessing pipeline.
The single :class:`SectionedMaskBuilder` handles all input formats
(single-sequence / DPO / GRPO) via declarative config: ``input.sections``
for single-output or ``input.sources`` for multi-output.
"""
from abc import ABC, abstractmethod
from typing import Optional
from astrai.factory import BaseFactory
class BaseMaskBuilder(ABC):
"""Convert a JSONL item into token ids and optional loss_mask."""
@abstractmethod
def build(self, item: dict, config, tokenizer) -> Optional[dict]:
"""Build ``{ids, loss_mask?, domain}`` from a JSONL record.
Returns ``None`` to skip the item entirely.
"""
...
class MaskBuilderFactory(BaseFactory["BaseMaskBuilder"]):
@classmethod
def _validate_component(cls, component_cls: type):
if not issubclass(component_cls, BaseMaskBuilder):
raise TypeError(
f"{component_cls.__name__} must inherit from BaseMaskBuilder"
)
def _extract_domain(item: dict, domain_key: Optional[str]) -> str:
if not domain_key:
return "__default__"
val = item.get(domain_key, "__default__")
return val if isinstance(val, str) else "__default__"
def _resolve_action(action: str, role: str, config) -> str:
"""Resolve action to "train" or "mask".
- ``"train"`` / ``"mask"`` literal
- ``"$role"`` look up ``role`` in ``config.mask``, fall back to ``config.mask_default``
"""
if action == "$role":
return config.mask.get(role, config.mask_default)
return action
@MaskBuilderFactory.register("sectioned")
class SectionedMaskBuilder(BaseMaskBuilder):
"""Config-driven builder supporting single and multi-output modes.
Single-output (backward-compatible)::
{"input": {"sections": [
{"field": "messages", "action": "$role", "template": true}
]}}
{"sequence": [...], "loss_mask": [...], "domain": "..."}
Multi-output (DPO / GRPO)::
{"input": {"sources": {
"chosen": {"sections": [
{"field": "chosen", "action": "$role", "template": true}
]},
"rejected": {"sections": [
{"field": "rejected", "action": "$role", "template": true}
]}
}}}
{"chosen": [...], "chosen_mask": [...],
"rejected": [...], "rejected_mask": [...], "domain": "..."}
Output spec fields::
sections list of section specs (same format as single-output)
list_field True when the JSONL field holds a list of values to
tokenise individually and concatenate (GRPO responses)
mask_key explicit output key for the loss mask
(default: ``"{output_key}_mask"``)
dtype explicit tensor dtype for this output key
(default: "int32")
"""
def build(self, item: dict, config, tokenizer) -> Optional[dict]:
sources_spec = getattr(config.input, "sources", None)
if sources_spec:
return self._build_multi(item, sources_spec, config, tokenizer)
return self._build_single(item, config, tokenizer)
def _build_single(self, item: dict, config, tokenizer) -> Optional[dict]:
sections = config.input.sections
if not sections:
return None
ids, mask = self._process_sections(
item, sections, config, tokenizer, is_top_level=True
)
if ids is None:
return None
result: dict = {
"sequence": ids,
"domain": _extract_domain(item, config.output.domain_key),
}
if not all(m == 1 for m in mask):
result["loss_mask"] = mask
return result
def _build_multi(
self, item: dict, sources_spec: dict, config, tokenizer
) -> Optional[dict]:
result: dict = {}
any_output = False
for output_key, spec in sources_spec.items():
sections = spec.get("sections", [])
if not sections:
continue
if self._is_value_section(sections):
ids = self._extract_raw_value(item, sections)
if ids is None:
continue
result[output_key] = ids
any_output = True
continue
list_field = spec.get("list_field", False)
mask_key = spec.get("mask_key", f"{output_key}_mask")
if list_field:
ids, mask = self._process_list_field(item, sections, config, tokenizer)
else:
ids, mask = self._process_sections(
item, sections, config, tokenizer, is_top_level=True
)
if ids is None:
continue
result[output_key] = ids
if not all(m == 1 for m in mask):
result[mask_key] = mask
elif "mask_key" in spec:
result[mask_key] = mask
any_output = True
if not any_output:
return None
result["domain"] = _extract_domain(item, config.output.domain_key)
return result
@staticmethod
def _is_value_section(sections: list) -> bool:
return len(sections) == 1 and sections[0].get("action") == "value"
@staticmethod
def _extract_raw_value(item: dict, sections: list):
"""Extract a raw value from a JSONL field without tokenisation.
Used for GRPO rewards where the field contains float values.
"""
sec = sections[0]
field = sec["field"]
raw = item.get(field)
if raw is None:
return None
if isinstance(raw, list):
return [float(v) for v in raw]
return [float(raw)]
def _process_sections(
self,
item: dict,
sections: list,
config,
tokenizer,
*,
is_top_level: bool = False,
):
"""Process a list of sections into ``(ids, loss_mask)``.
Returns ``(None, None)`` if the item should be skipped.
"""
all_ids: list[int] = []
loss_mask: list[int] = []
has_template = any(s.get("template") for s in sections)
is_text_config = not has_template and all(
s["action"] == "train" for s in sections
)
if is_top_level and has_template and tokenizer.bos_token_id is not None:
all_ids.append(tokenizer.bos_token_id)
loss_mask.append(0)
first_section = True
for sec in sections:
field = sec["field"]
action = sec["action"]
use_template = sec.get("template", False)
add_special = sec.get(
"add_special_tokens", not use_template and first_section
)
if use_template:
success = self._append_template_section(
item, field, action, tokenizer, config, all_ids, loss_mask
)
if not success:
continue
else:
success = self._append_text_section(
item,
field,
action,
tokenizer,
add_special,
is_text_config,
config,
all_ids,
loss_mask,
)
if not success:
continue
first_section = False
max_len = config.preprocessing.max_seq_len
all_ids = all_ids[:max_len]
loss_mask = loss_mask[: len(all_ids)]
if not all_ids:
return None, None
if is_top_level and has_template and len(all_ids) <= 1:
return None, None
return all_ids, loss_mask
def _append_template_section(
self, item, field, action, tokenizer, config, all_ids, loss_mask
):
messages = item.get(field)
if not isinstance(messages, list) or not messages:
return False
for msg in messages:
role = msg.get("role", "")
act = _resolve_action(action, role, config)
rendered = tokenizer.apply_chat_template(
[msg], tokenize=False, add_generation_prompt=False
)
ids = tokenizer.encode(rendered, add_special_tokens=False)
all_ids.extend(ids)
val = 1 if act == "train" else 0
loss_mask.extend([val] * len(ids))
return True
def _append_text_section(
self,
item,
field,
action,
tokenizer,
add_special,
is_text_config,
config,
all_ids,
loss_mask,
):
text = str(item.get(field, ""))
if not text.strip():
return False
if is_text_config:
pp = config.preprocessing
if pp.min_chars > 0 and len(text) < pp.min_chars:
return False
if len(text) > pp.max_chars:
return False
ids = tokenizer.encode(text, add_special_tokens=add_special)
all_ids.extend(ids)
val = 1 if action == "train" else 0
loss_mask.extend([val] * len(ids))
return True
def _process_list_field(self, item: dict, sections: list, config, tokenizer):
all_ids: list[int] = []
loss_mask: list[int] = []
for sec in sections:
field = sec["field"]
action = sec["action"]
use_template = sec.get("template", False)
values = item.get(field)
if not isinstance(values, list):
continue
for val in values:
if use_template:
if isinstance(val, list):
wrapper = {field: val}
self._append_template_section(
wrapper,
field,
action,
tokenizer,
config,
all_ids,
loss_mask,
)
else:
wrapper = {field: str(val)}
self._append_text_section(
wrapper,
field,
action,
tokenizer,
False,
False,
config,
all_ids,
loss_mask,
)
max_len = config.preprocessing.max_seq_len
all_ids = all_ids[:max_len]
loss_mask = loss_mask[: len(all_ids)]
if not all_ids:
return None, None
return all_ids, loss_mask

View File

@ -1,257 +0,0 @@
"""Config-driven JSONL preprocessing pipeline.
Composes a :class:`BaseMaskBuilder` (selected by ``input.type``) with
sharding and flush to ``.h5`` / ``.bin`` storage.
"""
import json
import os
from collections import defaultdict
from itertools import chain
from typing import List, Optional, Tuple
import torch
import tqdm
from astrai.config.preprocess_config import PipelineConfig
from astrai.dataset.storage import save_bin, save_h5
from astrai.preprocessing.builder import SectionedMaskBuilder
from astrai.tokenize import AutoTokenizer
_STR_TO_DTYPE: dict[str, torch.dtype] = {
"bool": torch.bool,
"uint8": torch.uint8,
"int8": torch.int8,
"int16": torch.int16,
"int32": torch.int32,
"int64": torch.int64,
"float16": torch.float16,
"float32": torch.float32,
"float64": torch.float64,
}
def filter_by_length(text: str, min_len: int = 50, max_len: int = 2_000_000) -> bool:
return min_len <= len(text) <= max_len
def _truncate(seq: list, max_len: int, mode: str) -> list:
if len(seq) <= max_len:
return seq
if mode == "keep_end":
return seq[-max_len:]
return seq[:max_len]
def pack_sequences(
sequences: List[list],
max_packed_len: int,
strategy: str,
truncation_mode: str,
) -> List[Tuple[int, int]]:
"""Pack *sequences* into bins and return a reorder plan.
Returns a list of ``(orig_idx, truncated_length)`` in flush order.
All keys (sequence, loss_mask, ) must be reordered and truncated
identically according to this plan.
Supported *strategy* values:
- ``"simple"``: sequential, no reordering.
- ``"bfd"``: best-fit decreasing bin packing.
"""
n = len(sequences)
if strategy == "simple":
return [(i, min(len(sequences[i]), max_packed_len)) for i in range(n)]
order = sorted(range(n), key=lambda i: len(sequences[i]), reverse=True)
bins: List[List[int]] = []
bin_lengths: List[int] = []
for orig_idx in order:
seq_len = min(len(sequences[orig_idx]), max_packed_len)
best_bin = None
best_remain = max_packed_len + 1
for i, bl in enumerate(bin_lengths):
remain = max_packed_len - bl
if seq_len <= remain < best_remain:
best_remain = remain
best_bin = i
if best_bin is not None:
bins[best_bin].append(orig_idx)
bin_lengths[best_bin] += seq_len
else:
bins.append([orig_idx])
bin_lengths.append(seq_len)
plan: List[Tuple[int, int]] = []
for bin_indices in bins:
for orig_idx in bin_indices:
plan.append((orig_idx, min(len(sequences[orig_idx]), max_packed_len)))
return plan
class Pipeline:
"""Tokenization pipeline driven by a declarative :class:`PipelineConfig`.
Usage::
config = PipelineConfig.from_json("sft_pipeline.json")
Pipeline(config, ["data.jsonl"], output_dir="out", tokenizer_path="params").run()
"""
def __init__(
self,
config: PipelineConfig,
input_paths: list[str],
output_dir: str,
tokenizer_path: str,
):
os.makedirs(output_dir, exist_ok=True)
self.config = config
self.paths = input_paths
self.output_dir = output_dir
self.tokenizer_path = tokenizer_path
self.mask_builder = SectionedMaskBuilder()
def transform(self, item: dict) -> Optional[dict]:
return self.mask_builder.build(item, self.config, self._tokenizer)
def run(self):
self._tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_path)
domains: dict = defaultdict(lambda: defaultdict(list))
total_tokens = 0
shard_idx: dict[str, int] = defaultdict(int)
count = 0
pp = self.config.preprocessing
for item in tqdm.tqdm(
self._iter_items(), desc="Tokenizing", unit="docs", mininterval=0.5
):
if pp.max_items and count >= pp.max_items:
break
result = self.transform(item)
if result is None:
continue
domain = result.pop("domain", "__default__")
is_multi = bool(getattr(self.config.input, "sources", None))
if is_multi:
ids = self._primary_ids(result)
else:
ids = result.pop("sequence")
result["sequence"] = ids
if not ids:
continue
bucket = domains[domain]
self._align_bucket(bucket, result, ids, is_multi)
for key, val in result.items():
bucket[key].append(val)
count += 1
total_tokens += len(ids)
if total_tokens >= self.config.output.max_tokens_per_shard:
self._flush(domains, shard_idx)
domains.clear()
total_tokens = 0
if total_tokens > 0:
self._flush(domains, shard_idx)
print(f"Done. {count} documents tokenized.")
@staticmethod
def _primary_ids(result: dict) -> list:
"""Return the first list-valued entry in *result* as the primary id
sequence for token counting."""
for val in result.values():
if isinstance(val, list) and val and isinstance(val[0], int):
return val
return []
@staticmethod
def _align_bucket(bucket: dict, result: dict, ids: list, is_multi: bool):
"""Pad previously-accumulated keys that are missing from *result*."""
for key in list(bucket.keys()):
if key in result:
continue
if is_multi:
pad = bucket[key][-1] if bucket[key] else [1] * len(ids)
bucket[key].append(pad)
else:
bucket[key].append([1] * len(ids))
def _iter_items(self):
for path in self.paths:
with open(path, "r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if not line:
continue
yield json.loads(line)
def _flush(self, domains, shard_idx):
for domain, keys in domains.items():
idx = shard_idx[domain]
chunk_dir = os.path.join(self.output_dir, domain)
pp = self.config.preprocessing
if pp.packing_strategy != "simple" and "sequence" in keys:
plan = pack_sequences(
keys["sequence"],
pp.max_packed_len,
pp.packing_strategy,
pp.truncation_mode,
)
reordered = defaultdict(list)
for orig_idx, truncated_len in plan:
for k, vals in keys.items():
reordered[k].append(
_truncate(
vals[orig_idx], pp.max_packed_len, pp.truncation_mode
)
)
keys = reordered
tensors = {}
for key, ids_list in keys.items():
dt = _STR_TO_DTYPE.get(
self.config.output.dtype.get(key, "int32"), torch.int32
)
tensors[key] = [
torch.tensor(list(chain.from_iterable(ids_list)), dtype=dt)
]
pid_mode = self.config.output.position_ids_mode
if pid_mode and pid_mode != "none" and "sequence" in tensors:
pos_ids = []
if pid_mode == "doc_reset":
for item in keys["sequence"]:
pos_ids.extend(range(len(item)))
else:
total = sum(len(item) for item in keys["sequence"])
pos_ids = list(range(total))
tensors["position_ids"] = [torch.tensor(pos_ids, dtype=torch.int32)]
shard_path = os.path.join(chunk_dir, f"shard_{idx:04d}")
fmt = self.config.output.storage_format
if fmt == "bin":
save_bin(shard_path, tensors)
else:
save_h5(chunk_dir, f"data_{idx:04d}", tensors)
shard_idx[domain] = idx + 1
first_key = "sequence" if "sequence" in tensors else next(iter(tensors))
tqdm.tqdm.write(
f" saved {domain}/shard_{idx:04d} "
f"({tensors[first_key][0].numel():,} tokens)"
)

View File

@ -1,21 +0,0 @@
"""Training component protocols — structural subtyping for optimizer/scheduler wrappers."""
from typing import Any, Protocol, runtime_checkable
@runtime_checkable
class OptimizerProtocol(Protocol):
def step(self, closure=None): ...
def zero_grad(self): ...
@property
def param_groups(self) -> Any: ...
def state_dict(self) -> dict: ...
def load_state_dict(self, d: dict): ...
@runtime_checkable
class SchedulerProtocol(Protocol):
def step(self): ...
def state_dict(self) -> dict: ...
def load_state_dict(self, d: dict): ...
def get_last_lr(self): ...

View File

@ -1,9 +1,7 @@
import io
import json
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, Union
from typing import Any, Dict, Optional
import safetensors.torch as st
import torch
@ -11,172 +9,75 @@ import torch.distributed as dist
from astrai.parallel.setup import get_rank
_META_FILE = "meta.json"
_CONFIG_FILE = "config.json"
_WEIGHTS_FILE = "model.safetensors"
def save_safetensors(state_dict: dict, path: Union[str, Path]):
st.save_file(state_dict, str(path))
def load_safetensors(path: Union[str, Path], broadcast: bool = False) -> dict:
if not broadcast or not dist.is_initialized():
return st.load_file(str(path))
rank = get_rank()
if rank == 0:
state_dict = st.load_file(str(path))
else:
state_dict = {}
tmp = [state_dict]
dist.broadcast_object_list(tmp, src=0)
return tmp[0]
def save_json(data: dict, path: Union[str, Path]):
with open(str(path), "w") as f:
json.dump(data, f, indent=2)
def load_json(path: Union[str, Path], broadcast: bool = False) -> dict:
if not broadcast or not dist.is_initialized():
with open(str(path), "r") as f:
return json.load(f)
rank = get_rank()
if rank == 0:
with open(str(path), "r") as f:
data = json.load(f)
else:
data = {}
tmp = [data]
dist.broadcast_object_list(tmp, src=0)
return tmp[0]
def save_torch(obj: Any, path: Union[str, Path]):
torch.save(obj, str(path))
def load_torch(path: Union[str, Path], broadcast: bool = False) -> Any:
if not broadcast or not dist.is_initialized():
return torch.load(str(path), map_location="cpu", weights_only=False)
path = Path(path)
rank = get_rank()
if rank == 0:
with open(path, "rb") as f:
raw = f.read()
data_tensor = torch.frombuffer(bytearray(raw), dtype=torch.uint8)
num_bytes = torch.tensor([len(raw)], dtype=torch.long)
else:
num_bytes = torch.tensor([0], dtype=torch.long)
dist.broadcast(num_bytes, src=0)
if rank != 0:
data_tensor = torch.empty(num_bytes.item(), dtype=torch.uint8)
dist.broadcast(data_tensor, src=0)
buf = io.BytesIO(data_tensor.numpy().tobytes())
return torch.load(buf, map_location="cpu", weights_only=False)
def save_model(config: dict, state_dict: dict, save_directory: str):
save_path = Path(save_directory)
save_path.mkdir(parents=True, exist_ok=True)
save_json(config, save_path / _CONFIG_FILE)
save_safetensors(state_dict, save_path / _WEIGHTS_FILE)
def load_model_config(save_directory: str) -> dict:
return load_json(Path(save_directory) / _CONFIG_FILE)
def load_model_weights(save_directory: str) -> dict:
return load_state_dict(Path(save_directory) / _WEIGHTS_FILE)
def load_state_dict(path: Union[str, Path], broadcast: bool = False) -> dict:
path = Path(path)
if not broadcast or not dist.is_initialized():
return load_safetensors(path)
rank = get_rank()
if rank == 0:
state_dict = load_safetensors(path)
specs = [
(k, list(state_dict[k].shape), str(state_dict[k].dtype).split(".")[-1])
for k in sorted(state_dict)
]
else:
state_dict = {}
specs = []
specs_list = [specs]
dist.broadcast_object_list(specs_list, src=0)
specs = specs_list[0]
for key, shape, dtype_name in specs:
dtype = getattr(torch, dtype_name)
if rank != 0:
tensor = torch.empty(shape, dtype=dtype, device="cpu")
else:
tensor = state_dict[key].contiguous().cpu()
dist.broadcast(tensor, src=0)
if rank != 0:
state_dict[key] = tensor
return state_dict
@dataclass
class Checkpoint:
state_dict: Dict[str, Any] = field(default_factory=dict)
epoch: int = 0
iteration: int = 0
extra: Dict[str, Any] = field(default_factory=dict)
meta: Dict[str, Any] = field(default_factory=dict)
config: Dict[str, Any] = field(default_factory=dict)
def __init__(
self,
state_dict: Dict[str, Any],
epoch: int = 0,
iteration: int = 0,
extra: Optional[Dict[str, Any]] = None,
meta: Optional[Dict[str, Any]] = None,
):
self.state_dict = state_dict
self.epoch = epoch
self.iteration = iteration
self.extra = extra or {}
self.meta = meta or {}
def save(
self,
save_dir: str,
) -> None:
def save(self, save_dir: str):
save_path = Path(save_dir)
save_path.mkdir(parents=True, exist_ok=True)
if get_rank() != 0:
return
rank = get_rank()
if rank == 0:
meta = {
"epoch": self.epoch,
"iteration": self.iteration,
"timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"),
**self.meta,
}
save_json(meta, save_path / _META_FILE)
save_json(self.config, save_path / _CONFIG_FILE)
save_safetensors(self.state_dict, save_path / _WEIGHTS_FILE)
meta.update(self.meta)
with open(save_path / "meta.json", "w") as f:
json.dump(meta, f, indent=2)
st.save_file(self.state_dict, save_path / "state_dict.safetensors")
if self.extra:
for key, value in self.extra.items():
save_torch(value, save_path / f"{key}.pt")
torch.save(value, save_path / f"{key}.pt")
@classmethod
def load(cls, save_dir: str, broadcast: bool = False) -> "Checkpoint":
def load(
cls,
save_dir: str,
) -> "Checkpoint":
rank = get_rank()
save_path = Path(save_dir)
meta = load_json(save_path / _META_FILE, broadcast)
config = load_json(save_path / _CONFIG_FILE, broadcast)
state_dict = load_state_dict(save_path / _WEIGHTS_FILE, broadcast=broadcast)
meta = {}
if rank == 0:
with open(Path(save_dir) / "meta.json", "r") as f:
meta = json.load(f)
if dist.is_initialized():
meta_list = [meta]
dist.broadcast_object_list(meta_list, src=0)
meta = meta_list[0]
state_dict = st.load_file(save_path / "state_dict.safetensors")
extra = {}
for f in sorted(save_path.iterdir()):
if f.suffix == ".pt":
extra[f.stem] = load_torch(f, broadcast=broadcast)
for f in save_path.iterdir():
if f.suffix == ".pt" and f.stem not in ("meta",):
extra[f.stem] = torch.load(f, map_location="cpu", weights_only=False)
return cls(
state_dict=state_dict,
epoch=meta.get("epoch", 0),
iteration=meta.get("iteration", 0),
extra=extra,
config=config,
epoch=meta["epoch"],
iteration=meta["iteration"],
extra=extra or None,
)

View File

@ -1,10 +1,13 @@
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
from jinja2 import Template
# Message type for chat messages
type MessageType = Dict[str, Any]
@dataclass
class ChatTemplate:
"""A chat template with Jinja2 rendering support.
@ -12,24 +15,23 @@ class ChatTemplate:
name: Unique identifier for the template.
template_str: Jinja2 template string.
description: Optional description.
default_variables: Optional dictionary of default variable values.
default_variables: Optional dictionary of default variable values
that will be passed to the template if not overridden during rendering.
special_tokens: Optional dictionary mapping token names to their string values.
These tokens are automatically added to the template variables.
"""
def __init__(
self,
name: str = "",
template_str: str = "",
description: str = "",
default_variables: Optional[Dict[str, Any]] = None,
special_tokens: Optional[Dict[str, str]] = None,
):
self.name = name
self.template_str = template_str
self.description = description
self.default_variables = default_variables or {}
self.special_tokens = special_tokens or {}
self._compiled: Template = Template(template_str)
name: str
template_str: str
description: str = ""
default_variables: Dict[str, Any] = None
special_tokens: Dict[str, str] = None
def __post_init__(self):
if self.default_variables is None:
self.default_variables = {}
if self.special_tokens is None:
self.special_tokens = {}
@classmethod
def from_string(
@ -41,7 +43,7 @@ class ChatTemplate:
) -> "ChatTemplate":
"""Create a ChatTemplate instance directly from a template string."""
return cls(
name="",
name="", # empty name for adhoc templates
template_str=template_str,
description=description,
default_variables=default_variables,
@ -71,4 +73,5 @@ class ChatTemplate:
if system_prompt is not None:
variables["system_prompt"] = system_prompt
return self._compiled.render(**variables)
jinja_template = Template(self.template_str)
return jinja_template.render(**variables)

View File

@ -4,17 +4,17 @@ from torch.optim import Optimizer
def _zeropower_via_newtonschulz(G: torch.Tensor, steps: int = 5):
assert G.ndim == 2
X = G
X = G.bfloat16()
scale = max(1, G.size(0) / G.size(1)) ** 0.5
X = X / (X.norm() + 1e-7) * scale
if steps == 0:
return X
return X.type_as(G)
a, b, c = (3.4445, -4.7750, 2.0315)
for _ in range(steps):
A = X @ X.T
B = A @ X
X = a * X + b * B + c * (A @ B)
return X
return X.type_as(G)
class Muon(Optimizer):
@ -50,94 +50,64 @@ class Muon(Optimizer):
if closure is not None:
with torch.enable_grad():
loss = closure()
for group in self.param_groups:
params_2d, params_1d = [], []
grads_2d, grads_1d = [], []
for p in group["params"]:
if p.grad is None:
continue
if p.grad.is_sparse:
grad = p.grad
if grad.is_sparse:
raise RuntimeError("Muon does not support sparse gradients")
if p.ndim >= 2:
params_2d.append(p)
grads_2d.append(p.grad)
self._muon_update(p, grad, group)
else:
params_1d.append(p)
grads_1d.append(p.grad)
if params_2d:
self._muon_update_foreach(params_2d, grads_2d, group)
if params_1d:
self._adamw_update_foreach(params_1d, grads_1d, group)
self._adamw_update(p, grad, group)
return loss
def _muon_update_foreach(self, params_2d, grads_2d, group):
def _muon_update(self, p, grad, group):
lr = group["lr"]
momentum = group["momentum"]
wd = group["weight_decay"]
nesterov = group["nesterov"]
ns_steps = group["ns_steps"]
state = self.state[p]
if wd != 0:
torch._foreach_mul_(params_2d, 1 - lr * wd)
p.mul_(1 - lr * wd)
if nesterov:
grads_2d = torch._foreach_add(grads_2d, params_2d, alpha=wd)
grad = grad.add(p, alpha=wd)
bufs = []
for p, grad in zip(params_2d, grads_2d):
state = self.state[p]
if "momentum_buffer" not in state:
state["momentum_buffer"] = torch.zeros_like(grad)
bufs.append(state["momentum_buffer"])
buf = state["momentum_buffer"]
buf.lerp_(grad, 1 - momentum)
torch._foreach_lerp_(bufs, grads_2d, 1 - momentum)
for p, buf in zip(params_2d, bufs):
update = _zeropower_via_newtonschulz(buf, steps=ns_steps)
scale = max(1, p.size(0) / p.size(1)) ** 0.5
p.add_(update, alpha=-lr * scale)
def _adamw_update_foreach(self, params_1d, grads_1d, group):
def _adamw_update(self, p, grad, group):
lr = group["adamw_lr"]
betas = group["adamw_betas"]
eps = group["adamw_eps"]
wd = group["adamw_wd"]
steps: list[int] = []
exp_avgs, exp_avg_sqs = [], []
has_state = []
for p in params_1d:
state = self.state[p]
if not state:
state["step"] = 0
state["exp_avg"] = torch.zeros_like(p)
state["exp_avg_sq"] = torch.zeros_like(p)
has_state.append(False)
else:
has_state.append(True)
state["step"] += 1
steps.append(state["step"])
exp_avgs.append(state["exp_avg"])
exp_avg_sqs.append(state["exp_avg_sq"])
state["step"] += 1
exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
beta1, beta2 = betas
torch._foreach_lerp_(exp_avgs, grads_1d, 1 - beta1)
grads_sq = torch._foreach_mul(grads_1d, grads_1d)
torch._foreach_lerp_(exp_avg_sqs, grads_sq, 1 - beta2)
exp_avg.lerp_(grad, 1 - beta1)
exp_avg_sq.lerp_(grad.square(), 1 - beta2)
bias_correction1 = [1 - beta1**s for s in steps]
bias_correction2 = [1 - beta2**s for s in steps]
step = state["step"]
bias1 = 1 - beta1**step
bias2 = 1 - beta2**step
if wd != 0:
torch._foreach_mul_(params_1d, 1 - lr * wd)
exp_avg_corrected = torch._foreach_div(exp_avgs, bias_correction1)
denom = torch._foreach_div(exp_avg_sqs, bias_correction2)
denom = torch._foreach_sqrt(denom)
torch._foreach_add_(denom, eps)
torch._foreach_addcdiv_(params_1d, exp_avg_corrected, denom, value=-lr)
p.mul_(1 - lr * wd)
denom = exp_avg_sq.sqrt().div_(bias2**0.5).add_(eps)
p.addcdiv_(exp_avg / bias1, denom, value=-lr)

View File

@ -42,7 +42,7 @@ class SchedulerFactory(BaseFactory["BaseScheduler"]):
"""
@classmethod
def _validate_component(cls, scheduler_cls: Type[BaseScheduler]):
def _validate_component(cls, scheduler_cls: Type[BaseScheduler]) -> None:
"""Validate that the scheduler class inherits from BaseScheduler."""
if not issubclass(scheduler_cls, BaseScheduler):
raise TypeError(f"{scheduler_cls.__name__} must inherit from BaseScheduler")

View File

@ -1,5 +1,6 @@
"""Training strategy implementations with factory pattern."""
import copy
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Union
@ -7,14 +8,26 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.nn.parallel import DistributedDataParallel as DDP
from astrai.factory import BaseFactory
def create_ref_model(model_fn, state_dict: dict) -> nn.Module:
"""Create a frozen reference model from model_fn + full state dict."""
ref_model = model_fn()
ref_model.load_state_dict(state_dict)
def unwrap_model(model: nn.Module) -> nn.Module:
"""Unwrap DDP wrapper if present to get the original model."""
if isinstance(model, DDP):
return model.module
return model
def create_ref_model(model: nn.Module) -> nn.Module:
"""Create a reference model for DPO/GRPO training.
Handles DDP-wrapped models safely by unwrapping first,
then creating a deep copy with frozen gradients.
"""
original_model = unwrap_model(model)
ref_model = copy.deepcopy(original_model)
ref_model.requires_grad_(False)
ref_model.eval()
return ref_model
@ -68,22 +81,6 @@ def get_logprobs(
return token_logprobs * shifted_mask
def make_doc_boundary_mask(position_ids: Tensor) -> Tensor:
S = position_ids.size(1)
device = position_ids.device
boundaries = position_ids[:, 1:] <= position_ids[:, :-1]
doc_ids = torch.cat(
[
torch.zeros(position_ids.size(0), 1, dtype=torch.long, device=device),
boundaries.long().cumsum(dim=1),
],
dim=1,
)
same_doc = doc_ids.unsqueeze(-1) == doc_ids.unsqueeze(-2)
causal = torch.tril(torch.ones(S, S, dtype=torch.bool, device=device))
return (same_doc & causal).unsqueeze(1)
class BaseStrategy(ABC):
"""Abstract base class for training strategies."""
@ -92,8 +89,6 @@ class BaseStrategy(ABC):
):
self.model = model
self.device = device
self.executor = kwargs.pop("executor", None)
self.model_fn = kwargs.pop("model_fn", None)
self.extra_kwargs = kwargs
@abstractmethod
@ -128,7 +123,7 @@ class StrategyFactory(BaseFactory["BaseStrategy"]):
"""
@classmethod
def _validate_component(cls, strategy_cls: type):
def _validate_component(cls, strategy_cls: type) -> None:
"""Validate that the strategy class inherits from BaseStrategy."""
if not issubclass(strategy_cls, BaseStrategy):
raise TypeError(f"{strategy_cls.__name__} must inherit from BaseStrategy")
@ -196,19 +191,15 @@ class SFTStrategy(BaseStrategy):
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
batch = move_to_device(batch, self.device)
input_ids, target_ids, position_ids, loss_mask = (
input_ids, target_ids, loss_mask = (
batch["input_ids"],
batch["target_ids"],
batch["position_ids"],
batch["loss_mask"],
)
ignore_index = -100
input_mask = make_doc_boundary_mask(position_ids)
logits = self.model(input_ids=input_ids)["logits"]
target_ids = target_ids.masked_fill(loss_mask == 0, ignore_index)
logits = self.model(
input_ids=input_ids, position_ids=position_ids, input_mask=input_mask
)["logits"]
loss = F.cross_entropy(
input=logits.flatten(0, 1).float(),
@ -237,9 +228,7 @@ class DPOStrategy(BaseStrategy):
**kwargs,
):
super().__init__(model, device, **kwargs)
self.ref_model = create_ref_model(
self.model_fn, self.executor.unwrap_model(model)
).to(device=self.device)
self.ref_model = create_ref_model(model)
self.beta = beta
self.reduction = reduction
@ -293,9 +282,7 @@ class GRPOStrategy(BaseStrategy):
**kwargs,
):
super().__init__(model, device, **kwargs)
self.ref_model = create_ref_model(
self.model_fn, self.executor.unwrap_model(model)
).to(device=self.device)
self.ref_model = create_ref_model(model)
self.clip_eps = clip_eps
self.kl_coef = kl_coef
self.group_size = group_size
@ -305,7 +292,8 @@ class GRPOStrategy(BaseStrategy):
def sync_ref_model(self):
"""Copy current model weights to ref model."""
self.ref_model.load_state_dict(self.executor.unwrap_model(self.model))
ref_state = self.model.state_dict()
self.ref_model.load_state_dict(ref_state)
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
self._step += 1

View File

@ -15,7 +15,7 @@ from tqdm import tqdm
from astrai.factory import BaseFactory
from astrai.parallel import only_on_rank
from astrai.parallel.setup import get_current_device, get_rank
from astrai.parallel.setup import get_current_device
from astrai.serialization import Checkpoint
from astrai.trainer.metric_util import (
ctx_get_grad_max,
@ -51,15 +51,18 @@ class TrainCallback(Protocol):
def on_epoch_end(self, context: TrainContext):
"""Called at the end of each epoch."""
def on_step_begin(self, context: TrainContext):
"""Called at the beginning of each step."""
def on_step_end(self, context: TrainContext):
"""Called at the end of each step."""
def on_batch_begin(self, context: TrainContext):
"""Called at the beginning of each batch."""
def on_batch_end(self, context: TrainContext):
"""Called at the end of each batch."""
def on_optimizer_step(self, context: TrainContext):
"""Called on every optimizer step (sync step only)."""
def on_error(self, context: TrainContext):
"""Called when an error occurs during training."""
@ -85,7 +88,7 @@ class GradientClippingCallback(TrainCallback):
def __init__(self, max_grad_norm: float):
self.max_grad_norm = max_grad_norm
def on_optimizer_step(self, context: TrainContext):
def on_step_begin(self, context: TrainContext):
clip_grad_norm_(context.model.parameters(), self.max_grad_norm)
@ -137,31 +140,44 @@ class CheckpointCallback(TrainCallback):
save_dir: str,
interval: int,
weight_only: bool = False,
state_dict_fn: Optional[Callable[[nn.Module], dict]] = None,
save_extra_fn: Optional[Callable[["TrainContext"], dict]] = None,
load_extra_fn: Optional[Callable[[dict, "TrainContext"], None]] = None,
):
self.save_dir = save_dir
self.interval = interval
self.weight_only = weight_only
self.state_dict_fn = state_dict_fn
self.save_extra_fn = save_extra_fn or CheckpointCallback.save_extra
self.load_extra_fn = load_extra_fn or CheckpointCallback.load_extra
self.last_ckpt_iter = 0
@only_on_rank(0)
def _save_checkpoint(self, context: TrainContext):
state_dict = context.executor.unwrap_model(context.model)
self.last_ckpt_iter = context.iteration
if get_rank() == 0:
save_path = os.path.join(
self.save_dir, f"epoch_{context.epoch}_iter_{context.iteration}"
)
state_dict = (
self.state_dict_fn(context.model)
if self.state_dict_fn
else context.model.state_dict()
)
extra = self.save_extra_fn(context)
context.checkpoint = Checkpoint(
state_dict=state_dict,
epoch=context.epoch,
iteration=context.iteration,
extra=extra,
config=context.model_config,
meta=context.config.to_dict(),
)
context.checkpoint.save(save_path)
self.last_ckpt_iter = context.iteration
def on_train_begin(self, context: TrainContext):
if context.checkpoint and context.checkpoint.extra:
self.load_extra_fn(context.checkpoint.extra, context)
def on_batch_end(self, context: TrainContext):
if context.iteration - self.last_ckpt_iter >= self.interval:
@ -183,6 +199,12 @@ class CheckpointCallback(TrainCallback):
extra[name] = obj.state_dict()
return extra
@staticmethod
def load_extra(extra: dict, context: TrainContext):
for name in CheckpointCallback.extra_keys:
if name in extra:
getattr(context, name).load_state_dict(extra[name])
@CallbackFactory.register("progress_bar")
class ProgressBarCallback(TrainCallback):
@ -191,7 +213,7 @@ class ProgressBarCallback(TrainCallback):
"""
def __init__(
self, num_epoch: int, log_interval: int = 100, file: Optional[IO[str]] = None
self, num_epoch: int, log_interval: int = 100, file: IO[str] = sys.stdout
):
self.num_epoch = num_epoch
self.log_interval = log_interval
@ -204,7 +226,7 @@ class ProgressBarCallback(TrainCallback):
context.dataloader,
desc=f"Epoch {context.epoch + 1}/{self.num_epoch}",
dynamic_ncols=True,
file=self.file or sys.stdout,
file=self.file,
)
@only_on_rank(0)
@ -322,7 +344,7 @@ class ValidationCallback(TrainCallback):
f"Epoch {context.epoch + 1}, Step {step_count}, Val Loss: {avg_loss:.4f}"
)
def on_optimizer_step(self, context: TrainContext):
def on_step_end(self, context: TrainContext):
if context.val_dataloader is None:
return
cfg = context.config

View File

@ -1,18 +1,15 @@
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional, Self
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from torch.utils.data import DataLoader
from astrai.config.train_config import TrainConfig
from astrai.dataset import ResumableDistributedSampler
from astrai.model.components.lora import inject_lora
from astrai.parallel.executor import BaseExecutor, ExecutorFactory
from astrai.parallel.setup import get_current_device, get_rank, get_world_size
from astrai.protocols import OptimizerProtocol, SchedulerProtocol
from astrai.serialization import Checkpoint, load_json, load_model_weights
from astrai.serialization import Checkpoint
from astrai.trainer.strategy import BaseStrategy, StrategyFactory
@ -21,12 +18,10 @@ class TrainContext:
model: nn.Module = field(default=None)
strategy: BaseStrategy = field(default=None)
dataloader: DataLoader = field(default=None)
optimizer: OptimizerProtocol = field(default=None)
scheduler: SchedulerProtocol = field(default=None)
optimizer: Optimizer = field(default=None)
scheduler: LRScheduler = field(default=None)
checkpoint: Checkpoint = field(default=None)
config: TrainConfig = field(default=None)
model_config: dict = field(default_factory=dict)
executor: BaseExecutor = field(default=None)
epoch: int = field(default=0)
iteration: int = field(default=0)
@ -45,91 +40,49 @@ class TrainContextBuilder:
config: TrainConfig,
):
self.config = config
self._resume_dir: Optional[str] = None
self._checkpoint: Optional[Checkpoint] = None
def with_resume_dir(self, resume_dir: Optional[str]) -> Self:
self._resume_dir = resume_dir
def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self:
self._checkpoint = checkpoint
return self
def build(self) -> TrainContext:
cfg = self.config
device = get_current_device()
executor = ExecutorFactory.create(
cfg.parallel_mode,
grad_accum_steps=cfg.grad_accum_steps,
**cfg.executor_kwargs,
)
model = cfg.model_fn()
model = model.to(device=device)
model_config = {}
if self._resume_dir:
config_path = Path(self._resume_dir) / "config.json"
if config_path.exists():
model_config = load_json(config_path)
if not model_config and hasattr(model, "config"):
model_config = model.config.to_dict()
context = TrainContext(
model=model,
model=self.config.model,
world_size=get_world_size(),
rank=get_rank(),
config=cfg,
model_config=model_config,
executor=executor,
config=self.config,
)
if self._resume_dir is not None:
resume_path = Path(self._resume_dir)
if (resume_path / "meta.json").exists():
checkpoint = Checkpoint.load(self._resume_dir)
state_dict = checkpoint.state_dict
if checkpoint.config:
context.model_config = checkpoint.config
device = get_current_device()
context.model = context.model.to(device=device)
if self.config.nprocs > 1 and self.config.parallel_wrapper:
context.model = self.config.parallel_wrapper(context.model)
if self._checkpoint is not None:
context.epoch = max(self._checkpoint.epoch, self.config.start_epoch)
context.iteration = max(self._checkpoint.iteration, self.config.start_batch)
context.model.load_state_dict(self._checkpoint.state_dict)
context.checkpoint = self._checkpoint
else:
checkpoint = None
state_dict = load_model_weights(self._resume_dir)
model.load_state_dict(state_dict, strict=False)
if checkpoint is not None:
context.epoch = cfg.start_epoch
context.iteration = cfg.start_batch
context.checkpoint = checkpoint
if cfg.lora is not None:
inject_lora(
model,
r=cfg.lora.r,
alpha=cfg.lora.alpha,
target_modules=set(cfg.lora.target_modules),
context.checkpoint = Checkpoint(
state_dict=context.model.state_dict(),
)
context.optimizer = cfg.optimizer_fn(model)
context.scheduler = cfg.scheduler_fn(context.optimizer)
train_dataset = cfg.dataset
val_dataset = cfg.val_dataset
if val_dataset is None and cfg.val_split is not None:
n_total = len(cfg.dataset)
n_val = max(1, int(n_total * cfg.val_split))
n_train = n_total - n_val
generator = torch.Generator().manual_seed(cfg.random_seed)
train_dataset, val_dataset = random_split(
cfg.dataset, [n_train, n_val], generator=generator
)
context.optimizer = self.config.optimizer_fn(context.model)
context.scheduler = self.config.scheduler_fn(context.optimizer)
cfg = self.config
sampler_offset = context.iteration * cfg.batch_per_device
sampler = ResumableDistributedSampler(
data_source=train_dataset,
data_source=cfg.dataset,
start_epoch=context.epoch,
start_iter=sampler_offset,
seed=cfg.random_seed,
)
context.dataloader = DataLoader(
train_dataset,
cfg.dataset,
batch_size=cfg.batch_per_device,
sampler=sampler,
num_workers=cfg.num_workers,
@ -137,16 +90,16 @@ class TrainContextBuilder:
prefetch_factor=cfg.prefetch_factor,
)
if val_dataset is not None:
if cfg.val_dataset is not None:
val_sampler = ResumableDistributedSampler(
data_source=val_dataset,
data_source=cfg.val_dataset,
start_epoch=0,
start_iter=0,
seed=cfg.random_seed,
shuffle=False,
)
context.val_dataloader = DataLoader(
val_dataset,
cfg.val_dataset,
batch_size=cfg.batch_per_device,
sampler=val_sampler,
num_workers=cfg.num_workers,
@ -154,30 +107,11 @@ class TrainContextBuilder:
prefetch_factor=cfg.prefetch_factor,
)
context.model, context.optimizer, context.dataloader, context.scheduler = (
executor.prepare(
model,
context.optimizer,
context.dataloader,
context.scheduler,
)
)
if context.checkpoint and context.checkpoint.extra:
extra = context.checkpoint.extra
for name in ("optimizer", "scheduler"):
if name in extra:
obj = getattr(context, name, None)
if obj is not None:
obj.load_state_dict(extra[name])
context.strategy = StrategyFactory.create(
model=context.model,
train_type=cfg.strategy,
train_type=self.config.strategy,
device=device,
executor=executor,
model_fn=cfg.model_fn,
**cfg.extra_kwargs,
**self.config.extra_kwargs,
)
return context

View File

@ -3,6 +3,7 @@ from typing import List, Optional
from astrai.config import TrainConfig
from astrai.parallel.setup import spawn_parallel_fn
from astrai.serialization import Checkpoint
from astrai.trainer.train_callback import (
CallbackFactory,
TrainCallback,
@ -33,6 +34,7 @@ class Trainer:
"checkpoint",
cfg.ckpt_dir,
cfg.ckpt_interval,
state_dict_fn=cfg.state_dict_fn,
),
CallbackFactory.create(
"metric_logger",
@ -53,35 +55,33 @@ class Trainer:
if method:
method(context)
def _trainer_loop(self, resume_dir: Optional[str] = None):
context = (
TrainContextBuilder(self.train_config).with_resume_dir(resume_dir).build()
)
executor = context.executor
def _trainer_loop(self, checkpoint: Optional[Checkpoint] = None):
cfg = self.train_config
context = TrainContextBuilder(cfg).with_checkpoint(checkpoint).build()
self._call_callbacks("on_train_begin", context)
try:
context.model.train()
grad_accum_steps = cfg.grad_accum_steps
for epoch in range(context.epoch, context.config.n_epoch):
for epoch in range(context.epoch, cfg.n_epoch):
context.epoch = epoch
self._call_callbacks("on_epoch_begin", context)
for batch in context.dataloader:
self._call_callbacks("on_batch_begin", context)
with executor.accumulate(context.model):
loss = context.strategy(batch)
context.loss = loss.item()
stand_loss = loss / executor.grad_accum_steps
executor.backward(stand_loss)
stand_loss = loss / grad_accum_steps
stand_loss.backward()
context.iteration += 1
self._call_callbacks("on_batch_end", context)
if executor.sync_gradients:
self._call_callbacks("on_optimizer_step", context)
if context.iteration % grad_accum_steps == 0:
self._call_callbacks("on_step_begin", context)
context.optimizer.step()
context.optimizer.zero_grad()
self._call_callbacks("on_step_end", context)
if context.scheduler:
context.scheduler.step()
@ -89,13 +89,13 @@ class Trainer:
self._call_callbacks("on_epoch_end", context)
except Exception as e:
logger.error("Training failed: %s", str(e), exc_info=True)
logger.error(f"Training failed: {str(e)}", exc_info=True)
self._call_callbacks("on_error", context)
raise
finally:
self._call_callbacks("on_train_end", context)
def train(self, resume_dir: Optional[str] = None):
def train(self, checkpoint: Optional[Checkpoint] = None):
cfg = self.train_config
spawn_parallel_fn(
self._trainer_loop,
@ -105,5 +105,5 @@ class Trainer:
master_port=cfg.master_port,
device_type=cfg.device_type,
start_method=cfg.start_method,
resume_dir=resume_dir,
checkpoint=checkpoint,
)

View File

@ -1,336 +0,0 @@
"""HumanEval code generation benchmark.
Generates n completions per problem, extracts function bodies, executes
against hidden tests, and computes pass@k.
Usage::
python scripts/tools/evaluate_humaneval.py --param_path ./params \
--data_path HumanEval.jsonl.gz --output results.json \
--num_samples 200 --temperature 0.8 --max_tokens 512
"""
import argparse
import json
import os
import re
import signal
import sys
from math import prod
from multiprocessing import Process, Queue
from typing import Dict, List, Optional, Tuple
import numpy as np
import torch
import tqdm
from astrai.inference import InferenceEngine
from astrai.model import AutoModel
from astrai.tokenize import AutoTokenizer
HUMANEVAL_URL = (
"https://github.com/openai/human-eval/raw/master/data/HumanEval.jsonl.gz"
)
_STOP_SEQUENCES = [
"\nclass ",
"\ndef ",
"\n# ",
"\nif __name__",
"\nprint(",
"\n\n\n",
]
def _download_humaneval(data_path: str):
if os.path.exists(data_path):
return
import gzip
import urllib.request
os.makedirs(os.path.dirname(data_path) or ".", exist_ok=True)
print(f"Downloading HumanEval from {HUMANEVAL_URL} ...")
tmp = data_path + ".tmp"
urllib.request.urlretrieve(HUMANEVAL_URL, tmp)
with gzip.open(tmp, "rb") as f_in:
with open(data_path, "wb") as f_out:
f_out.write(f_in.read())
os.remove(tmp)
print(f" saved to {data_path}")
def _load_problems(data_path: str) -> List[dict]:
problems = []
with open(data_path, "r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if line:
problems.append(json.loads(line))
return problems
def _extract_function_body(code: str, entry_point: str) -> Optional[str]:
"""Extract the function body from a completion."""
pattern = rf"def\s+{re.escape(entry_point)}\b[^:]*:"
match = re.search(pattern, code)
if not match:
# Use the full code as-is if we can't find the function
return code
body_start = match.end()
lines = code[body_start:].split("\n")
body_lines = []
started = False
for line in lines:
stripped = line.rstrip()
if not stripped and not started:
continue
if not stripped and started:
body_lines.append("")
continue
if not started:
started = True
if stripped.lstrip() == stripped and started:
break
body_lines.append(stripped)
body = "\n".join(body_lines)
if not body.strip():
return None
return body
def _trim_stop_sequences(text: str) -> str:
for stop in _STOP_SEQUENCES:
idx = text.find(stop)
if idx != -1:
text = text[:idx]
return text
def _execute_code(problem: dict, completion: str, timeout: float = 3.0) -> bool:
"""Run the completion against hidden tests in a subprocess."""
def _worker(queue, full_code):
try:
namespace = {}
exec(full_code, namespace)
check = namespace.get("check")
if check is None:
queue.put(False)
return
check(namespace.get(problem["entry_point"]))
queue.put(True)
except Exception:
queue.put(False)
full_code = problem["prompt"] + completion + "\n" + problem["test"]
queue: Queue = Queue()
proc = Process(target=_worker, args=(queue, full_code))
proc.start()
proc.join(timeout)
if proc.is_alive():
proc.terminate()
proc.join()
return False
try:
return queue.get_nowait()
except Exception:
return False
def _pass_at_k(n: int, c: int, k: int) -> float:
"""Unbiased estimator of pass@k."""
if n - c < k:
return 1.0
return 1.0 - float(prod(1.0 - k / np.arange(n - c + 1, n + 1)))
def _deduplicate(completions: List[str]) -> List[str]:
seen = set()
unique = []
for c in completions:
if c not in seen:
seen.add(c)
unique.append(c)
return unique
def _generate(
engine: InferenceEngine,
prompt: str,
num_samples: int,
max_tokens: int,
temperature: float,
top_p: float,
top_k: int,
batch_size: int,
) -> List[str]:
batches = [prompt] * min(batch_size, num_samples)
completions = []
remaining = num_samples
while remaining > 0:
current = min(batch_size, remaining)
batch_prompts = batches[:current]
outputs = engine.generate(
prompt=batch_prompts,
stream=False,
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
top_k=top_k,
)
if isinstance(outputs, str):
outputs = [outputs]
completions.extend(outputs)
remaining -= current
return _deduplicate(completions)
def evaluate(
engine: InferenceEngine,
problems: List[dict],
num_samples: int,
max_tokens: int,
temperature: float,
top_p: float,
top_k: int,
batch_size: int,
k_values: Tuple[int, ...] = (1, 10, 100),
) -> Dict:
results = {}
all_pass_at_k = {k: [] for k in k_values}
for problem in tqdm.tqdm(problems, desc="HumanEval", unit="problem"):
task_id = problem["task_id"]
prompt = problem["prompt"]
entry_point = problem["entry_point"]
raw_completions = _generate(
engine,
prompt,
num_samples,
max_tokens,
temperature,
top_p,
top_k,
batch_size,
)
completions = []
for raw in raw_completions:
trimmed = _trim_stop_sequences(raw)
body = _extract_function_body(trimmed, entry_point)
if body:
completions.append(body)
passed = 0
for comp in completions:
if _execute_code(problem, comp):
passed += 1
n = len(completions)
c = passed
result = {"task_id": task_id, "n": n, "passed": c}
for k in k_values:
result[f"pass@{k}"] = round(_pass_at_k(n, c, k), 4)
all_pass_at_k[k].append(_pass_at_k(n, c, k))
results[task_id] = result
summary = {}
for k in k_values:
vals = all_pass_at_k[k]
summary[f"pass@{k}"] = round(float(np.mean(vals)), 4)
results["_summary"] = summary
return results
def main():
parser = argparse.ArgumentParser(description="HumanEval benchmark")
parser.add_argument(
"--param_path", type=str, default="./params", help="Model directory"
)
parser.add_argument(
"--data_path",
type=str,
default="./humaneval/HumanEval.jsonl",
help="HumanEval JSONL file (auto-download if missing)",
)
parser.add_argument("--output", type=str, default=None, help="Output JSON path")
parser.add_argument(
"--num_samples",
type=int,
default=200,
help="Completions per problem",
)
parser.add_argument(
"--max_tokens", type=int, default=512, help="Max generation tokens"
)
parser.add_argument(
"--temperature", type=float, default=0.8, help="Sampling temperature"
)
parser.add_argument("--top_p", type=float, default=0.95, help="Top-p sampling")
parser.add_argument("--top_k", type=int, default=50, help="Top-k sampling")
parser.add_argument(
"--batch_size", type=int, default=1, help="Inference batch size"
)
parser.add_argument(
"--problems",
type=int,
nargs="+",
default=None,
help="Specific problem indices (0-based)",
)
args = parser.parse_args()
_download_humaneval(args.data_path)
problems = _load_problems(args.data_path)
if args.problems:
problems = [problems[i] for i in args.problems if i < len(problems)]
model = AutoModel.from_pretrained(args.param_path)
tokenizer = AutoTokenizer.from_pretrained(args.param_path)
model.to(device="cuda", dtype=torch.bfloat16)
engine = InferenceEngine(
model=model,
tokenizer=tokenizer,
max_batch_size=args.batch_size,
)
results = evaluate(
engine=engine,
problems=problems,
num_samples=args.num_samples,
max_tokens=args.max_tokens,
temperature=args.temperature,
top_p=args.top_p,
top_k=args.top_k,
batch_size=args.batch_size,
k_values=(1, 10, 100),
)
summary = results.pop("_summary")
print(f"\n{'=' * 60}")
for k, v in summary.items():
print(f" {k}: {v:.2%}")
print(f"{'=' * 60}")
if args.output:
results["_summary"] = summary
with open(args.output, "w", encoding="utf-8") as f:
json.dump(results, f, indent=2, ensure_ascii=False)
print(f"Results saved to {args.output}")
engine.shutdown()
if __name__ == "__main__":
main()

View File

@ -1,319 +0,0 @@
"""MMLU evaluation via log-likelihood ranking."""
import argparse
import csv
import json
import os
import shutil
import tarfile
import requests
import torch
import torch.nn.functional as F
import tqdm
from astrai.model import AutoModel
from astrai.tokenize import AutoTokenizer
MMLU_URL = "https://people.eecs.berkeley.edu/~hendrycks/data.tar"
MMLU_SUBJECTS = [
"abstract_algebra",
"anatomy",
"astronomy",
"business_ethics",
"clinical_knowledge",
"college_biology",
"college_chemistry",
"college_computer_science",
"college_mathematics",
"college_medicine",
"college_physics",
"computer_security",
"conceptual_physics",
"econometrics",
"electrical_engineering",
"elementary_mathematics",
"formal_logic",
"global_facts",
"high_school_biology",
"high_school_chemistry",
"high_school_computer_science",
"high_school_european_history",
"high_school_geography",
"high_school_government_and_politics",
"high_school_macroeconomics",
"high_school_mathematics",
"high_school_microeconomics",
"high_school_physics",
"high_school_psychology",
"high_school_statistics",
"high_school_us_history",
"high_school_world_history",
"human_aging",
"human_sexuality",
"international_law",
"jurisprudence",
"logical_fallacies",
"machine_learning",
"management",
"marketing",
"medical_genetics",
"miscellaneous",
"moral_disputes",
"moral_scenarios",
"nutrition",
"philosophy",
"prehistory",
"professional_accounting",
"professional_law",
"professional_medicine",
"professional_psychology",
"public_relations",
"security_studies",
"sociology",
"us_foreign_policy",
"virology",
"world_religions",
]
def _download_and_extract(url: str, data_dir: str):
tar_path = os.path.join(data_dir, "data.tar")
os.makedirs(data_dir, exist_ok=True)
print(f"Downloading MMLU data from {url}...")
resp = requests.get(url, stream=True, timeout=300)
resp.raise_for_status()
total = int(resp.headers.get("content-length", 0))
with tqdm.tqdm(total=total, unit="B", unit_scale=True, desc=" Download") as bar:
with open(tar_path, "wb") as f:
for chunk in resp.iter_content(chunk_size=8192):
f.write(chunk)
bar.update(len(chunk))
print("Extracting...")
with tarfile.open(tar_path, "r") as tf:
tf.extractall(data_dir)
os.remove(tar_path)
def download_mmlu(data_dir: str):
_download_and_extract(MMLU_URL, data_dir)
src = os.path.join(data_dir, "data")
if os.path.exists(src):
for item in os.listdir(src):
src_item = os.path.join(src, item)
dst_item = os.path.join(data_dir, item)
if os.path.exists(dst_item):
if os.path.isdir(dst_item):
shutil.rmtree(dst_item)
else:
os.remove(dst_item)
os.rename(src_item, dst_item)
os.rmdir(src)
print(f"MMLU data saved to {data_dir}")
def _strip_prefix(text: str, prefix: str) -> str:
if text.startswith(prefix):
return text[len(prefix) :].strip()
return text
def load_csv(path: str) -> list[dict]:
data = []
with open(path, "r", encoding="utf-8") as f:
for row in csv.reader(f):
if len(row) < 6:
continue
if row[0].strip().lower() == "question":
continue
data.append(
{
"question": row[0].strip(),
"A": _strip_prefix(row[1].strip(), "A)"),
"B": _strip_prefix(row[2].strip(), "B)"),
"C": _strip_prefix(row[3].strip(), "C)"),
"D": _strip_prefix(row[4].strip(), "D)"),
"answer": row[5].strip(),
}
)
return data
def build_prompt(
question: str, choices: dict, subject: str, n_shot: int, dev_data: list[dict]
) -> str:
prompt = ""
if n_shot > 0 and dev_data:
prompt = f"The following are multiple choice questions (with answers) about {subject}.\n\n"
for item in dev_data[:n_shot]:
prompt += f"Question: {item['question']}\n"
for k in ("A", "B", "C", "D"):
prompt += f"{k}. {item[k]}\n"
prompt += f"Answer: {item['answer']}\n\n"
prompt += f"Question: {question}\n"
for k in ("A", "B", "C", "D"):
prompt += f"{k}. {choices[k]}\n"
prompt += "Answer:"
return prompt
def apply_chat(
tokenizer, raw_prompt: str, n_shot: int, dev_data: list[dict] | None
) -> str:
"""Wrap raw MMLU prompt in the model's chat template format.
For few-shot, prepend example Q&A pairs as a second user/assistant exchange.
"""
messages = []
if n_shot > 0 and dev_data:
for item in dev_data[:n_shot]:
q = f"Question: {item['question']}\n"
for k in ("A", "B", "C", "D"):
q += f"{k}. {item[k]}\n"
q += "Answer:"
messages.append({"role": "user", "content": q})
messages.append({"role": "assistant", "content": item["answer"]})
messages.append({"role": "user", "content": raw_prompt})
return tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
def choice_logprob(
model, tokenizer, context_ids: list[int], choice_letter: str, device: str
) -> float:
choice_text = choice_letter
choice_ids = tokenizer.encode(choice_text, add_special_tokens=False)
input_ids = context_ids + choice_ids
max_len = model.config.max_len
if len(input_ids) > max_len:
overflow = len(input_ids) - max_len
input_ids = input_ids[overflow:]
ctx_len = len(input_ids) - len(choice_ids)
else:
ctx_len = len(context_ids)
input_tensor = torch.tensor([input_ids], device=device, dtype=torch.long)
with torch.inference_mode():
logits = model(input_tensor)["logits"][0]
score = 0.0
for i, tid in enumerate(choice_ids):
pos = ctx_len - 1 + i
if pos >= len(logits):
break
score += F.log_softmax(logits[pos], dim=-1)[tid].item()
return score
def evaluate_subject(
model,
tokenizer,
subject: str,
test_data: list[dict],
dev_data: list[dict] | None,
device: str,
n_shot: int,
) -> tuple[float, int, int]:
correct = 0
total = 0
for item in tqdm.tqdm(test_data, desc=f"{subject:40s}", leave=False):
raw_prompt = build_prompt(
item["question"], item, subject, n_shot, dev_data or []
)
context = apply_chat(tokenizer, raw_prompt, n_shot, dev_data or [])
context_ids = tokenizer.encode(context)
scores = {
c: choice_logprob(model, tokenizer, context_ids, c, device)
for c in ("A", "B", "C", "D")
}
if max(scores, key=scores.get) == item["answer"]:
correct += 1
total += 1
return correct / total, correct, total
def main():
parser = argparse.ArgumentParser(description="MMLU evaluation")
parser.add_argument(
"--param_path", type=str, default="./params", help="Model directory"
)
parser.add_argument(
"--data_dir", type=str, default="./mmlu_data", help="MMLU data directory"
)
parser.add_argument("--download", action="store_true", help="Download MMLU data")
parser.add_argument(
"--n_shot", type=int, default=5, help="Few-shot examples (0 for zero-shot)"
)
parser.add_argument(
"--subjects", type=str, nargs="+", help="Specific subjects (default: all)"
)
parser.add_argument("--output", type=str, help="Output JSON path")
parser.add_argument("--split", type=str, default="test", choices=["test", "val"])
parser.add_argument(
"--device",
type=str,
default="cuda" if torch.cuda.is_available() else "cpu",
help="Device",
)
parser.add_argument(
"--dtype",
type=str,
default="bfloat16" if torch.cuda.is_available() else "float32",
help="Torch dtype",
)
args = parser.parse_args()
if args.download or not os.path.exists(args.data_dir):
download_mmlu(args.data_dir)
model = AutoModel.from_pretrained(args.param_path)
tokenizer = AutoTokenizer.from_pretrained(args.param_path)
device = args.device
dtype = getattr(torch, args.dtype)
model.to(device=device, dtype=dtype)
model.eval()
subjects = args.subjects or MMLU_SUBJECTS
results = {}
total_correct = 0
total_questions = 0
for subject in subjects:
dev_path = os.path.join(args.data_dir, "dev", f"{subject}_dev.csv")
test_path = os.path.join(
args.data_dir, args.split, f"{subject}_{args.split}.csv"
)
if not os.path.exists(test_path):
print(f" Skipping {subject}: test file not found")
continue
dev_data = load_csv(dev_path) if os.path.exists(dev_path) else None
test_data = load_csv(test_path)
acc, corr, tot = evaluate_subject(
model, tokenizer, subject, test_data, dev_data, device, args.n_shot
)
results[subject] = {"accuracy": round(acc, 4), "correct": corr, "total": tot}
total_correct += corr
total_questions += tot
print(f" {subject:40s} {acc:.2%} ({corr}/{tot})")
overall = total_correct / total_questions if total_questions else 0
print(f"\n{'=' * 70}")
print(f" Overall: {overall:.2%} ({total_correct}/{total_questions})")
results["_overall"] = {
"accuracy": round(overall, 4),
"correct": total_correct,
"total": total_questions,
}
if args.output:
with open(args.output, "w", encoding="utf-8") as f:
json.dump(results, f, indent=2)
print(f"Results saved to {args.output}")
if __name__ == "__main__":
main()

View File

@ -10,11 +10,11 @@ from astrai.tokenize import AutoTokenizer
def process_file(
param_path: str, input_file: str, output_file: str, batch_size: int, text_key: str
model_dir: str, input_file: str, output_file: str, batch_size: int, text_key: str
):
# Load model and tokenizer
model = AutoModel.from_pretrained(param_path)
tokenizer = AutoTokenizer.from_pretrained(param_path)
model = AutoModel.from_pretrained(model_dir)
tokenizer = AutoTokenizer.from_pretrained(model_dir)
model.to(device="cuda", dtype=torch.bfloat16)
with open(input_file, "r", encoding="utf-8") as f:
@ -44,8 +44,8 @@ def process_file(
for seq in batch_encoded:
pad_len = max_len - len(seq)
padded_seq = seq + [tokenizer.pad_id] * pad_len
mask = [True] * len(seq) + [False] * pad_len
padded_seq = [tokenizer.pad_id] * pad_len + seq
mask = [False] * pad_len + [True] * len(seq)
padded_ids.append(padded_seq)
masks.append(mask)
@ -88,7 +88,7 @@ def process_file(
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run perplexity with a Khaosz model.")
parser.add_argument(
"--param_path", type=str, required=True, help="Path to the model directory."
"--model_dir", type=str, required=True, help="Path to the model directory."
)
parser.add_argument(
"--input_file", type=str, required=True, help="Path to the input file."

View File

@ -1,38 +0,0 @@
"""CLI: JSONL → tokenized .h5/.bin via config-driven Pipeline."""
import argparse
from astrai.config.preprocess_config import PipelineConfig
from astrai.preprocessing.pipeline import Pipeline
def main():
parser = argparse.ArgumentParser(
description="Raw JSONL → tokenized .h5/.bin via config-driven Pipeline"
)
parser.add_argument(
"inputs", nargs="+", metavar="JSONL", help="One or more JSONL files"
)
parser.add_argument("--output_dir", "-o", required=True, help="Output directory")
parser.add_argument(
"--config", "-c", required=True, help="Path to pipeline config JSON"
)
parser.add_argument(
"--tokenizer_path",
default="params",
help="Path to tokenizer directory (default: params)",
)
args = parser.parse_args()
config = PipelineConfig.from_json(args.config)
Pipeline(
config=config,
input_paths=args.inputs,
output_dir=args.output_dir,
tokenizer_path=args.tokenizer_path,
).run()
if __name__ == "__main__":
main()

View File

@ -18,7 +18,7 @@ def main():
"--reload", action="store_true", help="Enable auto-reload for development"
)
parser.add_argument(
"--param_path",
"--param-path",
type=Path,
default=None,
help="Path to model parameters (default: project_root/params)",

View File

@ -2,13 +2,16 @@ import argparse
import os
from functools import partial
import safetensors.torch as st
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
from astrai.config import AutoRegressiveLMConfig, TrainConfig
from astrai.dataset import DatasetFactory
from astrai.model import AutoRegressiveLM
from astrai.model.components.decoder_block import DecoderBlock
from astrai.parallel import get_rank
from astrai.trainer import SchedulerFactory, Trainer
@ -116,12 +119,6 @@ def parse_args() -> argparse.Namespace:
default=0.05,
help="cross_entropy function label smoothing parameter",
)
parser.add_argument(
"--gradient_checkpointing",
action=argparse.BooleanOptionalAction,
default=False,
help="Enable activation checkpointing for DecoderBlock modules.",
)
parser.add_argument(
"--ckpt_interval",
@ -135,36 +132,6 @@ def parse_args() -> argparse.Namespace:
default="checkpoint",
help="Directory to save checkpoints.",
)
parser.add_argument(
"--val_split",
type=float,
default=None,
help="Ratio to split from training dataset for validation (e.g. 0.05).",
)
parser.add_argument(
"--val_step",
type=int,
default=1000,
help="Number of optimizer steps between validation runs.",
)
parser.add_argument(
"--metrics",
nargs="*",
default=["loss", "lr"],
help="Metrics to log (e.g. --metrics loss lr val_loss). Default: loss lr.",
)
parser.add_argument(
"--log_dir",
type=str,
default="checkpoint/logs",
help="Directory for metric logs.",
)
parser.add_argument(
"--log_interval",
type=int,
default=100,
help="Number of batch iterations between metric logs.",
)
parser.add_argument(
"--grpo_sync_interval",
type=int,
@ -178,32 +145,7 @@ def parse_args() -> argparse.Namespace:
"--start_batch", type=int, default=0, help="Start batch for training."
)
parser.add_argument(
"--master_addr",
type=str,
default="localhost",
help="Master node address for distributed training.",
)
parser.add_argument(
"--master_port",
type=str,
default="29500",
help="Master node port for distributed training.",
)
parser.add_argument(
"--backend",
type=str,
default="nccl",
help="Distributed training backend.",
)
parser.add_argument("--nprocs", type=int, default=1, help="Number of GPUs to use.")
parser.add_argument(
"--parallel_mode",
type=str,
default="none",
choices=["none", "ddp", "fsdp"],
help="Parallel training strategy (none, ddp, fsdp).",
)
parser.add_argument(
"--device_type", type=str, default="cuda", help="Device type to use."
)
@ -220,11 +162,21 @@ def parse_args() -> argparse.Namespace:
return args
def create_model(config):
return AutoRegressiveLM(config).to(dtype=torch.bfloat16)
def ddp_wrap(model: nn.Module):
local_rank = get_rank()
ddp_model = DDP(
model,
device_ids=[local_rank],
output_device=local_rank,
static_graph=True,
find_unused_parameters=False,
gradient_as_bucket_view=True,
broadcast_buffers=False,
)
return ddp_model
def create_optimizer(model, **kwargs) -> optim.Optimizer:
def create_optimizer(model: nn.Module, **kwargs) -> optim.Optimizer:
return optim.AdamW(model.parameters(), fused=True, **kwargs)
@ -234,6 +186,12 @@ def create_scheduler(
return SchedulerFactory.create(optimizer, **kwargs)
def prepare_checkpoint(model: nn.Module) -> dict:
if isinstance(model, DDP):
return model.module.state_dict()
return model.state_dict()
def compute_total_steps(
dataset_len: int,
n_epoch: int,
@ -264,11 +222,6 @@ def train(
warmup_ratio: float,
ckpt_interval: int,
ckpt_dir: str,
val_split: float,
val_step: int,
metrics: list[str],
log_dir: str,
log_interval: int,
dpo_beta: float,
grpo_clip_eps: float,
grpo_kl_coef: float,
@ -282,21 +235,14 @@ def train(
random_seed: int,
num_workers: int,
pin_memory: bool,
gradient_checkpointing: bool,
window_size: int,
stride: int,
nprocs: int,
parallel_mode: str,
device_type: str,
backend: str,
master_addr: str,
master_port: str,
start_method: str,
):
assert train_type in ["seq", "sft", "dpo", "grpo"]
assert os.path.exists(param_path)
if nprocs > 1 and parallel_mode == "none":
raise ValueError("--nprocs > 1 requires --parallel_mode to be 'ddp' or 'fsdp'")
# Load config
config_path = os.path.join(param_path, "config.json")
@ -305,6 +251,17 @@ def train(
if window_size is None:
window_size = config.max_len
# Create bare AutoRegressiveLM (for training, no tokenizer needed)
model = AutoRegressiveLM(config)
# Load weights if available
weights_path = os.path.join(param_path, "model.safetensors")
if os.path.exists(weights_path):
state_dict = st.load_file(weights_path)
model.load_state_dict(state_dict, strict=False)
model = model.to(dtype=torch.bfloat16)
strategy_kwargs = {
"beta": dpo_beta,
"label_smoothing": label_smoothing,
@ -314,12 +271,6 @@ def train(
"sync_interval": grpo_sync_interval,
}
executor_kwargs = {
"gradient_as_bucket_view": True,
"broadcast_buffers": False,
}
model_fn = partial(create_model, config)
dataset = DatasetFactory.load(
train_type=train_type,
load_path=data_root_path,
@ -350,10 +301,8 @@ def train(
},
)
grad_ckpt_modules = [DecoderBlock] if gradient_checkpointing else []
train_config = TrainConfig(
model_fn=model_fn,
model=model,
strategy=train_type,
dataset=dataset,
optimizer_fn=optimizer_fn,
@ -370,24 +319,15 @@ def train(
num_workers=num_workers,
pin_memory=pin_memory,
nprocs=nprocs,
backend=backend,
master_addr=master_addr,
master_port=master_port,
parallel_mode=parallel_mode,
parallel_wrapper=ddp_wrap,
state_dict_fn=prepare_checkpoint,
device_type=device_type,
start_method=start_method,
val_split=val_split,
val_step=val_step,
metrics=metrics,
log_dir=log_dir,
log_interval=log_interval,
gradient_checkpointing_modules=grad_ckpt_modules,
executor_kwargs=executor_kwargs,
extra_kwargs=strategy_kwargs,
)
trainer = Trainer(train_config)
trainer.train(resume_dir=param_path)
trainer.train()
if __name__ == "__main__":

View File

@ -1,202 +0,0 @@
import tempfile
import pytest
from tokenizers import Tokenizer, models, pre_tokenizers, trainers
from astrai.config.preprocess_config import (
InputConfig,
PipelineConfig,
ProcessingConfig,
)
from astrai.tokenize import AutoTokenizer
_SPECIAL_TOKENS_CONFIG = {
"bos_token": "<|begin_of_sentence|>",
"eos_token": "<|end_of_sentence|>",
"pad_token": "<|_pad_|>",
"unk_token": "<|_unk_|>",
"im_start": "<|im_start|>",
"im_end": "<|im_end|>",
}
_SPECIAL_TOKENS = list(_SPECIAL_TOKENS_CONFIG.values())
_CHAT_TEMPLATE = (
"{% for message in messages %}"
"{% if message['role'] == 'system' %}"
"<|im_start|>system\n{{ message['content'] }}<|im_end|>\n"
"{% elif message['role'] == 'user' %}"
"<|im_start|>user\n{{ message['content'] }}<|im_end|>\n"
"{% elif message['role'] == 'assistant' %}"
"<|im_start|>assistant\n{{ message['content'] }}<|im_end|>\n"
"{% endif %}"
"{% endfor %}"
"{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"
)
_CHAT_SECTIONS = [{"field": "messages", "action": "$role", "template": True}]
_INSTRUCTION_SECTIONS = [
{"field": "prompt", "action": "mask", "add_special_tokens": True},
{"field": "response", "action": "train"},
]
_TEXT_SECTIONS = [{"field": "text", "action": "train"}]
_GRPO_RESPONSE_SECTIONS = [{"field": "responses", "action": "train"}]
def _build_chat_tokenizer():
tok = Tokenizer(models.BPE())
tok.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
tr = trainers.BpeTrainer(
vocab_size=512,
min_frequency=1,
special_tokens=_SPECIAL_TOKENS,
)
train_data = [
"hello world",
"Hi there!",
"You are helpful.",
"What is 2+2?",
"Tell me a story about dragons and knights.",
"Sure, here is a tale.",
"Translate to French: Hello",
"Bonjour",
"Artificial Intelligence is a field of computer science.",
"system",
"user",
"assistant",
"<|im_start|>",
"<|im_end|>",
*[chr(i) for i in range(32, 127)],
]
tok.train_from_iterator(train_data, tr)
auto_tok = AutoTokenizer()
auto_tok._tokenizer = tok
auto_tok._special_token_map = {
"bos_token": "<|begin_of_sentence|>",
"eos_token": "<|end_of_sentence|>",
"pad_token": "<|_pad_|>",
"unk_token": "<|_unk_|>",
}
auto_tok.set_chat_template(_CHAT_TEMPLATE)
return auto_tok
@pytest.fixture(scope="session")
def chat_tokenizer():
return _build_chat_tokenizer()
@pytest.fixture
def temp_dir():
d = tempfile.mkdtemp()
yield d
import shutil
shutil.rmtree(d, ignore_errors=True)
def make_chat_config():
return PipelineConfig(
input=InputConfig(sections=_CHAT_SECTIONS),
mask={"system": "mask", "user": "mask", "assistant": "train"},
mask_default="mask",
preprocessing=ProcessingConfig(max_seq_len=2048),
)
def make_instruction_config():
return PipelineConfig(
input=InputConfig(sections=_INSTRUCTION_SECTIONS),
mask={"prompt": "mask", "response": "train"},
mask_default="mask",
preprocessing=ProcessingConfig(max_seq_len=2048),
)
def make_text_config():
return PipelineConfig(
input=InputConfig(sections=_TEXT_SECTIONS),
preprocessing=ProcessingConfig(
max_seq_len=2048, min_chars=1, max_chars=2_000_000
),
)
def make_dpo_chat_config():
return PipelineConfig(
input=InputConfig(
sources={
"chosen": {
"sections": [
{"field": "chosen", "action": "$role", "template": True}
]
},
"rejected": {
"sections": [
{"field": "rejected", "action": "$role", "template": True}
]
},
}
),
mask={"user": "mask", "assistant": "train"},
mask_default="mask",
preprocessing=ProcessingConfig(max_seq_len=2048),
)
def make_grpo_config():
return PipelineConfig(
input=InputConfig(
sources={
"prompts": {
"sections": [
{"field": "prompt", "action": "mask", "template": True}
]
},
"responses": {
"sections": _GRPO_RESPONSE_SECTIONS,
"list_field": True,
"mask_key": "masks",
},
"rewards": {
"sections": [{"field": "rewards", "action": "value"}],
},
}
),
mask={"user": "mask", "assistant": "train"},
mask_default="mask",
preprocessing=ProcessingConfig(max_seq_len=2048),
)
def make_grpo_no_template_config():
return PipelineConfig(
input=InputConfig(
sources={
"prompts": {
"sections": [
{
"field": "prompt",
"action": "mask",
"add_special_tokens": True,
}
]
},
"responses": {
"sections": _GRPO_RESPONSE_SECTIONS,
"list_field": True,
"mask_key": "masks",
},
"rewards": {
"sections": [{"field": "rewards", "action": "value"}],
},
}
),
mask={"user": "mask", "assistant": "train"},
mask_default="mask",
preprocessing=ProcessingConfig(max_seq_len=2048),
)

View File

@ -1,4 +1,3 @@
import os
import tempfile
import torch
@ -37,6 +36,7 @@ def test_single_process():
def test_checkpoint_with_extra():
"""Verify extra keys are saved as individual .pt files and loaded back."""
model = torch.nn.Linear(10, 5)
optimizer = AdamW(model.parameters(), lr=1e-3)
optimizer.step()
@ -52,6 +52,8 @@ def test_checkpoint_with_extra():
with tempfile.TemporaryDirectory() as tmpdir:
checkpoint.save(tmpdir)
import os
assert os.path.exists(os.path.join(tmpdir, "optimizer.pt"))
assert os.path.exists(os.path.join(tmpdir, "scheduler.pt"))

View File

@ -1,3 +1,4 @@
import json
import os
import numpy as np
@ -6,11 +7,12 @@ import torch
from astrai.dataset.dataset import DatasetFactory, SEQDataset
from astrai.dataset.storage import (
H5Store,
StoreFactory,
BaseSegmentFetcher,
H5Storage,
MultiSegmentFetcher,
StorageFactory,
detect_format,
load_bin,
save_bin,
load_json,
save_h5,
)
@ -98,7 +100,6 @@ def test_sft_dataset_with_random_data(base_test_env):
dummy_data = {
"sequence": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64)],
"loss_mask": [torch.ones(seq_length, dtype=torch.bool)],
"position_ids": [torch.arange(seq_length, dtype=torch.int32)],
}
save_h5(test_dir, "sft_data", dummy_data)
@ -156,6 +157,111 @@ def test_dataset_with_custom_stride(base_test_env):
assert len(dataset) > len(default_stride_dataset)
# ============== JSON Storage Tests (raw text + tokenizer) ==============
def _make_tokenizer_fn(tokenizer):
"""Wrap tokenizer.encode() as a str -> List[int] callable."""
return lambda text: tokenizer.encode(text, add_special_tokens=False)
def test_seq_dataset_from_json_text(base_test_env):
"""Test loading SEQ dataset from raw-text JSON with tokenizer"""
tokenizer = base_test_env["tokenizer"]
tokenizer_fn = _make_tokenizer_fn(tokenizer)
test_dir = base_test_env["test_dir"]
data_dir = os.path.join(test_dir, "json_text")
os.makedirs(data_dir, exist_ok=True)
texts = [
"hello world this is a test sentence for tokenizer",
"another sentence with different words and tokens",
"machine learning is fascinating and powerful",
]
json_path = os.path.join(data_dir, "seq_data.json")
with open(json_path, "w", encoding="utf-8") as f:
json.dump({"sequence": texts}, f, ensure_ascii=False)
dataset = DatasetFactory.load(
train_type="seq",
load_path=data_dir,
window_size=16,
tokenizer=tokenizer_fn,
)
assert dataset is not None
assert len(dataset) > 0
assert dataset.count > 0
assert "sequence" in dataset.keys
item = dataset[0]
assert "input_ids" in item
assert "target_ids" in item
assert item["input_ids"].shape[0] == 16
def test_sft_dataset_from_json_text(base_test_env):
"""Test loading SFT dataset from raw-text JSON with tokenizer"""
tokenizer = base_test_env["tokenizer"]
tokenizer_fn = _make_tokenizer_fn(tokenizer)
test_dir = base_test_env["test_dir"]
data_dir = os.path.join(test_dir, "json_sft")
os.makedirs(data_dir, exist_ok=True)
texts = [
"user asks a question about the weather",
"assistant provides a helpful response to the user",
]
json_path = os.path.join(data_dir, "sft_data.json")
with open(json_path, "w", encoding="utf-8") as f:
json.dump(
{"sequence": texts, "loss_mask": texts},
f,
ensure_ascii=False,
)
dataset = DatasetFactory.load(
train_type="sft",
load_path=data_dir,
window_size=16,
tokenizer=tokenizer_fn,
)
assert dataset is not None
assert len(dataset) > 0
item = dataset[0]
assert "loss_mask" in item
def test_json_storage_explicit_tokenizer(base_test_env):
"""Test explicit JSON storage with tokenizer"""
tokenizer = base_test_env["tokenizer"]
tokenizer_fn = _make_tokenizer_fn(tokenizer)
test_dir = base_test_env["test_dir"]
data_dir = os.path.join(test_dir, "json_explicit")
os.makedirs(data_dir, exist_ok=True)
texts = ["abcdefghijklmnopqrstuvwxyz" * 10]
json_path = os.path.join(data_dir, "data.json")
with open(json_path, "w", encoding="utf-8") as f:
json.dump({"sequence": texts}, f, ensure_ascii=False)
token_count = len(tokenizer_fn(texts[0]))
dataset = DatasetFactory.load(
train_type="seq",
load_path=data_dir,
window_size=32,
storage_type="json",
tokenizer=tokenizer_fn,
)
assert dataset is not None
assert len(dataset) > 0
assert dataset.count == token_count
def test_dataset_count_property(base_test_env):
"""Test the count property returns correct raw token count"""
test_dir = base_test_env["test_dir"]
@ -212,29 +318,37 @@ def test_unloaded_dataset_len():
assert len(dataset) == 0
def test_store_unloaded_len():
"""Unloaded Store has __len__ == 0"""
store = H5Store()
assert len(store) == 0
assert store.keys == []
def test_base_segment_fetcher_empty():
"""BaseSegmentFetcher with empty segments list"""
fetcher = BaseSegmentFetcher([])
assert len(fetcher) == 0
with pytest.raises(ValueError, match="out of bounds"):
fetcher.fetch_data(0, 1)
def test_store_fetch_begin_equals_end(base_test_env):
"""Store.fetch with begin == end returns empty tensor"""
def test_base_segment_fetcher_begin_equals_end(base_test_env):
"""fetch_data with begin == end returns empty tensor"""
test_dir = base_test_env["test_dir"]
dummy = {"sequence": [torch.randint(0, 1000, (100,), dtype=torch.int64)]}
save_h5(test_dir, "empty_fetch", dummy)
dataset = DatasetFactory.load("seq", test_dir, window_size=32)
result = dataset.storage.fetch(10, 10, "sequence")
fetcher = dataset.storage._fetcher.multi_fetchers["sequence"]
result = fetcher.fetch_data(10, 10)
assert result.numel() == 0
def test_store_fetch_before_load():
"""Store.fetch before load raises RuntimeError"""
store = H5Store()
def test_multi_segment_fetcher_empty_dict():
"""MultiSegmentFetcher with empty dict has __len__ == 0"""
fetcher = MultiSegmentFetcher({})
assert len(fetcher) == 0
def test_storage_fetch_before_load():
"""BaseStorage.fetch before load raises RuntimeError"""
storage = H5Storage()
with pytest.raises(RuntimeError, match="not loaded"):
store.fetch(0, 10, "sequence")
storage.fetch(0, 10, "sequence")
def test_detect_format_nonexistent_path():
@ -253,192 +367,54 @@ def test_detect_format_unsupported_file(base_test_env):
detect_format(path)
def test_create_store_invalid_type():
"""StoreFactory.create raises ValueError for unknown type"""
def test_create_storage_invalid_type():
"""StorageFactory.create raises ValueError for unknown type"""
with pytest.raises(ValueError, match="Unknown component"):
StoreFactory.create("parquet")
StorageFactory.create("parquet")
def test_store_multi_segment_concat(base_test_env):
"""Multi-segment H5 data is concatenated into single tensor at load time"""
import os
def test_json_pretokenized_without_tokenizer(base_test_env):
"""Pre-tokenized JSON (List[List[int]]) loads without tokenizer"""
test_dir = base_test_env["test_dir"]
data_dir = os.path.join(test_dir, "multi_seg")
data_dir = os.path.join(test_dir, "json_pretok")
os.makedirs(data_dir, exist_ok=True)
json_path = os.path.join(data_dir, "data.json")
with open(json_path, "w", encoding="utf-8") as f:
json.dump({"sequence": [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]}, f)
dataset = DatasetFactory.load("seq", data_dir, window_size=4, storage_type="json")
assert len(dataset) > 0
assert dataset.count == 10
item = dataset[0]
assert item["input_ids"].tolist() == [1, 2, 3, 4]
assert item["target_ids"].tolist() == [2, 3, 4, 5]
def test_load_json_skips_config_file(base_test_env):
"""load_json skips scalar-value config files"""
test_dir = base_test_env["test_dir"]
with open(os.path.join(test_dir, "config.json"), "w") as f:
json.dump({"vocab_size": 1000, "dim": 16}, f)
with open(os.path.join(test_dir, "data.json"), "w") as f:
json.dump({"sequence": [[1, 2, 3, 4, 5]]}, f)
result = load_json(test_dir)
assert "sequence" in result
assert "vocab_size" not in result
assert len(result["sequence"]) == 1
def test_base_segment_fetcher_multi_segment():
"""fetch_data across multiple segment boundaries"""
segs = [
torch.tensor([1, 2, 3]),
torch.tensor([4, 5, 6, 7]),
torch.tensor([8, 9]),
]
save_h5(data_dir, "data", {"sequence": segs})
store = StoreFactory.create("h5")
store.load(data_dir)
assert len(store) == 9
result = store.fetch(2, 7, "sequence")
fetcher = BaseSegmentFetcher(segs)
assert len(fetcher) == 9
result = fetcher.fetch_data(2, 7)
assert result.tolist() == [3, 4, 5, 6, 7]
def test_save_load_bin_roundtrip(base_test_env):
"""save_bin + load_bin roundtrip preserves data"""
test_dir = base_test_env["test_dir"]
data = {
"sequence": [torch.tensor([1, 2, 3, 4, 5], dtype=torch.int64)],
"loss_mask": [torch.tensor([0, 1, 1, 0, 1], dtype=torch.int64)],
}
save_bin(test_dir, data)
result = load_bin(test_dir)
assert "sequence" in result
assert "loss_mask" in result
assert result["sequence"][0].tolist() == [1, 2, 3, 4, 5]
assert result["loss_mask"][0].tolist() == [0, 1, 1, 0, 1]
def test_mmap_store_load_and_fetch(base_test_env):
"""MmapStore loads bin data and fetches correctly"""
test_dir = base_test_env["test_dir"]
data = {
"sequence": [torch.randint(0, 1000, (200,), dtype=torch.int64)],
}
save_bin(test_dir, data)
store = StoreFactory.create("bin")
store.load(test_dir)
assert len(store) == 200
assert "sequence" in store.keys
result = store.fetch(10, 20, "sequence")
assert result.tolist() == data["sequence"][0][10:20].tolist()
def test_mmap_dataset_load(base_test_env):
"""DatasetFactory.load auto-detects bin format"""
test_dir = base_test_env["test_dir"]
data = {
"sequence": [torch.randint(0, 1000, (200,), dtype=torch.int64)],
}
save_bin(test_dir, data)
dataset = DatasetFactory.load("seq", test_dir, window_size=64)
assert len(dataset) > 0
assert dataset.count == 200
assert dataset[0]["input_ids"].shape[0] == 64
def test_normalize_empty_key():
"""_normalize with empty tensor list does not crash"""
store = H5Store()
store._normalize({"sequence": []})
assert len(store) == 0
assert store.keys == ["sequence"]
def test_normalize_mixed_empty_key():
"""_normalize with empty + non-empty keys returns min=0"""
store = H5Store()
store._normalize({"sequence": [torch.tensor([1, 2, 3])], "loss_mask": []})
assert len(store) == 0
assert set(store.keys) == {"sequence", "loss_mask"}
def test_grpo_dataset_dtype(base_test_env):
"""GRPODataset returns correct dtypes"""
test_dir = base_test_env["test_dir"]
seq_len = 100
data = {
"prompts": [torch.randint(0, 100, (seq_len,), dtype=torch.int32)],
"responses": [torch.randint(0, 100, (seq_len,), dtype=torch.int32)],
"masks": [torch.ones(seq_len, dtype=torch.int32)],
"rewards": [torch.ones(seq_len, dtype=torch.float32)],
}
save_h5(test_dir, "grpo_dtype", data)
dataset = DatasetFactory.load("grpo", test_dir, window_size=32)
item = dataset[0]
assert item["prompts"].dtype == torch.long
assert item["responses"].dtype == torch.long
assert item["masks"].dtype == torch.bool
assert item["rewards"].dtype == torch.float32
def test_grpo_dataset_load(base_test_env):
"""GRPODataset loads and returns correct keys"""
test_dir = base_test_env["test_dir"]
seq_len = 200
data = {
"prompts": [torch.randint(0, 1000, (seq_len,), dtype=torch.int64)],
"responses": [torch.randint(0, 1000, (seq_len,), dtype=torch.int64)],
"masks": [torch.ones(seq_len, dtype=torch.int64)],
"rewards": [torch.rand(seq_len, dtype=torch.float32)],
}
save_h5(test_dir, "grpo_test", data)
dataset = DatasetFactory.load("grpo", test_dir, window_size=64)
assert len(dataset) > 0
item = dataset[0]
assert "prompts" in item
assert "responses" in item
assert "masks" in item
assert "rewards" in item
assert item["prompts"].shape[0] == 64
assert item["responses"].shape[0] == 64
def test_detect_format_bin_dir(base_test_env):
"""detect_format returns 'bin' for directory with .bin + meta.json"""
test_dir = base_test_env["test_dir"]
save_bin(test_dir, {"sequence": [torch.randint(0, 100, (10,))]})
assert detect_format(test_dir) == "bin"
def test_store_fetch_multi_key(base_test_env):
"""Store.fetch with List[str] returns Dict[str, Tensor]"""
test_dir = base_test_env["test_dir"]
save_h5(
test_dir,
"multi_key",
{
"sequence": [torch.randint(0, 100, (100,), dtype=torch.int64)],
"loss_mask": [torch.ones(100, dtype=torch.int64)],
},
)
store = StoreFactory.create("h5")
store.load(test_dir)
result = store.fetch(10, 20, ["sequence", "loss_mask"])
assert isinstance(result, dict)
assert result["sequence"].shape[0] == 10
assert result["loss_mask"].shape[0] == 10
def test_store_fetch_out_of_bounds(base_test_env):
"""Store.fetch raises ValueError for out-of-bounds indices"""
test_dir = base_test_env["test_dir"]
save_h5(test_dir, "bounds", {"sequence": [torch.randint(0, 100, (50,))]})
store = StoreFactory.create("h5")
store.load(test_dir)
with pytest.raises(ValueError, match="out of bounds"):
store.fetch(-1, 10, "sequence")
with pytest.raises(ValueError, match="out of bounds"):
store.fetch(0, 51, "sequence")
with pytest.raises(ValueError, match="out of bounds"):
store.fetch(50, 50, "sequence")
def test_dataset_load_explicit_storage_type(base_test_env):
"""DatasetFactory.load with explicit storage_type bypasses auto-detect"""
test_dir = base_test_env["test_dir"]
save_h5(test_dir, "explicit", {"sequence": [torch.randint(0, 100, (200,))]})
dataset = DatasetFactory.load("seq", test_dir, window_size=64, storage_type="h5")
assert len(dataset) > 0
assert dataset.count == 200

View File

@ -1,396 +0,0 @@
from astrai.config.preprocess_config import (
InputConfig,
OutputConfig,
PipelineConfig,
ProcessingConfig,
)
from astrai.preprocessing.builder import (
MaskBuilderFactory,
SectionedMaskBuilder,
)
from tests.data.conftest import (
_CHAT_SECTIONS,
_INSTRUCTION_SECTIONS,
_TEXT_SECTIONS,
make_chat_config,
make_dpo_chat_config,
make_grpo_config,
make_instruction_config,
make_text_config,
)
def test_chat_simple(chat_tokenizer):
config = make_chat_config()
builder = SectionedMaskBuilder()
item = {
"messages": [
{"role": "system", "content": "You are helpful."},
{"role": "user", "content": "Hello."},
{"role": "assistant", "content": "Hi there!"},
]
}
result = builder.build(item, config, chat_tokenizer)
assert result is not None
assert "sequence" in result
assert "loss_mask" in result
assert len(result["sequence"]) == len(result["loss_mask"])
ids = chat_tokenizer.decode(result["sequence"], skip_special_tokens=False)
assert "system" in ids.lower() or "<|im_start|>system" in ids
assert "assistant" in ids.lower() or "<|im_start|>assistant" in ids
total = len(result["sequence"])
trained = sum(result["loss_mask"])
assert trained > 0
assert trained < total
def test_chat_mask_only_assistant(chat_tokenizer):
config = make_chat_config()
builder = SectionedMaskBuilder()
item = {
"messages": [
{"role": "user", "content": "What is 2+2?"},
{"role": "assistant", "content": "4"},
]
}
result = builder.build(item, config, chat_tokenizer)
mask = result["loss_mask"]
ids = result["sequence"]
assert len(ids) == len(mask)
trained = [i for i, m in enumerate(mask) if m == 1]
masked = [i for i, m in enumerate(mask) if m == 0]
assert len(trained) > 0
assert len(masked) > 0
def test_chat_all_masked(chat_tokenizer):
config = PipelineConfig(
input=InputConfig(sections=_CHAT_SECTIONS),
mask={"system": "mask", "user": "mask", "assistant": "mask"},
mask_default="mask",
preprocessing=ProcessingConfig(max_seq_len=2048),
)
builder = SectionedMaskBuilder()
item = {
"messages": [
{"role": "system", "content": "You are helpful."},
{"role": "assistant", "content": "Hi there!"},
]
}
result = builder.build(item, config, chat_tokenizer)
assert sum(result["loss_mask"]) == 0
def test_chat_all_trained(chat_tokenizer):
config = PipelineConfig(
input=InputConfig(sections=_CHAT_SECTIONS),
mask={},
mask_default="train",
preprocessing=ProcessingConfig(max_seq_len=2048),
)
builder = SectionedMaskBuilder()
item = {
"messages": [
{"role": "system", "content": "You are helpful."},
{"role": "assistant", "content": "Hi there!"},
]
}
result = builder.build(item, config, chat_tokenizer)
assert sum(result["loss_mask"]) == len(result["sequence"]) - 1
def test_chat_empty_messages(chat_tokenizer):
config = make_chat_config()
builder = SectionedMaskBuilder()
assert builder.build({"messages": []}, config, chat_tokenizer) is None
assert builder.build({}, config, chat_tokenizer) is None
def test_chat_domain_extraction(chat_tokenizer):
config = PipelineConfig(
input=InputConfig(sections=_CHAT_SECTIONS),
mask={"assistant": "train"},
mask_default="mask",
preprocessing=ProcessingConfig(max_seq_len=2048),
output=OutputConfig(domain_key="source"),
)
builder = SectionedMaskBuilder()
item = {
"messages": [
{"role": "user", "content": "Hi"},
{"role": "assistant", "content": "Hello"},
],
"source": "wiki",
}
result = builder.build(item, config, chat_tokenizer)
assert result["domain"] == "wiki"
def test_chat_truncation(chat_tokenizer):
config = PipelineConfig(
input=InputConfig(sections=_CHAT_SECTIONS),
mask={"assistant": "train"},
mask_default="mask",
preprocessing=ProcessingConfig(max_seq_len=10),
)
builder = SectionedMaskBuilder()
item = {
"messages": [
{
"role": "user",
"content": "Tell me a very long story about dragons and knights and magic.",
},
{"role": "assistant", "content": "Sure! Here is a tale..."},
]
}
result = builder.build(item, config, chat_tokenizer)
assert len(result["sequence"]) <= 10
assert len(result["loss_mask"]) == len(result["sequence"])
def test_instruction_basic(test_tokenizer):
config = make_instruction_config()
builder = SectionedMaskBuilder()
item = {"prompt": "Translate to French: Hello", "response": "Bonjour"}
result = builder.build(item, config, test_tokenizer)
assert result is not None
assert len(result["sequence"]) == len(result["loss_mask"])
def test_instruction_prompt_masked(test_tokenizer):
config = make_instruction_config()
builder = SectionedMaskBuilder()
item = {"prompt": "hello", "response": "world"}
result = builder.build(item, config, test_tokenizer)
mask = result["loss_mask"]
ids = result["sequence"]
prompt_ids = test_tokenizer.encode("hello", add_special_tokens=True)
p_len = min(len(prompt_ids), len(ids))
assert all(m == 0 for m in mask[:p_len])
if p_len < len(ids):
assert all(m == 1 for m in mask[p_len:])
def test_instruction_train_on_prompt(test_tokenizer):
config = PipelineConfig(
input=InputConfig(
sections=[
{"field": "prompt", "action": "train", "add_special_tokens": True},
{"field": "response", "action": "mask"},
]
),
preprocessing=ProcessingConfig(max_seq_len=2048),
)
builder = SectionedMaskBuilder()
item = {"prompt": "hello", "response": "world"}
result = builder.build(item, config, test_tokenizer)
mask = result["loss_mask"]
ids = result["sequence"]
prompt_ids = test_tokenizer.encode("hello", add_special_tokens=True)
p_len = min(len(prompt_ids), len(ids))
assert all(m == 1 for m in mask[:p_len])
def test_text_basic(test_tokenizer):
config = make_text_config()
builder = SectionedMaskBuilder()
item = {"text": "Hello world. This is a test document."}
result = builder.build(item, config, test_tokenizer)
assert result is not None
assert "sequence" in result
assert len(result["sequence"]) > 0
assert "loss_mask" not in result
def test_text_empty(test_tokenizer):
config = make_text_config()
builder = SectionedMaskBuilder()
assert builder.build({"text": ""}, config, test_tokenizer) is None
assert builder.build({"text": " "}, config, test_tokenizer) is None
def test_text_too_short(test_tokenizer):
config = PipelineConfig(
input=InputConfig(sections=_TEXT_SECTIONS),
preprocessing=ProcessingConfig(min_chars=100),
)
builder = SectionedMaskBuilder()
assert builder.build({"text": "short"}, config, test_tokenizer) is None
def test_text_truncation(test_tokenizer):
config = PipelineConfig(
input=InputConfig(sections=_TEXT_SECTIONS),
preprocessing=ProcessingConfig(max_seq_len=3, min_chars=1),
)
builder = SectionedMaskBuilder()
item = {"text": "This is a very long text that should be truncated"}
result = builder.build(item, config, test_tokenizer)
assert len(result["sequence"]) <= 3
def test_sectioned_chat(chat_tokenizer):
config = PipelineConfig(
input=InputConfig(sections=_CHAT_SECTIONS),
mask={"system": "mask", "user": "mask", "assistant": "train"},
mask_default="mask",
preprocessing=ProcessingConfig(max_seq_len=2048),
)
builder = SectionedMaskBuilder()
item = {
"messages": [
{"role": "user", "content": "What is 2+2?"},
{"role": "assistant", "content": "4"},
]
}
result = builder.build(item, config, chat_tokenizer)
assert result is not None
assert len(result["sequence"]) == len(result["loss_mask"])
assert sum(result["loss_mask"]) > 0
assert 0 in result["loss_mask"]
def test_sectioned_instruction(test_tokenizer):
config = PipelineConfig(
input=InputConfig(sections=_INSTRUCTION_SECTIONS),
preprocessing=ProcessingConfig(max_seq_len=2048, min_chars=0),
)
builder = SectionedMaskBuilder()
item = {"prompt": "Q: Why?", "response": "A: Because."}
result = builder.build(item, config, test_tokenizer)
assert result is not None
mask = result["loss_mask"]
assert mask[0] == 0
assert mask[-1] == 1
def test_sectioned_text(test_tokenizer):
config = PipelineConfig(
input=InputConfig(sections=_TEXT_SECTIONS),
preprocessing=ProcessingConfig(max_seq_len=2048, min_chars=1),
)
builder = SectionedMaskBuilder()
item = {"text": "Hello world, this is a test."}
result = builder.build(item, config, test_tokenizer)
assert result is not None
assert "loss_mask" not in result
def test_sectioned_text_too_short(test_tokenizer):
config = PipelineConfig(
input=InputConfig(sections=_TEXT_SECTIONS),
preprocessing=ProcessingConfig(max_seq_len=2048, min_chars=100),
)
builder = SectionedMaskBuilder()
assert builder.build({"text": "short"}, config, test_tokenizer) is None
def test_factory_registered():
names = MaskBuilderFactory._registry.list_names()
assert "sectioned" in names
def test_factory_create():
builder = MaskBuilderFactory.create("sectioned")
assert isinstance(builder, SectionedMaskBuilder)
def test_dpo_chat_basic(chat_tokenizer):
config = make_dpo_chat_config()
builder = SectionedMaskBuilder()
item = {
"chosen": [
{"role": "user", "content": "What is 2+2?"},
{"role": "assistant", "content": "4"},
],
"rejected": [
{"role": "user", "content": "What is 2+2?"},
{"role": "assistant", "content": "5"},
],
}
result = builder.build(item, config, chat_tokenizer)
assert result is not None
assert "chosen" in result
assert "rejected" in result
assert "chosen_mask" in result
assert "rejected_mask" in result
assert "domain" in result
assert len(result["chosen"]) == len(result["chosen_mask"])
assert len(result["rejected"]) == len(result["rejected_mask"])
assert sum(result["chosen_mask"]) > 0
assert sum(result["rejected_mask"]) > 0
def test_dpo_chosen_only_trained(chat_tokenizer):
config = make_dpo_chat_config()
builder = SectionedMaskBuilder()
item = {
"chosen": [
{"role": "user", "content": "Hi"},
{"role": "assistant", "content": "Hello"},
],
"rejected": [
{"role": "user", "content": "Hi"},
{"role": "assistant", "content": "Go away"},
],
}
result = builder.build(item, config, chat_tokenizer)
assert 0 in result["chosen_mask"]
assert 1 in result["chosen_mask"]
assert 0 in result["rejected_mask"]
assert 1 in result["rejected_mask"]
def test_dpo_missing_field_is_none(chat_tokenizer):
config = make_dpo_chat_config()
builder = SectionedMaskBuilder()
assert builder.build({"chosen": [], "rejected": []}, config, chat_tokenizer) is None
def test_grpo_basic(chat_tokenizer):
config = make_grpo_config()
builder = SectionedMaskBuilder()
item = {
"prompt": [{"role": "user", "content": "What is 2+2?"}],
"responses": ["4", "The answer is four", "Four", "2+2=4"],
"rewards": [1.0, 0.5, 0.8, 0.2],
}
result = builder.build(item, config, chat_tokenizer)
assert result is not None
assert "prompts" in result
assert "responses" in result
assert "masks" in result
assert "rewards" in result
assert len(result["responses"]) == len(result["masks"])
assert result["rewards"] == [1.0, 0.5, 0.8, 0.2]
def test_grpo_response_tokens_all_trained(chat_tokenizer):
config = make_grpo_config()
builder = SectionedMaskBuilder()
item = {
"prompt": [{"role": "user", "content": "Q"}],
"responses": ["A", "B"],
"rewards": [0.8, 0.2],
}
result = builder.build(item, config, chat_tokenizer)
masks = result["masks"]
assert all(m == 1 for m in masks)
assert len(masks) == len(result["responses"])
def test_grpo_single_reward(chat_tokenizer):
config = make_grpo_config()
builder = SectionedMaskBuilder()
item = {
"prompt": [{"role": "user", "content": "Q"}],
"responses": ["A"],
"rewards": 0.9,
}
result = builder.build(item, config, chat_tokenizer)
assert result["rewards"] == [0.9]

View File

@ -1,77 +0,0 @@
import os
from astrai.config.preprocess_config import (
InputConfig,
PipelineConfig,
)
from tests.data.conftest import (
_INSTRUCTION_SECTIONS,
_TEXT_SECTIONS,
make_dpo_chat_config,
)
def test_default_values():
config = PipelineConfig()
assert config.version == 1
assert config.mask == {}
assert config.mask_default == "mask"
assert config.preprocessing.max_seq_len == 2048
assert config.output.storage_format == "bin"
assert config.input.sections is None
def test_from_dict_flat():
data = {
"version": 1,
"input": {
"sections": [{"field": "messages", "action": "$role", "template": True}]
},
"mask": {"system": "mask", "assistant": "train"},
"mask_default": "mask",
"preprocessing": {"max_seq_len": 1024},
"output": {"storage_format": "h5"},
}
config = PipelineConfig.from_dict(data)
assert config.input.sections == [
{"field": "messages", "action": "$role", "template": True}
]
assert config.mask == {"system": "mask", "assistant": "train"}
assert config.preprocessing.max_seq_len == 1024
assert config.output.storage_format == "h5"
def test_to_dict_roundtrip():
config = PipelineConfig(
input=InputConfig(sections=_INSTRUCTION_SECTIONS),
mask={"prompt": "mask", "response": "train"},
mask_default="mask",
)
d = config.to_dict()
config2 = PipelineConfig.from_dict(d)
assert config2.input.sections == _INSTRUCTION_SECTIONS
assert config2.mask == {"prompt": "mask", "response": "train"}
def test_to_json_from_json(temp_dir):
config = PipelineConfig(
input=InputConfig(sections=_TEXT_SECTIONS),
mask={"text": "train"},
mask_default="mask",
)
path = os.path.join(temp_dir, "config.json")
config.to_json(path)
loaded = PipelineConfig.from_json(path)
assert loaded.input.sections == _TEXT_SECTIONS
assert loaded.mask == {"text": "train"}
def test_dpo_config_roundtrip(temp_dir):
config = make_dpo_chat_config()
path = os.path.join(temp_dir, "config.json")
config.to_json(path)
loaded = PipelineConfig.from_json(path)
assert loaded.input.sources is not None
assert "chosen" in loaded.input.sources
assert "rejected" in loaded.input.sources
assert loaded.input.sections is None

View File

@ -1,349 +0,0 @@
import json
import os
from astrai.config.preprocess_config import (
InputConfig,
OutputConfig,
PipelineConfig,
ProcessingConfig,
)
from astrai.preprocessing.pipeline import Pipeline, filter_by_length
from tests.data.conftest import (
_CHAT_SECTIONS,
_CHAT_TEMPLATE,
_INSTRUCTION_SECTIONS,
_SPECIAL_TOKENS_CONFIG,
_TEXT_SECTIONS,
make_dpo_chat_config,
make_grpo_no_template_config,
)
def test_filter_by_length():
assert filter_by_length("hello world", min_len=5)
assert not filter_by_length("hi", min_len=5)
assert not filter_by_length("x" * 100, max_len=50)
assert filter_by_length("just right", min_len=5, max_len=20)
def test_full_chat_pipeline(temp_dir, chat_tokenizer):
tokenizer_dir = os.path.join(temp_dir, "tok")
os.makedirs(tokenizer_dir, exist_ok=True)
chat_tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f:
json.dump(
{
"special_tokens": _SPECIAL_TOKENS_CONFIG,
"chat_template": _CHAT_TEMPLATE,
},
f,
)
jsonl_path = os.path.join(temp_dir, "chat.jsonl")
with open(jsonl_path, "w", encoding="utf-8") as f:
f.write(
json.dumps(
{
"messages": [
{"role": "system", "content": "You are helpful."},
{"role": "user", "content": "Hi."},
{"role": "assistant", "content": "Hello!"},
]
}
)
+ "\n"
)
f.write(
json.dumps(
{
"messages": [
{"role": "user", "content": "What is 2+2?"},
{"role": "assistant", "content": "4"},
]
}
)
+ "\n"
)
config = PipelineConfig(
input=InputConfig(sections=_CHAT_SECTIONS),
mask={"system": "mask", "user": "mask", "assistant": "train"},
mask_default="mask",
preprocessing=ProcessingConfig(max_seq_len=2048),
output=OutputConfig(storage_format="bin", domain_key=None),
)
out_dir = os.path.join(temp_dir, "output")
Pipeline(
config=config,
input_paths=[jsonl_path],
output_dir=out_dir,
tokenizer_path=tokenizer_dir,
).run()
meta_path = os.path.join(out_dir, "__default__", "shard_0000", "meta.json")
assert os.path.exists(meta_path)
with open(meta_path, "r") as f:
meta = json.load(f)
assert "sequence" in meta
assert "loss_mask" in meta
assert meta["sequence"]["dtype"] == "int32"
assert meta["loss_mask"]["dtype"] == "int32"
def test_full_text_pipeline(temp_dir, test_tokenizer):
tokenizer_dir = os.path.join(temp_dir, "tok")
os.makedirs(tokenizer_dir, exist_ok=True)
test_tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f:
json.dump(
{
"special_tokens": {
"pad_token": "<|_pad_|>",
"unk_token": "<|_unk_|>",
}
},
f,
)
jsonl_path = os.path.join(temp_dir, "text.jsonl")
with open(jsonl_path, "w", encoding="utf-8") as f:
f.write(
json.dumps(
{
"text": "Hello world this is a test document with enough characters to pass the minimum length filter."
}
)
+ "\n"
)
f.write(
json.dumps(
{
"text": "Another document for testing purposes with sufficient length to be processed."
}
)
+ "\n"
)
config = PipelineConfig(
input=InputConfig(sections=_TEXT_SECTIONS),
preprocessing=ProcessingConfig(max_seq_len=2048, min_chars=10),
output=OutputConfig(storage_format="bin"),
)
out_dir = os.path.join(temp_dir, "output")
Pipeline(
config=config,
input_paths=[jsonl_path],
output_dir=out_dir,
tokenizer_path=tokenizer_dir,
).run()
meta_path = os.path.join(out_dir, "__default__", "shard_0000", "meta.json")
assert os.path.exists(meta_path)
with open(meta_path, "r") as f:
meta = json.load(f)
assert "sequence" in meta
assert "loss_mask" not in meta
assert meta["sequence"]["dtype"] == "int32"
def test_full_instruction_pipeline(temp_dir, test_tokenizer):
tokenizer_dir = os.path.join(temp_dir, "tok")
os.makedirs(tokenizer_dir, exist_ok=True)
test_tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f:
json.dump(
{
"special_tokens": {
"pad_token": "<|_pad_|>",
"unk_token": "<|_unk_|>",
}
},
f,
)
jsonl_path = os.path.join(temp_dir, "instruct.jsonl")
with open(jsonl_path, "w", encoding="utf-8") as f:
f.write(
json.dumps(
{
"prompt": "Tell me a joke",
"response": "Why did the chicken cross the road?",
}
)
+ "\n"
)
f.write(
json.dumps(
{
"prompt": "What is AI?",
"response": "Artificial Intelligence is a field of computer science.",
}
)
+ "\n"
)
config = PipelineConfig(
input=InputConfig(sections=_INSTRUCTION_SECTIONS),
mask={"prompt": "mask", "response": "train"},
mask_default="mask",
preprocessing=ProcessingConfig(max_seq_len=2048),
output=OutputConfig(storage_format="bin"),
)
out_dir = os.path.join(temp_dir, "output")
Pipeline(
config=config,
input_paths=[jsonl_path],
output_dir=out_dir,
tokenizer_path=tokenizer_dir,
).run()
meta_path = os.path.join(out_dir, "__default__", "shard_0000", "meta.json")
assert os.path.exists(meta_path)
with open(meta_path, "r") as f:
meta = json.load(f)
assert "sequence" in meta
assert "loss_mask" in meta
assert meta["sequence"]["dtype"] == "int32"
assert meta["loss_mask"]["dtype"] == "int32"
def test_dtype_override(temp_dir, test_tokenizer):
tokenizer_dir = os.path.join(temp_dir, "tok")
os.makedirs(tokenizer_dir, exist_ok=True)
test_tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f:
json.dump(
{
"special_tokens": {
"pad_token": "<|_pad_|>",
"unk_token": "<|_unk_|>",
}
},
f,
)
jsonl_path = os.path.join(temp_dir, "data.jsonl")
with open(jsonl_path, "w", encoding="utf-8") as f:
f.write(json.dumps({"prompt": "Q", "response": "A"}) + "\n")
config = PipelineConfig(
input=InputConfig(sections=_INSTRUCTION_SECTIONS),
mask={"prompt": "mask", "response": "train"},
mask_default="mask",
preprocessing=ProcessingConfig(max_seq_len=2048),
output=OutputConfig(storage_format="bin", dtype={"loss_mask": "bool"}),
)
out_dir = os.path.join(temp_dir, "output")
Pipeline(
config=config,
input_paths=[jsonl_path],
output_dir=out_dir,
tokenizer_path=tokenizer_dir,
).run()
meta_path = os.path.join(out_dir, "__default__", "shard_0000", "meta.json")
with open(meta_path, "r") as f:
meta = json.load(f)
assert meta["sequence"]["dtype"] == "int32"
assert meta["loss_mask"]["dtype"] == "bool"
def test_dpo_pipeline(temp_dir, chat_tokenizer):
tokenizer_dir = os.path.join(temp_dir, "tok")
os.makedirs(tokenizer_dir, exist_ok=True)
chat_tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f:
json.dump(
{
"special_tokens": _SPECIAL_TOKENS_CONFIG,
"chat_template": _CHAT_TEMPLATE,
},
f,
)
jsonl_path = os.path.join(temp_dir, "dpo.jsonl")
with open(jsonl_path, "w", encoding="utf-8") as f:
f.write(
json.dumps(
{
"chosen": [
{"role": "user", "content": "Hi."},
{"role": "assistant", "content": "Hello!"},
],
"rejected": [
{"role": "user", "content": "Hi."},
{"role": "assistant", "content": "Go away."},
],
}
)
+ "\n"
)
out_dir = os.path.join(temp_dir, "output")
Pipeline(
config=make_dpo_chat_config(),
input_paths=[jsonl_path],
output_dir=out_dir,
tokenizer_path=tokenizer_dir,
).run()
meta_path = os.path.join(out_dir, "__default__", "shard_0000", "meta.json")
assert os.path.exists(meta_path)
with open(meta_path, "r") as f:
meta = json.load(f)
assert "chosen" in meta
assert "rejected" in meta
assert "chosen_mask" in meta
assert "rejected_mask" in meta
assert "sequence" not in meta
def test_grpo_pipeline(temp_dir, test_tokenizer):
tokenizer_dir = os.path.join(temp_dir, "tok")
os.makedirs(tokenizer_dir, exist_ok=True)
test_tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f:
json.dump(
{
"special_tokens": {
"pad_token": "<|_pad_|>",
"unk_token": "<|_unk_|>",
}
},
f,
)
jsonl_path = os.path.join(temp_dir, "grpo.jsonl")
with open(jsonl_path, "w", encoding="utf-8") as f:
f.write(
json.dumps(
{
"prompt": "Question?",
"responses": ["Answer A", "Answer B"],
"rewards": [0.8, 0.3],
}
)
+ "\n"
)
out_dir = os.path.join(temp_dir, "output")
Pipeline(
config=make_grpo_no_template_config(),
input_paths=[jsonl_path],
output_dir=out_dir,
tokenizer_path=tokenizer_dir,
).run()
meta_path = os.path.join(out_dir, "__default__", "shard_0000", "meta.json")
assert os.path.exists(meta_path)
with open(meta_path, "r") as f:
meta = json.load(f)
assert "prompts" in meta
assert "responses" in meta
assert "masks" in meta
assert "rewards" in meta
assert "sequence" not in meta

View File

@ -5,22 +5,21 @@ from unittest.mock import MagicMock
import pytest
from fastapi.testclient import TestClient
from astrai.inference import get_app
from astrai.inference import app
@pytest.fixture
def client():
"""Provide a test client for the FastAPI app."""
_app = get_app()
_app.state.server_config = {
app.state.server_config = {
"device": "cpu",
"dtype": "bfloat16",
"param_path": None,
"max_batch_size": 1,
"_test": True,
}
_app.state.engine = None
return TestClient(_app)
app.state.engine = None
return TestClient(app)
@pytest.fixture
@ -50,5 +49,5 @@ def mock_engine():
@pytest.fixture
def loaded_model(client, mock_engine):
"""Simulate that the engine is loaded."""
get_app().state.engine = mock_engine
app.state.engine = mock_engine
return mock_engine

View File

@ -1,286 +0,0 @@
"""Unit tests for protocol builders, StopChecker, GenContext, StopInfo."""
import json
from unittest.mock import MagicMock
import pytest
from astrai.inference.api.anthropic import AnthropicResponseBuilder
from astrai.inference.api.openai import OpenAIResponseBuilder
from astrai.inference.api.protocol import GenContext, StopChecker, StopInfo
from astrai.inference.engine import GenerationRequest
def _make_ctx(**kwargs):
defaults = {
"resp_id": "test-123",
"created": 1000,
"model": "test-model",
"prompt_tokens": 10,
"completion_tokens": 5,
}
defaults.update(kwargs)
return GenContext(**defaults)
def _sse_payloads(events):
payloads = []
for chunk in events:
for line in chunk.strip().split("\n"):
if line.startswith("data: "):
try:
payloads.append(json.loads(line[6:]))
except json.JSONDecodeError:
pass
return payloads
class TestStopChecker:
def test_check_finds_match(self):
sc = StopChecker(["stop", "end"])
assert sc.check("hello stop world") == "stop"
def test_check_returns_none_when_no_match(self):
sc = StopChecker(["stop"])
assert sc.check("hello world") is None
def test_check_empty_sequences(self):
sc = StopChecker([])
assert sc.check("hello") is None
class TestGenContext:
def test_defaults(self):
ctx = GenContext(resp_id="a", created=1, model="m", prompt_tokens=10)
assert ctx.completion_tokens == 0
def test_fields_mutable(self):
ctx = GenContext(resp_id="a", created=1, model="m", prompt_tokens=10)
ctx.completion_tokens = 42
assert ctx.completion_tokens == 42
class TestStopInfo:
def test_defaults(self):
s = StopInfo()
assert s.matched is None
assert s.body == ""
assert s.yielded == ""
def test_with_values(self):
s = StopInfo(matched="stop", body="hello stop", yielded="hello ")
assert s.matched == "stop"
assert s.body == "hello stop"
assert s.yielded == "hello "
class TestOpenAIResponseBuilder:
@pytest.fixture
def builder(self):
builder = OpenAIResponseBuilder()
req = MagicMock()
req.messages = [MagicMock(role="user", content="Hello")]
req.stop = None
req.model = "astrai"
engine = MagicMock()
engine.tokenizer.apply_chat_template.return_value = "Hello"
builder.prepare(req, engine)
return builder
def test_prepare_returns_prompt_ctx_stops(self, builder):
req = MagicMock()
req.messages = [MagicMock(role="user", content="Hi")]
req.stop = ["END"]
req.model = "gpt"
engine = MagicMock()
engine.tokenizer.apply_chat_template.return_value = "Hi"
prompt, ctx, stops = builder.prepare(req, engine)
assert prompt == "Hi"
assert ctx.model == "gpt"
assert ctx.prompt_tokens == 0
assert stops == ["END"]
def test_prepare_no_stop_returns_empty_list(self, builder):
req = MagicMock()
req.messages = []
req.stop = None
req.model = "x"
engine = MagicMock()
engine.tokenizer.apply_chat_template.return_value = ""
_, _, stops = builder.prepare(req, engine)
assert stops == []
def test_format_stream_start(self, builder):
ctx = _make_ctx()
events = builder.format_stream_start(ctx)
payloads = _sse_payloads(events)
assert len(payloads) == 1
p = payloads[0]
assert p["object"] == "chat.completion.chunk"
assert p["choices"][0]["delta"]["role"] == "assistant"
assert p["choices"][0]["finish_reason"] is None
def test_format_chunk(self, builder):
event = builder.format_chunk("hello")
payload = json.loads(event.split("data: ", 1)[1])
assert payload["choices"][0]["delta"]["content"] == "hello"
assert payload["choices"][0]["finish_reason"] is None
def test_format_stream_end(self, builder):
ctx = _make_ctx(completion_tokens=5)
stop = StopInfo(matched="stop")
events = builder.format_stream_end(ctx, stop)
payloads = _sse_payloads(events)
finish = payloads[0]
assert finish["choices"][0]["finish_reason"] == "stop"
usage = payloads[1]
assert usage["completion_tokens"] == 5
assert usage["total_tokens"] == 15
def test_format_response(self, builder):
ctx = _make_ctx()
stop = StopInfo()
resp = builder.format_response(ctx, "hello", stop)
assert resp["object"] == "chat.completion"
assert resp["choices"][0]["message"]["content"] == "hello"
assert resp["usage"]["prompt_tokens"] == 10
class TestAnthropicResponseBuilder:
@pytest.fixture
def builder(self):
builder = AnthropicResponseBuilder()
req = MagicMock()
req.messages = [MagicMock(role="user", content="Hello")]
req.model = "claude"
engine = MagicMock()
engine.tokenizer.apply_chat_template.return_value = "Hello"
req.system = None
builder.prepare(req, engine)
return builder
def test_prepare_messages(self, builder):
req = MagicMock()
req.messages = [MagicMock(role="user", content="Hi")]
req.model = "claude"
req.system = None
req.stop_sequences = None
engine = MagicMock()
engine.tokenizer.apply_chat_template.return_value = "Hi"
prompt, ctx, stops = builder.prepare(req, engine)
assert prompt == "Hi"
assert stops == []
def test_prepare_with_stop_sequences(self, builder):
req = MagicMock()
req.messages = []
req.model = "x"
req.stop_sequences = ["stop", "end"]
req.system = None
engine = MagicMock()
engine.tokenizer.apply_chat_template.return_value = ""
_, _, stops = builder.prepare(req, engine)
assert stops == ["stop", "end"]
def test_format_stream_start(self, builder):
ctx = _make_ctx(prompt_tokens=3)
events = builder.format_stream_start(ctx)
payloads = _sse_payloads(events)
assert len(payloads) == 2
assert payloads[0]["type"] == "message_start"
assert payloads[0]["message"]["usage"]["input_tokens"] == 3
assert payloads[1]["type"] == "content_block_start"
def test_format_chunk(self, builder):
event = builder.format_chunk("tok")
payload = json.loads(event.split("data: ", 1)[1])
assert payload["type"] == "content_block_delta"
assert payload["delta"]["text"] == "tok"
def test_format_stream_end_no_stop(self, builder):
ctx = _make_ctx(completion_tokens=3)
stop = StopInfo()
events = builder.format_stream_end(ctx, stop)
payloads = _sse_payloads(events)
# content_block_stop, message_delta, message_stop
types = [p["type"] for p in payloads]
assert types == ["content_block_stop", "message_delta", "message_stop"]
assert payloads[1]["delta"]["stop_reason"] == "end_turn"
def test_format_stream_end_with_stop_trims_and_emits_remaining(self, builder):
ctx = _make_ctx(completion_tokens=7)
stop = StopInfo(
matched="END",
body="Hello world END extra",
yielded="Hello ",
)
events = builder.format_stream_end(ctx, stop)
payloads = _sse_payloads(events)
# unyielded delta, content_block_stop, message_delta, message_stop
types = [p["type"] for p in payloads]
assert types == [
"content_block_delta",
"content_block_stop",
"message_delta",
"message_stop",
]
assert payloads[0]["delta"]["text"] == "world "
assert payloads[2]["delta"]["stop_reason"] == "stop_sequence"
assert payloads[2]["delta"]["stop_sequence"] == "END"
def test_format_stream_end_stop_trimmed_already_yielded(self, builder):
ctx = _make_ctx()
stop = StopInfo(
matched="END",
body="Hello END",
yielded="Hello ",
)
events = builder.format_stream_end(ctx, stop)
payloads = _sse_payloads(events)
# No unyielded delta (everything already sent)
types = [p["type"] for p in payloads]
assert types == ["content_block_stop", "message_delta", "message_stop"]
def test_format_response_with_stop_trims_content(self, builder):
ctx = _make_ctx()
stop = StopInfo(matched="STOP", body="text STOP extra", yielded="text ")
resp = builder.format_response(ctx, "text STOP extra", stop)
assert resp["content"][0]["text"] == "text "
assert resp["stop_reason"] == "stop_sequence"
assert resp["stop_sequence"] == "STOP"
def test_format_response_no_stop(self, builder):
ctx = _make_ctx()
stop = StopInfo()
resp = builder.format_response(ctx, "full text", stop)
assert resp["content"][0]["text"] == "full text"
assert resp["stop_reason"] == "end_turn"
class TestGenerationRequestValidation:
def test_valid_params(self):
gr = GenerationRequest(
messages=[{"role": "user", "content": "hi"}],
top_k=50,
top_p=0.9,
temperature=0.7,
)
assert gr.top_k == 50
def test_invalid_top_p_raises(self):
with pytest.raises(ValueError, match="top_p"):
GenerationRequest(messages=[{"role": "user", "content": "hi"}], top_p=1.5)
def test_invalid_top_k_raises(self):
with pytest.raises(ValueError, match="top_k"):
GenerationRequest(messages=[{"role": "user", "content": "hi"}], top_k=-1)
def test_invalid_temperature_raises(self):
with pytest.raises(ValueError, match="temperature"):
GenerationRequest(
messages=[{"role": "user", "content": "hi"}], temperature=-0.1
)
def test_top_k_zero_valid(self):
gr = GenerationRequest(messages=[{"role": "user", "content": "hi"}], top_k=0)
assert gr.top_k == 0

View File

@ -173,21 +173,3 @@ def test_scheduler_concurrent_get_stats(mock_model_and_tokenizer):
for stats in results["stats"]:
assert "total_tasks" in stats
assert stats["total_tasks"] >= 0
def test_prefill_skips_fully_cached_tasks(mock_model_and_tokenizer):
"""Tasks whose entire prompt is cached skip the prefill phase."""
mock_model, mock_tokenizer = mock_model_and_tokenizer
with patch("astrai.inference.core.scheduler.AutoModel"):
with patch("astrai.inference.core.scheduler.AutoTokenizer"):
scheduler = InferenceScheduler(
model=mock_model,
tokenizer=mock_tokenizer,
max_batch_size=4,
device="cpu",
)
task_id = scheduler.add_task("short prompt", stream_callback=lambda t: None)
scheduler.stop()
assert task_id.startswith("task_")

View File

@ -2,12 +2,12 @@
import pytest
from astrai.inference import get_app
from astrai.inference import app
def test_health_no_model(client):
"""GET /health should return 200 even when engine not loaded."""
get_app().state.engine = None
app.state.engine = None
response = client.get("/health")
assert response.status_code == 200
data = response.json()
@ -30,7 +30,7 @@ def test_chat_completions_non_stream(client, loaded_model):
async def async_gen():
yield "Assistant reply"
get_app().state.engine = loaded_model
app.state.engine = loaded_model
loaded_model.generate_async.return_value = async_gen()
response = client.post(
"/v1/chat/completions",
@ -56,7 +56,7 @@ def test_chat_completions_stream(client, loaded_model):
yield "cumulative1"
yield "cumulative2"
get_app().state.engine = loaded_model
app.state.engine = loaded_model
loaded_model.generate_async.return_value = async_gen()
response = client.post(
"/v1/chat/completions",
@ -83,7 +83,7 @@ def test_messages_non_stream(client, loaded_model):
async def async_gen():
yield "Assistant reply"
get_app().state.engine = loaded_model
app.state.engine = loaded_model
loaded_model.generate_async.return_value = async_gen()
response = client.post(
"/v1/messages",
@ -111,7 +111,7 @@ def test_messages_stream(client, loaded_model):
yield "cumulative1"
yield "cumulative2"
get_app().state.engine = loaded_model
app.state.engine = loaded_model
loaded_model.generate_async.return_value = async_gen()
response = client.post(
"/v1/messages",
@ -141,7 +141,7 @@ def test_messages_with_system(client, loaded_model):
async def async_gen():
yield "Reply"
get_app().state.engine = loaded_model
app.state.engine = loaded_model
loaded_model.generate_async.return_value = async_gen()
response = client.post(
"/v1/messages",
@ -165,7 +165,7 @@ def test_chat_completions_stop_sequence(client, loaded_model):
yield "X"
yield "world"
get_app().state.engine = loaded_model
app.state.engine = loaded_model
loaded_model.generate_async.return_value = async_gen()
response = client.post(
"/v1/chat/completions",
@ -191,7 +191,7 @@ def test_chat_completions_stop_sequence_stream(client, loaded_model):
yield "X"
yield "world"
get_app().state.engine = loaded_model
app.state.engine = loaded_model
loaded_model.generate_async.return_value = async_gen()
response = client.post(
"/v1/chat/completions",

View File

@ -1,355 +0,0 @@
import tempfile
import pytest
import torch
from astrai.config.model_config import AutoRegressiveLMConfig
from astrai.model import AutoRegressiveLM
from astrai.model.components.linear import Linear
from astrai.model.components.lora import (
LoRAConfig,
LoRALinear,
_collect_lora_info,
_get_lora_count,
inject_lora,
load_lora,
merge_lora,
save_lora,
)
MODEL_KWARGS = dict(
vocab_size=1000,
dim=64,
n_heads=4,
n_kv_heads=2,
dim_ffn=128,
n_layers=2,
max_len=32,
norm_eps=1e-5,
)
def _make_model(**kwargs):
kw = {**MODEL_KWARGS, **kwargs}
config = AutoRegressiveLMConfig(**kw)
model = AutoRegressiveLM(config)
model.eval()
return model
def test_loralinear_init():
base = Linear(64, 128)
lora = LoRALinear(base, r=8, alpha=16)
assert lora.weight is base.weight
assert not lora.weight.requires_grad
assert lora.lora_A.shape == (8, 64)
assert lora.lora_B.shape == (128, 8)
assert lora.scaling == 2.0
assert not lora._merged
assert lora.lora_A.requires_grad
assert lora.lora_B.requires_grad
def test_loralinear_forward_init_zero_delta():
base = Linear(4, 4)
with torch.no_grad():
base.weight.zero_()
x = torch.randn(2, 4)
lora = LoRALinear(base, r=2, alpha=2)
base_out = base(x)
lora_out = lora(x)
assert torch.allclose(base_out, lora_out)
def test_loralinear_forward_with_delta():
base = Linear(4, 4)
with torch.no_grad():
base.weight.zero_()
x = torch.randn(2, 4)
lora = LoRALinear(base, r=2, alpha=2)
base_out = base(x)
with torch.no_grad():
lora.lora_B.fill_(1.0)
lora_out = lora(x)
assert not torch.allclose(base_out, lora_out)
def test_loralinear_merge():
base = Linear(4, 4)
with torch.no_grad():
base.weight.zero_()
x = torch.randn(2, 4)
lora = LoRALinear(base, r=2, alpha=2)
with torch.no_grad():
lora.lora_B.fill_(1.0)
out_before = lora(x).clone()
lora.merge()
out_after = lora(x)
torch.testing.assert_close(out_before, out_after)
assert lora._merged
assert not hasattr(lora, "lora_A")
def test_loralinear_merge_is_idempotent():
base = Linear(4, 4)
with torch.no_grad():
base.weight.zero_()
lora = LoRALinear(base, r=2, alpha=2)
with torch.no_grad():
lora.lora_B.fill_(1.0)
lora.merge()
lora.merge()
def test_inject_lora_default_target():
model = _make_model()
n_before = sum(1 for m in model.modules() if isinstance(m, Linear))
inject_lora(model, r=4, alpha=8)
lora_count = _get_lora_count(model)
assert lora_count > 0
assert lora_count < n_before
def test_inject_lora_ffn():
model = _make_model()
from astrai.model.components.lora import TARGET_MODULES_FFN
inject_lora(model, r=4, alpha=8, target_modules=TARGET_MODULES_FFN)
assert _get_lora_count(model) > 0
def test_inject_lora_returns_config():
model = _make_model()
cfg = inject_lora(model, r=8, alpha=32)
assert isinstance(cfg, LoRAConfig)
assert cfg.r == 8
assert cfg.alpha == 32
def test_inject_lora_no_matching_targets_warns(caplog):
model = _make_model()
inject_lora(model, r=4, alpha=8, target_modules={"nonexistent"})
assert "No LoRA layers injected" in caplog.text
def test_inject_lora_preserves_base_output():
model = _make_model()
x = torch.randint(0, 1000, (2, 16))
with torch.no_grad():
out_before = model(x)["logits"].clone()
inject_lora(model, r=4, alpha=8)
with torch.no_grad():
out_after = model(x)["logits"]
torch.testing.assert_close(out_before, out_after)
def test_inject_lora_does_not_reinject():
model = _make_model()
inject_lora(model, r=4, alpha=8, target_modules={"q_proj"})
first_count = _get_lora_count(model)
inject_lora(model, r=2, alpha=4, target_modules={"q_proj"})
assert _get_lora_count(model) == first_count
def test_inject_lora_adds_new_modules():
model = _make_model()
inject_lora(model, r=4, alpha=8, target_modules={"q_proj"})
first = _get_lora_count(model)
inject_lora(model, r=4, alpha=8, target_modules={"v_proj"})
assert _get_lora_count(model) > first
def test_inject_lora_on_mla_model():
model = _make_model(
attn_type="mla", kv_lora_rank=16, qk_nope_head_dim=16, qk_rope_head_dim=16
)
inject_lora(model, r=4, alpha=8, target_modules={"q_proj", "o_proj"})
assert _get_lora_count(model) > 0
def test_inject_lora_on_moe_model():
model = _make_model(
ffn_type="moe",
n_routed_experts=4,
n_shared_experts=1,
n_activated_experts=2,
dim_ffn=32,
)
inject_lora(model, r=4, alpha=8, target_modules={"up", "gate", "down"})
assert _get_lora_count(model) > 0
def test_state_dict_key_format():
model = _make_model()
inject_lora(model, r=4, alpha=8, target_modules={"q_proj"})
sd = model.state_dict()
assert "layers.0.attention.q_proj.weight" in sd
assert "layers.0.attention.q_proj.lora_A" in sd
assert "layers.0.attention.q_proj.lora_B" in sd
def test_only_lora_params_trainable():
model = _make_model()
inject_lora(model, r=4, alpha=8, target_modules={"q_proj", "v_proj"})
for name, param in model.named_parameters():
if isinstance(name.split(".")[-1], str) and "lora" in name:
assert param.requires_grad, f"lora param should be trainable: {name}"
elif any(name.endswith(f".{t}.weight") for t in ("q_proj", "v_proj")):
assert not param.requires_grad, f"injected weight should be frozen: {name}"
def test_state_dict_after_inject_consistent_with_original():
model = _make_model()
sd_before = {k: v for k, v in model.state_dict().items()}
inject_lora(model, r=4, alpha=8, target_modules={"q_proj"})
sd_after = model.state_dict()
# original keys unchanged
for k in sd_before:
assert k in sd_after
assert sd_before[k].shape == sd_after[k].shape
# new lora keys present
lora_keys = [k for k in sd_after if "lora" in k]
assert len(lora_keys) > 0
def test_save_load_roundtrip():
model = _make_model()
cfg = inject_lora(model, r=4, alpha=8, target_modules={"q_proj"})
with torch.no_grad():
for m in model.modules():
if isinstance(m, LoRALinear):
m.lora_B.fill_(0.5)
x = torch.randint(0, 1000, (2, 16))
with torch.no_grad():
out_src = model(x)["logits"].clone()
tmpdir = tempfile.mkdtemp()
save_lora(model, tmpdir, cfg)
model2 = _make_model()
model2.load_state_dict(model.state_dict(), strict=False)
load_lora(model2, tmpdir)
with torch.no_grad():
out_dst = model2(x)["logits"]
torch.testing.assert_close(out_src, out_dst)
def test_save_after_merge_raises():
model = _make_model()
cfg = inject_lora(model, r=4, alpha=8, target_modules={"q_proj"})
with torch.no_grad():
for m in model.modules():
if isinstance(m, LoRALinear):
m.lora_B.fill_(0.5)
tmpdir = tempfile.mkdtemp()
save_lora(model, tmpdir, cfg)
merge_lora(model)
tmpdir2 = tempfile.mkdtemp()
with pytest.raises(RuntimeError, match="No LoRA parameters"):
save_lora(model, tmpdir2, cfg)
def test_load_lora_on_already_injected():
model = _make_model()
inject_lora(model, r=4, alpha=8, target_modules={"q_proj"})
with torch.no_grad():
for m in model.modules():
if isinstance(m, LoRALinear):
m.lora_B.fill_(0.5)
tmpdir = tempfile.mkdtemp()
save_lora(model, tmpdir, LoRAConfig(r=4, alpha=8, target_modules=("q_proj",)))
model2 = _make_model()
model2.load_state_dict(model.state_dict(), strict=False)
inject_lora(model2, r=4, alpha=8, target_modules={"q_proj"})
# load onto already-injected model
load_lora(model2, tmpdir)
assert _get_lora_count(model2) > 0
def test_load_lora_mismatched_r_raises():
model = _make_model()
cfg = inject_lora(model, r=8, alpha=16, target_modules={"q_proj"})
with torch.no_grad():
for m in model.modules():
if isinstance(m, LoRALinear):
m.lora_B.fill_(0.5)
tmpdir = tempfile.mkdtemp()
save_lora(model, tmpdir, cfg)
model2 = _make_model()
model2.load_state_dict(model.state_dict(), strict=False)
inject_lora(model2, r=4, alpha=8, target_modules={"q_proj"})
with pytest.raises(RuntimeError, match="size mismatch"):
load_lora(model2, tmpdir) # strict=False, only lora keys
def test_merge_preserves_output():
model = _make_model()
inject_lora(model, r=4, alpha=8, target_modules={"q_proj"})
with torch.no_grad():
for m in model.modules():
if isinstance(m, LoRALinear):
m.lora_B.fill_(0.5)
x = torch.randint(0, 1000, (2, 16))
with torch.no_grad():
out_before = model(x)["logits"].clone()
merge_lora(model)
with torch.no_grad():
out_after = model(x)["logits"]
torch.testing.assert_close(out_before, out_after)
def test_merge_no_lora_warns(caplog):
model = _make_model()
merge_lora(model)
assert "No LoRA layers to merge" in caplog.text
def test_collect_lora_info():
model = _make_model()
info = _collect_lora_info(model)
assert "q_proj" in info
assert "o_proj" in info
assert "q_proj" in info # each layer has one

View File

@ -1,5 +1,3 @@
import os
import pytest
import torch
from torch.utils.data import Dataset
@ -27,7 +25,7 @@ class TrainerDataset(Dataset):
def create_train_config(
model_fn,
model: torch.nn.Module,
dataset: Dataset,
test_dir: str,
device: str,
@ -43,7 +41,7 @@ def create_train_config(
"""Factory function to create common TrainConfig for tests.
Args:
model_fn: Model factory (callable returning nn.Module)
model: The model to train
dataset: Training dataset
test_dir: Checkpoint directory
device: Device type ("cuda" or "cpu")
@ -70,12 +68,11 @@ def create_train_config(
return TrainConfig(
strategy=strategy,
model_fn=model_fn,
model=model,
dataset=dataset,
optimizer_fn=optimizer_fn,
scheduler_fn=scheduler_fn,
ckpt_dir=test_dir,
log_dir=os.path.join(test_dir, "logs"),
n_epoch=n_epoch,
batch_per_device=batch_per_device,
ckpt_interval=ckpt_interval,

View File

@ -1,5 +1,3 @@
import os
import torch
from astrai.config.train_config import TrainConfig
@ -106,13 +104,12 @@ def test_gradient_checkpointing_trainer_integration(base_test_env, random_datase
)
train_config = TrainConfig(
model_fn=lambda: base_test_env["model"],
model=base_test_env["model"],
strategy="seq",
dataset=random_dataset,
optimizer_fn=optimizer_fn,
scheduler_fn=scheduler_fn,
ckpt_dir=base_test_env["test_dir"],
log_dir=os.path.join(base_test_env["test_dir"], "logs"),
n_epoch=1,
batch_per_device=2,
ckpt_interval=3,
@ -140,13 +137,12 @@ def test_callback_integration(base_test_env, random_dataset):
)
train_config = TrainConfig(
model_fn=lambda: base_test_env["model"],
model=base_test_env["model"],
strategy="seq",
dataset=random_dataset,
optimizer_fn=optimizer_fn,
scheduler_fn=scheduler_fn,
ckpt_dir=base_test_env["test_dir"],
log_dir=os.path.join(base_test_env["test_dir"], "logs"),
n_epoch=1,
batch_per_device=2,
ckpt_interval=3,

View File

@ -4,6 +4,7 @@ import numpy as np
import torch
from astrai.config.train_config import TrainConfig
from astrai.serialization import Checkpoint
from astrai.trainer.schedule import SchedulerFactory
from astrai.trainer.trainer import Trainer
@ -23,10 +24,9 @@ def test_early_stopping_simulation(base_test_env, early_stopping_dataset):
strategy="seq",
optimizer_fn=optimizer_fn,
scheduler_fn=scheduler_fn,
model_fn=lambda: base_test_env["model"],
model=base_test_env["model"],
dataset=early_stopping_dataset,
ckpt_dir=base_test_env["test_dir"],
log_dir=os.path.join(base_test_env["test_dir"], "logs"),
n_epoch=2,
batch_per_device=2,
ckpt_interval=1,
@ -38,20 +38,17 @@ def test_early_stopping_simulation(base_test_env, early_stopping_dataset):
trainer = Trainer(train_config)
# Should handle early stopping gracefully
checkpoint = None
try:
trainer.train()
checkpoint = trainer.train()
except Exception:
# Handle any exceptions
pass
# Resume from latest checkpoint
load_dir = os.path.join(base_test_env["test_dir"], "epoch_0_iter_2")
trainer = Trainer(train_config)
trainer.train(resume_dir=load_dir)
checkpoint = Checkpoint.load(load_dir)
trainer.train(checkpoint)
# Verify checkpoint was saved at expected iteration
load_dir = os.path.join(base_test_env["test_dir"], "epoch_1_iter_10")
import json
with open(os.path.join(load_dir, "meta.json")) as f:
meta = json.load(f)
assert meta["iteration"] == 10
checkpoint = Checkpoint.load(load_dir)
assert checkpoint.iteration == 10

View File

@ -9,7 +9,7 @@ def test_different_batch_sizes(base_test_env, random_dataset, train_config_facto
for batch_per_device in batch_sizes:
train_config = train_config_factory(
model_fn=lambda: base_test_env["model"],
model=base_test_env["model"],
dataset=random_dataset,
test_dir=base_test_env["test_dir"],
device=base_test_env["device"],
@ -25,7 +25,7 @@ def test_gradient_accumulation(base_test_env, random_dataset, train_config_facto
for grad_accum_steps in grad_accum_steps_list:
train_config = train_config_factory(
model_fn=lambda: base_test_env["model"],
model=base_test_env["model"],
dataset=random_dataset,
test_dir=base_test_env["test_dir"],
device=base_test_env["device"],
@ -50,7 +50,7 @@ def test_memory_efficient_training(base_test_env, random_dataset, train_config_f
for config in small_batch_configs:
train_config = train_config_factory(
model_fn=lambda: base_test_env["model"],
model=base_test_env["model"],
dataset=random_dataset,
test_dir=base_test_env["test_dir"],
device=base_test_env["device"],