Compare commits
No commits in common. "main" and "v1.3.6" have entirely different histories.
|
|
@ -82,7 +82,6 @@ export CUDA_VISIBLE_DEVICES=0,1,2,3
|
|||
|
||||
nohup python scripts/tools/train.py \
|
||||
--nprocs=4 \
|
||||
--parallel_mode=ddp \
|
||||
--train_type=seq \
|
||||
--data_root_path=/path/to/dataset \
|
||||
--param_path=/path/to/model \
|
||||
|
|
@ -109,8 +108,8 @@ Full reference at [Parameter Guide](assets/docs/params.md).
|
|||
```bash
|
||||
python scripts/tools/generate.py \
|
||||
--param_path /path/to/model \
|
||||
--input_json_file /path/to/input.jsonl \
|
||||
--output_json_file /path/to/output.jsonl
|
||||
--input_json_file /path/to/input.json \
|
||||
--output_json_file /path/to/output.json
|
||||
```
|
||||
|
||||
#### Docker
|
||||
|
|
@ -225,7 +224,6 @@ Watch a video walkthrough on [bilibili](https://www.bilibili.com/video/BV1fuLB6y
|
|||
| [Training](./assets/docs/training.md) | Training loop, strategies & formulas |
|
||||
| [Inference](./assets/docs/inference.md) | KVCache, continuous batching, sampling & HTTP API |
|
||||
| [Data Flow](./assets/docs/dataflow.md) | Data pipeline, storage backends & dataset architecture |
|
||||
| [Preprocessing](./assets/docs/preprocessing.md) | Declarative JSON-driven data preprocessing |
|
||||
|
||||
### Contributing
|
||||
|
||||
|
|
|
|||
|
|
@ -88,7 +88,6 @@ export CUDA_VISIBLE_DEVICES=0,1,2,3
|
|||
|
||||
nohup python scripts/tools/train.py \
|
||||
--nprocs=4 \
|
||||
--parallel_mode=ddp \
|
||||
--train_type=seq \
|
||||
--data_root_path=/path/to/dataset \
|
||||
--param_path=/path/to/model \
|
||||
|
|
@ -115,8 +114,8 @@ nohup python scripts/tools/train.py \
|
|||
```bash
|
||||
python scripts/tools/generate.py \
|
||||
--param_path /path/to/model \
|
||||
--input_json_file /path/to/input.jsonl \
|
||||
--output_json_file /path/to/output.jsonl
|
||||
--input_json_file /path/to/input.json \
|
||||
--output_json_file /path/to/output.json
|
||||
```
|
||||
|
||||
#### Docker
|
||||
|
|
@ -231,7 +230,6 @@ python scripts/demo/generate_ar.py
|
|||
| [训练文档](./training.md) | 训练循环、策略与公式 |
|
||||
| [推理文档](./inference.md) | KVCache、连续批处理、采样与 HTTP API |
|
||||
| [数据流程](./dataflow.md) | 数据管道、存储后端与数据集架构 |
|
||||
| [数据预处理](./preprocessing.md) | 声明式 JSON 驱动数据预处理 |
|
||||
|
||||
### 贡献
|
||||
|
||||
|
|
|
|||
|
|
@ -8,8 +8,6 @@ classDiagram
|
|||
class BaseConfig {
|
||||
+to_dict() Dict
|
||||
+from_dict(d) Self
|
||||
+from_json(path) Self
|
||||
+to_json(path)
|
||||
}
|
||||
|
||||
class BaseModelConfig {
|
||||
|
|
@ -19,43 +17,41 @@ classDiagram
|
|||
}
|
||||
|
||||
class AutoRegressiveLMConfig {
|
||||
+Optional[int] vocab_size
|
||||
+Optional[int] dim
|
||||
+Optional[int] n_layers
|
||||
+Optional[float] norm_eps
|
||||
+Optional[int] dim_ffn
|
||||
+Optional[bool] tie_weight
|
||||
+Optional[dict] rope_scaling
|
||||
+Optional[int] max_len
|
||||
+Optional[float] rope_theta
|
||||
+int vocab_size
|
||||
+int dim
|
||||
+int n_layers
|
||||
+float norm_eps
|
||||
+int dim_ffn
|
||||
+bool tie_weight
|
||||
+int max_len
|
||||
+float rope_theta
|
||||
+str attn_type
|
||||
+Optional[int] n_heads
|
||||
+Optional[int] n_kv_heads
|
||||
+Optional[bool] use_qk_norm
|
||||
+Optional[bool] use_gated_attention
|
||||
+int n_heads
|
||||
+int n_kv_heads
|
||||
+bool use_qk_norm
|
||||
+bool use_gated_attention
|
||||
+Optional[int] kv_lora_rank
|
||||
+Optional[int] qk_nope_head_dim
|
||||
+Optional[int] qk_rope_head_dim
|
||||
+str ffn_type
|
||||
+Optional[int] n_routed_experts
|
||||
+Optional[int] n_shared_experts
|
||||
+Optional[int] n_activated_experts
|
||||
+int n_routed_experts
|
||||
+int n_shared_experts
|
||||
+int n_activated_experts
|
||||
+Optional[str] topk_method
|
||||
}
|
||||
|
||||
class EncoderConfig {
|
||||
+Optional[int] vocab_size
|
||||
+Optional[int] dim
|
||||
+Optional[int] n_layers
|
||||
+Optional[float] norm_eps
|
||||
+Optional[int] dim_ffn
|
||||
+Optional[int] max_len
|
||||
+Optional[float] rope_theta
|
||||
+Optional[int] n_heads
|
||||
+Optional[int] n_kv_heads
|
||||
+Optional[bool] use_qk_norm
|
||||
+Optional[bool] use_gated_attention
|
||||
+Optional[dict] rope_scaling
|
||||
+int vocab_size
|
||||
+int dim
|
||||
+int n_layers
|
||||
+float norm_eps
|
||||
+int dim_ffn
|
||||
+int max_len
|
||||
+float rope_theta
|
||||
+int n_heads
|
||||
+int n_kv_heads
|
||||
+bool use_qk_norm
|
||||
+bool use_gated_attention
|
||||
+Optional[str] pooling_type
|
||||
+Optional[bool] normalize_embeddings
|
||||
}
|
||||
|
|
@ -66,40 +62,8 @@ classDiagram
|
|||
+load(raw) BaseConfig
|
||||
}
|
||||
|
||||
class InputConfig {
|
||||
+str type
|
||||
+str messages_key
|
||||
+str prompt_key
|
||||
+str response_key
|
||||
+str text_key
|
||||
}
|
||||
|
||||
class ProcessingConfig {
|
||||
+int max_seq_len
|
||||
+int min_chars
|
||||
+int max_chars
|
||||
+bool deduplicate
|
||||
+Optional[int] max_items
|
||||
}
|
||||
|
||||
class OutputConfig {
|
||||
+Optional[str] domain_key
|
||||
+str storage_format
|
||||
+int max_tokens_per_shard
|
||||
}
|
||||
|
||||
class PipelineConfig {
|
||||
+int version
|
||||
+InputConfig input
|
||||
+dict mask
|
||||
+str mask_default
|
||||
+ProcessingConfig preprocessing
|
||||
+OutputConfig output
|
||||
+from_dict(d) Self
|
||||
}
|
||||
|
||||
class TrainConfig {
|
||||
+Callable[[], nn.Module] model_fn
|
||||
+nn.Module model
|
||||
+str strategy
|
||||
+Dataset dataset
|
||||
+Callable optimizer_fn
|
||||
|
|
@ -116,7 +80,6 @@ classDiagram
|
|||
+str log_dir
|
||||
+int log_interval
|
||||
+List[str] metrics
|
||||
+Optional[LoRAConfig] lora
|
||||
+int random_seed
|
||||
+int num_workers
|
||||
+Optional[int] prefetch_factor
|
||||
|
|
@ -125,12 +88,12 @@ classDiagram
|
|||
+str backend
|
||||
+str master_addr
|
||||
+str master_port
|
||||
+Callable parallel_wrapper
|
||||
+Callable state_dict_fn
|
||||
+str start_method
|
||||
+str device_type
|
||||
+Optional[Dataset] val_dataset
|
||||
+int val_step
|
||||
+str parallel_mode
|
||||
+dict executor_kwargs
|
||||
+dict extra_kwargs
|
||||
+validate()
|
||||
}
|
||||
|
|
@ -141,8 +104,8 @@ classDiagram
|
|||
class BaseDataset {
|
||||
+int window_size
|
||||
+int stride
|
||||
+Optional[Store] storage
|
||||
+load(load_path, storage_type)
|
||||
+Optional[BaseStorage] storage
|
||||
+load(load_path, storage_type, tokenizer)
|
||||
+__getitem__(index)
|
||||
+__len__()
|
||||
}
|
||||
|
|
@ -163,25 +126,38 @@ classDiagram
|
|||
+__getitem__(index) Dict
|
||||
}
|
||||
|
||||
class Store {
|
||||
+Dict[str, List[Tensor]] _data
|
||||
+Dict[str, List[int]] _cum
|
||||
+int _length
|
||||
class BaseSegmentFetcher {
|
||||
+List[Tensor] segments
|
||||
+List[int] cum_lengths
|
||||
+int total_length
|
||||
+fetch_data(begin_idx, end_idx) Tensor
|
||||
}
|
||||
|
||||
class BaseStorage {
|
||||
+MultiSegmentFetcher _fetcher
|
||||
+keys (property)
|
||||
+load(path)
|
||||
+load(load_path, tokenizer)
|
||||
+fetch(begin, end, keys)
|
||||
+__len__()
|
||||
-_fetch_key(key, begin, end) Tensor
|
||||
-_normalize(raw)
|
||||
}
|
||||
|
||||
class H5Store {
|
||||
+load(path)
|
||||
class H5Storage {
|
||||
+load(load_path, tokenizer)
|
||||
+fetch(begin, end, keys) Dict
|
||||
+keys() List
|
||||
}
|
||||
|
||||
class MmapStore {
|
||||
+List _mmap_refs
|
||||
+load(path)
|
||||
class JSONStorage {
|
||||
+load(load_path, tokenizer)
|
||||
+fetch(begin, end, keys) Dict
|
||||
+keys() List
|
||||
}
|
||||
|
||||
class MultiSegmentFetcher {
|
||||
+Dict multi_fetchers
|
||||
+List multi_keys
|
||||
+key_fetch(begin_idx, end_idx, keys) Dict
|
||||
+fetch_data(begin_idx, end_idx) Dict
|
||||
}
|
||||
|
||||
class ResumableDistributedSampler {
|
||||
|
|
@ -189,17 +165,17 @@ classDiagram
|
|||
+int iter
|
||||
}
|
||||
|
||||
class StoreFactory {
|
||||
class StorageFactory {
|
||||
+Registry _registry
|
||||
+register(name) decorator
|
||||
+create(storage_type) Store
|
||||
+create(storage_type) BaseStorage
|
||||
}
|
||||
|
||||
class DatasetFactory {
|
||||
+Registry _registry
|
||||
+register(name) decorator
|
||||
+create(train_type, window_size, stride) BaseDataset
|
||||
+load(train_type, load_path, window_size, stride, storage_type) BaseDataset
|
||||
+load(train_type, load_path, window_size, stride, storage_type, tokenizer) BaseDataset
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -210,9 +186,8 @@ classDiagram
|
|||
+int iteration
|
||||
+dict extra
|
||||
+dict meta
|
||||
+dict config
|
||||
+save(save_dir)
|
||||
+load(save_dir, broadcast) Checkpoint
|
||||
+load(save_dir) Checkpoint
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -220,8 +195,8 @@ classDiagram
|
|||
class AutoModel {
|
||||
+BaseModelConfig config
|
||||
+Registry _registry
|
||||
+register(name) decorator
|
||||
+get_component_class(name) Type
|
||||
+register(model_type) decorator
|
||||
+get_component_class(model_type) Type
|
||||
+from_pretrained(path, disable_random_init, strict) nn.Module
|
||||
+save_pretrained(save_directory)
|
||||
+to(*args, **kwargs) Self
|
||||
|
|
@ -235,7 +210,7 @@ classDiagram
|
|||
+RMSNorm norm
|
||||
+Linear lm_head
|
||||
+forward(input_ids, input_mask, paged_cache, position_ids) Dict[str, Tensor]
|
||||
+load_state_dict(state_dict, strict, assign)
|
||||
+load_state_dict(state_dict)
|
||||
+state_dict()
|
||||
}
|
||||
|
||||
|
|
@ -260,7 +235,6 @@ classDiagram
|
|||
}
|
||||
|
||||
class GQA {
|
||||
+int dim
|
||||
+int n_heads
|
||||
+int n_kv_heads
|
||||
+int head_dim
|
||||
|
|
@ -275,7 +249,6 @@ classDiagram
|
|||
}
|
||||
|
||||
class MLA {
|
||||
+int dim
|
||||
+int n_heads
|
||||
+int n_kv_heads
|
||||
+int head_dim
|
||||
|
|
@ -284,13 +257,11 @@ classDiagram
|
|||
+int qk_rope_head_dim
|
||||
+int n_rep
|
||||
+int layer_id
|
||||
+bool use_qk_norm
|
||||
+bool use_gated_attention
|
||||
+Linear q_proj, kv_a_proj, kv_b_proj
|
||||
+Linear o_proj
|
||||
+Linear gate # only if use_gated_attention
|
||||
+RMSNorm kv_norm
|
||||
+RMSNorm q_norm, k_norm # only if use_qk_norm
|
||||
+forward(x, rotary_emb, attn_mask, paged_cache) Tensor
|
||||
}
|
||||
|
||||
|
|
@ -336,7 +307,6 @@ classDiagram
|
|||
+int dim
|
||||
+int max_len
|
||||
+float base
|
||||
+Optional[Dict] rope_scaling
|
||||
+forward(x, position_ids=None) Tensor
|
||||
}
|
||||
|
||||
|
|
@ -346,42 +316,13 @@ classDiagram
|
|||
}
|
||||
}
|
||||
|
||||
namespace preprocessing {
|
||||
class BaseMaskBuilder {
|
||||
<<abstract>>
|
||||
+build(item, config, tokenizer) Optional[dict]
|
||||
}
|
||||
|
||||
class ChatMaskBuilder {
|
||||
+build(item, config, tokenizer) Optional[dict]
|
||||
}
|
||||
|
||||
class InstructionMaskBuilder {
|
||||
+build(item, config, tokenizer) Optional[dict]
|
||||
}
|
||||
|
||||
class TextMaskBuilder {
|
||||
+build(item, config, tokenizer) Optional[dict]
|
||||
}
|
||||
|
||||
class Pipeline {
|
||||
+PipelineConfig config
|
||||
+List[str] paths
|
||||
+str output_dir
|
||||
+str tokenizer_path
|
||||
+BaseMaskBuilder mask_builder
|
||||
+transform(item) Optional[dict]
|
||||
+run()
|
||||
}
|
||||
}
|
||||
|
||||
namespace tokenize {
|
||||
class AutoTokenizer {
|
||||
+vocab_size int
|
||||
+encode(tokens, out_ids, is_pretokenized, add_special_tokens) List
|
||||
+encode(tokens, out_ids, add_special_tokens) List[int]
|
||||
+decode(tokens, skip_special_tokens) str
|
||||
+__getattr__(name) Any (bos_id, eos_id, pad_id, stop_ids)
|
||||
+apply_chat_template(messages, system_prompt, tokenize, add_generation_prompt) Union[str, List[int]]
|
||||
+apply_chat_template(messages, tokenize) Union[str, List[int]]
|
||||
+set_chat_template(template)
|
||||
+load(path)
|
||||
+from_pretrained(path) AutoTokenizer
|
||||
|
|
@ -389,7 +330,7 @@ classDiagram
|
|||
}
|
||||
|
||||
class ChatTemplate {
|
||||
+str template_str
|
||||
+String template_str
|
||||
+render(messages, system_prompt, **extra_variables) str
|
||||
+from_string(template) ChatTemplate
|
||||
}
|
||||
|
|
@ -409,32 +350,24 @@ classDiagram
|
|||
+create(name, *args, **kwargs) T
|
||||
+list_registered() list
|
||||
}
|
||||
|
||||
class MaskBuilderFactory {
|
||||
+Registry _registry
|
||||
+register(name) decorator
|
||||
+create(input_type, config, tokenizer) BaseMaskBuilder
|
||||
}
|
||||
}
|
||||
|
||||
namespace trainer {
|
||||
class Trainer {
|
||||
+TrainConfig train_config
|
||||
+List[TrainCallback] callbacks
|
||||
+train(resume_dir)
|
||||
-_get_default_callbacks() List[TrainCallback]
|
||||
+train(checkpoint)
|
||||
+_get_default_callbacks() List[TrainCallback]
|
||||
}
|
||||
|
||||
class TrainContext {
|
||||
+nn.Module model
|
||||
+BaseStrategy strategy
|
||||
+DataLoader dataloader
|
||||
+OptimizerProtocol optimizer
|
||||
+SchedulerProtocol scheduler
|
||||
+Optimizer optimizer
|
||||
+LRScheduler scheduler
|
||||
+Checkpoint checkpoint
|
||||
+TrainConfig config
|
||||
+dict model_config
|
||||
+BaseExecutor executor
|
||||
+int epoch
|
||||
+int iteration
|
||||
+float loss
|
||||
|
|
@ -447,17 +380,13 @@ classDiagram
|
|||
|
||||
class TrainContextBuilder {
|
||||
+TrainConfig config
|
||||
+with_resume_dir(resume_dir) TrainContextBuilder
|
||||
+with_checkpoint(checkpoint) TrainContextBuilder
|
||||
+build() TrainContext
|
||||
}
|
||||
|
||||
class BaseStrategy {
|
||||
+Callable model
|
||||
+Optional[BaseExecutor] executor
|
||||
+Optional[Callable] model_fn
|
||||
+dict extra_kwargs
|
||||
+Union[Callable, nn.Module] model
|
||||
+str device
|
||||
+__call__(batch) Tensor
|
||||
+compute_loss(batch) Tensor
|
||||
}
|
||||
|
||||
|
|
@ -498,8 +427,6 @@ classDiagram
|
|||
class BaseScheduler {
|
||||
+get_lr() List[float]
|
||||
+step()
|
||||
+state_dict() dict
|
||||
+load_state_dict(d)
|
||||
}
|
||||
|
||||
class SchedulerFactory {
|
||||
|
|
@ -511,7 +438,6 @@ classDiagram
|
|||
class CosineScheduler {
|
||||
+int warmup_steps
|
||||
+int lr_decay_steps
|
||||
+int total_steps
|
||||
+float min_rate
|
||||
}
|
||||
|
||||
|
|
@ -528,15 +454,16 @@ classDiagram
|
|||
+on_train_end(context)
|
||||
+on_epoch_begin(context)
|
||||
+on_epoch_end(context)
|
||||
+on_step_begin(context)
|
||||
+on_step_end(context)
|
||||
+on_batch_begin(context)
|
||||
+on_batch_end(context)
|
||||
+on_optimizer_step(context)
|
||||
+on_error(context)
|
||||
}
|
||||
|
||||
class GradientClippingCallback {
|
||||
+float max_grad_norm
|
||||
+on_optimizer_step(context)
|
||||
+on_step_begin(context)
|
||||
}
|
||||
|
||||
class GradientCheckpointingCallback {
|
||||
|
|
@ -549,12 +476,16 @@ classDiagram
|
|||
+str save_dir
|
||||
+int interval
|
||||
+bool weight_only
|
||||
+Callable state_dict_fn
|
||||
+Callable save_extra_fn
|
||||
-_save_checkpoint(context)
|
||||
+Callable load_extra_fn
|
||||
+_save_checkpoint(context)
|
||||
+on_train_begin(context)
|
||||
+on_batch_end(context)
|
||||
+on_train_end(context)
|
||||
+on_error(context)
|
||||
+save_extra(context) dict$
|
||||
+save_extra(context)$
|
||||
+load_extra(extra, context)$
|
||||
}
|
||||
|
||||
class ProgressBarCallback {
|
||||
|
|
@ -567,7 +498,7 @@ classDiagram
|
|||
}
|
||||
|
||||
class MetricLoggerCallback {
|
||||
+Path log_dir
|
||||
+str log_dir
|
||||
+int save_interval
|
||||
+int log_interval
|
||||
+List[str] metrics
|
||||
|
|
@ -577,8 +508,8 @@ classDiagram
|
|||
}
|
||||
|
||||
class ValidationCallback {
|
||||
-_run_validation(context)
|
||||
+on_optimizer_step(context)
|
||||
+_run_validation(context)
|
||||
+on_step_end(context)
|
||||
}
|
||||
|
||||
class CallbackFactory {
|
||||
|
|
@ -591,12 +522,7 @@ classDiagram
|
|||
+float lr
|
||||
+float momentum
|
||||
+float weight_decay
|
||||
+bool nesterov
|
||||
+int ns_steps
|
||||
+Optional[float] adamw_lr
|
||||
+tuple adamw_betas
|
||||
+float adamw_eps
|
||||
+float adamw_wd
|
||||
+step(closure) Optional[float]
|
||||
}
|
||||
}
|
||||
|
|
@ -617,8 +543,6 @@ classDiagram
|
|||
+AutoModel model
|
||||
+AutoTokenizer tokenizer
|
||||
+KVCache page_cache
|
||||
+Optional[str] device
|
||||
+Optional[torch.dtype] dtype
|
||||
+execute_prefill(tasks, prompt_len, start_pos)
|
||||
+execute_decode(tasks) List[int]
|
||||
}
|
||||
|
|
@ -630,9 +554,7 @@ classDiagram
|
|||
+bool _running
|
||||
+Thread _loop_thread
|
||||
+int max_seq_len
|
||||
+str device
|
||||
+torch.dtype dtype
|
||||
+add_task(prompt, **kwargs) str
|
||||
+add_task(prompt, max_tokens, temperature, top_p, top_k, stream_callback) str
|
||||
+remove_task(task_id)
|
||||
+start()
|
||||
+stop()
|
||||
|
|
@ -710,7 +632,7 @@ classDiagram
|
|||
class Task {
|
||||
+str task_id
|
||||
+List prompt_ids
|
||||
+Optional[int] max_tokens
|
||||
+int max_tokens
|
||||
+float temperature
|
||||
+float top_p
|
||||
+int top_k
|
||||
|
|
@ -719,8 +641,8 @@ classDiagram
|
|||
+int input_tokens
|
||||
+int output_tokens
|
||||
+float arrival_time
|
||||
+Optional[float] finish_time
|
||||
+Optional[Callable] stream_callback
|
||||
+float finish_time
|
||||
+Callable stream_callback
|
||||
+int next_pos
|
||||
+is_finished(stop_ids) bool
|
||||
}
|
||||
|
|
@ -735,24 +657,15 @@ classDiagram
|
|||
|
||||
class TaskManager {
|
||||
+AutoTokenizer tokenizer
|
||||
+int max_batch_size
|
||||
+int max_seq_len
|
||||
+int max_prompt_len
|
||||
+Deque waiting_queue
|
||||
+List active_tasks
|
||||
+add_task(prompt, max_tokens, temperature, top_p, top_k, stream_callback) str
|
||||
+add_task(prompt, **kwargs) str
|
||||
+remove_task(task_id) List[Task]
|
||||
+remove_finished_tasks(stop_ids) List[Task]
|
||||
+pull_candidates(n) List[Task]
|
||||
+activate(task)
|
||||
+return_to_waiting(tasks)
|
||||
+get_active_tasks() List[Task]
|
||||
+has_work() bool
|
||||
+wait_for_tasks(timeout)
|
||||
+get_waiting_tasks() List[Task]
|
||||
+clear_queues()
|
||||
+wake()
|
||||
+get_stats() Dict
|
||||
}
|
||||
|
||||
class GenerationRequest {
|
||||
|
|
@ -831,65 +744,56 @@ classDiagram
|
|||
+str model
|
||||
+List[AnthropicMessage] messages
|
||||
+Optional[str] system
|
||||
+Optional[float] temperature
|
||||
+Optional[float] top_p
|
||||
+Optional[int] top_k
|
||||
+float temperature
|
||||
+float top_p
|
||||
+int top_k
|
||||
+int max_tokens
|
||||
+Optional[bool] stream
|
||||
+bool stream
|
||||
+Optional[List[str]] stop_sequences
|
||||
}
|
||||
|
||||
class ResponseBuilder {
|
||||
<<abstract>>
|
||||
+prepare(request, tokenizer) Tuple[str, GenContext, List[str]]
|
||||
+format_stream_start(ctx) List[str]
|
||||
+format_chunk(token) str
|
||||
+format_stream_end(ctx, stop) List[str]
|
||||
+format_response(ctx, content, stop) Dict
|
||||
}
|
||||
|
||||
class OpenAIResponseBuilder {
|
||||
+prepare(request, tokenizer) Tuple
|
||||
+format_stream_start(ctx) List[str]
|
||||
+format_chunk(token) str
|
||||
+format_stream_end(ctx, stop) List[str]
|
||||
+format_response(ctx, content, stop) Dict
|
||||
}
|
||||
|
||||
class AnthropicResponseBuilder {
|
||||
+prepare(request, tokenizer) Tuple
|
||||
+format_stream_start(ctx) List[str]
|
||||
+format_chunk(token) str
|
||||
+format_stream_end(ctx, stop) List[str]
|
||||
+format_response(ctx, content, stop) Dict
|
||||
}
|
||||
|
||||
class ProtocolHandler {
|
||||
<<abstract>>
|
||||
+request
|
||||
+engine
|
||||
+builder: ResponseBuilder
|
||||
+async handle() Union[StreamingResponse, Dict]
|
||||
-_handle_stream(agen, ctx, stop_sequences) StreamingResponse
|
||||
-async _handle_non_stream(agen, ctx, stop_sequences) Dict
|
||||
+build_prompt() str
|
||||
+create_response_id() str
|
||||
+get_stop_sequences() List[str]
|
||||
+create_stop_checker() StopChecker
|
||||
+on_token(ctx, token, stop_checker) Optional[str]
|
||||
+format_stream_start(ctx) List[str]
|
||||
+format_stream_token(ctx, token) str
|
||||
+format_stream_end(ctx) List[str]
|
||||
+format_non_stream_response(ctx, content) Dict
|
||||
+handle() Union[StreamingResponse, Dict]
|
||||
}
|
||||
|
||||
class OpenAIHandler {
|
||||
+build_prompt() str
|
||||
+create_response_id() str
|
||||
}
|
||||
|
||||
class AnthropicHandler {
|
||||
+build_prompt() str
|
||||
+create_response_id() str
|
||||
+on_token(ctx, token, stop_checker) Optional[str]
|
||||
}
|
||||
|
||||
class StopChecker {
|
||||
+__init__(sequences)
|
||||
+has_sequences (property) bool
|
||||
+check(text) Optional[str]
|
||||
+trim(text, matched) str
|
||||
}
|
||||
|
||||
class GenContext {
|
||||
class StreamContext {
|
||||
+str resp_id
|
||||
+int created
|
||||
+str model
|
||||
+int prompt_tokens
|
||||
+int completion_tokens
|
||||
}
|
||||
|
||||
class StopInfo {
|
||||
+Optional[str] matched
|
||||
+str body
|
||||
+str yielded
|
||||
+str accumulated
|
||||
+Optional[str] stop_matched
|
||||
+str last_yield_trimmed
|
||||
}
|
||||
|
||||
class app {
|
||||
|
|
@ -898,87 +802,15 @@ classDiagram
|
|||
}
|
||||
}
|
||||
|
||||
namespace protocols {
|
||||
class OptimizerProtocol {
|
||||
<<protocol>>
|
||||
+step(closure)
|
||||
+zero_grad()
|
||||
+state_dict() dict
|
||||
+load_state_dict(d)
|
||||
}
|
||||
|
||||
class SchedulerProtocol {
|
||||
<<protocol>>
|
||||
+step()
|
||||
+state_dict() dict
|
||||
+load_state_dict(d)
|
||||
+get_last_lr()
|
||||
}
|
||||
}
|
||||
|
||||
namespace parallel {
|
||||
class setup {
|
||||
class Functions {
|
||||
<<module>>
|
||||
+spawn_parallel_fn(func, world_size, backend, master_addr, master_port, device_type, start_method, **kwargs)
|
||||
+setup_parallel(rank, world_size, backend, master_addr, master_port, device_type) contextmanager
|
||||
+setup_parallel(rank, world_size, backend, master_addr, master_port, device_type)
|
||||
+get_current_device() str
|
||||
+get_world_size() int
|
||||
+get_rank() int
|
||||
+only_on_rank(rank, sync=False) decorator
|
||||
}
|
||||
|
||||
class GradientState {
|
||||
+int num_steps
|
||||
+sync_gradients (property) bool
|
||||
}
|
||||
|
||||
class AccumOptimizer {
|
||||
+Optimizer optimizer
|
||||
+GradientState gradient_state
|
||||
+param_groups (property)
|
||||
+step(closure)
|
||||
+zero_grad()
|
||||
+state_dict() dict
|
||||
+load_state_dict(d)
|
||||
}
|
||||
|
||||
class AccumScheduler {
|
||||
+LRScheduler scheduler
|
||||
+GradientState gradient_state
|
||||
+step()
|
||||
+state_dict() dict
|
||||
+load_state_dict(d)
|
||||
+get_last_lr()
|
||||
}
|
||||
|
||||
class BaseExecutor {
|
||||
+GradientState gradient_state
|
||||
+prepare(model, optimizer, dataloader, scheduler) tuple
|
||||
+accumulate(model) context manager
|
||||
+backward(loss)
|
||||
+unwrap_model(model) dict
|
||||
+sync_gradients (property) bool
|
||||
+grad_accum_steps (property) int
|
||||
}
|
||||
|
||||
class NoneExecutor {
|
||||
}
|
||||
|
||||
class DDPExecutor {
|
||||
-_prepare_model(model) nn.Module
|
||||
-_no_sync(model) context manager
|
||||
+unwrap_model(model) dict
|
||||
}
|
||||
|
||||
class FSDPExecutor {
|
||||
-_prepare_model(model) nn.Module
|
||||
+unwrap_model(model) dict
|
||||
}
|
||||
|
||||
class ExecutorFactory {
|
||||
+Registry _registry
|
||||
+register(name) decorator
|
||||
+create(parallel_mode, **kwargs) BaseExecutor
|
||||
+only_on_rank(rank, sync) decorator
|
||||
}
|
||||
|
||||
class ParallelModel {
|
||||
|
|
@ -988,25 +820,11 @@ classDiagram
|
|||
}
|
||||
|
||||
class ColumnParallelLinear {
|
||||
+int in_features
|
||||
+int out_features
|
||||
+int out_features_per_rank
|
||||
+bool gather_results
|
||||
+Parameter weight
|
||||
+Optional[Parameter] bias
|
||||
+forward(x) Tensor
|
||||
+load_state_dict(state_dict)
|
||||
}
|
||||
|
||||
class RowParallelLinear {
|
||||
+int in_features
|
||||
+int out_features
|
||||
+int in_features_per_rank
|
||||
+bool reduce_results
|
||||
+Parameter weight
|
||||
+Optional[Parameter] bias
|
||||
+forward(x) Tensor
|
||||
+load_state_dict(state_dict)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -1024,13 +842,12 @@ classDiagram
|
|||
TrainCallback <|-- CheckpointCallback
|
||||
TrainCallback <|-- ProgressBarCallback
|
||||
TrainCallback <|-- MetricLoggerCallback
|
||||
TrainCallback <|-- ValidationCallback
|
||||
BaseDataset <|-- SEQDataset
|
||||
BaseDataset <|-- SFTDataset
|
||||
BaseDataset <|-- DPODataset
|
||||
BaseDataset <|-- GRPODataset
|
||||
Store <|-- H5Store
|
||||
Store <|-- MmapStore
|
||||
BaseStorage <|-- H5Storage
|
||||
BaseStorage <|-- JSONStorage
|
||||
BaseSamplingStrategy <|-- TemperatureStrategy
|
||||
BaseSamplingStrategy <|-- TopKStrategy
|
||||
BaseSamplingStrategy <|-- TopPStrategy
|
||||
|
|
@ -1041,10 +858,6 @@ classDiagram
|
|||
AutoModel <|-- EmbeddingEncoder
|
||||
BaseConfig <|-- BaseModelConfig
|
||||
BaseConfig <|-- TrainConfig
|
||||
BaseConfig <|-- InputConfig
|
||||
BaseConfig <|-- ProcessingConfig
|
||||
BaseConfig <|-- OutputConfig
|
||||
BaseConfig <|-- PipelineConfig
|
||||
BaseModelConfig <|-- AutoRegressiveLMConfig
|
||||
BaseModelConfig <|-- EncoderConfig
|
||||
BaseFactory <|-- AutoModel
|
||||
|
|
@ -1054,23 +867,18 @@ classDiagram
|
|||
BaseFactory <|-- StrategyFactory
|
||||
BaseFactory <|-- SchedulerFactory
|
||||
BaseFactory <|-- CallbackFactory
|
||||
BaseFactory <|-- StoreFactory
|
||||
BaseFactory <|-- ExecutorFactory
|
||||
BaseFactory <|-- StorageFactory
|
||||
BaseFactory <|-- ConfigFactory
|
||||
BaseFactory <|-- MaskBuilderFactory
|
||||
BaseExecutor <|-- NoneExecutor
|
||||
BaseExecutor <|-- DDPExecutor
|
||||
BaseExecutor <|-- FSDPExecutor
|
||||
ResponseBuilder <|-- OpenAIResponseBuilder
|
||||
ResponseBuilder <|-- AnthropicResponseBuilder
|
||||
BaseMaskBuilder <|-- ChatMaskBuilder
|
||||
BaseMaskBuilder <|-- InstructionMaskBuilder
|
||||
BaseMaskBuilder <|-- TextMaskBuilder
|
||||
TrainCallback <|-- ValidationCallback
|
||||
ProtocolHandler <|-- OpenAIHandler
|
||||
ProtocolHandler <|-- AnthropicHandler
|
||||
|
||||
%% --- Composition (strong ownership, part destroyed with whole) ---
|
||||
KVCache *-- PagePool
|
||||
KVCache *-- Storage
|
||||
KVCache *-- TaskTable
|
||||
PagePool *-- Allocator
|
||||
PagePool *-- PrefixCache
|
||||
InferenceEngine *-- InferenceScheduler
|
||||
InferenceScheduler *-- KVCache
|
||||
InferenceScheduler *-- Executor
|
||||
|
|
@ -1084,31 +892,21 @@ classDiagram
|
|||
DecoderBlock *-- RMSNorm
|
||||
ChatCompletionRequest *-- ChatMessage
|
||||
MessagesRequest *-- AnthropicMessage
|
||||
AutoTokenizer *-- ChatTemplate
|
||||
BaseFactory *-- Registry
|
||||
BaseExecutor *-- GradientState
|
||||
AccumOptimizer o-- GradientState
|
||||
AccumScheduler o-- GradientState
|
||||
|
||||
%% --- Aggregation (weak ownership) ---
|
||||
AutoModel o-- BaseModelConfig
|
||||
AutoTokenizer o-- ChatTemplate
|
||||
PagePool o-- Allocator
|
||||
PagePool o-- PrefixCache
|
||||
Trainer o-- TrainCallback
|
||||
TrainContext o-- BaseStrategy
|
||||
TrainContext o-- BaseScheduler
|
||||
TrainContext o-- Checkpoint
|
||||
TrainContext o-- BaseExecutor
|
||||
KvcacheView o-- Storage
|
||||
SamplingPipeline o-- BaseSamplingStrategy
|
||||
BaseDataset o-- Store
|
||||
Pipeline o-- PipelineConfig
|
||||
Pipeline o-- BaseMaskBuilder
|
||||
BaseDataset o-- BaseStorage
|
||||
|
||||
%% --- Dependency (uses temporarily) ---
|
||||
TrainConfig ..> BaseStrategy : selects
|
||||
PipelineConfig ..> MaskBuilderFactory : selects
|
||||
MaskBuilderFactory ..> BaseMaskBuilder : creates
|
||||
StrategyFactory ..> BaseStrategy : creates
|
||||
SchedulerFactory ..> BaseScheduler : creates
|
||||
DatasetFactory ..> BaseDataset : creates
|
||||
|
|
@ -1119,14 +917,10 @@ classDiagram
|
|||
FFNFactory ..> DeepSeekMoE : creates
|
||||
DecoderBlock ..> AttnFactory : uses
|
||||
DecoderBlock ..> FFNFactory : uses
|
||||
StoreFactory ..> H5Store : creates
|
||||
StoreFactory ..> MmapStore : creates
|
||||
StorageFactory ..> H5Storage : creates
|
||||
StorageFactory ..> JSONStorage : creates
|
||||
ConfigFactory ..> AutoRegressiveLMConfig : creates
|
||||
ConfigFactory ..> EncoderConfig : creates
|
||||
ExecutorFactory ..> NoneExecutor : creates
|
||||
ExecutorFactory ..> DDPExecutor : creates
|
||||
ExecutorFactory ..> FSDPExecutor : creates
|
||||
TrainContextBuilder ..> ExecutorFactory : creates
|
||||
Trainer ..> TrainContextBuilder : uses
|
||||
TrainContextBuilder ..> TrainContext : creates
|
||||
Trainer ..> Functions : spawns
|
||||
|
|
@ -1137,10 +931,10 @@ classDiagram
|
|||
KVCache ..> KvcacheView : binds
|
||||
InferenceEngine ..> GenerationRequest : uses
|
||||
InferenceEngine ..> GenerateResult : creates
|
||||
OpenAIResponseBuilder ..> ChatCompletionRequest : receives
|
||||
AnthropicResponseBuilder ..> MessagesRequest : receives
|
||||
OpenAIHandler ..> ChatCompletionRequest : receives
|
||||
AnthropicHandler ..> MessagesRequest : receives
|
||||
ProtocolHandler ..> StopChecker : creates
|
||||
ProtocolHandler ..> GenContext : creates
|
||||
ProtocolHandler ..> StreamContext : creates
|
||||
|
||||
%% --- Association (general usage) ---
|
||||
Trainer --> TrainConfig
|
||||
|
|
@ -1153,6 +947,8 @@ classDiagram
|
|||
Executor --> AutoModel
|
||||
Executor --> AutoTokenizer
|
||||
TaskManager --> AutoTokenizer
|
||||
MultiSegmentFetcher --> BaseSegmentFetcher
|
||||
ResumableDistributedSampler --> BaseDataset
|
||||
|
||||
```
|
||||
|
||||
|
|
@ -1161,48 +957,43 @@ classDiagram
|
|||
|
||||
| Module | Components | Description |
|
||||
|--------|------------|-------------|
|
||||
| **astrai.config** | BaseConfig, BaseModelConfig, AutoRegressiveLMConfig, EncoderConfig, ConfigFactory, TrainConfig, PipelineConfig, InputConfig, ProcessingConfig, OutputConfig | Configuration management (to_dict/from_dict, to_file/from_file, from_json/to_json) |
|
||||
| **astrai.preprocessing** | BaseMaskBuilder, MaskBuilderFactory, ChatMaskBuilder, InstructionMaskBuilder, TextMaskBuilder, Pipeline, filter_by_length, dedup_signature | Declarative JSON-driven data preprocessing |
|
||||
| **astrai.dataset** | BaseDataset–GRPODataset, Store–MmapStore, StoreFactory, ResumableDistributedSampler, DatasetFactory | Dataset loading and management |
|
||||
| **astrai.config** | BaseConfig, BaseModelConfig, AutoRegressiveLMConfig, EncoderConfig, ConfigFactory, TrainConfig | Configuration management (to_dict/from_dict, to_file/from_file) |
|
||||
| **astrai.dataset** | BaseDataset–GRPODataset, BaseStorage–JSONStorage, StorageFactory, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory | Dataset loading and management |
|
||||
| **astrai.serialization** | Checkpoint | Model serialization |
|
||||
| **astrai.model** | AutoModel, AutoRegressiveLM, EmbeddingEncoder, DecoderBlock, GQA, MLA, MLP, DeepSeekMoE, AttnFactory, FFNFactory, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
|
||||
| **astrai.tokenize** | AutoTokenizer, ChatTemplate | Tokenizer and chat template |
|
||||
| **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy–GRPOStrategy, StrategyFactory, BaseScheduler–SGDRScheduler, SchedulerFactory, TrainCallback(Protocol)–ValidationCallback, CallbackFactory, Muon | Training workflow |
|
||||
| **astrai.inference** | InferenceEngine, InferenceScheduler, Executor, KVCache–KvcacheView, Allocator–Storage, Task, TaskManager, TaskStatus, GenerationRequest, GenerateResult, BaseSamplingStrategy–SamplingPipeline, ProtocolHandler, ResponseBuilder, OpenAIResponseBuilder, AnthropicResponseBuilder, StopChecker, GenContext, ChatMessage–MessagesRequest, app | Inference service |
|
||||
| **astrai.parallel** | spawn_parallel_fn, setup_parallel, get_rank/get_world_size/get_current_device, only_on_rank, BaseExecutor, ExecutorFactory, NoneExecutor, DDPExecutor, FSDPExecutor, GradientState, AccumOptimizer, AccumScheduler, ParallelModel, RowParallelLinear, ColumnParallelLinear | Distributed parallel & gradient accumulation |
|
||||
| **astrai.inference** | InferenceEngine, InferenceScheduler, Executor, KVCache–KvcacheView, Allocator–Storage, Task, TaskManager, TaskStatus, GenerationRequest, BaseSamplingStrategy–SamplingPipeline, ProtocolHandler–AnthropicHandler, ChatMessage–MessagesRequest, app | Inference service |
|
||||
| **astrai.parallel** | spawn_parallel_fn, setup_parallel, get_rank/get_world_size/get_current_device, only_on_rank, ParallelModel, RowParallelLinear, ColumnParallelLinear | Distributed parallel |
|
||||
| **astrai.factory** | Registry, BaseFactory[T] | Component registration |
|
||||
| **astrai.protocols** | OptimizerProtocol, SchedulerProtocol | Structural subtyping for optimizer/scheduler wrappers |
|
||||
|
||||
## Design Patterns
|
||||
|
||||
| Pattern | Classes | Purpose |
|
||||
|---------|---------|---------|
|
||||
| **Factory** | `AttnFactory`, `FFNFactory`, `StrategyFactory`, `DatasetFactory`, `SchedulerFactory`, `CallbackFactory`, `StoreFactory`, `ConfigFactory`, `ExecutorFactory` | Decorator-based component creation |
|
||||
| **Factory** | `AttnFactory`, `FFNFactory`, `StrategyFactory`, `DatasetFactory`, `SchedulerFactory`, `CallbackFactory`, `StorageFactory`, `ConfigFactory` | Decorator-based component creation |
|
||||
| **Registry** | `BaseFactory`, `Registry` | Component registration with category/priority |
|
||||
| **Strategy** | `SEQStrategy`, `SFTStrategy`, `DPOStrategy`, `GRPOStrategy` | Training strategy switching |
|
||||
| **Strategy (Sampling)** | `TemperatureStrategy`, `TopKStrategy`, `TopPStrategy`, `SamplingPipeline` | Composable logit transformations |
|
||||
| **Strategy (API)** | `ResponseBuilder`, `OpenAIResponseBuilder`, `AnthropicResponseBuilder` | Pluggable per-protocol request preparation and response formatting for the HTTP API |
|
||||
| **Template Method** | `ProtocolHandler`, `OpenAIHandler`, `AnthropicHandler` | HTTP API handler with format hooks |
|
||||
| **Builder** | `TrainContextBuilder` | Chain-building training context |
|
||||
| **Observer** | `TrainCallback`, callback implementations | Training process monitoring |
|
||||
| **Context** | `TrainContext` | Unified training state bag |
|
||||
| **Object Pool** | `Allocator`, `PagePool` | Page-based KV cache with LRU eviction |
|
||||
| **Executor** | `BaseExecutor`, `NoneExecutor`, `DDPExecutor`, `FSDPExecutor` | Gradient accumulation & model distribution |
|
||||
| **Storage** | `Store`, `H5Store`, `MmapStore` | Format-agnostic data access with multi-segment support |
|
||||
| **Storage** | `BaseStorage`, `H5Storage`, `JSONStorage` | Format-agnostic data access |
|
||||
| **Producer-Consumer** | `InferenceScheduler`, `Task`, queues | Continuous batching |
|
||||
| **AutoModel Registry** | `AutoModel`, `AutoRegressiveLM`, `EmbeddingEncoder` | Model-type dynamic loading |
|
||||
|
||||
## Core Relationships
|
||||
|
||||
1. **Config → Training**: `TrainConfig` holds `model_fn`, `dataset`, `optimizer_fn`, `scheduler_fn`, `parallel_mode`, `executor_kwargs`
|
||||
2. **Training Flow**: `Trainer` → `TrainContextBuilder` → `TrainContext`, uses `BaseStrategy` for loss, `BaseExecutor` for gradient accumulation + model distribution
|
||||
1. **Config → Training**: `TrainConfig` holds model, dataset, optimizer_fn, scheduler_fn
|
||||
2. **Training Flow**: `Trainer` → `TrainContextBuilder` → `TrainContext`, uses `BaseStrategy` for loss
|
||||
3. **Strategy Selection**: `StrategyFactory` creates strategy by `train_type`
|
||||
4. **Executor Selection**: `ExecutorFactory.create(cfg.parallel_mode, grad_accum_steps=cfg.grad_accum_steps, **cfg.executor_kwargs)` → `NoneExecutor` / `DDPExecutor` / `FSDPExecutor`
|
||||
5. **Inference Flow**: `InferenceEngine` → `InferenceScheduler` → `AutoRegressiveLM`, backed by `KVCache` + `SamplingPipeline`
|
||||
6. **Distributed**: `spawn_parallel_fn` + `setup_parallel` for multi-process DDP
|
||||
7. **Dataset Loading**: `DatasetFactory` creates datasets, `Store` (H5Store/MmapStore) loads data with explicit `_length` and multi-segment `_data`
|
||||
8. **Checkpoint**: `Checkpoint` saves/loads safetensors + metadata (rank-0 only), extra state saved as `{key}.pt`
|
||||
9. **Scheduler**: `SchedulerFactory` creates `CosineScheduler`/`SGDRScheduler`
|
||||
10. **AutoModel**: `from_pretrained()` loads `config.json` + `model.safetensors`, `_disable_random_init` replaces `nn.init.*` with no-ops
|
||||
11. **Protocols**: `OptimizerProtocol` / `SchedulerProtocol` — structural subtyping for `AccumOptimizer` / `AccumScheduler` wrappers
|
||||
4. **Inference Flow**: `InferenceEngine` → `InferenceScheduler` → `AutoRegressiveLM`, backed by `KVCache` + `SamplingPipeline`
|
||||
5. **Distributed**: `spawn_parallel_fn` + `setup_parallel` for multi-process DDP
|
||||
6. **Dataset Loading**: `DatasetFactory` creates datasets, `BaseStorage` (H5Storage/JSONStorage) loads via `BaseSegmentFetcher` + `MultiSegmentFetcher`
|
||||
7. **Checkpoint**: `Checkpoint` saves/loads safetensors + metadata (rank-0 only)
|
||||
8. **Scheduler**: `SchedulerFactory` creates `CosineScheduler`/`SGDRScheduler`
|
||||
9. **AutoModel**: `from_pretrained()` loads `config.json` + `model.safetensors`, `_disable_random_init` replaces `nn.init.*` with no-ops
|
||||
|
||||
> Document Update Time: 2026-05-30
|
||||
> Document Update Time: 2026-05-17
|
||||
|
|
|
|||
|
|
@ -5,21 +5,21 @@ This document describes the data pipeline: from raw text to model input tensors.
|
|||
## Overview
|
||||
|
||||
```
|
||||
Raw Text → AutoTokenizer → Token IDs → .h5/.bin → Store.load() → Store.fetch() → Dataset → Sampler → DataLoader → Training/Inference
|
||||
Raw Text → AutoTokenizer → Token IDs → .h5/.json → Dataset → Sampler → DataLoader → Training/Inference
|
||||
```
|
||||
|
||||
## Data Preparation
|
||||
|
||||
Raw text is tokenized via `AutoTokenizer.encode()` and saved as HDF5 (`.h5`) or binary (`.bin` + `meta.json`) files with keyed tensor groups.
|
||||
Raw text is tokenized via `AutoTokenizer.encode()` and saved as HDF5 (`.h5`) or JSON (`.json`/`.jsonl`) files with keyed tensor groups.
|
||||
|
||||
Storage format is auto-detected by `detect_format()`; backends are dispatched via registry:
|
||||
|
||||
```
|
||||
StoreFactory.create("h5") → H5Store
|
||||
StoreFactory.create("bin") → MmapStore
|
||||
StorageFactory.create("h5") → H5Storage
|
||||
StorageFactory.create("json") → JSONStorage
|
||||
```
|
||||
|
||||
H5 backend supports shared memory via `.share_memory_()`. Bin (mmap) uses OS page-cache sharing natively.
|
||||
Both support shared memory via `.share_memory_()`.
|
||||
|
||||
## Data Keys by Training Type
|
||||
|
||||
|
|
@ -33,21 +33,14 @@ H5 backend supports shared memory via `.share_memory_()`. Bin (mmap) uses OS pag
|
|||
## Dataset Architecture
|
||||
|
||||
```
|
||||
DatasetFactory.load(train_type, load_path, window_size, stride=None, storage_type=None)
|
||||
→ BaseDataset.load(load_path, storage_type=None)
|
||||
→ detect_format(load_path)
|
||||
→ StoreFactory.create(storage_type)
|
||||
→ Store.load(load_path)
|
||||
→ H5Store._normalize() / MmapStore._normalize()
|
||||
→ Store._data[Dict[str, List[Tensor]]] + _cum[Dict[str, List[int]]]
|
||||
→ BaseDataset.__getitem__(idx)
|
||||
→ get_index(idx) → [begin, end)
|
||||
→ Store.fetch(begin, end, keys) → Tensor / Dict[str, Tensor]
|
||||
DatasetFactory.load(train_type, path, window_size, stride)
|
||||
→ StorageFactory.create(detect_format(path))
|
||||
→ MultiSegmentFetcher(BaseSegmentFetcher per key)
|
||||
→ BaseDataset.__getitem__(idx)
|
||||
→ sliding window [begin, end) via get_index(idx)
|
||||
```
|
||||
|
||||
`window_size` = max input length; `stride` = step between consecutive samples (optional, defaults to `window_size`). `storage_type` defaults to `None` (auto-detect via `detect_format`).
|
||||
|
||||
`Store.fetch(begin, end, keys)` accepts a single key (`str`) returning a `Tensor`, or a list of keys returning `Dict[str, Tensor]`. Internally uses `bisect` across multi-segment tensors. Raises `RuntimeError("Store not loaded")` if called before `load()`.
|
||||
`window_size` = max input length, `stride` = step between consecutive samples.
|
||||
|
||||
## Sampler
|
||||
|
||||
|
|
@ -61,4 +54,4 @@ DatasetFactory.load(train_type, load_path, window_size, stride=None, storage_typ
|
|||
|
||||
Standard PyTorch `DataLoader` with configurable `batch_size`, `num_workers`, `pin_memory`, `prefetch_factor`. Sampler produces indices; dataloader fetches tensor batches via `__getitem__`.
|
||||
|
||||
> Document Update Time: 2026-05-30
|
||||
> Document Update Time: 2026-05-17
|
||||
|
|
|
|||
|
|
@ -12,16 +12,16 @@ RoPE is applied **before** KV cache write, not after — otherwise position enco
|
|||
|
||||
## KVCache System
|
||||
|
||||
Six classes (plus two helpers) working together:
|
||||
Six classes working together:
|
||||
|
||||
```
|
||||
KVCache (facade)
|
||||
├── PagePool orchestrates page allocation + prefix matching
|
||||
│ ├── Allocator bitmask-based page allocator + ref-count + LRU eviction (inside PagePool)
|
||||
│ └── PrefixCache hash-based prefix matching (page_hash via polynomial hash) (inside PagePool)
|
||||
├── Allocator bitmask-based page allocator + ref-count + LRU eviction
|
||||
├── PrefixCache hash-based prefix matching (page_hash via rolling hash)
|
||||
├── PagePool orchestrates Allocator + PrefixCache
|
||||
├── TaskTable maps task_id → page_table + cached token count
|
||||
├── Storage k_cache / v_cache tensors (n_layers × n_pages × page_size × n_kv_heads × head_dim)
|
||||
└── KvcacheView bundles Storage + page_table + total_len for attention layers (returned by bind())
|
||||
└── KvcacheView bundles Storage + page_table + total_len for attention layers
|
||||
```
|
||||
|
||||
`KVCache.bind(page_table, total_len)` returns a `KvcacheView` used by attention layers via `write()` / `gather()`.
|
||||
|
|
@ -40,33 +40,26 @@ KVCache (facade)
|
|||
## Sampling (Strategy Pattern)
|
||||
|
||||
```
|
||||
BaseSamplingStrategy (ABC)
|
||||
├── TemperatureStrategy
|
||||
├── TopKStrategy
|
||||
├── TopPStrategy
|
||||
└── SamplingPipeline
|
||||
BaseSamplingStrategy → TemperatureStrategy → TopKStrategy → TopPStrategy
|
||||
```
|
||||
|
||||
`SamplingPipeline` composes them: Temperature → Top-K → Top-P → softmax → multinomial.
|
||||
`sample()` is a convenience shortcut for one-shot usage.
|
||||
|
||||
## Protocol Handlers (Strategy Pattern)
|
||||
## Protocol Handlers (Template Method)
|
||||
|
||||
```python
|
||||
class ProtocolHandler: # concrete orchestrator
|
||||
def __init__(self, request, engine, builder): ...
|
||||
async def handle(self):
|
||||
prompt, ctx, stops = builder.prepare(request, engine)
|
||||
class ProtocolHandler(ABC):
|
||||
def handle(self):
|
||||
ctx = StreamContext(...)
|
||||
agen = engine.generate_async(prompt, ...)
|
||||
if stream: self._handle_stream(agen, ctx, stops)
|
||||
else: return await self._handle_non_stream(agen, ctx, stops)
|
||||
if stream: self._handle_stream(agen, ctx)
|
||||
else: self._handle_non_stream(agen, ctx)
|
||||
```
|
||||
|
||||
`ResponseBuilder` (ABC): `prepare()`, `format_stream_start()`, `format_chunk()`, `format_stream_end()`, `format_response()`.
|
||||
Subclass hooks: `build_prompt()`, `create_response_id()`, `format_stream_start/token/end()`, `format_non_stream_response()`.
|
||||
|
||||
`OpenAIResponseBuilder` → `/v1/chat/completions`, `AnthropicResponseBuilder` → `/v1/messages`.
|
||||
|
||||
Adding a protocol = one builder file, no handler subclassing needed.
|
||||
`OpenAIHandler` → `/v1/chat/completions`, `AnthropicHandler` → `/v1/messages`.
|
||||
|
||||
## Engine & GenerateResult
|
||||
|
||||
|
|
@ -74,9 +67,7 @@ Adding a protocol = one builder file, no handler subclassing needed.
|
|||
InferenceEngine
|
||||
├── generate(prompt, stream, ...) → str | List[str] | Generator
|
||||
├── generate_with_request(req) → same
|
||||
├── generate_async(prompt, ...) → AsyncGenerator
|
||||
├── get_stats() → Dict
|
||||
└── shutdown()
|
||||
└── generate_async(prompt, ...) → AsyncGenerator
|
||||
```
|
||||
|
||||
`GenerateResult` uses `Condition` for non-streaming (`wait_completion()`) and `Event` for streaming (`wait()`). Stream callback is `cb(token)`.
|
||||
|
|
@ -103,14 +94,12 @@ Response:
|
|||
{
|
||||
"id": "chatcmpl-abc123",
|
||||
"object": "chat.completion",
|
||||
"created": 1717000000,
|
||||
"model": "astrai",
|
||||
"choices": [{"index": 0, "message": {"role": "assistant", "content": "Hello!"}, "finish_reason": "stop"}],
|
||||
"choices": [{"message": {"role": "assistant", "content": "Hello!"}, "finish_reason": "stop"}],
|
||||
"usage": {"prompt_tokens": 5, "completion_tokens": 10, "total_tokens": 15}
|
||||
}
|
||||
```
|
||||
|
||||
Streaming SSE: `object: "chat.completion.chunk"` — starts with role delta, then token chunks, ends with finish chunk + usage stats, then `data: [DONE]`.
|
||||
Streaming SSE: `data: {"choices":[{"delta":{"role":"assistant"}}]}` → token chunks → `data: [DONE]`
|
||||
|
||||
### Anthropic
|
||||
|
||||
|
|
@ -127,10 +116,10 @@ Supports `stop_sequences` and streaming via `event: content_block_delta`.
|
|||
| Param | Type | Default | Description |
|
||||
|-------|------|---------|-------------|
|
||||
| `messages` | List[dict] | required | Chat messages (role, content) |
|
||||
| `top_k` | int | 50 | Top-k count |
|
||||
| `temperature` | float | 1.0 | Sampling temperature (0.0–2.0) |
|
||||
| `top_p` | float | 1.0 | Nucleus threshold |
|
||||
| `temperature` | float | 1.0 | Sampling temperature (> 0.0) |
|
||||
| `max_tokens` | Optional[int] | None | Max generation length |
|
||||
| `top_k` | int | 50 | Top-k count |
|
||||
| `max_tokens` | Optional[int] | None | Max generation length |
|
||||
| `stream` | bool | False | Stream output |
|
||||
|
||||
## Engine API
|
||||
|
|
@ -145,8 +134,7 @@ engine.generate("Hello", stream=True) # -> Generator[str]
|
|||
engine.generate(["A", "B"], stream=True) # -> Generator[Tuple[int, str]]
|
||||
|
||||
# Async
|
||||
async for token in engine.generate_async("Hello", ...): # -> AsyncGenerator[str]
|
||||
print(token)
|
||||
await engine.generate_async("Hello", ...) # -> AsyncGenerator[str]
|
||||
```
|
||||
|
||||
> Document Update Time: 2026-05-30
|
||||
> Document Update Time: 2026-05-17
|
||||
|
|
|
|||
|
|
@ -53,9 +53,7 @@
|
|||
| Parameter | Description | Default |
|
||||
|-----------|-------------|---------|
|
||||
| `--nprocs` | Number of GPUs / processes | 1 |
|
||||
| `--parallel_mode` | Parallel strategy (`none`, `ddp`, or `fsdp`) | none |
|
||||
| `--device_type` | Device type | cuda |
|
||||
| `--start_method` | Multiprocessing start method (`spawn`, `fork`, `forkserver`) | spawn |
|
||||
|
||||
### Strategy-specific
|
||||
|
||||
|
|
@ -75,7 +73,6 @@ export CUDA_VISIBLE_DEVICES=0,1,2,3
|
|||
|
||||
nohup python scripts/tools/train.py \
|
||||
--nprocs=4 \
|
||||
--parallel_mode=ddp \
|
||||
--train_type=seq \
|
||||
--data_root_path=/path/to/dataset \
|
||||
--param_path=/path/to/model \
|
||||
|
|
@ -97,4 +94,4 @@ nohup python scripts/tools/train.py \
|
|||
|
||||
---
|
||||
|
||||
> Document Update Time: 2026-05-24
|
||||
> Document Update Time: 2026-05-17
|
||||
|
|
@ -1,346 +0,0 @@
|
|||
# Preprocessing Pipeline
|
||||
|
||||
Declarative JSON-driven data preprocessing. One `SectionedMaskBuilder` handles all formats via `input.sections` (single-output) or `input.sources` (multi-output).
|
||||
|
||||
## Philosophy
|
||||
|
||||
| Component | Responsibility |
|
||||
|-----------|---------------|
|
||||
| `tokenizer_config.json` (`chat_template`) | Formatting -- how roles become tokens |
|
||||
| `pipeline.json` (`mask`) | Masking -- which roles participate in training |
|
||||
|
||||
A single config file captures the entire pipeline, reusable and version-controllable.
|
||||
|
||||
## Config Structure
|
||||
|
||||
```json
|
||||
{
|
||||
"input": {}, // sections (single) or sources (multi)
|
||||
"mask": {}, // role → "train" | "mask"
|
||||
"mask_default": "mask",
|
||||
"preprocessing": {},
|
||||
"output": {}
|
||||
}
|
||||
```
|
||||
|
||||
### Section Fields
|
||||
|
||||
| Field | Type | Default | Description |
|
||||
|-------|------|---------|-------------|
|
||||
| `field` | str | -- | JSONL key to read |
|
||||
| `action` | str | -- | `"train"` / `"mask"` / `"$role"` |
|
||||
| `template` | bool | `false` | Apply `chat_template` per message |
|
||||
| `add_special_tokens` | bool | `true` for first non-template section | Add special tokens during encode |
|
||||
|
||||
### Source Fields (multi-output mode)
|
||||
|
||||
| Field | Type | Default | Description |
|
||||
|-------|------|---------|-------------|
|
||||
| `sections` | list[dict] | -- | Same as single-output section list |
|
||||
| `list_field` | bool | `false` | JSONL field holds a list; tokenise each element |
|
||||
| `mask_key` | str | `"{key}_mask"` | Explicit output key for loss mask |
|
||||
|
||||
---
|
||||
|
||||
## Quick Start
|
||||
|
||||
### SFT Chat
|
||||
|
||||
Input JSONL:
|
||||
|
||||
```json
|
||||
{"messages": [{"role": "system", "content": "You are helpful."}, {"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello!"}]}
|
||||
```
|
||||
|
||||
Config:
|
||||
|
||||
```json
|
||||
{
|
||||
"input": {
|
||||
"sections": [
|
||||
{"field": "messages", "action": "$role", "template": true}
|
||||
]
|
||||
},
|
||||
"mask": {
|
||||
"system": "mask",
|
||||
"user": "mask",
|
||||
"assistant": "train"
|
||||
},
|
||||
"mask_default": "mask",
|
||||
"preprocessing": {
|
||||
"max_seq_len": 2048
|
||||
},
|
||||
"output": {
|
||||
"storage_format": "bin",
|
||||
"dtype": {"loss_mask": "bool"}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Output keys: `sequence` (int32), `loss_mask` (bool)
|
||||
|
||||
### SFT Instruction
|
||||
|
||||
Input JSONL:
|
||||
|
||||
```json
|
||||
{"prompt": "Translate to French: Hello", "response": "Bonjour"}
|
||||
```
|
||||
|
||||
Config:
|
||||
|
||||
```json
|
||||
{
|
||||
"input": {
|
||||
"sections": [
|
||||
{"field": "prompt", "action": "mask", "add_special_tokens": true},
|
||||
{"field": "response", "action": "train"}
|
||||
]
|
||||
},
|
||||
"mask_default": "mask",
|
||||
"preprocessing": {
|
||||
"max_seq_len": 2048
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Output keys: `sequence`, `loss_mask`
|
||||
|
||||
### Pretrain
|
||||
|
||||
Input JSONL:
|
||||
|
||||
```json
|
||||
{"text": "Artificial Intelligence is a field of computer science..."}
|
||||
```
|
||||
|
||||
Config:
|
||||
|
||||
```json
|
||||
{
|
||||
"input": {
|
||||
"sections": [
|
||||
{"field": "text", "action": "train"}
|
||||
]
|
||||
},
|
||||
"preprocessing": {
|
||||
"max_seq_len": 8192,
|
||||
"min_chars": 100
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Output keys: `sequence` (no `loss_mask` — all tokens trained)
|
||||
|
||||
### DPO
|
||||
|
||||
Input JSONL:
|
||||
|
||||
```json
|
||||
{"chosen": [{"role": "user", "content": "What is 2+2?"}, {"role": "assistant", "content": "4"}], "rejected": [{"role": "user", "content": "What is 2+2?"}, {"role": "assistant", "content": "5"}]}
|
||||
```
|
||||
|
||||
Config:
|
||||
|
||||
```json
|
||||
{
|
||||
"input": {
|
||||
"sources": {
|
||||
"chosen": {
|
||||
"sections": [
|
||||
{"field": "chosen", "action": "$role", "template": true}
|
||||
]
|
||||
},
|
||||
"rejected": {
|
||||
"sections": [
|
||||
{"field": "rejected", "action": "$role", "template": true}
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
"mask": {
|
||||
"user": "mask",
|
||||
"assistant": "train"
|
||||
},
|
||||
"mask_default": "mask"
|
||||
}
|
||||
```
|
||||
|
||||
Output keys: `chosen`, `chosen_mask`, `rejected`, `rejected_mask`
|
||||
|
||||
### GRPO
|
||||
|
||||
Input JSONL:
|
||||
|
||||
```json
|
||||
{"prompt": [{"role": "user", "content": "What is 2+2?"}], "responses": ["4", "Five", "Four"], "rewards": [1.0, 0.3, 0.8]}
|
||||
```
|
||||
|
||||
Config:
|
||||
|
||||
```json
|
||||
{
|
||||
"input": {
|
||||
"sources": {
|
||||
"prompts": {
|
||||
"sections": [
|
||||
{"field": "prompt", "action": "mask", "template": true}
|
||||
]
|
||||
},
|
||||
"responses": {
|
||||
"sections": [
|
||||
{"field": "responses", "action": "train"}
|
||||
],
|
||||
"list_field": true,
|
||||
"mask_key": "masks"
|
||||
},
|
||||
"rewards": {
|
||||
"sections": [
|
||||
{"field": "rewards", "action": "value"}
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
"mask": {
|
||||
"user": "mask",
|
||||
"assistant": "train"
|
||||
},
|
||||
"mask_default": "mask"
|
||||
}
|
||||
```
|
||||
|
||||
Output keys: `prompts`, `responses`, `masks`, `rewards` (float32)
|
||||
|
||||
- `action: "value"` — extract raw values from JSONL without tokenisation
|
||||
- `list_field: true` — tokenise each list element independently, then concatenate
|
||||
- `mask_key: "masks"` — rename the auto-generated mask key (default: `responses_mask`)
|
||||
|
||||
---
|
||||
|
||||
## Configuration Reference
|
||||
|
||||
### `input`
|
||||
|
||||
| Field | Type | Default | Description |
|
||||
|-------|------|---------|-------------|
|
||||
| `sections` | list[dict] or null | `null` | Section specs for single-output mode |
|
||||
| `sources` | dict[str, dict] or null | `null` | Source specs for multi-output mode (DPO/GRPO) |
|
||||
|
||||
When `sources` is set, `sections` is ignored.
|
||||
|
||||
### `mask`
|
||||
|
||||
| Field | Type | Default | Description |
|
||||
|-------|------|---------|-------------|
|
||||
| `mask` | dict | `{}` | `{role: "train" \| "mask"}` |
|
||||
| `mask_default` | str | `"mask"` | Default action for unlisted roles |
|
||||
|
||||
### `preprocessing`
|
||||
|
||||
| Field | Type | Default | Description |
|
||||
|-------|------|---------|-------------|
|
||||
| `max_seq_len` | int | `2048` | Truncate sequences to this length |
|
||||
| `min_chars` | int | `50` | Skip text-mode items shorter than this |
|
||||
| `max_chars` | int | `2000000` | Skip text-mode items longer than this |
|
||||
| `max_items` | int or null | `null` | Stop after N documents |
|
||||
|
||||
### `output`
|
||||
|
||||
| Field | Type | Default | Description |
|
||||
|-------|------|---------|-------------|
|
||||
| `domain_key` | str or null | `null` | JSONL key for domain grouping |
|
||||
| `storage_format` | str | `"bin"` | `"bin"` (mmap) or `"h5"` |
|
||||
| `max_tokens_per_shard` | int | `100000000` | Flush threshold in cumulative tokens |
|
||||
| `dtype` | dict[str, str] | `{}` | Per-key tensor dtype override (e.g. `{"loss_mask": "bool"}`) |
|
||||
|
||||
---
|
||||
|
||||
## Mask Algorithm
|
||||
|
||||
### Template mode (`template: true`)
|
||||
|
||||
For each message in the field's array:
|
||||
|
||||
1. Prepend BOS token (masked)
|
||||
2. Render through `chat_template` for that single message
|
||||
3. Encode rendered text
|
||||
4. Apply mask rule for the message's role
|
||||
|
||||
### Non-template mode
|
||||
|
||||
Encode the field value as text. Mask value is 1 (train) or 0 (mask) per the section's `action`.
|
||||
|
||||
### Text config detection
|
||||
|
||||
When no section uses `template` and all sections have `action: "train"`, the builder skips mask generation entirely — all tokens are trained.
|
||||
|
||||
---
|
||||
|
||||
## Output Layout
|
||||
|
||||
### Single-Shard (`bin`)
|
||||
|
||||
```
|
||||
output/
|
||||
__default__/
|
||||
meta.json
|
||||
sequence.bin
|
||||
loss_mask.bin
|
||||
wiki/
|
||||
meta.json
|
||||
sequence.bin
|
||||
loss_mask.bin
|
||||
```
|
||||
|
||||
### Multi-Shard (`bin`)
|
||||
|
||||
When `max_tokens_per_shard` is exceeded:
|
||||
|
||||
```
|
||||
output/
|
||||
__default__/
|
||||
shard_0000/
|
||||
meta.json
|
||||
sequence.bin
|
||||
loss_mask.bin
|
||||
shard_0001/
|
||||
meta.json
|
||||
sequence.bin
|
||||
loss_mask.bin
|
||||
```
|
||||
|
||||
`MmapStore` discovers all shards under the domain directory via `rglob("meta.json")`.
|
||||
|
||||
---
|
||||
|
||||
## CLI
|
||||
|
||||
```bash
|
||||
# SFT
|
||||
python scripts/tools/preprocess.py data/sft/*.jsonl -o output/sft/ -c configs/sft_chat.json
|
||||
|
||||
# DPO
|
||||
python scripts/tools/preprocess.py data/dpo/*.jsonl -o output/dpo/ -c configs/dpo.json --tokenizer_path params
|
||||
|
||||
# GRPO
|
||||
python scripts/tools/preprocess.py data/grpo/*.jsonl -o output/grpo/ -c configs/grpo.json
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Python API
|
||||
|
||||
```python
|
||||
from astrai.preprocessing.pipeline import Pipeline
|
||||
from astrai.config.preprocess_config import PipelineConfig
|
||||
|
||||
config = PipelineConfig.from_json("sft.json")
|
||||
Pipeline(
|
||||
config,
|
||||
["data_part1.jsonl", "data_part2.jsonl"],
|
||||
output_dir="output/",
|
||||
tokenizer_path="params",
|
||||
).run()
|
||||
```
|
||||
|
||||
> Document Update Time: 2026-06-03
|
||||
|
|
@ -1,5 +1,38 @@
|
|||
# Training
|
||||
|
||||
## Model Architecture
|
||||
|
||||
The model uses a decoder-only Transformer with **GQA** (Grouped Query Attention) and optional **MLA** (Multi-head Latent Attention). It has 1.0 billion parameters and is bilingual (Chinese–English).
|
||||
|
||||
```mermaid
|
||||
flowchart TB
|
||||
subgraph Layers["Transformer Layers"]
|
||||
direction TB
|
||||
A[Input Embedding] --> B[Transformer Block\nLayer 1]
|
||||
B --> C[Transformer Block\nLayer ...]
|
||||
C --> D[Transformer Block\nLayer ...]
|
||||
D --> E[RMSNorm]
|
||||
E --> F[Linear]
|
||||
F --> G[SoftMax]
|
||||
end
|
||||
|
||||
subgraph TransformerBlock["Transformer Block"]
|
||||
direction TB
|
||||
H[x] --> I[RMSNorm]
|
||||
I --> J[Linear → Q/K/V]
|
||||
J --> K[Q]; J --> L[K]; J --> M[V]
|
||||
K --> N[RoPE]; L --> O[RoPE]
|
||||
N --> P["Q @ K^T / sqrt(d)"]; O --> P
|
||||
P --> Q[Masked SoftMax]; Q --> R[S @ V]; M --> R
|
||||
R --> S[Linear]; S --> T[+]; H --> T
|
||||
T --> U[RMSNorm]
|
||||
U --> V["Linear (gate)"]; U --> W["Linear (up)"]
|
||||
V --> X[SiLU]; X --> Y[×]; W --> Y
|
||||
Y --> Z["Linear (down)"]; Z --> AA[+]; T --> AA
|
||||
AA --> BB[x']
|
||||
end
|
||||
```
|
||||
|
||||
### Autoregression
|
||||
|
||||
Given a token sequence, the model predicts the probability of the next token. Each generated token is appended to the input and fed back, repeating until an end-of-sequence token or max length.
|
||||
|
|
@ -36,24 +69,20 @@ Two-level loop: **epoch** → **batch**. Optimizer step fires every `grad_accum_
|
|||
|
||||
```
|
||||
on_train_begin
|
||||
model.train()
|
||||
on_epoch_begin
|
||||
for batch in dataloader:
|
||||
on_batch_begin
|
||||
with executor.accumulate(model):
|
||||
loss = strategy.compute_loss(batch)
|
||||
context.loss = loss.item()
|
||||
stand_loss = loss / executor.grad_accum_steps
|
||||
executor.backward(stand_loss)
|
||||
context.iteration += 1
|
||||
on_batch_end
|
||||
loss = strategy(batch)
|
||||
(loss / grad_accum_steps).backward()
|
||||
iteration += 1
|
||||
on_batch_end
|
||||
|
||||
if executor.sync_gradients:
|
||||
on_optimizer_step
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
if scheduler:
|
||||
scheduler.step()
|
||||
if iteration % grad_accum_steps == 0:
|
||||
on_step_begin
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
on_step_end
|
||||
scheduler.step()
|
||||
on_epoch_end
|
||||
on_train_end
|
||||
```
|
||||
|
|
@ -63,15 +92,12 @@ on_train_end
|
|||
| Hook | Fires | Default callback |
|
||||
|------|-------|-----------------|
|
||||
| `on_train_begin` | Before training starts | `GradientCheckpointingCallback` |
|
||||
| `on_epoch_begin` | Start of each epoch | `ProgressBarCallback` |
|
||||
| `on_batch_begin` | Every batch | — |
|
||||
| `on_optimizer_step` | Every accumulation window | `GradientClippingCallback`, `ValidationCallback` |
|
||||
| `on_step_begin` | Every accumulation window | `GradientClippingCallback` |
|
||||
| `on_batch_end` | Every batch | `CheckpointCallback`, `MetricLoggerCallback`, `ProgressBarCallback` |
|
||||
| `on_epoch_end` | End of each epoch | `ProgressBarCallback` |
|
||||
| `on_error` | On exception during training | `CheckpointCallback`, `MetricLoggerCallback` |
|
||||
| `on_train_end` | Training ends (always via finally) | `CheckpointCallback`, `MetricLoggerCallback`, `GradientCheckpointingCallback` |
|
||||
| `on_step_end` | Every accumulation window | `ValidationCallback` |
|
||||
| `on_train_end` | Training ends | `CheckpointCallback`, `MetricLoggerCallback` (final save) |
|
||||
|
||||
Default callbacks (in order): `gradient_checkpointing` (activation checkpointing, optional), `checkpoint` (safetensors, rank-0), `metric_logger` (JSONL, rank-0), `progress_bar` (tqdm), `gradient_clipping`, `validation` (periodic validation on val_dataset).
|
||||
Default callbacks: `gradient_checkpointing` (activation checkpointing, optional), `progress_bar` (tqdm), `checkpoint` (safetensors, rank-0), `metric_logger` (JSONL, rank-0), `gradient_clipping`, `validation` (periodic validation on val_dataset).
|
||||
|
||||
## Strategies
|
||||
|
||||
|
|
@ -83,7 +109,7 @@ $$
|
|||
L_{\text{PT}} = -\sum_{t=1}^{T} \log P(x_t \mid x_{\lt t}; \theta)
|
||||
$$
|
||||
|
||||
Keys: `input_ids`, `target_ids`. Optional: `label_smoothing`.
|
||||
Keys: `input_ids`, `target_ids`
|
||||
|
||||
### SFT (Supervised Fine-Tuning)
|
||||
|
||||
|
|
@ -93,7 +119,7 @@ $$
|
|||
L_{\text{SFT}} = -\sum_{t=P+1}^{P+L} \log P(s_t \mid s_{\lt t}; \theta)
|
||||
$$
|
||||
|
||||
Keys: `input_ids`, `target_ids`, `loss_mask`. Optional: `label_smoothing`.
|
||||
Keys: `input_ids`, `target_ids`, `loss_mask`
|
||||
|
||||
### DPO (Direct Preference Optimization)
|
||||
|
||||
|
|
@ -103,7 +129,7 @@ $$
|
|||
L_{\text{DPO}} = -\mathbb{E}\left[\log\sigma\left(\beta\log\frac{\pi_\theta(y_w\mid x)}{\pi_{\text{ref}}(y_w\mid x)} - \beta\log\frac{\pi_\theta(y_l\mid x)}{\pi_{\text{ref}}(y_l\mid x)}\right)\right]
|
||||
$$
|
||||
|
||||
Parameters: `beta=0.1`, `reduction="mean"`. Keys: `chosen`, `rejected`, `chosen_mask`, `rejected_mask`.
|
||||
Parameters: `beta=0.1`. Keys: `chosen`, `rejected`, `chosen_mask`, `rejected_mask`.
|
||||
|
||||
### GRPO (Group Relative Policy Optimization)
|
||||
|
||||
|
|
@ -117,7 +143,7 @@ $$
|
|||
L_{\text{GRPO}} = -\mathbb{E}\left[\min\left(\frac{\pi_\theta}{\pi_{\text{ref}}}A,\; \text{clip}\left(\frac{\pi_\theta}{\pi_{\text{ref}}}, 1-\epsilon, 1+\epsilon\right)A\right)\right] + \lambda \cdot \mathbb{E}\left[(\log\pi_\theta - \log\pi_{\text{ref}})^2\right]
|
||||
$$
|
||||
|
||||
Parameters: `group_size=4`, `clip_eps=0.2`, `kl_coef=0.01`, `sync_interval=200`, `reduction="mean"`.
|
||||
Parameters: `group_size=4`, `clip_eps=0.2`, `kl_coef=0.01`, `sync_interval=200`.
|
||||
|
||||
Keys: `prompts`, `responses`, `masks`, `rewards`.
|
||||
|
||||
|
|
@ -128,7 +154,7 @@ Keys: `prompts`, `responses`, `masks`, `rewards`.
|
|||
| Cosine | `CosineScheduler` | Linear warmup → cosine decay to `min_rate` |
|
||||
| SGDR | `SGDRScheduler` | Cosine annealing with warm restarts (`t_mult=2`) |
|
||||
|
||||
Created by `SchedulerFactory.create(optimizer, schedule_type, **kwargs)`. Valid types: `"cosine"`, `"sgdr"`. Omit to use no scheduler.
|
||||
Created by `SchedulerFactory.create(optimizer, schedule_type, **kwargs)`.
|
||||
|
||||
## Gradient Checkpointing
|
||||
|
||||
|
|
@ -144,30 +170,29 @@ Callback wraps each `DecoderBlock.forward` with `torch.utils.checkpoint.checkpoi
|
|||
## Checkpoint
|
||||
|
||||
```
|
||||
Checkpoint(state_dict, epoch, iteration, extra, meta, config)
|
||||
├── save(save_dir) rank-0 only: meta.json (epoch/iteration/timestamp) + config.json (model config) + model.safetensors + optional {key}.pt (optimizer.pt, scheduler.pt)
|
||||
└── load(save_dir, broadcast=False) loads from local disk; set broadcast=True to broadcast metadata from rank-0
|
||||
Checkpoint(state_dict, epoch, iteration, extra, meta)
|
||||
├── save(save_dir) rank-0 only: meta.json (includes training config) + state_dict.safetensors + optional extra.pt
|
||||
└── load(save_dir) broadcasts metadata from rank-0
|
||||
```
|
||||
|
||||
Optimizer/scheduler state persisted by default via `Checkpoint.extra`.
|
||||
Model config (`context.model_config`) saved into `config.json` during training via `CheckpointCallback`.
|
||||
Training config (`TrainConfig.to_dict()`) saved into `meta.json` during training via `CheckpointCallback`.
|
||||
|
||||
## TrainContextBuilder (Builder Pattern)
|
||||
|
||||
```python
|
||||
context = (
|
||||
TrainContextBuilder(config)
|
||||
.with_resume_dir(resume_dir)
|
||||
.with_checkpoint(checkpoint)
|
||||
.build()
|
||||
)
|
||||
# Returns TrainContext with model, strategy, optimizer, scheduler, dataloader, checkpoint
|
||||
```
|
||||
|
||||
- Loads checkpoint weights if provided
|
||||
- Creates executor via `ExecutorFactory.create(cfg.parallel_mode, grad_accum_steps=cfg.grad_accum_steps, **cfg.executor_kwargs)`
|
||||
- Calls `executor.prepare(model, optimizer, dataloader, scheduler)` for model distribution (e.g. DDP) + gradient accumulation wrappers
|
||||
- Wraps model with `parallel_wrapper` if `nprocs > 1`
|
||||
- Creates `ResumableDistributedSampler` for shuffle+resume
|
||||
- Builds strategy via `StrategyFactory.create(train_type, model, device, **kwargs)`
|
||||
- Builds strategy via `StrategyFactory.create(train_type, ...)`
|
||||
|
||||
## Training CLI
|
||||
|
||||
|
|
@ -176,7 +201,6 @@ export CUDA_VISIBLE_DEVICES=0,1,2,3
|
|||
|
||||
nohup python scripts/tools/train.py \
|
||||
--nprocs=4 \
|
||||
--parallel_mode=ddp \
|
||||
--train_type=seq \
|
||||
--data_root_path=/path/to/dataset \
|
||||
--param_path=/path/to/model \
|
||||
|
|
@ -198,4 +222,4 @@ nohup python scripts/tools/train.py \
|
|||
|
||||
Full parameter reference at [params.md](params.md).
|
||||
|
||||
> Document Update Time: 2026-05-30
|
||||
> Document Update Time: 2026-05-17
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
__version__ = "1.3.7"
|
||||
__version__ = "1.3.6"
|
||||
__author__ = "ViperEkura"
|
||||
|
||||
from astrai.config import (
|
||||
|
|
|
|||
|
|
@ -4,22 +4,13 @@ from astrai.config.model_config import (
|
|||
ConfigFactory,
|
||||
EncoderConfig,
|
||||
)
|
||||
from astrai.config.preprocess_config import (
|
||||
InputConfig,
|
||||
OutputConfig,
|
||||
PipelineConfig,
|
||||
ProcessingConfig,
|
||||
)
|
||||
from astrai.config.train_config import TrainConfig
|
||||
|
||||
__all__ = [
|
||||
# Model configuration
|
||||
"BaseModelConfig",
|
||||
"AutoRegressiveLMConfig",
|
||||
"EncoderConfig",
|
||||
"ConfigFactory",
|
||||
"TrainConfig",
|
||||
"InputConfig",
|
||||
"OutputConfig",
|
||||
"PipelineConfig",
|
||||
"ProcessingConfig",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
import json
|
||||
from dataclasses import MISSING, dataclass, fields
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional, Self, Union, get_type_hints
|
||||
from typing import Any, Dict, Optional, Self, get_type_hints
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -14,21 +13,12 @@ class BaseConfig:
|
|||
d[fld.name] = v
|
||||
elif v is None:
|
||||
d[fld.name] = None
|
||||
elif isinstance(v, (dict, list, tuple)):
|
||||
elif isinstance(v, (dict, list)):
|
||||
try:
|
||||
val = list(v) if isinstance(v, tuple) else v
|
||||
json.dumps(val)
|
||||
d[fld.name] = val
|
||||
json.dumps(v)
|
||||
d[fld.name] = v
|
||||
except (TypeError, ValueError):
|
||||
pass
|
||||
elif isinstance(v, BaseConfig):
|
||||
d[fld.name] = v.to_dict()
|
||||
elif hasattr(v, "__dataclass_fields__"):
|
||||
sub = {}
|
||||
for f in fields(v):
|
||||
a = getattr(v, f.name)
|
||||
sub[f.name] = list(a) if isinstance(a, tuple) else a
|
||||
d[fld.name] = sub
|
||||
return d
|
||||
|
||||
@classmethod
|
||||
|
|
@ -84,15 +74,4 @@ class BaseConfig:
|
|||
return value
|
||||
if isinstance(value, target_type):
|
||||
return value
|
||||
if isinstance(value, dict) and issubclass(target_type, BaseConfig):
|
||||
return target_type.from_dict(value)
|
||||
raise TypeError
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, path: Union[str, Path]) -> Self:
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
return cls.from_dict(json.load(f))
|
||||
|
||||
def to_json(self, path: Union[str, Path]):
|
||||
with open(path, "w", encoding="utf-8") as f:
|
||||
json.dump(self.to_dict(), f, indent=2, ensure_ascii=False)
|
||||
|
|
|
|||
|
|
@ -49,7 +49,6 @@ class AutoRegressiveLMConfig(BaseModelConfig):
|
|||
|
||||
max_len: Optional[int] = None
|
||||
rope_theta: Optional[float] = None
|
||||
rope_scaling: Optional[dict] = None
|
||||
|
||||
attn_type: str = "gqa"
|
||||
n_heads: Optional[int] = None
|
||||
|
|
@ -81,7 +80,6 @@ class EncoderConfig(BaseModelConfig):
|
|||
|
||||
max_len: Optional[int] = None
|
||||
rope_theta: Optional[float] = None
|
||||
rope_scaling: Optional[dict] = None
|
||||
|
||||
n_heads: Optional[int] = None
|
||||
n_kv_heads: Optional[int] = None
|
||||
|
|
|
|||
|
|
@ -1,109 +0,0 @@
|
|||
"""Pipeline configuration for JSONL preprocessing.
|
||||
|
||||
Supports single-sequence (SFT/pretrain) and multi-output (DPO/GRPO)
|
||||
modes, both driven declaratively through ``input.sections`` or
|
||||
``input.sources``.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from astrai.config.base import BaseConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
class InputConfig(BaseConfig):
|
||||
"""Declarative input mapping.
|
||||
|
||||
Single-output mode (backward-compatible)::
|
||||
|
||||
{"input": {"sections": [{"field": "messages", ...}]}}
|
||||
|
||||
Multi-output mode (DPO / GRPO)::
|
||||
|
||||
{"input": {"sources": {
|
||||
"chosen": {"sections": [{"field": "chosen", ...}]},
|
||||
"rejected": {"sections": [{"field": "rejected", ...}]},
|
||||
}}}
|
||||
"""
|
||||
|
||||
sections: Optional[List[Dict]] = None
|
||||
sources: Optional[Dict[str, Dict]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProcessingConfig(BaseConfig):
|
||||
"""Processing configuration.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
max_seq_len : int
|
||||
Maximum sequence length (default: 2048).
|
||||
min_chars : int
|
||||
Minimum number of characters to keep (default: 50).
|
||||
max_chars : int
|
||||
Maximum number of characters to keep (default: 2_000_000).
|
||||
max_items : Optional[int]
|
||||
Maximum number of items to process (default: None, unlimited).
|
||||
packing_strategy : str
|
||||
How to pack sequences into a contiguous stream.
|
||||
|
||||
- ``"simple"``: sequential concatenation (default, backward compatible).
|
||||
- ``"bfd"``: best-fit decreasing bin packing, minimises wasted tokens.
|
||||
- ``"bfd_split"``: BFD with over-length sequences split into chunks.
|
||||
max_packed_len : int
|
||||
Maximum length of a packed bin. Sequences longer than this are
|
||||
truncated or split depending on ``packing_strategy`` (default: 8192).
|
||||
truncation_mode : str
|
||||
How to truncate sequences longer than ``max_packed_len``.
|
||||
|
||||
- ``"keep_start"``: keep the first ``max_packed_len`` tokens (default).
|
||||
- ``"keep_end"``: keep the last ``max_packed_len`` tokens.
|
||||
"""
|
||||
|
||||
max_seq_len: int = 2048
|
||||
min_chars: int = 50
|
||||
max_chars: int = 2_000_000
|
||||
max_items: Optional[int] = None
|
||||
packing_strategy: str = "simple"
|
||||
max_packed_len: int = 8192
|
||||
truncation_mode: str = "keep_start"
|
||||
|
||||
|
||||
@dataclass
|
||||
class OutputConfig(BaseConfig):
|
||||
"""Output configuration.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
domain_key : Optional[str]
|
||||
Domain key for the output store (default: None).
|
||||
storage_format : str
|
||||
Storage format, one of ``"bin"``, ``"jsonl"`` (default: ``"bin"``).
|
||||
max_tokens_per_shard : int
|
||||
Maximum tokens per shard before splitting (default: 100_000_000).
|
||||
dtype : Dict[str, str]
|
||||
Per-key dtype overrides, e.g. ``{"input_ids": "int32"}`` (default: {}).
|
||||
position_ids_mode : Optional[str]
|
||||
How to compute position_ids in packed sequences.
|
||||
|
||||
- ``None`` / ``"none"``: do not generate (backward compatible).
|
||||
- ``"doc_reset"``: reset to 0 at each document boundary.
|
||||
- ``"continuous"``: sequential 0, 1, 2, ... (pretrain, single doc).
|
||||
"""
|
||||
|
||||
domain_key: Optional[str] = None
|
||||
storage_format: str = "bin"
|
||||
max_tokens_per_shard: int = 100_000_000
|
||||
dtype: Dict[str, str] = field(default_factory=dict)
|
||||
position_ids_mode: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class PipelineConfig(BaseConfig):
|
||||
version: int = 1
|
||||
input: InputConfig = field(default_factory=InputConfig)
|
||||
mask: Dict[str, str] = field(default_factory=dict)
|
||||
mask_default: str = "mask"
|
||||
preprocessing: ProcessingConfig = field(default_factory=ProcessingConfig)
|
||||
output: OutputConfig = field(default_factory=OutputConfig)
|
||||
|
|
@ -7,7 +7,6 @@ from torch.optim.lr_scheduler import LRScheduler
|
|||
from torch.utils.data import Dataset
|
||||
|
||||
from astrai.config.base import BaseConfig
|
||||
from astrai.model.components.lora import LoRAConfig
|
||||
|
||||
|
||||
def required(**kw):
|
||||
|
|
@ -17,8 +16,8 @@ def required(**kw):
|
|||
@dataclass
|
||||
class TrainConfig(BaseConfig):
|
||||
# basic setting
|
||||
model_fn: Callable[[], nn.Module] = field(
|
||||
default=None, metadata=required(help="Model factory for training.")
|
||||
model: nn.Module = field(
|
||||
default=None, metadata=required(help="Model for training.")
|
||||
)
|
||||
strategy: str = field(default=None, metadata=required(help="Training strategy."))
|
||||
dataset: Dataset = field(
|
||||
|
|
@ -57,12 +56,6 @@ class TrainConfig(BaseConfig):
|
|||
default=5000, metadata={"help": "Number of iterations between checkpoints."}
|
||||
)
|
||||
|
||||
# lora setting
|
||||
lora: Optional[LoRAConfig] = field(
|
||||
default=None,
|
||||
metadata={"help": "LoRA config. None means full fine-tuning."},
|
||||
)
|
||||
|
||||
# metric setting
|
||||
log_dir: str = field(
|
||||
default="./checkpoint/logs", metadata={"help": "Directory for metric logs."}
|
||||
|
|
@ -102,9 +95,11 @@ class TrainConfig(BaseConfig):
|
|||
master_port: str = field(
|
||||
default="29500", metadata={"help": "Master port for distributed training."}
|
||||
)
|
||||
parallel_mode: str = field(
|
||||
default="none",
|
||||
metadata={"help": "Parallel strategy: none, ddp, fsdp."},
|
||||
parallel_wrapper: Optional[Callable] = field(
|
||||
default=None, metadata={"help": "Parallel function for training."}
|
||||
)
|
||||
state_dict_fn: Optional[Callable] = field(
|
||||
default=None, metadata={"help": "Parallel function for state dict saving."}
|
||||
)
|
||||
start_method: str = field(
|
||||
default="spawn",
|
||||
|
|
@ -118,21 +113,11 @@ class TrainConfig(BaseConfig):
|
|||
val_dataset: Optional[Dataset] = field(
|
||||
default=None, metadata={"help": "Dataset for validation."}
|
||||
)
|
||||
val_split: Optional[float] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Ratio to split from training dataset for validation (e.g. 0.05). Ignored if val_dataset is set."
|
||||
},
|
||||
)
|
||||
val_step: int = field(
|
||||
default=1000,
|
||||
metadata={"help": "Number of optimizer steps between validation runs."},
|
||||
)
|
||||
|
||||
executor_kwargs: dict = field(
|
||||
default_factory=dict,
|
||||
metadata={"help": "Extra kwargs passed to ExecutorFactory.create()."},
|
||||
)
|
||||
extra_kwargs: dict = field(
|
||||
default_factory=dict, metadata={"help": "Other arguments."}
|
||||
)
|
||||
|
|
|
|||
|
|
@ -4,28 +4,32 @@ from astrai.dataset.dataset import (
|
|||
)
|
||||
from astrai.dataset.sampler import ResumableDistributedSampler
|
||||
from astrai.dataset.storage import (
|
||||
H5Store,
|
||||
MmapStore,
|
||||
Store,
|
||||
StoreFactory,
|
||||
BaseSegmentFetcher,
|
||||
BaseStorage,
|
||||
H5Storage,
|
||||
JSONStorage,
|
||||
MultiSegmentFetcher,
|
||||
StorageFactory,
|
||||
detect_format,
|
||||
load_bin,
|
||||
load_h5,
|
||||
save_bin,
|
||||
load_json,
|
||||
save_h5,
|
||||
save_json,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"BaseDataset",
|
||||
"DatasetFactory",
|
||||
"Store",
|
||||
"StoreFactory",
|
||||
"H5Store",
|
||||
"MmapStore",
|
||||
"BaseSegmentFetcher",
|
||||
"MultiSegmentFetcher",
|
||||
"BaseStorage",
|
||||
"H5Storage",
|
||||
"JSONStorage",
|
||||
"StorageFactory",
|
||||
"detect_format",
|
||||
"save_h5",
|
||||
"load_h5",
|
||||
"save_bin",
|
||||
"load_bin",
|
||||
"save_json",
|
||||
"load_json",
|
||||
"ResumableDistributedSampler",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -8,8 +8,8 @@ from torch import Tensor
|
|||
from torch.utils.data import Dataset
|
||||
|
||||
from astrai.dataset.storage import (
|
||||
Store,
|
||||
StoreFactory,
|
||||
BaseStorage,
|
||||
StorageFactory,
|
||||
detect_format,
|
||||
)
|
||||
from astrai.factory import BaseFactory
|
||||
|
|
@ -26,7 +26,7 @@ class BaseDataset(Dataset, ABC):
|
|||
super().__init__()
|
||||
self.window_size = window_size
|
||||
self.stride = stride
|
||||
self.storage: Optional[Store] = None
|
||||
self.storage: Optional[BaseStorage] = None
|
||||
|
||||
@property
|
||||
def required_keys(self) -> List[str]:
|
||||
|
|
@ -48,26 +48,37 @@ class BaseDataset(Dataset, ABC):
|
|||
f"Missing: {missing}"
|
||||
)
|
||||
|
||||
def load(self, load_path: str, storage_type: Optional[str] = None):
|
||||
def load(self, load_path: str, storage_type: Optional[str] = None, tokenizer=None):
|
||||
"""Load dataset from the given path.
|
||||
|
||||
Auto-detects the storage format if not specified.
|
||||
|
||||
Args:
|
||||
load_path: Path to the data directory or file
|
||||
storage_type: Force a specific storage type ("h5", "bin"),
|
||||
storage_type: Force a specific storage type ("h5", "json"),
|
||||
or None for auto-detection
|
||||
tokenizer: Callable str -> List[int], used to tokenize raw text
|
||||
in JSON files. Ignored for HDF5.
|
||||
|
||||
Raises:
|
||||
KeyError: If the loaded storage is missing required keys.
|
||||
"""
|
||||
if storage_type is None:
|
||||
storage_type = detect_format(load_path)
|
||||
self.storage = StoreFactory.create(storage_type)
|
||||
self.storage = StorageFactory.create(storage_type)
|
||||
self._load_path = load_path
|
||||
self.storage.load(load_path)
|
||||
self.storage.load(load_path, tokenizer=tokenizer)
|
||||
self._validate_keys()
|
||||
|
||||
def load_json(self, load_path: str, tokenizer=None):
|
||||
"""Load dataset from JSON files explicitly.
|
||||
|
||||
Args:
|
||||
load_path: Path to the JSON data file or directory
|
||||
tokenizer: Optional tokenizer callable for raw text JSON.
|
||||
"""
|
||||
self.load(load_path, storage_type="json", tokenizer=tokenizer)
|
||||
|
||||
@property
|
||||
def count(self) -> int:
|
||||
"""Return the total number of raw elements (tokens) in the dataset."""
|
||||
|
|
@ -137,7 +148,7 @@ class DatasetFactory(BaseFactory["BaseDataset"]):
|
|||
"""
|
||||
|
||||
@classmethod
|
||||
def _validate_component(cls, dataset_cls: type):
|
||||
def _validate_component(cls, dataset_cls: type) -> None:
|
||||
"""Validate that the dataset class inherits from BaseDataset."""
|
||||
if not issubclass(dataset_cls, BaseDataset):
|
||||
raise TypeError(f"{dataset_cls.__name__} must inherit from BaseDataset")
|
||||
|
|
@ -164,6 +175,7 @@ class DatasetFactory(BaseFactory["BaseDataset"]):
|
|||
window_size: int,
|
||||
stride: Optional[int] = None,
|
||||
storage_type: Optional[str] = None,
|
||||
tokenizer=None,
|
||||
) -> "BaseDataset":
|
||||
"""Create and load a dataset in one step.
|
||||
|
||||
|
|
@ -172,7 +184,8 @@ class DatasetFactory(BaseFactory["BaseDataset"]):
|
|||
load_path: Path to the data file
|
||||
window_size: Window size for data sampling
|
||||
stride: Stride between consecutive samples (default: same as window_size)
|
||||
storage_type: Storage type ("h5", "bin") or None for auto-detection
|
||||
storage_type: Storage type ("h5", "json") or None for auto-detection
|
||||
tokenizer: Callable str -> List[int] for raw text JSON tokenization
|
||||
|
||||
Returns:
|
||||
Loaded dataset instance
|
||||
|
|
@ -181,7 +194,7 @@ class DatasetFactory(BaseFactory["BaseDataset"]):
|
|||
stride = window_size
|
||||
|
||||
dataset = cls.create(train_type, window_size, stride)
|
||||
dataset.load(load_path, storage_type=storage_type)
|
||||
dataset.load(load_path, storage_type=storage_type, tokenizer=tokenizer)
|
||||
|
||||
return dataset
|
||||
|
||||
|
|
@ -223,7 +236,7 @@ class SFTDataset(BaseDataset):
|
|||
|
||||
@property
|
||||
def required_keys(self) -> List[str]:
|
||||
return ["sequence", "loss_mask", "position_ids"]
|
||||
return ["sequence", "loss_mask"]
|
||||
|
||||
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
|
||||
return self.storage.fetch(begin_idx, end_idx, key)
|
||||
|
|
@ -231,17 +244,15 @@ class SFTDataset(BaseDataset):
|
|||
def __getitem__(self, index):
|
||||
begin_idx, end_idx = self.get_index(index)
|
||||
|
||||
x = self._fetch_data(begin_idx, end_idx, "sequence")
|
||||
y = self._fetch_data(begin_idx + 1, end_idx + 1, "sequence")
|
||||
position_ids = self._fetch_data(begin_idx, end_idx, "position_ids")
|
||||
loss_mask = self._fetch_data(begin_idx + 1, end_idx + 1, "loss_mask")
|
||||
x = self._fetch_data(begin_idx, end_idx, "sequence").to(dtype=torch.long)
|
||||
y = self._fetch_data(begin_idx + 1, end_idx + 1, "sequence").to(
|
||||
dtype=torch.long
|
||||
)
|
||||
loss_mask = self._fetch_data(begin_idx + 1, end_idx + 1, "loss_mask").to(
|
||||
dtype=torch.bool
|
||||
)
|
||||
|
||||
return {
|
||||
"input_ids": x.to(dtype=torch.long),
|
||||
"target_ids": y.to(dtype=torch.long),
|
||||
"position_ids": position_ids.to(dtype=torch.long),
|
||||
"loss_mask": loss_mask.to(dtype=torch.bool),
|
||||
}
|
||||
return {"input_ids": x, "target_ids": y, "loss_mask": loss_mask}
|
||||
|
||||
|
||||
@DatasetFactory.register("dpo")
|
||||
|
|
@ -295,11 +306,9 @@ class GRPODataset(BaseDataset):
|
|||
def __getitem__(self, index: int) -> Dict[str, Tensor]:
|
||||
begin_idx, end_idx = self.get_index(index)
|
||||
|
||||
prompts = self._fetch_data(begin_idx, end_idx, "prompts").to(dtype=torch.long)
|
||||
responses = self._fetch_data(begin_idx, end_idx, "responses").to(
|
||||
dtype=torch.long
|
||||
)
|
||||
masks = self._fetch_data(begin_idx, end_idx, "masks").to(dtype=torch.bool)
|
||||
prompts = self._fetch_data(begin_idx, end_idx, "prompts")
|
||||
responses = self._fetch_data(begin_idx, end_idx, "responses")
|
||||
masks = self._fetch_data(begin_idx, end_idx, "masks")
|
||||
rewards = self._fetch_data(begin_idx, end_idx, "rewards")
|
||||
|
||||
return {
|
||||
|
|
|
|||
|
|
@ -43,7 +43,6 @@ class ResumableDistributedSampler(Sampler[int]):
|
|||
offset = 0 if drop_last else self.num_replicas - 1
|
||||
self.num_samples_per_replica = (self.num_samples + offset) // self.num_replicas
|
||||
self.total_size = self.num_samples_per_replica * self.num_replicas
|
||||
self.iter = self.iter % self.num_samples_per_replica
|
||||
|
||||
self._indices = None
|
||||
|
||||
|
|
@ -75,10 +74,5 @@ class ResumableDistributedSampler(Sampler[int]):
|
|||
self.epoch += 1
|
||||
self._indices = None
|
||||
|
||||
@property
|
||||
def _remaining(self):
|
||||
remaining = self.num_samples_per_replica - self.iter
|
||||
return max(remaining, 0)
|
||||
|
||||
def __len__(self):
|
||||
return self._remaining
|
||||
return self.num_samples_per_replica
|
||||
|
|
|
|||
|
|
@ -1,32 +1,17 @@
|
|||
"""Storage backends for different data formats.
|
||||
|
||||
Layers:
|
||||
- I/O layer: save_* / load_* functions, read/write raw files (HDF5/bin)
|
||||
return Dict[str, List[Tensor]] — format-specific, no state
|
||||
- Store (ABC): central abstraction, normalizes multi-segment into
|
||||
Dict[str, List[Tensor]] per key via _normalize(),
|
||||
fetch() uses bisect across segments — no forced concat
|
||||
- Dataset layer: BaseDataset owns a Store, only calls store.fetch(begin, end, key)
|
||||
|
||||
Key properties:
|
||||
- Multi-segment: segments kept as-is, no forced concatenation — safe for
|
||||
datasets larger than RAM
|
||||
- Explicit length: _length = min(total elements across keys), set at load,
|
||||
__len__ returns O(1)
|
||||
- Zero-copy mmap: MmapStore wraps np.memmap(mode="r"), all DataLoader
|
||||
workers share OS page-cache pages
|
||||
Each storage handles format-specific loading (HDF5, JSON, etc.) and provides
|
||||
a uniform interface for data access and length observation via fetchers.
|
||||
"""
|
||||
|
||||
import bisect
|
||||
import glob
|
||||
import json
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Union
|
||||
from typing import Callable, Dict, List, Optional, Union
|
||||
|
||||
import h5py
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
|
|
@ -69,30 +54,54 @@ def load_h5(file_path: str, share_memory=True) -> Dict[str, List[Tensor]]:
|
|||
return tensor_group
|
||||
|
||||
|
||||
def save_bin(file_path: str, tensor_group: Dict[str, List[Tensor]]):
|
||||
def save_json(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]):
|
||||
os.makedirs(file_path, exist_ok=True)
|
||||
meta = {}
|
||||
full_file_path = os.path.join(file_path, f"{file_name}.json")
|
||||
json_data = {}
|
||||
for key, tensors in tensor_group.items():
|
||||
cat = torch.cat(tensors, dim=0)
|
||||
meta[key] = {"shape": list(cat.shape), "dtype": str(cat.dtype).split(".")[-1]}
|
||||
np.asarray(cat.cpu().numpy()).tofile(os.path.join(file_path, f"{key}.bin"))
|
||||
with open(os.path.join(file_path, "meta.json"), "w") as f:
|
||||
json.dump(meta, f)
|
||||
json_data[key] = [tensor.tolist() for tensor in tensors]
|
||||
with open(full_file_path, "w", encoding="utf-8") as f:
|
||||
json.dump(json_data, f, ensure_ascii=False)
|
||||
|
||||
|
||||
def load_bin(file_path: str) -> Dict[str, List[Tensor]]:
|
||||
with open(os.path.join(file_path, "meta.json"), "r") as f:
|
||||
meta = json.load(f)
|
||||
segments: Dict[str, List[Tensor]] = {}
|
||||
for key, info in meta.items():
|
||||
arr = np.memmap(
|
||||
os.path.join(file_path, f"{key}.bin"),
|
||||
dtype=info["dtype"],
|
||||
mode="r+",
|
||||
shape=tuple(info["shape"]),
|
||||
)
|
||||
segments[key] = [torch.from_numpy(arr)]
|
||||
return segments
|
||||
def load_json(
|
||||
file_path: str,
|
||||
share_memory: bool = True,
|
||||
tokenizer: Optional[Callable[[str], List[int]]] = None,
|
||||
) -> Dict[str, List[Tensor]]:
|
||||
"""Load tensor data from JSON files.
|
||||
|
||||
Supports two modes:
|
||||
- Pre-tokenized: JSON values are List[List[int]] (token IDs), loaded as-is.
|
||||
- Raw text: JSON values are List[str], tokenized via ``tokenizer`` callable
|
||||
at load time. A ``tokenizer`` receives a str and returns List[int].
|
||||
|
||||
Non-data JSON files (e.g. config.json) with scalar/object values are
|
||||
silently skipped.
|
||||
"""
|
||||
tensor_group: Dict[str, List[Tensor]] = {}
|
||||
root_path = Path(file_path)
|
||||
json_files = list(root_path.rglob("*.json")) + list(root_path.rglob("*.jsonl"))
|
||||
for json_file in json_files:
|
||||
with open(json_file, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
if not isinstance(data, dict):
|
||||
continue
|
||||
for key, sequences in data.items():
|
||||
if not isinstance(sequences, list):
|
||||
continue
|
||||
tensors = []
|
||||
for seq in sequences:
|
||||
if tokenizer is not None and isinstance(seq, str):
|
||||
seq = tokenizer(seq)
|
||||
tensor = torch.tensor(seq, dtype=torch.long)
|
||||
if share_memory:
|
||||
tensor = tensor.share_memory_()
|
||||
tensors.append(tensor)
|
||||
if tensor_group.get(key) is None:
|
||||
tensor_group[key] = []
|
||||
tensor_group[key].extend(tensors)
|
||||
return tensor_group
|
||||
|
||||
|
||||
def detect_format(load_path: str) -> str:
|
||||
|
|
@ -102,7 +111,7 @@ def detect_format(load_path: str) -> str:
|
|||
load_path: Directory or file path
|
||||
|
||||
Returns:
|
||||
Format string ("h5" or "bin")
|
||||
Format string ("h5" or "json")
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If no supported data files are found
|
||||
|
|
@ -112,160 +121,181 @@ def detect_format(load_path: str) -> str:
|
|||
suffix = root.suffix.lower()
|
||||
if suffix in (".h5", ".hdf5"):
|
||||
return "h5"
|
||||
if suffix in (".json", ".jsonl"):
|
||||
return "json"
|
||||
raise ValueError(f"Unsupported file format: {suffix}")
|
||||
|
||||
h5_files = [
|
||||
Path(p)
|
||||
for pattern in ("*.h5", "*.hdf5")
|
||||
for p in glob.glob(str(root / "**" / pattern), recursive=True)
|
||||
]
|
||||
h5_files = list(root.rglob("*.h5")) + list(root.rglob("*.hdf5"))
|
||||
if h5_files:
|
||||
return "h5"
|
||||
bin_files = [Path(p) for p in glob.glob(str(root / "**" / "*.bin"), recursive=True)]
|
||||
if bin_files:
|
||||
has_meta = (root / "meta.json").exists() or len(
|
||||
[Path(p) for p in glob.glob(str(root / "**" / "meta.json"), recursive=True)]
|
||||
) > 0
|
||||
if has_meta:
|
||||
return "bin"
|
||||
json_files = list(root.rglob("*.json")) + list(root.rglob("*.jsonl"))
|
||||
if json_files:
|
||||
return "json"
|
||||
raise FileNotFoundError(f"No supported data files found at {load_path}")
|
||||
|
||||
|
||||
class Store(ABC):
|
||||
"""String keys -> segmented tensors with ``fetch(begin, end, keys)``.
|
||||
class BaseSegmentFetcher:
|
||||
"""Fetches data segments across multiple tensor segments.
|
||||
|
||||
Each key maps to one or more tensor segments (no forced concatenation).
|
||||
``len(store)`` returns ``self._length`` (explicit, O(1)), the minimum
|
||||
total element count across all keys.
|
||||
Maintains cumulative lengths for efficient range queries across
|
||||
multiple discontinuous segments.
|
||||
"""
|
||||
|
||||
Subclasses fill ``self._data`` and ``self._cum`` during ``load()``
|
||||
via ``_normalize()``.
|
||||
def __init__(self, segments: List[Tensor]):
|
||||
self.segments = segments
|
||||
self.cum_lengths = []
|
||||
|
||||
total = 0
|
||||
for seg in segments:
|
||||
total += torch.numel(seg)
|
||||
self.cum_lengths.append(total)
|
||||
|
||||
self.total_length = total
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self.total_length
|
||||
|
||||
def fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
|
||||
"""Fetch data in the range [begin_idx, end_idx)."""
|
||||
if not (
|
||||
0 <= begin_idx < self.total_length and 0 <= end_idx <= self.total_length
|
||||
):
|
||||
raise ValueError("begin_idx or end_idx out of bounds")
|
||||
if begin_idx >= end_idx:
|
||||
return torch.tensor([], dtype=torch.long)
|
||||
|
||||
seg_start_idx = bisect.bisect_right(self.cum_lengths, begin_idx)
|
||||
seg_end_idx = bisect.bisect_left(self.cum_lengths, end_idx)
|
||||
|
||||
result_segments = []
|
||||
|
||||
for i in range(seg_start_idx, seg_end_idx + 1):
|
||||
prev_cum = self.cum_lengths[i - 1] if i > 0 else 0
|
||||
start = max(begin_idx - prev_cum, 0)
|
||||
end = min(end_idx - prev_cum, len(self.segments[i]))
|
||||
result_segments.append(self.segments[i][start:end])
|
||||
|
||||
return torch.cat(result_segments, dim=0)
|
||||
|
||||
|
||||
class MultiSegmentFetcher:
|
||||
"""Manages multiple segment fetchers for different data keys."""
|
||||
|
||||
def __init__(self, multi_segments: Dict):
|
||||
self.multi_keys = list(multi_segments.keys())
|
||||
self.multi_fetchers = {
|
||||
key: BaseSegmentFetcher(segments)
|
||||
for key, segments in multi_segments.items()
|
||||
}
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Returns the minimum length across all fetchers."""
|
||||
if not self.multi_fetchers:
|
||||
return 0
|
||||
len_list = [len(seg) for seg in self.multi_fetchers.values()]
|
||||
return min(len_list)
|
||||
|
||||
def key_fetch(
|
||||
self, begin_idx: int, end_idx: int, keys: Union[str, List[str]]
|
||||
) -> Dict:
|
||||
"""Fetch data for specific keys."""
|
||||
fetch_dict = {}
|
||||
keys = [keys] if isinstance(keys, str) else keys
|
||||
|
||||
for key in keys:
|
||||
fetcher = self.multi_fetchers[key]
|
||||
fetch_tensor = fetcher.fetch_data(begin_idx, end_idx)
|
||||
fetch_dict[key] = fetch_tensor
|
||||
|
||||
return fetch_dict if len(keys) > 1 else fetch_dict[keys[0]]
|
||||
|
||||
def fetch_data(self, begin_idx: int, end_idx: int) -> Dict:
|
||||
"""Fetch all keys."""
|
||||
return self.key_fetch(begin_idx, end_idx, self.multi_keys)
|
||||
|
||||
|
||||
class BaseStorage(ABC):
|
||||
"""Abstract storage backend for loading and dispatching data.
|
||||
|
||||
Storage encapsulates format-specific loading and provides a uniform
|
||||
interface for data access and length observation. Subclasses handle
|
||||
different data formats (HDF5, JSON, etc.) while exposing the same
|
||||
fetch interface.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._data: Dict[str, List[Tensor]] = {}
|
||||
self._cum: Dict[str, List[int]] = {}
|
||||
self._length: int = 0
|
||||
self._fetcher: Optional[MultiSegmentFetcher] = None
|
||||
|
||||
@abstractmethod
|
||||
def load(self, path: str) -> None:
|
||||
def load(self, load_path: str, tokenizer=None) -> None:
|
||||
"""Load data from the given path into internal fetcher."""
|
||||
raise NotImplementedError
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Total number of raw elements (tokens) in storage."""
|
||||
if self._fetcher is None:
|
||||
return 0
|
||||
return len(self._fetcher)
|
||||
|
||||
def fetch(self, begin_idx: int, end_idx: int, keys: Union[str, List[str]]):
    """Delegate a ranged read to the loaded fetcher.

    Args:
        begin_idx: Starting index (inclusive).
        end_idx: Ending index (exclusive).
        keys: A single key or a list of keys to read.

    Returns:
        Whatever the fetcher's ``key_fetch`` returns for the given keys.

    Raises:
        RuntimeError: If no fetcher has been set yet.
    """
    fetcher = self._fetcher
    if fetcher is not None:
        return fetcher.key_fetch(begin_idx, end_idx, keys)
    raise RuntimeError("Storage not loaded")
|
||||
|
||||
@property
|
||||
def keys(self) -> List[str]:
|
||||
return list(self._data.keys())
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self._length
|
||||
|
||||
def fetch(
    self,
    begin: int,
    end: int,
    keys: Union[str, List[str]],
):
    """Return the slice ``[begin, end)`` for one key or for several keys.

    Args:
        begin: Start index; must satisfy ``0 <= begin < self._length``.
        end: End index; must satisfy ``0 <= end <= self._length``.
        keys: A single key (string) or an iterable of keys.

    Returns:
        The result of ``_fetch_key`` for a string key, otherwise a dict
        mapping each key to its ``_fetch_key`` result.

    Raises:
        RuntimeError: If no data has been registered.
        ValueError: If ``begin`` or ``end`` falls outside the allowed range.
    """
    if not self._data:
        raise RuntimeError("Store not loaded")

    length = self._length
    begin_ok = 0 <= begin < length
    end_ok = 0 <= end <= length
    if not (begin_ok and end_ok):
        raise ValueError(
            f"Index out of bounds: begin={begin}, end={end}, length={self._length}"
        )

    if isinstance(keys, str):
        return self._fetch_key(keys, begin, end)

    result = {}
    for name in keys:
        result[name] = self._fetch_key(name, begin, end)
    return result
|
||||
|
||||
def _fetch_key(self, key: str, begin: int, end: int) -> Tensor:
|
||||
"""Fetch slice [begin, end) across potentially multiple segments."""
|
||||
segments = self._data[key]
|
||||
cum = self._cum[key]
|
||||
seg_start = bisect.bisect_right(cum, begin)
|
||||
seg_end = bisect.bisect_left(cum, end)
|
||||
|
||||
results = []
|
||||
for i in range(seg_start, seg_end + 1):
|
||||
prev = cum[i - 1] if i > 0 else 0
|
||||
s = max(begin - prev, 0)
|
||||
e = min(end - prev, segments[i].shape[0])
|
||||
results.append(segments[i][s:e])
|
||||
|
||||
return results[0] if len(results) == 1 else torch.cat(results, dim=0)
|
||||
|
||||
def _normalize(self, raw: Dict[str, List[Tensor]]):
|
||||
"""Register segments and pre-compute cumulative lengths.
|
||||
|
||||
Does NOT concatenate — segments are kept as-is to avoid OOM on
|
||||
large datasets. Sets ``self._length`` to the minimum total
|
||||
element count across all keys.
|
||||
"""
|
||||
for key, tensors in raw.items():
|
||||
self._data[key] = tensors
|
||||
cum = []
|
||||
total = 0
|
||||
for t in tensors:
|
||||
total += t.shape[0]
|
||||
cum.append(total)
|
||||
self._cum[key] = cum
|
||||
self._length = (
|
||||
min((cum[-1] if cum else 0) for cum in self._cum.values())
|
||||
if self._cum
|
||||
else 0
|
||||
)
|
||||
"""Return the data keys available in this storage."""
|
||||
if self._fetcher is None:
|
||||
return []
|
||||
return self._fetcher.multi_keys
|
||||
|
||||
|
||||
class StoreFactory(BaseFactory["Store"]):
|
||||
"""Factory for creating Store instances by type name.
|
||||
class StorageFactory(BaseFactory["BaseStorage"]):
|
||||
"""Factory for creating storage backends by type name.
|
||||
|
||||
Example::
|
||||
|
||||
@StoreFactory.register("custom")
|
||||
class CustomStore(Store):
|
||||
Example:
|
||||
@StorageFactory.register("custom")
|
||||
class CustomStorage(BaseStorage):
|
||||
...
|
||||
|
||||
storage = StorageFactory.create("custom")
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def _validate_component(cls, store_cls: type):
|
||||
if not issubclass(store_cls, Store):
|
||||
raise TypeError(f"{store_cls.__name__} must inherit from Store")
|
||||
def _validate_component(cls, storage_cls: type) -> None:
|
||||
if not issubclass(storage_cls, BaseStorage):
|
||||
raise TypeError(f"{storage_cls.__name__} must inherit from BaseStorage")
|
||||
|
||||
|
||||
@StoreFactory.register("h5")
|
||||
class H5Store(Store):
|
||||
@StorageFactory.register("h5")
|
||||
class H5Storage(BaseStorage):
|
||||
"""HDF5-based storage backend (pre-tokenized data)."""
|
||||
|
||||
def load(self, path: str):
    """Read the data at ``path`` with ``load_h5`` and register its segments."""
    raw = load_h5(path)
    self._normalize(raw)
|
||||
def load(self, load_path: str, tokenizer=None) -> None:
    """Load segments from ``load_path`` with ``load_h5`` and wrap them in a fetcher.

    ``tokenizer`` is accepted but not used by this backend.
    """
    self._fetcher = MultiSegmentFetcher(load_h5(load_path))
|
||||
|
||||
|
||||
@StoreFactory.register("bin")
|
||||
class MmapStore(Store):
|
||||
"""Memory-mapped binary storage backend.
|
||||
@StorageFactory.register("json")
|
||||
class JSONStorage(BaseStorage):
|
||||
"""JSON-based storage backend.
|
||||
|
||||
Each key is a single .bin file backed by ``np.memmap(mode="r")``.
|
||||
No per-process memory duplication — all DataLoader workers share the
|
||||
same OS page-cache pages.
|
||||
|
||||
Format on disk::
|
||||
|
||||
data_root/
|
||||
meta.json # {key: {shape, dtype}, ...}
|
||||
<key>.bin # raw numpy array, one per key
|
||||
Supports two modes:
|
||||
- Pre-tokenized: JSON values are List[List[int]], loaded as-is.
|
||||
- Raw text: JSON values are List[str], tokenized via ``tokenizer``
|
||||
callable (str -> List[int]) at load time.
|
||||
"""
|
||||
|
||||
def load(self, path: str):
    """Load every dataset described by a ``meta.json`` found under ``path``.

    The directory tree is searched recursively; each directory containing a
    ``meta.json`` is read with ``load_bin`` and its per-key tensors are
    appended, in sorted path order, to one combined segment list per key.

    Args:
        path: Root directory to search.

    Raises:
        FileNotFoundError: If no ``meta.json`` exists anywhere under ``path``.
    """
    self._mmap_refs = []
    root = Path(path)
    # glob order is filesystem-dependent; sort so that segment order (and
    # therefore global indexing) is the same on every run and machine.
    meta_paths = sorted(
        Path(p) for p in glob.glob(str(root / "**" / "meta.json"), recursive=True)
    )
    if not meta_paths:
        # Fail before touching any data files.
        raise FileNotFoundError(f"No meta.json found under {path}")

    all_raw: Dict[str, List[Tensor]] = {}
    for meta_path in meta_paths:
        raw = load_bin(str(meta_path.parent))
        for key, tensors in raw.items():
            all_raw.setdefault(key, []).extend(tensors)

    self._normalize(all_raw)
    # Hold a reference to every registered tensor (presumably to keep the
    # underlying memory maps alive — confirm against load_bin).
    for tensors in self._data.values():
        self._mmap_refs.extend(tensors)
|
||||
def load(self, load_path: str, tokenizer=None) -> None:
    """Load segments from ``load_path`` with ``load_json`` and wrap them in a fetcher.

    ``tokenizer`` is forwarded unchanged to ``load_json``.
    """
    self._fetcher = MultiSegmentFetcher(load_json(load_path, tokenizer=tokenizer))
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ class Registry:
|
|||
component_cls: Type,
|
||||
category: Optional[str] = None,
|
||||
priority: int = 0,
|
||||
):
|
||||
) -> None:
|
||||
"""Register a component class with optional category and priority."""
|
||||
if name in self._entries:
|
||||
raise ValueError(f"Component '{name}' is already registered")
|
||||
|
|
@ -158,7 +158,7 @@ class BaseFactory(ABC, Generic[T]):
|
|||
return component_cls(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def _validate_component(cls, component_cls: Type[T]):
|
||||
def _validate_component(cls, component_cls: Type[T]) -> None:
|
||||
"""Validate that the component class is valid for this factory.
|
||||
|
||||
Override this method in subclasses to add custom validation.
|
||||
|
|
|
|||
|
|
@ -1,27 +1,25 @@
|
|||
"""Inference module for continuous batching.
|
||||
|
||||
Layers:
|
||||
- core/: Core inference loop (cache, executor, scheduler, task)
|
||||
- api/: HTTP orchestration (ProtocolHandler, server)
|
||||
- protocols/: Response builders (OpenAI, Anthropic)
|
||||
- transport/: SSE transport utilities
|
||||
- engine.py: Facade (InferenceEngine), Value Object (GenerationRequest)
|
||||
- sample.py: Strategy pattern (TemperatureStrategy, TopKStrategy, TopPStrategy)
|
||||
- core/: Core inference loop (cache, executor, scheduler, task)
|
||||
- api/: HTTP protocol handlers (OpenAI, Anthropic)
|
||||
- engine.py: Facade (InferenceEngine), Value Object (GenerationRequest)
|
||||
- sample.py: Strategy pattern (TemperatureStrategy, TopKStrategy, TopPStrategy)
|
||||
"""
|
||||
|
||||
from astrai.inference.api import (
|
||||
AnthropicHandler,
|
||||
AnthropicMessage,
|
||||
ChatCompletionRequest,
|
||||
ChatMessage,
|
||||
GenContext,
|
||||
MessagesRequest,
|
||||
OpenAIHandler,
|
||||
ProtocolHandler,
|
||||
StopChecker,
|
||||
get_app,
|
||||
StreamContext,
|
||||
app,
|
||||
run_server,
|
||||
)
|
||||
from astrai.inference.api.anthropic import AnthropicResponseBuilder
|
||||
from astrai.inference.api.openai import OpenAIResponseBuilder
|
||||
from astrai.inference.core import (
|
||||
STOP,
|
||||
Allocator,
|
||||
|
|
@ -38,7 +36,10 @@ from astrai.inference.core import (
|
|||
TaskTable,
|
||||
page_hash,
|
||||
)
|
||||
from astrai.inference.engine import GenerationRequest, InferenceEngine
|
||||
from astrai.inference.engine import (
|
||||
GenerationRequest,
|
||||
InferenceEngine,
|
||||
)
|
||||
from astrai.inference.sample import (
|
||||
BaseSamplingStrategy,
|
||||
SamplingPipeline,
|
||||
|
|
@ -49,14 +50,17 @@ from astrai.inference.sample import (
|
|||
)
|
||||
|
||||
__all__ = [
|
||||
# Engine / Requests
|
||||
"InferenceEngine",
|
||||
"GenerationRequest",
|
||||
# Core scheduler
|
||||
"InferenceScheduler",
|
||||
"Executor",
|
||||
"STOP",
|
||||
"Task",
|
||||
"TaskManager",
|
||||
"TaskStatus",
|
||||
# Core cache
|
||||
"Allocator",
|
||||
"KVCache",
|
||||
"KvcacheView",
|
||||
|
|
@ -65,21 +69,24 @@ __all__ = [
|
|||
"Storage",
|
||||
"TaskTable",
|
||||
"page_hash",
|
||||
# Sampling (Strategy pattern)
|
||||
"sample",
|
||||
"BaseSamplingStrategy",
|
||||
"TemperatureStrategy",
|
||||
"TopKStrategy",
|
||||
"TopPStrategy",
|
||||
"SamplingPipeline",
|
||||
# Protocol
|
||||
"ProtocolHandler",
|
||||
"StopChecker",
|
||||
"GenContext",
|
||||
"OpenAIResponseBuilder",
|
||||
"AnthropicResponseBuilder",
|
||||
"StreamContext",
|
||||
"AnthropicHandler",
|
||||
"OpenAIHandler",
|
||||
# Server
|
||||
"ChatMessage",
|
||||
"ChatCompletionRequest",
|
||||
"AnthropicMessage",
|
||||
"MessagesRequest",
|
||||
"get_app",
|
||||
"app",
|
||||
"run_server",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -1,27 +1,31 @@
|
|||
"""Inference API: protocol handler, stop checker, and FastAPI server.
|
||||
"""Inference API: protocol handlers and FastAPI server."""
|
||||
|
||||
``app`` is no longer a module-level global. Use :func:`get_app` to access the
|
||||
lazy singleton FastAPI instance.
|
||||
"""
|
||||
|
||||
from astrai.inference.api.protocol import GenContext, ProtocolHandler, StopChecker
|
||||
from astrai.inference.api.protocol import (
|
||||
AnthropicHandler,
|
||||
OpenAIHandler,
|
||||
ProtocolHandler,
|
||||
StopChecker,
|
||||
StreamContext,
|
||||
)
|
||||
from astrai.inference.api.server import (
|
||||
AnthropicMessage,
|
||||
ChatCompletionRequest,
|
||||
ChatMessage,
|
||||
MessagesRequest,
|
||||
get_app,
|
||||
app,
|
||||
run_server,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AnthropicHandler",
|
||||
"OpenAIHandler",
|
||||
"ProtocolHandler",
|
||||
"StopChecker",
|
||||
"GenContext",
|
||||
"StreamContext",
|
||||
"AnthropicMessage",
|
||||
"ChatCompletionRequest",
|
||||
"ChatMessage",
|
||||
"MessagesRequest",
|
||||
"get_app",
|
||||
"app",
|
||||
"run_server",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -1,141 +0,0 @@
|
|||
"""Anthropic message completion response builder."""
|
||||
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Tuple, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from astrai.inference.api.protocol import (
|
||||
GenContext,
|
||||
ResponseBuilder,
|
||||
StopInfo,
|
||||
sse_event,
|
||||
)
|
||||
from astrai.inference.engine import InferenceEngine
|
||||
|
||||
|
||||
def _extract_text(content: Union[str, List[Dict[str, Any]]]) -> str:
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
for block in content:
|
||||
if isinstance(block, dict) and block.get("type") == "text":
|
||||
return block.get("text", "")
|
||||
return ""
|
||||
|
||||
|
||||
class AnthropicResponseBuilder(ResponseBuilder):
    """Format generation output as Anthropic-style message payloads.

    Covers prompt preparation, the streaming SSE event sequence
    (``message_start``, ``content_block_start``, ``content_block_delta``,
    ``content_block_stop``, ``message_delta``, ``message_stop``) and the
    non-streaming JSON body.
    """

    @staticmethod
    def _trim_at_stop(text: str, matched: str) -> str:
        """Cut ``text`` before the last occurrence of ``matched``.

        Returns ``text`` unchanged when ``matched`` does not occur, instead
        of slicing with ``-1`` and losing the final character.
        """
        cut = text.rfind(matched)
        return text if cut == -1 else text[:cut]

    @staticmethod
    def _stop_reason(stop: StopInfo) -> str:
        """Map the stop-check result to the ``stop_reason`` value."""
        return "stop_sequence" if stop.matched else "end_turn"

    def prepare(
        self, request: BaseModel, engine: InferenceEngine
    ) -> Tuple[str, GenContext, List[str]]:
        """Build the prompt, generation context and stop sequences.

        Args:
            request: Parsed request. ``messages`` and ``model`` are read
                directly; ``system`` and ``stop_sequences`` are optional.
            engine: Engine whose tokenizer renders the chat template.

        Returns:
            ``(prompt, ctx, stop_sequences)``. ``ctx.prompt_tokens`` is
            initialised to 0 here.
        """
        messages: List[Dict[str, str]] = []
        system = getattr(request, "system", None)
        if isinstance(system, list):
            # A list of content blocks is reduced to its text so the chat
            # template always receives a string.
            system = _extract_text(system)
        if system:
            messages.append({"role": "system", "content": system})
        for m in request.messages:
            text = _extract_text(m.content)
            if text:  # messages without any text are left out of the prompt
                messages.append({"role": m.role, "content": text})
        prompt = engine.tokenizer.apply_chat_template(messages, tokenize=False)
        ctx = GenContext(
            resp_id=f"msg_{uuid.uuid4().hex[:24]}",
            created=int(time.time()),
            model=request.model,
            prompt_tokens=0,
        )
        stop_sequences = getattr(request, "stop_sequences", None) or []
        return prompt, ctx, stop_sequences

    def format_stream_start(self, ctx: GenContext) -> List[str]:
        """Return the ``message_start`` and ``content_block_start`` events."""
        return [
            sse_event(
                {
                    "type": "message_start",
                    "message": {
                        "id": ctx.resp_id,
                        "type": "message",
                        "role": "assistant",
                        "model": ctx.model,
                        "content": [],
                        "usage": {"input_tokens": ctx.prompt_tokens},
                    },
                },
                event="message_start",
            ),
            sse_event(
                {
                    "type": "content_block_start",
                    "index": 0,
                    "content_block": {"type": "text", "text": ""},
                },
                event="content_block_start",
            ),
        ]

    def format_chunk(self, token: str) -> str:
        """Return a ``content_block_delta`` event carrying ``token``."""
        return sse_event(
            {
                "type": "content_block_delta",
                "index": 0,
                "delta": {"type": "text_delta", "text": token},
            },
            event="content_block_delta",
        )

    def format_stream_end(self, ctx: GenContext, stop: StopInfo) -> List[str]:
        """Return the events that close the stream.

        When a stop sequence matched, any text that precedes it in
        ``stop.body`` but is not yet part of ``stop.yielded`` is flushed as
        one last delta before the closing events.
        """
        events: List[str] = []
        if stop.matched:
            trimmed = self._trim_at_stop(stop.body, stop.matched)
            unyielded = trimmed[len(stop.yielded) :]
            if unyielded:
                events.append(self.format_chunk(unyielded))
        events.append(
            sse_event(
                {"type": "content_block_stop", "index": 0},
                event="content_block_stop",
            )
        )
        events.append(
            sse_event(
                {
                    "type": "message_delta",
                    "delta": {
                        "stop_reason": self._stop_reason(stop),
                        "stop_sequence": stop.matched,
                    },
                    "usage": {"output_tokens": ctx.completion_tokens},
                },
                event="message_delta",
            )
        )
        events.append(sse_event({"type": "message_stop"}, event="message_stop"))
        return events

    def format_response(
        self, ctx: GenContext, content: str, stop: StopInfo
    ) -> Dict[str, Any]:
        """Return the non-streaming JSON body.

        ``content`` is cut before the matched stop sequence, if any.
        """
        if stop.matched:
            content = self._trim_at_stop(content, stop.matched)
        return {
            "id": ctx.resp_id,
            "type": "message",
            "role": "assistant",
            "model": ctx.model,
            "content": [{"type": "text", "text": content}],
            "stop_reason": self._stop_reason(stop),
            "stop_sequence": stop.matched,
            "usage": {
                "input_tokens": ctx.prompt_tokens,
                "output_tokens": ctx.completion_tokens,
            },
        }
|
||||
|
|
@ -1,140 +0,0 @@
|
|||
"""OpenAI chat completion response builder."""
|
||||
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from astrai.inference.api.protocol import (
|
||||
GenContext,
|
||||
ResponseBuilder,
|
||||
StopInfo,
|
||||
sse_event,
|
||||
)
|
||||
from astrai.inference.engine import InferenceEngine
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_UNSUPPORTED_PARAMS = (
|
||||
"n",
|
||||
"presence_penalty",
|
||||
"frequency_penalty",
|
||||
"logit_bias",
|
||||
"user",
|
||||
)
|
||||
|
||||
|
||||
class OpenAIResponseBuilder(ResponseBuilder):
    """Format generation output as OpenAI-style chat completion payloads.

    ``prepare`` stores the response id, model name and creation timestamp on
    the instance; the ``format_*`` methods read them, so ``prepare`` must be
    called first and one builder serves one request.
    """

    @staticmethod
    def _warn_unsupported(request: BaseModel) -> None:
        """Log one warning per unsupported parameter set to a non-default value."""
        fields = getattr(type(request), "model_fields", {})
        for param in _UNSUPPORTED_PARAMS:
            value = getattr(request, param, None)
            default = fields[param].default if param in fields else None
            if value is not None and value != default:
                logger.warning(
                    "ChatCompletionRequest param '%s'=%r is not supported and will be ignored",
                    param,
                    value,
                )

    def _chunk(self, created: int, delta: Dict[str, Any], finish_reason) -> str:
        """Return one ``chat.completion.chunk`` SSE event."""
        return sse_event(
            {
                "id": self._resp_id,
                "object": "chat.completion.chunk",
                "created": created,
                "model": self._model,
                "choices": [
                    {"index": 0, "delta": delta, "finish_reason": finish_reason}
                ],
            }
        )

    def prepare(
        self, request: BaseModel, engine: InferenceEngine
    ) -> Tuple[str, GenContext, List[str]]:
        """Build the prompt, generation context and stop sequences.

        Args:
            request: Parsed request; ``messages``, ``model`` and ``stop``
                are read directly.
            engine: Engine whose tokenizer renders the chat template.

        Returns:
            ``(prompt, ctx, stop_sequences)``. ``ctx.prompt_tokens`` is
            initialised to 0 here; ``stop_sequences`` is always a list.
        """
        messages = [{"role": m.role, "content": m.content} for m in request.messages]
        prompt = engine.tokenizer.apply_chat_template(messages, tokenize=False)

        self._resp_id = f"chatcmpl-{uuid.uuid4().hex[:12]}"
        self._model = request.model
        # Remembered so that token chunks, which receive no context, can
        # carry the same timestamp as the start and end chunks.
        self._created = int(time.time())

        # Previously this check was written twice, producing duplicate logs.
        self._warn_unsupported(request)

        ctx = GenContext(
            resp_id=self._resp_id,
            created=self._created,
            model=self._model,
            prompt_tokens=0,
        )
        stop = request.stop
        if stop is None:
            stop_sequences = []
        elif isinstance(stop, str):
            stop_sequences = [stop]
        else:
            stop_sequences = stop
        return prompt, ctx, stop_sequences

    def format_stream_start(self, ctx: GenContext) -> List[str]:
        """Return the opening chunk that announces the assistant role."""
        return [self._chunk(ctx.created, {"role": "assistant"}, None)]

    def format_chunk(self, token: str) -> str:
        """Return a chunk carrying ``token`` as content.

        Uses the timestamp recorded by ``prepare`` rather than a literal 0,
        so every chunk of one response reports the same ``created`` value.
        """
        return self._chunk(self._created, {"content": token}, None)

    def format_stream_end(self, ctx: GenContext, stop: StopInfo) -> List[str]:
        """Return the final ``finish_reason`` chunk followed by a usage event."""
        return [
            self._chunk(ctx.created, {}, "stop"),
            sse_event(
                {
                    "prompt_tokens": ctx.prompt_tokens,
                    "completion_tokens": ctx.completion_tokens,
                    "total_tokens": ctx.prompt_tokens + ctx.completion_tokens,
                }
            ),
        ]

    def format_response(
        self, ctx: GenContext, content: str, stop: StopInfo
    ) -> Dict[str, Any]:
        """Return the non-streaming ``chat.completion`` JSON body."""
        return {
            "id": self._resp_id,
            "object": "chat.completion",
            "created": ctx.created,
            "model": self._model,
            "choices": [
                {
                    "index": 0,
                    "message": {"role": "assistant", "content": content},
                    "finish_reason": "stop",
                }
            ],
            "usage": {
                "prompt_tokens": ctx.prompt_tokens,
                "completion_tokens": ctx.completion_tokens,
                "total_tokens": ctx.prompt_tokens + ctx.completion_tokens,
            },
        }
|
||||
|
|
@ -1,13 +1,15 @@
|
|||
"""Orchestration layer: ProtocolHandler, StopChecker, GenContext, StopInfo, ResponseBuilder, SSE utils.
|
||||
"""Protocol handlers for OpenAI and Anthropic chat completion APIs.
|
||||
|
||||
ProtocolHandler orchestrates the async generation loop and delegates
|
||||
protocol-specific formatting to a ResponseBuilder.
|
||||
Template Method + Builder patterns eliminate the 45% code duplication between
|
||||
stream/non-stream branches and across protocol adapters.
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
|
|
@ -15,7 +17,7 @@ from pydantic import BaseModel
|
|||
from astrai.inference.engine import InferenceEngine
|
||||
|
||||
|
||||
def sse_event(data: Dict[str, Any], event: Optional[str] = None) -> str:
|
||||
def _sse_event(data: Dict[str, Any], event: Optional[str] = None) -> str:
|
||||
lines: List[str] = []
|
||||
if event:
|
||||
lines.append(f"event: {event}")
|
||||
|
|
@ -24,28 +26,22 @@ def sse_event(data: Dict[str, Any], event: Optional[str] = None) -> str:
|
|||
return "\n".join(lines)
|
||||
|
||||
|
||||
def sse_done() -> str:
|
||||
def _sse_done() -> str:
|
||||
return "data: [DONE]\n\n"
|
||||
|
||||
|
||||
@dataclass
|
||||
class GenContext:
|
||||
"""Per-generation metadata passed to builder format methods."""
|
||||
class StreamContext:
|
||||
"""Shared state across the streaming generation lifecycle."""
|
||||
|
||||
resp_id: str
|
||||
created: int
|
||||
model: str
|
||||
prompt_tokens: int
|
||||
completion_tokens: int = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class StopInfo:
|
||||
"""Stop-check result passed to format_stream_end / format_response."""
|
||||
|
||||
matched: Optional[str] = None
|
||||
body: str = ""
|
||||
yielded: str = ""
|
||||
accumulated: str = ""
|
||||
stop_matched: Optional[str] = None
|
||||
last_yield_trimmed: str = ""
|
||||
|
||||
|
||||
class StopChecker:
|
||||
|
|
@ -60,60 +56,95 @@ class StopChecker:
|
|||
return seq
|
||||
return None
|
||||
|
||||
def trim(self, text: str, matched: str) -> str:
    """Cut ``text`` just before the last occurrence of ``matched``.

    ``text`` is returned unchanged when ``matched`` does not occur in it.
    """
    cut = text.rfind(matched)
    if cut == -1:
        return text
    return text[:cut]
|
||||
|
||||
class ResponseBuilder(ABC):
|
||||
"""Interface for protocol-specific response formatting.
|
||||
@property
|
||||
def has_sequences(self) -> bool:
|
||||
return len(self._sequences) > 0
|
||||
|
||||
A new protocol requires one concrete builder implementing 5 methods.
|
||||
|
||||
class ProtocolHandler(ABC):
|
||||
"""Template-method base for API protocol handlers.
|
||||
|
||||
Subclasses implement format hooks; the base class orchestrates the
|
||||
generate-async loop and SSE/JSON response construction.
|
||||
|
||||
Lifecycle::
|
||||
|
||||
handle()
|
||||
├─ build_prompt() # protocol-specific prompt assembly
|
||||
├─ create_response_id() # unique response identifier
|
||||
├─ [stream]
|
||||
│ ├─ format_stream_start()
|
||||
│ ├─ format_stream_token() × N
|
||||
│ │ └─ on_token() hook for stop-sequence interception
|
||||
│ └─ format_stream_end()
|
||||
└─ [non-stream]
|
||||
├─ (accumulate tokens)
|
||||
└─ format_non_stream_response()
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def prepare(
|
||||
self, request: BaseModel, engine: InferenceEngine
|
||||
) -> Tuple[str, GenContext, List[str]]:
|
||||
"""Return (prompt, ctx, stop_sequences) for a generation request."""
|
||||
request_model: type[BaseModel]
|
||||
|
||||
@abstractmethod
|
||||
def format_stream_start(self, ctx: GenContext) -> List[str]:
|
||||
"""SSE events that open the stream."""
|
||||
|
||||
@abstractmethod
|
||||
def format_chunk(self, token: str) -> str:
|
||||
"""SSE event for a single generated token."""
|
||||
|
||||
@abstractmethod
|
||||
def format_stream_end(self, ctx: GenContext, stop: StopInfo) -> List[str]:
|
||||
"""SSE events that close the stream."""
|
||||
|
||||
@abstractmethod
|
||||
def format_response(
|
||||
self, ctx: GenContext, content: str, stop: StopInfo
|
||||
) -> Dict[str, Any]:
|
||||
"""JSON response body for non-streaming mode."""
|
||||
|
||||
|
||||
class ProtocolHandler:
|
||||
"""Orchestrates the generation loop, delegates formatting to a builder.
|
||||
|
||||
Usage::
|
||||
|
||||
handler = ProtocolHandler(request, engine, OpenAIResponseBuilder())
|
||||
response = await handler.handle()
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, request: BaseModel, engine: InferenceEngine, builder: ResponseBuilder
|
||||
):
|
||||
def __init__(self, request: BaseModel, engine: InferenceEngine):
|
||||
self.request = request
|
||||
self.engine = engine
|
||||
self.builder = builder
|
||||
|
||||
@abstractmethod
|
||||
def build_prompt(self) -> str:
|
||||
"""Build the full prompt string from the request messages."""
|
||||
|
||||
@abstractmethod
|
||||
def create_response_id(self) -> str:
|
||||
"""Generate a unique response ID following the protocol convention."""
|
||||
|
||||
@abstractmethod
|
||||
def format_stream_start(self, ctx: StreamContext) -> List[str]:
|
||||
"""Yield SSE events that open the stream (role marker, metadata)."""
|
||||
|
||||
@abstractmethod
|
||||
def format_stream_token(self, ctx: StreamContext, token: str) -> str:
|
||||
"""Yield an SSE event for a single generated token."""
|
||||
|
||||
@abstractmethod
|
||||
def format_stream_end(self, ctx: StreamContext) -> List[str]:
|
||||
"""Yield SSE events that close the stream (finish reason, usage stats)."""
|
||||
|
||||
@abstractmethod
|
||||
def format_non_stream_response(
|
||||
self, ctx: StreamContext, content: str
|
||||
) -> Dict[str, Any]:
|
||||
"""Build the JSON response body for non-streaming mode."""
|
||||
|
||||
def get_stop_sequences(self) -> List[str]:
|
||||
return []
|
||||
|
||||
def create_stop_checker(self) -> StopChecker:
|
||||
return StopChecker(self.get_stop_sequences())
|
||||
|
||||
def on_token(
|
||||
self, ctx: StreamContext, token: str, stop_checker: StopChecker
|
||||
) -> Optional[str]:
|
||||
"""Hook after each token is appended to accumulated.
|
||||
|
||||
Return a matched stop-sequence string to break the loop,
|
||||
or None to continue.
|
||||
|
||||
"""
|
||||
return None
|
||||
|
||||
async def handle(self) -> Union[StreamingResponse, Dict[str, Any]]:
|
||||
prompt, ctx, stop_sequences = self.builder.prepare(self.request, self.engine)
|
||||
ctx.prompt_tokens = len(self.engine.tokenizer.encode(prompt))
|
||||
ctx = StreamContext(
|
||||
resp_id=self.create_response_id(),
|
||||
created=int(time.time()),
|
||||
model=self.request.model,
|
||||
prompt_tokens=self._count_prompt_tokens(),
|
||||
)
|
||||
|
||||
agen = self.engine.generate_async(
|
||||
prompt=prompt,
|
||||
prompt=self.build_prompt(),
|
||||
max_tokens=self.request.max_tokens,
|
||||
temperature=self.request.temperature,
|
||||
top_p=self.request.top_p,
|
||||
|
|
@ -121,37 +152,33 @@ class ProtocolHandler:
|
|||
)
|
||||
|
||||
if self.request.stream:
|
||||
return self._handle_stream(agen, ctx, stop_sequences)
|
||||
return self._handle_stream(agen, ctx)
|
||||
else:
|
||||
return await self._handle_non_stream(agen, ctx, stop_sequences)
|
||||
return await self._handle_non_stream(agen, ctx)
|
||||
|
||||
def _handle_stream(
|
||||
self, agen: AsyncGenerator, ctx: GenContext, stop_sequences: List[str]
|
||||
) -> StreamingResponse:
|
||||
checker = StopChecker(stop_sequences)
|
||||
def _count_prompt_tokens(self) -> int:
|
||||
return len(self.engine.tokenizer.encode(self.build_prompt()))
|
||||
|
||||
def _handle_stream(self, agen, ctx: StreamContext) -> StreamingResponse:
|
||||
stop_checker = self.create_stop_checker()
|
||||
|
||||
async def event_stream():
|
||||
for event in self.builder.format_stream_start(ctx):
|
||||
for event in self.format_stream_start(ctx):
|
||||
yield event
|
||||
|
||||
body = ""
|
||||
yielded = ""
|
||||
matched = None
|
||||
async for token in agen:
|
||||
body += token
|
||||
ctx.completion_tokens += 1
|
||||
ctx.accumulated += token
|
||||
|
||||
matched = checker.check(body)
|
||||
matched = self.on_token(ctx, token, stop_checker)
|
||||
if matched:
|
||||
break
|
||||
|
||||
ctx.completion_tokens += 1
|
||||
yield self.builder.format_chunk(token)
|
||||
yielded += token
|
||||
yield self.format_stream_token(ctx, token)
|
||||
|
||||
stop = StopInfo(matched=matched, body=body, yielded=yielded)
|
||||
for event in self.builder.format_stream_end(ctx, stop):
|
||||
for event in self.format_stream_end(ctx):
|
||||
yield event
|
||||
yield sse_done()
|
||||
yield _sse_done()
|
||||
|
||||
return StreamingResponse(
|
||||
event_stream(),
|
||||
|
|
@ -159,24 +186,260 @@ class ProtocolHandler:
|
|||
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
|
||||
)
|
||||
|
||||
async def _handle_non_stream(
|
||||
self, agen: AsyncGenerator, ctx: GenContext, stop_sequences: List[str]
|
||||
) -> Dict[str, Any]:
|
||||
checker = StopChecker(stop_sequences)
|
||||
async def _handle_non_stream(self, agen, ctx: StreamContext) -> Dict[str, Any]:
|
||||
stop_checker = self.create_stop_checker()
|
||||
chunks: List[str] = []
|
||||
body = ""
|
||||
matched = None
|
||||
|
||||
async for token in agen:
|
||||
ctx.completion_tokens += 1
|
||||
ctx.accumulated += token
|
||||
chunks.append(token)
|
||||
body += token
|
||||
|
||||
matched = checker.check(body)
|
||||
matched = self.on_token(ctx, token, stop_checker)
|
||||
if matched:
|
||||
break
|
||||
|
||||
ctx.completion_tokens += 1
|
||||
|
||||
content = "".join(chunks)
|
||||
stop = StopInfo(matched=matched, body=body)
|
||||
return self.builder.format_response(ctx, content, stop)
|
||||
return self.format_non_stream_response(ctx, content)
|
||||
|
||||
|
||||
def _extract_text_content(content: Union[str, List[Dict[str, Any]]]) -> str:
|
||||
"""Extract plain text from an Anthropic content block (string or list)."""
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
for block in content:
|
||||
if isinstance(block, dict) and block.get("type") == "text":
|
||||
return block.get("text", "")
|
||||
return ""
|
||||
|
||||
|
||||
class OpenAIHandler(ProtocolHandler):
    """OpenAI-compatible /v1/chat/completions handler."""

    @staticmethod
    def _chunk(ctx: StreamContext, delta: Dict[str, Any], finish_reason) -> str:
        """Return one ``chat.completion.chunk`` SSE event for ``ctx``."""
        return _sse_event(
            {
                "id": ctx.resp_id,
                "object": "chat.completion.chunk",
                "created": ctx.created,
                "model": ctx.model,
                "choices": [
                    {"index": 0, "delta": delta, "finish_reason": finish_reason}
                ],
            }
        )

    def build_prompt(self) -> str:
        """Render the request messages through the tokenizer's chat template."""
        messages = [
            {"role": m.role, "content": m.content} for m in self.request.messages
        ]
        return self.engine.tokenizer.apply_chat_template(messages, tokenize=False)

    def create_response_id(self) -> str:
        """Return a ``chatcmpl-`` id with 12 random hex characters."""
        return f"chatcmpl-{uuid.uuid4().hex[:12]}"

    def get_stop_sequences(self) -> List[str]:
        """Return the request's ``stop`` value normalised to a list."""
        stop = self.request.stop
        if stop is None:
            return []
        return [stop] if isinstance(stop, str) else stop

    def on_token(
        self, ctx: StreamContext, token: str, stop_checker: StopChecker
    ) -> Optional[str]:
        """Check the accumulated text for a stop sequence.

        Returns:
            The matched stop sequence (ending generation) or ``None``. A
            match is also recorded on ``ctx.stop_matched`` so the
            non-streaming response can cut it out of the content.
        """
        matched = stop_checker.check(ctx.accumulated)
        if matched:
            ctx.stop_matched = matched
        return matched

    def format_stream_start(self, ctx: StreamContext) -> List[str]:
        """Return the opening chunk that announces the assistant role."""
        return [self._chunk(ctx, {"role": "assistant"}, None)]

    def format_stream_token(self, ctx: StreamContext, token: str) -> str:
        """Return a chunk carrying ``token`` as content."""
        return self._chunk(ctx, {"content": token}, None)

    def format_stream_end(self, ctx: StreamContext) -> List[str]:
        """Return the final ``finish_reason`` chunk followed by a usage event."""
        return [
            self._chunk(ctx, {}, "stop"),
            _sse_event(
                {
                    "prompt_tokens": ctx.prompt_tokens,
                    "completion_tokens": ctx.completion_tokens,
                    "total_tokens": ctx.prompt_tokens + ctx.completion_tokens,
                }
            ),
        ]

    def format_non_stream_response(
        self, ctx: StreamContext, content: str
    ) -> Dict[str, Any]:
        """Return the non-streaming ``chat.completion`` JSON body.

        The token that completes a stop sequence is already part of
        ``content`` when generation stops, so the text from the matched stop
        sequence onwards is removed before the body is built.
        """
        if ctx.stop_matched:
            cut = content.rfind(ctx.stop_matched)
            if cut != -1:
                content = content[:cut]
        return {
            "id": ctx.resp_id,
            "object": "chat.completion",
            "created": ctx.created,
            "model": ctx.model,
            "choices": [
                {
                    "index": 0,
                    "message": {"role": "assistant", "content": content},
                    "finish_reason": "stop",
                }
            ],
            "usage": {
                "prompt_tokens": ctx.prompt_tokens,
                "completion_tokens": ctx.completion_tokens,
                "total_tokens": ctx.prompt_tokens + ctx.completion_tokens,
            },
        }
|
||||
|
||||
|
||||
class AnthropicHandler(ProtocolHandler):
    """Anthropic-compatible /v1/messages handler.

    Formats prompts, streaming events and final responses in the shape of
    the Anthropic Messages API, including stop-sequence trimming.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Concatenation of every token passed to format_stream_token so far;
        # on_token uses its length to work out what is still unsent.
        self._yielded = ""

    @staticmethod
    def _trim_at_stop(text: str, stop: str) -> str:
        """Return *text* cut just before the last occurrence of *stop*.

        If *stop* does not occur, *text* is returned unchanged. Slicing with
        the raw ``rfind`` result would use -1 and silently drop the final
        character in that case.
        """
        cut = text.rfind(stop)
        return text[:cut] if cut >= 0 else text

    def build_prompt(self) -> str:
        """Render the request into a prompt via the tokenizer's chat template.

        An optional ``system`` attribute on the request becomes a leading
        system message. Request messages whose extracted text is empty are
        skipped.
        """
        messages: List[Dict[str, str]] = []
        system = getattr(self.request, "system", None)
        if system:
            messages.append({"role": "system", "content": system})
        for m in self.request.messages:
            content = _extract_text_content(m.content)
            if content:
                messages.append({"role": m.role, "content": content})
        return self.engine.tokenizer.apply_chat_template(messages, tokenize=False)

    def create_response_id(self) -> str:
        """Return a fresh id of the form ``msg_`` + 24 hex characters."""
        return f"msg_{uuid.uuid4().hex[:24]}"

    def get_stop_sequences(self) -> List[str]:
        """Return the request's ``stop_sequences``, or ``[]`` if absent/empty."""
        return getattr(self.request, "stop_sequences", None) or []

    def on_token(
        self, ctx: StreamContext, token: str, stop_checker: StopChecker
    ) -> Optional[str]:
        """Check the accumulated text for a stop sequence.

        Args:
            ctx: Stream state; ``accumulated`` is inspected, and on a match
                ``stop_matched`` and possibly ``last_yield_trimmed`` are set.
            token: The newly generated token (not used directly here).
            stop_checker: Object whose ``check`` is called on the accumulated
                text. NOTE(review): presumably returns the matched stop
                string or a falsy value; confirm against StopChecker.

        Returns:
            The matched value, or ``None`` when nothing matched.
        """
        matched = stop_checker.check(ctx.accumulated)
        if not matched:
            return None

        ctx.stop_matched = matched
        trimmed = self._trim_at_stop(ctx.accumulated, matched)
        # Text that precedes the stop sequence but has not been streamed yet;
        # format_stream_end flushes it as one last delta.
        unyielded = trimmed[len(self._yielded) :]
        if unyielded:
            ctx.last_yield_trimmed = unyielded
        return matched

    def format_stream_start(self, ctx: StreamContext) -> List[str]:
        """Return the ``message_start`` and ``content_block_start`` events."""
        return [
            _sse_event(
                {
                    "type": "message_start",
                    "message": {
                        "id": ctx.resp_id,
                        "type": "message",
                        "role": "assistant",
                        "model": ctx.model,
                        "content": [],
                        "usage": {"input_tokens": ctx.prompt_tokens},
                    },
                },
                event="message_start",
            ),
            _sse_event(
                {
                    "type": "content_block_start",
                    "index": 0,
                    "content_block": {"type": "text", "text": ""},
                },
                event="content_block_start",
            ),
        ]

    def format_stream_token(self, ctx: StreamContext, token: str) -> str:
        """Return a ``content_block_delta`` event for *token*.

        Also appends *token* to the handler's record of streamed text.
        """
        self._yielded += token
        return _sse_event(
            {
                "type": "content_block_delta",
                "index": 0,
                "delta": {"type": "text_delta", "text": token},
            },
            event="content_block_delta",
        )

    def format_stream_end(self, ctx: StreamContext) -> List[str]:
        """Return the events that close the stream.

        In order: an optional final ``content_block_delta`` carrying
        ``ctx.last_yield_trimmed``, then ``content_block_stop``,
        ``message_delta`` (stop reason and output token count) and
        ``message_stop``.
        """
        matched = ctx.stop_matched
        events: List[str] = []
        last_yielded = ctx.last_yield_trimmed
        if last_yielded:
            events.append(
                _sse_event(
                    {
                        "type": "content_block_delta",
                        "index": 0,
                        "delta": {"type": "text_delta", "text": last_yielded},
                    },
                    event="content_block_delta",
                )
            )
        events.append(
            _sse_event(
                {"type": "content_block_stop", "index": 0},
                event="content_block_stop",
            )
        )
        events.append(
            _sse_event(
                {
                    "type": "message_delta",
                    "delta": {
                        "stop_reason": "stop_sequence" if matched else "end_turn",
                        "stop_sequence": matched,
                    },
                    "usage": {"output_tokens": ctx.completion_tokens},
                },
                event="message_delta",
            )
        )
        events.append(_sse_event({"type": "message_stop"}, event="message_stop"))
        return events

    def format_non_stream_response(
        self, ctx: StreamContext, content: str
    ) -> Dict[str, Any]:
        """Build the complete (non-streaming) message body.

        When a stop sequence was matched, *content* is cut just before its
        last occurrence. If the sequence is not present in *content*, the
        text is left intact.
        """
        matched = ctx.stop_matched
        if matched:
            content = self._trim_at_stop(content, matched)
        return {
            "id": ctx.resp_id,
            "type": "message",
            "role": "assistant",
            "model": ctx.model,
            "content": [{"type": "text", "text": content}],
            "stop_reason": "stop_sequence" if matched else "end_turn",
            "stop_sequence": matched,
            "usage": {
                "input_tokens": ctx.prompt_tokens,
                "output_tokens": ctx.completion_tokens,
            },
        }
|
||||
|
|
|
|||
|
|
@ -3,9 +3,6 @@ OpenAI / Anthropic-compatible chat completion server backed by continuous-batchi
|
|||
|
||||
Protocol-specific formatting is delegated to ``astrai.inference.protocol``.
|
||||
This module owns the FastAPI app, request/response schemas, and dependency wiring.
|
||||
|
||||
``app`` is lazily constructed — importing this module does NOT create a FastAPI instance.
|
||||
Use :func:`get_app` to access the singleton.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
|
@ -15,19 +12,17 @@ from typing import Any, Dict, List, Optional, Union
|
|||
|
||||
import torch
|
||||
import uvicorn
|
||||
from fastapi import APIRouter, FastAPI, HTTPException
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from astrai.inference.api.anthropic import AnthropicResponseBuilder
|
||||
from astrai.inference.api.openai import OpenAIResponseBuilder
|
||||
from astrai.inference.api.protocol import ProtocolHandler
|
||||
from astrai.inference.api.protocol import AnthropicHandler, OpenAIHandler
|
||||
from astrai.inference.engine import InferenceEngine
|
||||
from astrai.model import AutoModel
|
||||
from astrai.tokenize import AutoTokenizer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_app_instance: Optional[FastAPI] = None
|
||||
_project_root = Path(__file__).parent.parent.parent
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
|
|
@ -87,15 +82,17 @@ async def lifespan(app: FastAPI):
|
|||
logger.info("Inference engine shutdown complete")
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
app = FastAPI(title="AstrAI Inference Server", version="0.2.0", lifespan=lifespan)
|
||||
|
||||
|
||||
def _create_engine(
|
||||
param_path: Path,
|
||||
param_path: Optional[Path] = None,
|
||||
device: str = "cuda",
|
||||
dtype: torch.dtype = torch.bfloat16,
|
||||
max_batch_size: int = 16,
|
||||
) -> InferenceEngine:
|
||||
if param_path is None:
|
||||
param_path = _project_root / "params"
|
||||
if not param_path.exists():
|
||||
raise FileNotFoundError(f"Parameter directory not found: {param_path}")
|
||||
|
||||
|
|
@ -113,66 +110,49 @@ def _create_engine(
|
|||
return engine
|
||||
|
||||
|
||||
def get_app() -> FastAPI:
    """Return the process-wide FastAPI application, building it on first use.

    The first call creates the app with the module's ``lifespan`` handler,
    attaches ``router``, and initialises ``state.server_config`` to an empty
    dict and ``state.engine`` to ``None``. Later calls return that instance.
    """
    global _app_instance
    if _app_instance is not None:
        return _app_instance

    # Publish the instance before wiring it up, as the original did.
    _app_instance = instance = FastAPI(
        title="AstrAI Inference Server",
        version="0.2.0",
        lifespan=lifespan,
    )
    instance.include_router(router)
    instance.state.server_config = {}
    instance.state.engine = None
    return _app_instance
|
||||
|
||||
|
||||
def _get_engine() -> InferenceEngine:
|
||||
engine = get_app().state.engine
|
||||
engine = app.state.engine
|
||||
if engine is None:
|
||||
raise HTTPException(status_code=503, detail="Engine not initialized")
|
||||
return engine
|
||||
|
||||
|
||||
@router.get("/health")
|
||||
@app.get("/health")
|
||||
async def health():
|
||||
app = get_app()
|
||||
return {
|
||||
"status": "ok",
|
||||
"model_loaded": app.state.engine is not None,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/stats")
|
||||
@app.get("/stats")
|
||||
async def get_stats():
|
||||
return _get_engine().get_stats()
|
||||
|
||||
|
||||
@router.post("/v1/chat/completions")
|
||||
@app.post("/v1/chat/completions")
|
||||
async def chat_completion(request: ChatCompletionRequest):
|
||||
engine = _get_engine()
|
||||
handler = ProtocolHandler(request, engine, OpenAIResponseBuilder())
|
||||
handler = OpenAIHandler(request, engine)
|
||||
return await handler.handle()
|
||||
|
||||
|
||||
@router.post("/v1/messages")
|
||||
@app.post("/v1/messages")
|
||||
async def create_message(request: MessagesRequest):
|
||||
engine = _get_engine()
|
||||
handler = ProtocolHandler(request, engine, AnthropicResponseBuilder())
|
||||
handler = AnthropicHandler(request, engine)
|
||||
return await handler.handle()
|
||||
|
||||
|
||||
def run_server(
|
||||
param_path: Path,
|
||||
host: str = "0.0.0.0",
|
||||
port: int = 8000,
|
||||
reload: bool = False,
|
||||
device: str = "cuda",
|
||||
dtype: torch.dtype = torch.bfloat16,
|
||||
param_path: Optional[Path] = None,
|
||||
max_batch_size: int = 16,
|
||||
):
|
||||
app = get_app()
|
||||
app.state.server_config = {
|
||||
"device": device,
|
||||
"dtype": dtype,
|
||||
|
|
|
|||
|
|
@ -42,7 +42,7 @@ class Allocator:
|
|||
return idx
|
||||
return -1
|
||||
|
||||
def free(self, idx: int, keep_cached: bool = False):
|
||||
def free(self, idx: int, keep_cached: bool = False) -> None:
|
||||
with self._lock:
|
||||
self._refs[idx] -= 1
|
||||
if self._refs[idx] == 0:
|
||||
|
|
@ -51,7 +51,7 @@ class Allocator:
|
|||
else:
|
||||
self._free_mask |= 1 << idx
|
||||
|
||||
def inc_ref(self, idx: int):
|
||||
def inc_ref(self, idx: int) -> None:
|
||||
with self._lock:
|
||||
self._refs[idx] += 1
|
||||
self._lru.pop(idx, None)
|
||||
|
|
@ -60,7 +60,7 @@ class Allocator:
|
|||
with self._lock:
|
||||
return self._refs[idx]
|
||||
|
||||
def touch(self, idx: int):
|
||||
def touch(self, idx: int) -> None:
|
||||
with self._lock:
|
||||
self._lru.move_to_end(idx)
|
||||
|
||||
|
|
@ -74,7 +74,7 @@ class PrefixCache:
|
|||
self._hash_to_page: Dict[int, int] = {}
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def evict(self, idx: int):
|
||||
def evict(self, idx: int) -> None:
|
||||
with self._lock:
|
||||
h = self._page_to_hash.pop(idx, None)
|
||||
if h is not None:
|
||||
|
|
@ -96,7 +96,9 @@ class PrefixCache:
|
|||
hits.append(p)
|
||||
return hits
|
||||
|
||||
def record(self, page_idx: int, token_ids: List[int], logical_page_idx: int):
|
||||
def record(
|
||||
self, page_idx: int, token_ids: List[int], logical_page_idx: int
|
||||
) -> None:
|
||||
with self._lock:
|
||||
h = page_hash(token_ids, logical_page_idx, self._page_size)
|
||||
old_h = self._page_to_hash.pop(page_idx, None)
|
||||
|
|
@ -125,13 +127,13 @@ class PagePool:
|
|||
def alloc(self) -> int:
|
||||
return self._alloc.alloc()
|
||||
|
||||
def free(self, idx: int):
|
||||
def free(self, idx: int) -> None:
|
||||
keep = self._prefix.has_page(idx)
|
||||
self._alloc.free(idx, keep_cached=keep)
|
||||
if not keep:
|
||||
self._prefix.evict(idx)
|
||||
|
||||
def inc_ref(self, idx: int):
|
||||
def inc_ref(self, idx: int) -> None:
|
||||
self._alloc.inc_ref(idx)
|
||||
|
||||
def lookup(self, token_ids: List[int]) -> List[int]:
|
||||
|
|
@ -140,7 +142,9 @@ class PagePool:
|
|||
self._alloc.touch(p)
|
||||
return hits
|
||||
|
||||
def record(self, page_idx: int, token_ids: List[int], logical_page_idx: int):
|
||||
def record(
|
||||
self, page_idx: int, token_ids: List[int], logical_page_idx: int
|
||||
) -> None:
|
||||
self._prefix.record(page_idx, token_ids, logical_page_idx)
|
||||
|
||||
|
||||
|
|
@ -153,7 +157,7 @@ class TaskTable:
|
|||
self._cached: Dict[str, int] = {}
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def set(self, task_id: str, page_table: List[int], cached: int):
|
||||
def set(self, task_id: str, page_table: List[int], cached: int) -> None:
|
||||
with self._lock:
|
||||
self._pages[task_id] = page_table
|
||||
self._cached[task_id] = cached
|
||||
|
|
@ -216,7 +220,7 @@ class Storage:
|
|||
start_pos: int,
|
||||
k: Tensor,
|
||||
v: Tensor,
|
||||
):
|
||||
) -> None:
|
||||
seq_len = k.size(1)
|
||||
if seq_len == 0:
|
||||
return
|
||||
|
|
@ -282,7 +286,7 @@ class KvcacheView:
|
|||
self._page_table = page_table
|
||||
self._total_len = total_len
|
||||
|
||||
def write(self, layer_id: int, k: Tensor, v: Tensor):
|
||||
def write(self, layer_id: int, k: Tensor, v: Tensor) -> None:
|
||||
start_pos = self._total_len - k.size(1)
|
||||
self._storage.write(layer_id, self._page_table, start_pos, k, v)
|
||||
|
||||
|
|
@ -335,7 +339,7 @@ class KVCache:
|
|||
self._table.set(task_id, hits + new_pages, cached)
|
||||
return True
|
||||
|
||||
def task_free(self, task_id: str):
|
||||
def task_free(self, task_id: str) -> None:
|
||||
page_table, _ = self._table.pop(task_id)
|
||||
for idx in page_table:
|
||||
self._pool.free(idx)
|
||||
|
|
@ -355,7 +359,7 @@ class KVCache:
|
|||
|
||||
def task_record_hashes(
|
||||
self, task_id: str, prompt_ids: List[int], start_logical_page: int = 0
|
||||
):
|
||||
) -> None:
|
||||
page_table = self._table.get(task_id)
|
||||
full_pages = len(prompt_ids) // self.page_size
|
||||
for i in range(start_logical_page, full_pages):
|
||||
|
|
|
|||
|
|
@ -29,7 +29,9 @@ class Executor:
|
|||
self.device = device or next(model.parameters()).device
|
||||
self.dtype = dtype or next(model.parameters()).dtype
|
||||
|
||||
def execute_prefill(self, tasks: List[Task], prompt_len: int, start_pos: int = 0):
|
||||
def execute_prefill(
|
||||
self, tasks: List[Task], prompt_len: int, start_pos: int = 0
|
||||
) -> None:
|
||||
if start_pos >= prompt_len:
|
||||
return
|
||||
|
||||
|
|
|
|||
|
|
@ -71,19 +71,18 @@ class InferenceScheduler:
|
|||
)
|
||||
|
||||
self._running = False
|
||||
self._fatal_error: Optional[Exception] = None
|
||||
|
||||
def add_task(self, prompt: str, **kwargs) -> str:
|
||||
return self._task_mgr.add_task(prompt, **kwargs)
|
||||
|
||||
def remove_task(self, task_id: str):
|
||||
def remove_task(self, task_id: str) -> None:
|
||||
for task in self._task_mgr.remove_task(task_id):
|
||||
self._page_cache.task_free(task.task_id)
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
return self._task_mgr.get_stats()
|
||||
|
||||
def _run_generation_loop(self):
|
||||
def _run_generation_loop(self) -> None:
|
||||
stop_ids = self._task_mgr.tokenizer.stop_ids
|
||||
try:
|
||||
while self._running:
|
||||
|
|
@ -109,10 +108,7 @@ class InferenceScheduler:
|
|||
continue
|
||||
|
||||
to_prefill = [
|
||||
t
|
||||
for t in self._task_mgr.get_active_tasks()
|
||||
if t.output_tokens == 0
|
||||
and self._page_cache.task_cached(t.task_id) < len(t.prompt_ids)
|
||||
t for t in self._task_mgr.get_active_tasks() if t.output_tokens == 0
|
||||
]
|
||||
if to_prefill:
|
||||
for t in to_prefill:
|
||||
|
|
@ -160,15 +156,11 @@ class InferenceScheduler:
|
|||
t.output_ids.append(ntok)
|
||||
t.output_tokens += 1
|
||||
pos = t.input_tokens + t.output_tokens
|
||||
extend_ok = self._page_cache.task_extend(t.task_id, pos)
|
||||
self._page_cache.task_extend(t.task_id, pos)
|
||||
if t.stream_callback:
|
||||
t.stream_callback(
|
||||
self._task_mgr.tokenizer.decode([ntok])
|
||||
)
|
||||
if not extend_ok:
|
||||
t.status = TaskStatus.ABORTED
|
||||
if t.stream_callback:
|
||||
t.stream_callback(STOP)
|
||||
|
||||
for t in valid:
|
||||
if t.is_finished(stop_ids):
|
||||
|
|
@ -176,37 +168,28 @@ class InferenceScheduler:
|
|||
t.stream_callback(STOP)
|
||||
|
||||
except Exception as e:
|
||||
self._fatal_error = e
|
||||
self._running = False
|
||||
logger.error(f"Scheduler loop crashed: {e}", exc_info=True)
|
||||
for task in self._task_mgr.get_active_tasks():
|
||||
if task.stream_callback:
|
||||
task.stream_callback(STOP)
|
||||
self._page_cache.task_free(task.task_id)
|
||||
for task in self._task_mgr.get_waiting_tasks():
|
||||
if task.stream_callback:
|
||||
task.stream_callback(STOP)
|
||||
self._task_mgr.clear_queues()
|
||||
raise
|
||||
|
||||
def start(self):
|
||||
def start(self) -> None:
|
||||
if not self._running:
|
||||
self._running = True
|
||||
t = threading.Thread(target=self._run_generation_loop, daemon=True)
|
||||
t.start()
|
||||
self._loop_thread = t
|
||||
|
||||
def stop(self):
|
||||
def stop(self) -> None:
|
||||
self._running = False
|
||||
self._task_mgr.wake()
|
||||
if hasattr(self, "_loop_thread"):
|
||||
self._loop_thread.join(timeout=2.0)
|
||||
for task in self._task_mgr.get_active_tasks():
|
||||
if task.stream_callback:
|
||||
task.stream_callback(STOP)
|
||||
self._page_cache.task_free(task.task_id)
|
||||
for task in self._task_mgr.get_waiting_tasks():
|
||||
if task.stream_callback:
|
||||
task.stream_callback(STOP)
|
||||
self._task_mgr.clear_queues()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
|
|
|||
|
|
@ -172,12 +172,12 @@ class TaskManager:
|
|||
to_add.append(self.waiting_queue.popleft())
|
||||
return to_add
|
||||
|
||||
def activate(self, task: Task):
|
||||
def activate(self, task: Task) -> None:
|
||||
task.status = TaskStatus.RUNNING
|
||||
with self._lock:
|
||||
self.active_tasks.append(task)
|
||||
|
||||
def return_to_waiting(self, tasks: List[Task]):
|
||||
def return_to_waiting(self, tasks: List[Task]) -> None:
|
||||
with self._lock:
|
||||
for task in reversed(tasks):
|
||||
self.waiting_queue.appendleft(task)
|
||||
|
|
@ -185,25 +185,18 @@ class TaskManager:
|
|||
def has_work(self) -> bool:
|
||||
return bool(self.active_tasks or self.waiting_queue)
|
||||
|
||||
def wait_for_tasks(self, timeout: float = 1.0):
|
||||
with self._lock:
|
||||
if self.waiting_queue or self.active_tasks:
|
||||
return
|
||||
self._task_event.clear()
|
||||
def wait_for_tasks(self, timeout: float = 1.0) -> None:
|
||||
self._task_event.clear()
|
||||
self._task_event.wait(timeout=timeout)
|
||||
|
||||
def get_active_tasks(self) -> List[Task]:
|
||||
with self._lock:
|
||||
return list(self.active_tasks)
|
||||
|
||||
def get_waiting_tasks(self) -> List[Task]:
|
||||
with self._lock:
|
||||
return list(self.waiting_queue)
|
||||
|
||||
def clear_queues(self):
|
||||
def clear_queues(self) -> None:
|
||||
with self._lock:
|
||||
self.waiting_queue.clear()
|
||||
self.active_tasks.clear()
|
||||
|
||||
def wake(self):
|
||||
def wake(self) -> None:
|
||||
self._task_event.set()
|
||||
|
|
|
|||
|
|
@ -13,6 +13,17 @@ from astrai.inference.core.task import STOP
|
|||
from astrai.tokenize import AutoTokenizer
|
||||
|
||||
|
||||
def _validate_sampling_params(
|
||||
top_k: int, top_p: float, temperature: float, max_tokens: Optional[int] = None
|
||||
):
|
||||
if not (isinstance(top_k, int) and top_k >= 0):
|
||||
raise ValueError("top_k must be a non-negative integer")
|
||||
if not (0.0 <= top_p <= 1.0):
|
||||
raise ValueError("top_p must be a float between 0.0 and 1.0")
|
||||
if not (isinstance(temperature, (int, float)) and temperature >= 0):
|
||||
raise ValueError("temperature must be a non-negative number")
|
||||
|
||||
|
||||
class GenerateResult:
|
||||
"""Thread-safe token accumulator for streaming and non-streaming modes."""
|
||||
|
||||
|
|
@ -48,7 +59,7 @@ class GenerateResult:
|
|||
def wait(self, timeout: Optional[float] = None) -> bool:
|
||||
return self._event.wait(timeout=timeout)
|
||||
|
||||
def wait_completion(self, timeout: float = 300.0):
|
||||
def wait_completion(self, timeout: float = 300.0) -> None:
|
||||
with self._cond:
|
||||
if not self._cond.wait_for(
|
||||
lambda: self._completed >= self._total, timeout=timeout
|
||||
|
|
@ -75,12 +86,7 @@ class GenerationRequest:
|
|||
max_tokens: Optional[int] = None,
|
||||
stream: bool = False,
|
||||
):
|
||||
if not (isinstance(top_k, int) and top_k >= 0):
|
||||
raise ValueError("top_k must be a non-negative integer")
|
||||
if not (0.0 <= top_p <= 1.0):
|
||||
raise ValueError("top_p must be a float between 0.0 and 1.0")
|
||||
if not (isinstance(temperature, (int, float)) and temperature > 0):
|
||||
raise ValueError("temperature must be a positive number")
|
||||
_validate_sampling_params(top_k, top_p, temperature, max_tokens)
|
||||
|
||||
self.messages = messages
|
||||
self.top_k = top_k
|
||||
|
|
@ -131,6 +137,7 @@ class InferenceEngine:
|
|||
top_p: float = 1.0,
|
||||
top_k: int = 50,
|
||||
) -> Union[Generator, str, List[str]]:
|
||||
_validate_sampling_params(top_k, top_p, temperature, max_tokens)
|
||||
is_batch = isinstance(prompt, list)
|
||||
prompts = prompt if is_batch else [prompt]
|
||||
|
||||
|
|
@ -151,6 +158,7 @@ class InferenceEngine:
|
|||
top_p: float = 1.0,
|
||||
top_k: int = 50,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
_validate_sampling_params(top_k, top_p, temperature, max_tokens)
|
||||
sync_gen = self._generate_streaming(
|
||||
[prompt], False, max_tokens, temperature, top_p, top_k
|
||||
)
|
||||
|
|
@ -281,7 +289,7 @@ class InferenceEngine:
|
|||
def get_stats(self) -> Dict[str, Any]:
|
||||
return self.scheduler.get_stats()
|
||||
|
||||
def shutdown(self):
|
||||
def shutdown(self) -> None:
|
||||
self.scheduler.stop()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
|
|
|||
|
|
@ -44,12 +44,10 @@ class TemperatureStrategy(BaseSamplingStrategy):
|
|||
def apply(self, logits, filter_value=-float("inf")):
|
||||
t = self.temperature
|
||||
if isinstance(t, Tensor):
|
||||
t = t.to(logits.device, non_blocking=True).view(-1, 1)
|
||||
t = torch.clamp(t, min=1e-8)
|
||||
if (t != 1.0).any():
|
||||
logits = logits / t
|
||||
logits = logits / t.to(logits.device, non_blocking=True).view(-1, 1)
|
||||
elif t != 1.0:
|
||||
logits = logits / max(t, 1e-8)
|
||||
logits = logits / t
|
||||
return logits
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -2,13 +2,6 @@ from astrai.model.automodel import AutoModel
|
|||
from astrai.model.components.attention import GQA
|
||||
from astrai.model.components.decoder_block import DecoderBlock
|
||||
from astrai.model.components.linear import Linear
|
||||
from astrai.model.components.lora import (
|
||||
LoRAConfig,
|
||||
inject_lora,
|
||||
load_lora,
|
||||
merge_lora,
|
||||
save_lora,
|
||||
)
|
||||
from astrai.model.components.mlp import MLP
|
||||
from astrai.model.components.norm import RMSNorm
|
||||
from astrai.model.encoder import EmbeddingEncoder
|
||||
|
|
@ -25,10 +18,4 @@ __all__ = [
|
|||
"AutoRegressiveLM",
|
||||
"EmbeddingEncoder",
|
||||
"AutoModel",
|
||||
# LoRA
|
||||
"LoRAConfig",
|
||||
"inject_lora",
|
||||
"merge_lora",
|
||||
"save_lora",
|
||||
"load_lora",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -2,24 +2,21 @@
|
|||
AutoModel base class for model loading and saving.
|
||||
"""
|
||||
|
||||
import json
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Self, Union
|
||||
|
||||
import safetensors.torch as st
|
||||
import torch.nn as nn
|
||||
|
||||
from astrai.config.model_config import BaseModelConfig, ConfigFactory
|
||||
from astrai.factory import BaseFactory
|
||||
from astrai.serialization import load_model_config, load_model_weights, save_model
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _disable_random_init(enable: bool = True):
|
||||
if not enable:
|
||||
yield
|
||||
return
|
||||
|
||||
names = (
|
||||
init_functions = [
|
||||
"xavier_normal_",
|
||||
"xavier_uniform_",
|
||||
"kaiming_normal_",
|
||||
|
|
@ -29,15 +26,18 @@ def _disable_random_init(enable: bool = True):
|
|||
"constant_",
|
||||
"normal_",
|
||||
"uniform_",
|
||||
)
|
||||
orig = {n: getattr(nn.init, n) for n in names if hasattr(nn.init, n)}
|
||||
for n in orig:
|
||||
setattr(nn.init, n, lambda *a, **kw: None)
|
||||
]
|
||||
original_funcs = {}
|
||||
for name in init_functions:
|
||||
if enable and hasattr(nn.init, name):
|
||||
original_funcs[name] = getattr(nn.init, name)
|
||||
setattr(nn.init, name, lambda *args, **kwargs: None)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
for n, fn in orig.items():
|
||||
setattr(nn.init, n, fn)
|
||||
if enable:
|
||||
for name, orig_func in original_funcs.items():
|
||||
setattr(nn.init, name, orig_func)
|
||||
|
||||
|
||||
class AutoModel(BaseFactory["AutoModel"], nn.Module):
|
||||
|
|
@ -60,22 +60,25 @@ class AutoModel(BaseFactory["AutoModel"], nn.Module):
|
|||
|
||||
model_path = Path(path)
|
||||
|
||||
# Load config
|
||||
config_path = model_path / "config.json"
|
||||
if not config_path.exists():
|
||||
if config_path.exists():
|
||||
with open(config_path, "r") as f:
|
||||
raw = json.load(f)
|
||||
config = ConfigFactory.load(raw)
|
||||
model_type = config.model_type or "autoregressive_lm"
|
||||
else:
|
||||
raise FileNotFoundError(f"Config file not found: {config_path}")
|
||||
|
||||
raw = load_model_config(str(model_path))
|
||||
config = ConfigFactory.load(raw)
|
||||
model_type = config.model_type or "autoregressive_lm"
|
||||
|
||||
actual_cls = AutoModel.get_component_class(model_type)
|
||||
|
||||
with _disable_random_init(enable=disable_random_init):
|
||||
model = actual_cls(config)
|
||||
|
||||
# Load weights
|
||||
weights_path = model_path / "model.safetensors"
|
||||
if weights_path.exists():
|
||||
state_dict = load_model_weights(str(model_path))
|
||||
state_dict = st.load_file(str(weights_path))
|
||||
model.load_state_dict(state_dict, strict=strict)
|
||||
|
||||
return model
|
||||
|
|
@ -83,12 +86,15 @@ class AutoModel(BaseFactory["AutoModel"], nn.Module):
|
|||
def save_pretrained(
|
||||
self,
|
||||
save_directory: Union[str, Path],
|
||||
):
|
||||
save_model(
|
||||
config=self.config.to_dict(),
|
||||
state_dict=self.state_dict(),
|
||||
save_directory=str(save_directory),
|
||||
)
|
||||
) -> None:
|
||||
save_path = Path(save_directory)
|
||||
save_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Save config
|
||||
self.config.to_file(str(save_path / "config.json"))
|
||||
|
||||
# Save weights
|
||||
st.save_file(self.state_dict(), str(save_path / "model.safetensors"))
|
||||
|
||||
def to(self, *args, **kwargs) -> Self:
|
||||
"""Move model to device/dtype."""
|
||||
|
|
|
|||
|
|
@ -1,194 +0,0 @@
|
|||
import logging
|
||||
from dataclasses import asdict, dataclass
|
||||
from pathlib import Path
|
||||
from typing import Optional, Set
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from astrai.model.components.linear import Linear
|
||||
from astrai.serialization import (
|
||||
load_json,
|
||||
load_safetensors,
|
||||
save_json,
|
||||
save_safetensors,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
TARGET_MODULES_ATTN = {"q_proj", "k_proj", "v_proj", "o_proj"}
|
||||
TARGET_MODULES_FFN = {"up", "gate", "down"}
|
||||
|
||||
|
||||
@dataclass
class LoRAConfig:
    """Hyper-parameters describing a LoRA adapter."""

    # Rank of the low-rank update matrices.
    r: int = 16
    # Scaling numerator; LoRALinear multiplies the update by alpha / r.
    alpha: int = 32
    # Child-module names whose Linear layers are wrapped during injection.
    target_modules: tuple = ("q_proj", "v_proj")
|
||||
|
||||
|
||||
class LoRALinear(nn.Module):
    """Frozen linear layer plus a trainable low-rank (LoRA) update.

    The wrapped layer's ``weight`` and ``bias`` are reused and frozen. The
    forward pass adds ``(x @ A.T @ B.T) * (alpha / r)`` until :meth:`merge`
    folds the update into ``weight``.
    """

    def __init__(self, base: Linear, r: int = 16, alpha: int = 32):
        """Wrap *base* with rank-*r* adapter matrices.

        Args:
            base: Layer whose ``weight`` / ``bias`` parameters are reused.
            r: Adapter rank; must be positive.
            alpha: Scaling numerator (effective scale is ``alpha / r``).

        Raises:
            ValueError: If ``r`` is not positive.
        """
        # Fail early with a clear message instead of a ZeroDivisionError or
        # an invalid-size error further down.
        if r <= 0:
            raise ValueError(f"LoRA rank r must be positive, got {r}")
        super().__init__()
        # Register the base layer's own Parameter objects, then freeze them.
        self.register_parameter("weight", base.weight)
        self.weight.requires_grad_(False)
        self.bias = base.bias
        if self.bias is not None:
            self.bias.requires_grad_(False)

        self.r = r
        self.scaling = alpha / r
        # Create the adapter on the same device as the frozen weight, so
        # injecting into a model that already lives on an accelerator does
        # not leave CPU parameters behind.
        device = self.weight.device
        self.lora_A = nn.Parameter(
            torch.randn(r, self.weight.shape[1], device=device) / r
        )
        # B starts at zero, so the initial update B @ A is zero.
        self.lora_B = nn.Parameter(
            torch.zeros(self.weight.shape[0], r, device=device)
        )
        self._merged = False

    def forward(self, x):
        """Apply the frozen linear map, plus the scaled update if unmerged."""
        out = F.linear(x, self.weight, self.bias)
        if not self._merged:
            out += (F.linear(x, self.lora_A) @ self.lora_B.T) * self.scaling
        return out

    def merge(self):
        """Fold the adapter into ``weight`` and drop the adapter parameters.

        Calling it again after a merge is a no-op.
        """
        if self._merged:
            return
        self.weight.data += (self.lora_B @ self.lora_A) * self.scaling
        self._merged = True
        # Deleting the parameters also removes them from the state dict.
        del self.lora_A
        del self.lora_B
|
||||
|
||||
|
||||
def _collect_lora_info(model: nn.Module) -> dict:
    """Group the qualified names of all ``Linear`` sub-modules by leaf name.

    Returns:
        A dict mapping the last dotted component of each module name to the
        list of full names that end with it.
    """
    grouped: dict = {}
    for qualified_name, module in model.named_modules():
        if not isinstance(module, Linear):
            continue
        leaf = qualified_name.rpartition(".")[2]
        grouped.setdefault(leaf, []).append(qualified_name)
    return grouped
|
||||
|
||||
|
||||
def _get_lora_count(model: nn.Module) -> int:
    """Return how many ``LoRALinear`` modules *model* contains."""
    lora_layers = [m for m in model.modules() if isinstance(m, LoRALinear)]
    return len(lora_layers)
|
||||
|
||||
|
||||
def inject_lora(
    model: nn.Module,
    r: int = 16,
    alpha: int = 32,
    target_modules: Optional[Set[str]] = None,
) -> LoRAConfig:
    """Replace targeted ``Linear`` sub-modules of *model* with ``LoRALinear``.

    Args:
        model: Module tree to modify in place.
        r: Adapter rank passed to each ``LoRALinear``.
        alpha: Scaling numerator passed to each ``LoRALinear``.
        target_modules: Leaf module names to wrap; defaults to
            ``TARGET_MODULES_ATTN`` when ``None``.

    Returns:
        A ``LoRAConfig`` recording ``r``, ``alpha`` and the target names.
    """
    targets = TARGET_MODULES_ATTN if target_modules is None else target_modules

    # Collected before any replacement, for the diagnostic below.
    available = _collect_lora_info(model)

    # Snapshot the module list: the loop rewrites attributes on the tree.
    snapshot = list(model.named_modules())
    injected = 0
    for qualified_name, module in snapshot:
        parent_name, _, child_name = qualified_name.rpartition(".")
        if isinstance(module, Linear) and child_name in targets:
            parent = model.get_submodule(parent_name) if parent_name else model
            setattr(parent, child_name, LoRALinear(module, r=r, alpha=alpha))
            injected += 1

    if injected:
        logger.info("LoRA injected: %d layers (r=%d, alpha=%d)", injected, r, alpha)
    else:
        logger.warning(
            "No LoRA layers injected. Available Linear child names: %s. "
            "target_modules: %s. Check model type and target_modules.",
            sorted(available),
            sorted(targets),
        )

    return LoRAConfig(r=r, alpha=alpha, target_modules=tuple(targets))
|
||||
|
||||
|
||||
def merge_lora(model: nn.Module):
    """Call ``merge()`` on every ``LoRALinear`` in *model* and log the count.

    Logs a warning instead when the model contains no such layer.
    """
    lora_layers = [m for m in model.modules() if isinstance(m, LoRALinear)]
    for layer in lora_layers:
        layer.merge()

    merged = len(lora_layers)
    if merged:
        logger.info("Merged %d LoRA layers", merged)
    else:
        logger.warning("No LoRA layers to merge.")
|
||||
|
||||
|
||||
def save_lora(model: nn.Module, save_dir: str, config: LoRAConfig):
    """Write the adapter weights and config of *model* into *save_dir*.

    Only state-dict entries whose key ends in ``.lora_A`` or ``.lora_B`` are
    saved, to ``adapter_model.safetensors``. The config goes to
    ``adapter_config.json``. The directory is created if needed.

    Raises:
        RuntimeError: If the state dict holds no LoRA parameters.
    """
    lora_sd = {}
    for key, tensor in model.state_dict().items():
        if key.endswith(".lora_A") or key.endswith(".lora_B"):
            lora_sd[key] = tensor

    if not lora_sd:
        raise RuntimeError(
            "No LoRA parameters found in model. "
            "The model may not have been injected or was already merged."
        )

    out_dir = Path(save_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    save_safetensors(lora_sd, out_dir / "adapter_model.safetensors")
    save_json(asdict(config), out_dir / "adapter_config.json")
    logger.info("LoRA adapter saved to %s (%d keys)", save_dir, len(lora_sd))
|
||||
|
||||
|
||||
def load_lora(model: nn.Module, load_dir: str) -> LoRAConfig:
    """Load a saved LoRA adapter from *load_dir* into *model*.

    Reads ``adapter_config.json``, injects LoRA layers unless the model
    already has some, then loads ``adapter_model.safetensors`` non-strictly.

    Returns:
        The ``LoRAConfig`` rebuilt from the adapter config file.

    Raises:
        RuntimeError: On a shape mismatch, when the model has no LoRA layers
            after loading, or when LoRA keys are reported missing.
    """
    adapter_dir = Path(load_dir)
    raw = load_json(adapter_dir / "adapter_config.json")
    config = LoRAConfig(
        r=raw["r"], alpha=raw["alpha"], target_modules=tuple(raw["target_modules"])
    )

    already_present = _get_lora_count(model)
    if not already_present:
        inject_lora(
            model,
            r=config.r,
            alpha=config.alpha,
            target_modules=set(config.target_modules),
        )
    else:
        logger.warning(
            "Model already has %d LoRA layers. Skipping injection, "
            "loading weights onto existing layers only.",
            already_present,
        )

    weights = load_safetensors(adapter_dir / "adapter_model.safetensors")
    try:
        # Non-strict: the adapter file holds only the LoRA tensors.
        missing, unexpected = model.load_state_dict(weights, strict=False)
    except RuntimeError as e:
        msg = str(e)
        if "size mismatch" not in msg:
            raise
        raise RuntimeError(
            f"LoRA weight shapes do not match the model. "
            f"The adapter config (r={config.r}) may not match the injected layers. "
            f"Original error: {msg}"
        ) from e

    if _get_lora_count(model) == 0:
        raise RuntimeError(
            "No LoRA layers found after loading. "
            "Inject LoRA before calling load_lora, or check the adapter config."
        )

    if missing:
        # Missing base-weight keys are expected; missing LoRA keys are not.
        lora_missing = [key for key in missing if "lora" in key]
        if lora_missing:
            raise RuntimeError(
                f"LoRA weight keys not found in model: {lora_missing}. "
                f"The adapter config (r={config.r}) may not match the model."
            )
        logger.debug("LoRA load: %d missing base-weight keys (expected)", len(missing))
    if unexpected:
        logger.warning("LoRA load: %d unexpected keys", len(unexpected))

    logger.info("LoRA adapter loaded from %s", load_dir)
    return config
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Dict, Optional
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
|
@ -19,10 +19,6 @@ def get_rotary_emb(
|
|||
return torch.complex(cos, sin)
|
||||
|
||||
|
||||
def ntk_base(base: float, dim: int, factor: float) -> float:
    """Return *base* multiplied by ``factor ** (dim / (dim - 2))``.

    Raises:
        ZeroDivisionError: If ``dim`` equals 2.
    """
    exponent = dim / (dim - 2)
    return base * (factor**exponent)
|
||||
|
||||
|
||||
def apply_rotary_emb(x: torch.Tensor, freqs_cis: Tensor) -> Tensor:
|
||||
dtype = x.dtype
|
||||
x_ = x.float().reshape(*x.shape[:-1], -1, 2)
|
||||
|
|
@ -34,25 +30,11 @@ def apply_rotary_emb(x: torch.Tensor, freqs_cis: Tensor) -> Tensor:
|
|||
|
||||
|
||||
class RotaryEmbedding(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
max_len: int,
|
||||
base: float = 10000,
|
||||
rope_scaling: Optional[Dict] = None,
|
||||
):
|
||||
def __init__(self, dim: int, max_len: int, base: float = 10000):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.max_len = max_len
|
||||
self.base = base
|
||||
self.rope_scaling = rope_scaling
|
||||
|
||||
if rope_scaling is not None:
|
||||
scaling_type = rope_scaling.get("type", "ntk")
|
||||
factor = rope_scaling.get("factor", 1.0)
|
||||
if scaling_type == "ntk":
|
||||
self.base = ntk_base(base, dim, factor)
|
||||
|
||||
self._set_rotary_buffer(self.max_len)
|
||||
|
||||
def _set_rotary_buffer(self, max_len: int):
|
||||
|
|
|
|||
|
|
@ -20,9 +20,7 @@ class EmbeddingEncoder(AutoModel):
|
|||
self.config = config
|
||||
rope_dim = config.dim // config.n_heads
|
||||
rope_base = config.rope_theta if config.rope_theta is not None else 10000
|
||||
self.rotary_embedding = RotaryEmbedding(
|
||||
rope_dim, config.max_len, rope_base, rope_scaling=config.rope_scaling
|
||||
)
|
||||
self.rotary_embedding = RotaryEmbedding(rope_dim, config.max_len, rope_base)
|
||||
self.embed_tokens = Embedding(config.vocab_size, config.dim)
|
||||
|
||||
self.layers = nn.ModuleList(
|
||||
|
|
@ -68,6 +66,9 @@ class EmbeddingEncoder(AutoModel):
|
|||
|
||||
x = self.embed_tokens(input_ids)
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = torch.arange(S, device=x.device).unsqueeze(0).expand(B, -1)
|
||||
|
||||
rotary_emb = self.rotary_embedding(x, position_ids)
|
||||
attn_mask = process_attention_mask(x, position_ids, input_mask, is_causal=False)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Any, Dict, Mapping, Optional
|
||||
from typing import Any, Mapping, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
|
@ -26,21 +26,24 @@ def process_attention_mask(
|
|||
return input_mask
|
||||
|
||||
device = input_tensor.device
|
||||
B = input_tensor.size(0)
|
||||
dtype = input_tensor.dtype
|
||||
B, S = input_tensor.size()[:2]
|
||||
T = position_ids.max().item() + 1
|
||||
|
||||
if input_mask is None:
|
||||
if position_ids.min().item() == 0 and is_causal:
|
||||
return None
|
||||
attend = torch.ones(B, 1, T, dtype=torch.bool, device=device)
|
||||
pad = torch.ones(B, T, dtype=torch.bool, device=device)
|
||||
else:
|
||||
attend = input_mask[:, :T].to(device=device, dtype=torch.bool).unsqueeze(1)
|
||||
pad = input_mask[:, :T].to(device=device, dtype=torch.bool)
|
||||
|
||||
attend = pad.view(B, 1, T).expand(B, S, T).clone()
|
||||
if is_causal:
|
||||
causal = position_ids.unsqueeze(-1) >= torch.arange(T, device=device)
|
||||
attend = attend & causal
|
||||
attend &= position_ids.unsqueeze(-1) >= torch.arange(T, device=device)
|
||||
|
||||
return attend.unsqueeze(1)
|
||||
return torch.full(
|
||||
(B, 1, S, T), -torch.finfo(dtype).max / 2, dtype=dtype, device=device
|
||||
).masked_fill_(attend.unsqueeze(1), 0.0)
|
||||
|
||||
|
||||
@AutoModel.register("autoregressive_lm")
|
||||
|
|
@ -56,9 +59,7 @@ class AutoRegressiveLM(AutoModel):
|
|||
else config.dim // config.n_heads
|
||||
)
|
||||
rope_base = config.rope_theta if config.rope_theta is not None else 10000
|
||||
self.rotary_embedding = RotaryEmbedding(
|
||||
rope_dim, config.max_len, rope_base, rope_scaling=config.rope_scaling
|
||||
)
|
||||
self.rotary_embedding = RotaryEmbedding(rope_dim, config.max_len, rope_base)
|
||||
self.embed_tokens = Embedding(config.vocab_size, config.dim)
|
||||
|
||||
self.layers = nn.ModuleList(
|
||||
|
|
@ -133,7 +134,7 @@ class AutoRegressiveLM(AutoModel):
|
|||
input_mask: Optional[Tensor] = None,
|
||||
paged_cache: Optional[KvcacheView] = None,
|
||||
position_ids: Optional[Tensor] = None,
|
||||
) -> Dict[str, Tensor]:
|
||||
) -> Tensor:
|
||||
assert input_ids.ndim == 2
|
||||
|
||||
x = self.embed_tokens(input_ids)
|
||||
|
|
|
|||
|
|
@ -1,13 +1,3 @@
|
|||
from astrai.parallel.executor import (
|
||||
AccumOptimizer,
|
||||
AccumScheduler,
|
||||
BaseExecutor,
|
||||
DDPExecutor,
|
||||
ExecutorFactory,
|
||||
FSDPExecutor,
|
||||
GradientState,
|
||||
NoneExecutor,
|
||||
)
|
||||
from astrai.parallel.module import ColumnParallelLinear, RowParallelLinear
|
||||
from astrai.parallel.setup import (
|
||||
get_current_device,
|
||||
|
|
@ -27,12 +17,4 @@ __all__ = [
|
|||
"spawn_parallel_fn",
|
||||
"RowParallelLinear",
|
||||
"ColumnParallelLinear",
|
||||
"ExecutorFactory",
|
||||
"BaseExecutor",
|
||||
"GradientState",
|
||||
"AccumOptimizer",
|
||||
"AccumScheduler",
|
||||
"NoneExecutor",
|
||||
"DDPExecutor",
|
||||
"FSDPExecutor",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -1,272 +0,0 @@
|
|||
"""Unified training executor — parallel strategy + gradient accumulation."""
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
import os
|
||||
from contextlib import contextmanager
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.distributed.fsdp import FullStateDictConfig, StateDictType
|
||||
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.optim import Optimizer
|
||||
from torch.optim.lr_scheduler import LRScheduler
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from astrai.factory import BaseFactory
|
||||
from astrai.parallel.setup import get_rank, get_world_size
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GradientState:
    """Counter that decides on which micro-steps gradients are synchronised.

    ``num_steps`` is the accumulation window (clamped to at least 1).  Every
    call to :meth:`_do_sync` advances the internal step counter and marks the
    step as a sync step when the counter is a multiple of ``num_steps``.
    """

    def __init__(self, grad_accum_steps: int = 1):
        # A window smaller than one makes no sense; fall back to "sync always".
        self.num_steps = max(grad_accum_steps, 1)
        self._step: int = 0
        self._sync_gradients: bool = True

    @property
    def sync_gradients(self) -> bool:
        """Whether the current micro-step is a synchronisation step."""
        return self._sync_gradients

    def _do_sync(self):
        """Advance one micro-step and refresh the sync flag."""
        self._step += 1
        remainder = self._step % self.num_steps
        self._sync_gradients = remainder == 0
|
||||
|
||||
|
||||
class AccumOptimizer:
    """Optimizer proxy that only acts on gradient-synchronisation steps.

    ``step`` and ``zero_grad`` are forwarded to the wrapped optimizer only
    while ``gradient_state.sync_gradients`` is true; on accumulation steps
    they are no-ops.  State (de)serialisation and ``param_groups`` always
    pass straight through.
    """

    def __init__(self, optimizer: Optimizer, gradient_state: "GradientState"):
        self.optimizer = optimizer
        self.gradient_state = gradient_state

    def step(self, closure=None):
        """Run the wrapped ``step`` on sync steps and return its result.

        Returns ``None`` on accumulation steps, where nothing is executed.
        """
        if self.gradient_state.sync_gradients:
            # Propagate the return value so closure-based callers get the loss.
            return self.optimizer.step(closure)
        return None

    def zero_grad(self, *args, **kwargs):
        """Clear gradients on sync steps.

        Positional/keyword arguments (e.g. ``set_to_none``) are forwarded
        unchanged so the proxy accepts the same calls as the wrapped optimizer.
        """
        if self.gradient_state.sync_gradients:
            self.optimizer.zero_grad(*args, **kwargs)

    @property
    def param_groups(self):
        """Parameter groups of the wrapped optimizer."""
        return self.optimizer.param_groups

    def state_dict(self):
        """Return the wrapped optimizer's state dict."""
        return self.optimizer.state_dict()

    def load_state_dict(self, d):
        """Load ``d`` into the wrapped optimizer."""
        self.optimizer.load_state_dict(d)
|
||||
|
||||
|
||||
class AccumScheduler:
    """LR-scheduler proxy that advances only on gradient-synchronisation steps."""

    def __init__(self, scheduler: LRScheduler, gradient_state: "GradientState"):
        self.scheduler = scheduler
        self.gradient_state = gradient_state

    def step(self):
        """Step the wrapped scheduler unless this is an accumulation step."""
        if not self.gradient_state.sync_gradients:
            return
        self.scheduler.step()

    def state_dict(self):
        """Return the wrapped scheduler's state dict."""
        return self.scheduler.state_dict()

    def load_state_dict(self, d):
        """Load ``d`` into the wrapped scheduler."""
        self.scheduler.load_state_dict(d)

    def get_last_lr(self):
        """Return the wrapped scheduler's ``get_last_lr()`` result."""
        return self.scheduler.get_last_lr()
|
||||
|
||||
|
||||
class BaseExecutor:
    """Base training executor: no model wrapping, plus gradient accumulation.

    Subclasses override :meth:`_prepare_model`, :meth:`_no_sync` and
    :meth:`unwrap_model` to plug in a parallel strategy.
    """

    def __init__(self, grad_accum_steps: int = 1):
        self.gradient_state = GradientState(grad_accum_steps)

    def prepare(
        self,
        model: nn.Module,
        optimizer: Optional[Optimizer] = None,
        dataloader: Optional[DataLoader] = None,
        scheduler: Optional[LRScheduler] = None,
    ) -> Tuple[
        nn.Module, Optional[Optimizer], Optional[DataLoader], Optional[LRScheduler]
    ]:
        """Prepare the training objects for this executor.

        The model goes through :meth:`_prepare_model`; a given optimizer or
        scheduler is wrapped in its accumulation-aware proxy; the dataloader
        is returned untouched.
        """
        prepared_model = self._prepare_model(model)
        wrapped_optimizer = (
            None
            if optimizer is None
            else AccumOptimizer(optimizer, self.gradient_state)
        )
        wrapped_scheduler = (
            None
            if scheduler is None
            else AccumScheduler(scheduler, self.gradient_state)
        )
        return prepared_model, wrapped_optimizer, dataloader, wrapped_scheduler

    def _prepare_model(self, model: nn.Module) -> nn.Module:
        """Hook for wrapping the model; the base implementation returns it as is."""
        return model

    def _no_sync(self, model: nn.Module):
        """Context manager used on accumulation steps; a no-op here."""
        return contextlib.nullcontext()

    @contextmanager
    def accumulate(self, model: nn.Module):
        """Advance the gradient state and run the body under the right context.

        On accumulation (non-sync) steps the body runs inside
        :meth:`_no_sync`; on sync steps it runs without any extra context.
        """
        state = self.gradient_state
        state._do_sync()
        guard = contextlib.nullcontext() if state.sync_gradients else self._no_sync(model)
        with guard:
            yield

    def backward(self, loss: torch.Tensor):
        """Back-propagate ``loss``."""
        loss.backward()

    def unwrap_model(self, model: nn.Module):
        """Return the model's ``state_dict()``."""
        return model.state_dict()

    @property
    def use_distributed(self) -> bool:
        """True when ``get_world_size()`` reports more than one process."""
        return get_world_size() > 1

    @property
    def sync_gradients(self) -> bool:
        """Sync flag of the underlying gradient state."""
        return self.gradient_state.sync_gradients

    @property
    def grad_accum_steps(self) -> int:
        """Accumulation window of the underlying gradient state."""
        return self.gradient_state.num_steps
|
||||
|
||||
|
||||
class ExecutorFactory(BaseFactory[BaseExecutor]):
    """Registry of :class:`BaseExecutor` implementations; adds nothing to ``BaseFactory``."""
|
||||
|
||||
|
||||
@ExecutorFactory.register("none")
class NoneExecutor(BaseExecutor):
    """Executor registered as ``"none"``; inherits :class:`BaseExecutor` unchanged."""
|
||||
|
||||
|
||||
@ExecutorFactory.register("ddp")
class DDPExecutor(BaseExecutor):
    """Executor that wraps the model in ``DistributedDataParallel``.

    All constructor options other than ``grad_accum_steps`` are stored and
    forwarded verbatim to the ``DDP`` constructor.
    """

    def __init__(
        self,
        grad_accum_steps: int = 1,
        dim: int = 0,
        broadcast_buffers: bool = True,
        init_sync: bool = True,
        process_group=None,
        bucket_cap_mb: int = 25,
        find_unused_parameters: bool = False,
        check_reduction: bool = False,
        gradient_as_bucket_view: bool = False,
        static_graph: bool = False,
        delay_all_reduce_named_params=None,
        param_to_hook_all_reduce=None,
        mixed_precision=None,
        device_mesh=None,
    ):
        super().__init__(grad_accum_steps=grad_accum_steps)
        # Keyword arguments handed to DDP(...) in _prepare_model.
        self._ddp_kwargs = {
            "dim": dim,
            "broadcast_buffers": broadcast_buffers,
            "init_sync": init_sync,
            "process_group": process_group,
            "bucket_cap_mb": bucket_cap_mb,
            "find_unused_parameters": find_unused_parameters,
            "check_reduction": check_reduction,
            "gradient_as_bucket_view": gradient_as_bucket_view,
            "static_graph": static_graph,
            "delay_all_reduce_named_params": delay_all_reduce_named_params,
            "param_to_hook_all_reduce": param_to_hook_all_reduce,
            "mixed_precision": mixed_precision,
            "device_mesh": device_mesh,
        }

    def _prepare_model(self, model: nn.Module) -> nn.Module:
        """Wrap ``model`` in DDP, or return it unchanged when world size is 1."""
        if self.use_distributed:
            # LOCAL_RANK wins when set; otherwise fall back to get_rank().
            local_rank = int(os.environ.get("LOCAL_RANK", get_rank()))
            wrapped = DDP(
                model,
                device_ids=[local_rank],
                output_device=local_rank,
                **self._ddp_kwargs,
            )
            logger.info("Model wrapped with DDP (world_size=%d)", get_world_size())
            return wrapped

        logger.warning("DDP backend selected but world_size=1, model not wrapped")
        return model

    def _no_sync(self, model: nn.Module):
        """Return ``model.no_sync()`` for DDP models, else a null context."""
        return model.no_sync() if isinstance(model, DDP) else contextlib.nullcontext()

    def unwrap_model(self, model: nn.Module):
        """Return the state dict of the inner module (unwrapping DDP if present)."""
        inner = model.module if isinstance(model, DDP) else model
        return inner.state_dict()
|
||||
|
||||
|
||||
@ExecutorFactory.register("fsdp")
class FSDPExecutor(BaseExecutor):
    """Executor that wraps the model in ``FullyShardedDataParallel``.

    Constructor options left as ``None`` are dropped so that FSDP's own
    defaults apply; ``use_orig_params=True`` is always passed.
    """

    def __init__(
        self,
        grad_accum_steps: int = 1,
        process_group=None,
        sharding_strategy=None,
        cpu_offload=None,
        auto_wrap_policy=None,
        backward_prefetch=None,
        mixed_precision=None,
        ignored_modules=None,
        param_init_fn=None,
        sync_module_states: bool = False,
        forward_prefetch: bool = False,
        limit_all_gathers: bool = True,
        ignored_states=None,
        device_mesh=None,
    ):
        super().__init__(grad_accum_steps=grad_accum_steps)
        candidate_kwargs = dict(
            process_group=process_group,
            sharding_strategy=sharding_strategy,
            cpu_offload=cpu_offload,
            auto_wrap_policy=auto_wrap_policy,
            backward_prefetch=backward_prefetch,
            mixed_precision=mixed_precision,
            ignored_modules=ignored_modules,
            param_init_fn=param_init_fn,
            sync_module_states=sync_module_states,
            forward_prefetch=forward_prefetch,
            limit_all_gathers=limit_all_gathers,
            use_orig_params=True,
            ignored_states=ignored_states,
            device_mesh=device_mesh,
        )
        # Only forward options that were explicitly set.
        self._fsdp_kwargs = {k: v for k, v in candidate_kwargs.items() if v is not None}
        self._original_model: Optional[nn.Module] = None

    def _prepare_model(self, model: nn.Module) -> nn.Module:
        """Wrap ``model`` in FSDP, or return it unchanged when world size is 1."""
        if not self.use_distributed:
            logger.warning("FSDP backend selected but world_size=1, model not wrapped")
            return model
        self._original_model = model
        # Device indices are per node, so prefer LOCAL_RANK over the global
        # rank (which exceeds the local device count on multi-node runs).
        # Same lookup as DDPExecutor._prepare_model.
        local_rank = int(os.environ.get("LOCAL_RANK", get_rank()))
        device_id = torch.device("cuda", local_rank)
        model = FSDP(model, device_id=device_id, **self._fsdp_kwargs)
        logger.info("Model wrapped with FSDP (world_size=%d)", get_world_size())
        return model

    def _no_sync(self, model: nn.Module):
        """Return ``model.no_sync()`` for FSDP models, else a null context."""
        if isinstance(model, FSDP):
            return model.no_sync()
        return contextlib.nullcontext()

    def unwrap_model(self, model: nn.Module):
        """Return the model's state dict.

        For an FSDP-wrapped model in a distributed run, the state dict is
        gathered as ``FULL_STATE_DICT`` with ``offload_to_cpu=True`` and
        ``rank0_only=False``.
        """
        if isinstance(model, FSDP) and self.use_distributed:
            with FSDP.state_dict_type(
                model,
                StateDictType.FULL_STATE_DICT,
                FullStateDictConfig(offload_to_cpu=True, rank0_only=False),
            ):
                return model.state_dict()

        return model.state_dict()
|
||||
|
|
@ -1,5 +1,4 @@
|
|||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from contextlib import contextmanager
|
||||
from functools import wraps
|
||||
from typing import Callable
|
||||
|
|
@ -31,7 +30,6 @@ def get_rank() -> int:
|
|||
def setup_parallel(
|
||||
rank: int,
|
||||
world_size: int,
|
||||
local_rank: int,
|
||||
backend: str = "nccl",
|
||||
master_addr: str = "localhost",
|
||||
master_port: str = "29500",
|
||||
|
|
@ -43,18 +41,14 @@ def setup_parallel(
|
|||
return
|
||||
|
||||
if world_size <= 1:
|
||||
device_id = torch.device(device_type, local_rank)
|
||||
os.environ["LOCAL_RANK"] = str(local_rank)
|
||||
os.environ["WORLD_SIZE"] = "1"
|
||||
os.environ["LOCAL_DEVICE"] = str(device_id)
|
||||
yield None
|
||||
return
|
||||
|
||||
device_id = torch.device(device_type, local_rank)
|
||||
device_id = torch.device(device_type, rank)
|
||||
|
||||
os.environ["MASTER_ADDR"] = master_addr
|
||||
os.environ["MASTER_PORT"] = master_port
|
||||
os.environ["LOCAL_RANK"] = str(local_rank)
|
||||
os.environ["LOCAL_RANK"] = str(rank)
|
||||
os.environ["WORLD_SIZE"] = str(world_size)
|
||||
os.environ["LOCAL_DEVICE"] = str(device_id)
|
||||
|
||||
|
|
@ -96,7 +90,7 @@ def only_on_rank(rank, sync=False):
|
|||
return decorator
|
||||
|
||||
|
||||
def _run_single_rank(
|
||||
def wrapper_spawn_func(
|
||||
rank: int,
|
||||
world_size: int,
|
||||
backend: str,
|
||||
|
|
@ -106,108 +100,20 @@ def _run_single_rank(
|
|||
func: Callable,
|
||||
kwargs: dict,
|
||||
):
|
||||
with setup_parallel(
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
local_rank=rank,
|
||||
backend=backend,
|
||||
master_addr=master_addr,
|
||||
master_port=master_port,
|
||||
device_type=device_type,
|
||||
):
|
||||
func(**kwargs)
|
||||
|
||||
|
||||
class LaunchStrategy(ABC):
    """Strategy for launching a function in a distributed context.

    Stores the launch settings as attributes; subclasses implement
    :meth:`launch`.
    """

    def __init__(
        self,
        world_size: int,
        backend: str,
        master_addr: str,
        master_port: str,
        device_type: str,
        start_method: str,
    ):
        self.world_size, self.backend = world_size, backend
        self.master_addr, self.master_port = master_addr, master_port
        self.device_type, self.start_method = device_type, start_method

    @abstractmethod
    def launch(self, func: Callable, **kwargs):
        """Run ``func(**kwargs)`` under this strategy (must be overridden)."""
        raise NotImplementedError
|
||||
|
||||
|
||||
class TorchrunStrategy(LaunchStrategy):
|
||||
"""External orchestrator (torchrun, SLURM, K8s) — env vars pre-set."""
|
||||
|
||||
def launch(self, func: Callable, **kwargs):
|
||||
rank = int(os.environ["RANK"])
|
||||
world_size = int(os.environ["WORLD_SIZE"])
|
||||
local_rank = int(os.environ.get("LOCAL_RANK", rank))
|
||||
try:
|
||||
with setup_parallel(
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
local_rank=local_rank,
|
||||
backend=self.backend,
|
||||
master_addr=os.environ.get("MASTER_ADDR", self.master_addr),
|
||||
master_port=os.environ.get("MASTER_PORT", self.master_port),
|
||||
device_type=self.device_type,
|
||||
backend=backend,
|
||||
master_addr=master_addr,
|
||||
master_port=master_port,
|
||||
device_type=device_type,
|
||||
):
|
||||
func(**kwargs)
|
||||
|
||||
|
||||
class LocalStrategy(LaunchStrategy):
    """Local launcher — single-process or mp.start_processes."""

    def launch(self, func: Callable, **kwargs):
        """Run ``func`` in-process for one rank, else spawn ``world_size`` workers.

        If anything is raised while waiting on the workers, all of them are
        terminated and joined before the exception is re-raised.
        """
        # Arguments shared by every rank; the rank itself is prepended by the caller.
        worker_args = (
            self.world_size,
            self.backend,
            self.master_addr,
            self.master_port,
            self.device_type,
            func,
            kwargs,
        )

        if self.world_size == 1:
            _run_single_rank(0, *worker_args)
            return

        ctx = mp.start_processes(
            _run_single_rank,
            args=worker_args,
            nprocs=self.world_size,
            start_method=self.start_method,
            join=False,
        )
        try:
            # Keep joining until join() reports that it is finished.
            finished = ctx.join()
            while not finished:
                finished = ctx.join()
        except BaseException:
            for worker in ctx.processes:
                worker.terminate()
            ctx.join()
            raise
|
||||
|
||||
|
||||
def _detect_launcher() -> str:
    """Detect the distributed launcher from environment.

    Checks, in order: ``dist.is_torchelastic_launched()``, the presence of
    ``LOCAL_WORLD_SIZE``, then the presence of both ``RANK`` and
    ``WORLD_SIZE``.

    Returns one of: "torchelastic", "torchrun", "external", "local".
    """
    env = os.environ
    if dist.is_torchelastic_launched():
        return "torchelastic"
    if "LOCAL_WORLD_SIZE" in env:
        return "torchrun"
    if all(name in env for name in ("RANK", "WORLD_SIZE")):
        return "external"
    return "local"
|
||||
except Exception as e:
|
||||
print(f"Error in rank {rank}: {e}")
|
||||
raise
|
||||
|
||||
|
||||
def spawn_parallel_fn(
|
||||
|
|
@ -220,13 +126,41 @@ def spawn_parallel_fn(
|
|||
start_method: str = "spawn",
|
||||
**kwargs,
|
||||
):
|
||||
launcher = _detect_launcher()
|
||||
if launcher in ("torchelastic", "torchrun", "external"):
|
||||
strategy = TorchrunStrategy(
|
||||
world_size, backend, master_addr, master_port, device_type, start_method
|
||||
)
|
||||
else:
|
||||
strategy = LocalStrategy(
|
||||
world_size, backend, master_addr, master_port, device_type, start_method
|
||||
)
|
||||
strategy.launch(func, **kwargs)
|
||||
# clear environment variables
|
||||
for key in [
|
||||
"MASTER_ADDR",
|
||||
"MASTER_PORT",
|
||||
"RANK",
|
||||
"WORLD_SIZE",
|
||||
"LOCAL_RANK",
|
||||
"LOCAL_DEVICE",
|
||||
]:
|
||||
if key in os.environ:
|
||||
del os.environ[key]
|
||||
|
||||
if world_size == 1:
|
||||
device_id = torch.device(device_type, 0)
|
||||
os.environ["LOCAL_RANK"] = "0"
|
||||
os.environ["WORLD_SIZE"] = "1"
|
||||
os.environ["LOCAL_DEVICE"] = str(device_id)
|
||||
|
||||
func(**kwargs)
|
||||
return
|
||||
|
||||
wrapper_spawn_func_args = (
|
||||
world_size,
|
||||
backend,
|
||||
master_addr,
|
||||
master_port,
|
||||
device_type,
|
||||
func,
|
||||
kwargs,
|
||||
)
|
||||
|
||||
mp.start_processes(
|
||||
wrapper_spawn_func,
|
||||
args=wrapper_spawn_func_args,
|
||||
nprocs=world_size,
|
||||
start_method=start_method,
|
||||
join=True,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,14 +0,0 @@
|
|||
from astrai.preprocessing.builder import (
|
||||
BaseMaskBuilder,
|
||||
MaskBuilderFactory,
|
||||
SectionedMaskBuilder,
|
||||
)
|
||||
from astrai.preprocessing.pipeline import Pipeline, filter_by_length
|
||||
|
||||
__all__ = [
|
||||
"BaseMaskBuilder",
|
||||
"MaskBuilderFactory",
|
||||
"SectionedMaskBuilder",
|
||||
"Pipeline",
|
||||
"filter_by_length",
|
||||
]
|
||||
|
|
@ -1,338 +0,0 @@
|
|||
"""Mask building strategies for preprocessing pipeline.
|
||||
|
||||
The single :class:`SectionedMaskBuilder` handles all input formats
|
||||
(single-sequence / DPO / GRPO) via declarative config: ``input.sections``
|
||||
for single-output or ``input.sources`` for multi-output.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
from astrai.factory import BaseFactory
|
||||
|
||||
|
||||
class BaseMaskBuilder(ABC):
    """Abstract converter from one JSONL record to token ids and an optional loss mask."""

    @abstractmethod
    def build(self, item: dict, config, tokenizer) -> Optional[dict]:
        """Turn a JSONL record into ``{ids, loss_mask?, domain}``.

        Implementations return ``None`` when the record should be skipped.
        """
        ...
|
||||
|
||||
|
||||
class MaskBuilderFactory(BaseFactory["BaseMaskBuilder"]):
    """Factory whose components must subclass :class:`BaseMaskBuilder`."""

    @classmethod
    def _validate_component(cls, component_cls: type):
        """Raise ``TypeError`` unless ``component_cls`` derives from ``BaseMaskBuilder``."""
        if issubclass(component_cls, BaseMaskBuilder):
            return
        raise TypeError(f"{component_cls.__name__} must inherit from BaseMaskBuilder")
|
||||
|
||||
|
||||
def _extract_domain(item: dict, domain_key: Optional[str]) -> str:
    """Return ``item[domain_key]`` when it is a string, else ``"__default__"``.

    The default is also returned when ``domain_key`` is empty/``None`` or the
    key is missing from ``item``.
    """
    fallback = "__default__"
    if domain_key:
        candidate = item.get(domain_key, fallback)
        if isinstance(candidate, str):
            return candidate
    return fallback
|
||||
|
||||
|
||||
def _resolve_action(action: str, role: str, config) -> str:
    """Resolve a section action, expanding the ``"$role"`` placeholder.

    - ``"$role"`` → ``config.mask[role]``, or ``config.mask_default`` when the
      role has no entry
    - anything else (e.g. ``"train"`` / ``"mask"``) → returned unchanged
    """
    if action != "$role":
        return action
    return config.mask.get(role, config.mask_default)
|
||||
|
||||
|
||||
@MaskBuilderFactory.register("sectioned")
class SectionedMaskBuilder(BaseMaskBuilder):
    """Config-driven builder supporting single and multi-output modes.

    Single-output (backward-compatible)::

        {"input": {"sections": [
            {"field": "messages", "action": "$role", "template": true}
        ]}}
        → {"sequence": [...], "loss_mask": [...], "domain": "..."}

    Multi-output (DPO / GRPO)::

        {"input": {"sources": {
            "chosen": {"sections": [
                {"field": "chosen", "action": "$role", "template": true}
            ]},
            "rejected": {"sections": [
                {"field": "rejected", "action": "$role", "template": true}
            ]}
        }}}
        → {"chosen": [...], "chosen_mask": [...],
           "rejected": [...], "rejected_mask": [...], "domain": "..."}

    Output spec fields::

        sections   – list of section specs (same format as single-output)
        list_field – True when the JSONL field holds a list of values to
                     tokenise individually and concatenate (GRPO responses)
        mask_key   – explicit output key for the loss mask
                     (default: ``"{output_key}_mask"``)
        dtype      – explicit tensor dtype for this output key
                     (default: "int32")
    """

    def build(self, item: dict, config, tokenizer) -> Optional[dict]:
        """Dispatch to multi-output mode when ``config.input.sources`` is set."""
        sources_spec = getattr(config.input, "sources", None)
        if sources_spec:
            return self._build_multi(item, sources_spec, config, tokenizer)
        return self._build_single(item, config, tokenizer)

    def _build_single(self, item: dict, config, tokenizer) -> Optional[dict]:
        """Build ``{"sequence", "domain"[, "loss_mask"]}`` from ``config.input.sections``.

        ``loss_mask`` is omitted when every position is trained (all ones).
        Returns ``None`` when there are no sections or nothing was produced.
        """
        sections = config.input.sections
        if not sections:
            return None

        ids, mask = self._process_sections(
            item, sections, config, tokenizer, is_top_level=True
        )
        if ids is None:
            return None

        result: dict = {
            "sequence": ids,
            "domain": _extract_domain(item, config.output.domain_key),
        }
        if not all(m == 1 for m in mask):
            result["loss_mask"] = mask
        return result

    def _build_multi(
        self, item: dict, sources_spec: dict, config, tokenizer
    ) -> Optional[dict]:
        """Build one output (plus optional mask) per entry of ``sources_spec``.

        Sources without sections, or that yield nothing, are skipped; if no
        source yields anything the whole item is skipped (``None``).
        """
        result: dict = {}
        any_output = False

        for output_key, spec in sources_spec.items():
            sections = spec.get("sections", [])
            if not sections:
                continue

            # A lone ``action == "value"`` section copies raw numbers through.
            if self._is_value_section(sections):
                ids = self._extract_raw_value(item, sections)
                if ids is None:
                    continue
                result[output_key] = ids
                any_output = True
                continue

            list_field = spec.get("list_field", False)
            mask_key = spec.get("mask_key", f"{output_key}_mask")

            if list_field:
                ids, mask = self._process_list_field(item, sections, config, tokenizer)
            else:
                ids, mask = self._process_sections(
                    item, sections, config, tokenizer, is_top_level=True
                )

            if ids is None:
                continue

            result[output_key] = ids
            # Emit the mask when it carries information, or when the spec
            # explicitly names a mask key.
            if not all(m == 1 for m in mask) or "mask_key" in spec:
                result[mask_key] = mask

            any_output = True

        if not any_output:
            return None

        result["domain"] = _extract_domain(item, config.output.domain_key)
        return result

    @staticmethod
    def _is_value_section(sections: list) -> bool:
        """True when ``sections`` is a single section whose action is ``"value"``."""
        return len(sections) == 1 and sections[0].get("action") == "value"

    @staticmethod
    def _extract_raw_value(item: dict, sections: list):
        """Extract a raw value from a JSONL field without tokenisation.

        Used for GRPO rewards where the field contains float values.
        Returns a list of floats, or ``None`` when the field is absent/null.
        """
        sec = sections[0]
        field = sec["field"]
        raw = item.get(field)
        if raw is None:
            return None
        if isinstance(raw, list):
            return [float(v) for v in raw]
        return [float(raw)]

    def _process_sections(
        self,
        item: dict,
        sections: list,
        config,
        tokenizer,
        *,
        is_top_level: bool = False,
    ):
        """Process a list of sections into ``(ids, loss_mask)``.

        Returns ``(None, None)`` if the item should be skipped.
        """
        all_ids: list[int] = []
        loss_mask: list[int] = []

        has_template = any(s.get("template") for s in sections)
        # Pure-text configs (no template, every section trained) get the
        # character-length filters applied in _append_text_section.
        is_text_config = not has_template and all(
            s["action"] == "train" for s in sections
        )

        if is_top_level and has_template and tokenizer.bos_token_id is not None:
            all_ids.append(tokenizer.bos_token_id)
            loss_mask.append(0)

        first_section = True
        for sec in sections:
            field = sec["field"]
            action = sec["action"]
            use_template = sec.get("template", False)
            # By default only the first successfully appended plain-text
            # section gets special tokens.
            add_special = sec.get(
                "add_special_tokens", not use_template and first_section
            )

            if use_template:
                success = self._append_template_section(
                    item, field, action, tokenizer, config, all_ids, loss_mask
                )
            else:
                success = self._append_text_section(
                    item,
                    field,
                    action,
                    tokenizer,
                    add_special,
                    is_text_config,
                    config,
                    all_ids,
                    loss_mask,
                )
            if not success:
                continue

            first_section = False

        max_len = config.preprocessing.max_seq_len
        all_ids = all_ids[:max_len]
        loss_mask = loss_mask[: len(all_ids)]

        if not all_ids:
            return None, None

        # A templated top-level sequence holding at most one token (e.g. just
        # the BOS) carries no content.
        if is_top_level and has_template and len(all_ids) <= 1:
            return None, None

        return all_ids, loss_mask

    def _append_template_section(
        self, item, field, action, tokenizer, config, all_ids, loss_mask
    ):
        """Render each message of ``item[field]`` with the chat template and append it.

        Returns ``False`` (appending nothing) when the field is not a
        non-empty list.
        """
        messages = item.get(field)
        if not isinstance(messages, list) or not messages:
            return False
        for msg in messages:
            role = msg.get("role", "")
            act = _resolve_action(action, role, config)
            rendered = tokenizer.apply_chat_template(
                [msg], tokenize=False, add_generation_prompt=False
            )
            ids = tokenizer.encode(rendered, add_special_tokens=False)
            all_ids.extend(ids)
            val = 1 if act == "train" else 0
            loss_mask.extend([val] * len(ids))
        return True

    def _append_text_section(
        self,
        item,
        field,
        action,
        tokenizer,
        add_special,
        is_text_config,
        config,
        all_ids,
        loss_mask,
    ):
        """Tokenise ``item[field]`` as plain text and append ids and mask.

        Returns ``False`` (appending nothing) when the field is missing,
        null, blank, or — for pure-text configs — outside the configured
        ``min_chars`` / ``max_chars`` bounds.
        """
        raw = item.get(field)
        if raw is None:
            # A JSON null must not be stringified into the literal "None".
            return False
        text = str(raw)
        if not text.strip():
            return False
        if is_text_config:
            pp = config.preprocessing
            if pp.min_chars > 0 and len(text) < pp.min_chars:
                return False
            if len(text) > pp.max_chars:
                return False
        ids = tokenizer.encode(text, add_special_tokens=add_special)
        all_ids.extend(ids)
        val = 1 if action == "train" else 0
        loss_mask.extend([val] * len(ids))
        return True

    def _process_list_field(self, item: dict, sections: list, config, tokenizer):
        """Tokenise every entry of a list-valued field and concatenate the results.

        Returns ``(ids, loss_mask)`` truncated to ``max_seq_len``, or
        ``(None, None)`` when nothing was produced.
        """
        all_ids: list[int] = []
        loss_mask: list[int] = []

        for sec in sections:
            field = sec["field"]
            action = sec["action"]
            use_template = sec.get("template", False)

            values = item.get(field)
            if not isinstance(values, list):
                continue

            for val in values:
                if use_template:
                    # Templated entries must themselves be message lists;
                    # anything else is ignored.
                    if isinstance(val, list):
                        wrapper = {field: val}
                        self._append_template_section(
                            wrapper,
                            field,
                            action,
                            tokenizer,
                            config,
                            all_ids,
                            loss_mask,
                        )
                else:
                    # Pass the raw entry through: _append_text_section does
                    # the str() conversion and skips null entries.
                    wrapper = {field: val}
                    self._append_text_section(
                        wrapper,
                        field,
                        action,
                        tokenizer,
                        False,
                        False,
                        config,
                        all_ids,
                        loss_mask,
                    )

        max_len = config.preprocessing.max_seq_len
        all_ids = all_ids[:max_len]
        loss_mask = loss_mask[: len(all_ids)]

        if not all_ids:
            return None, None
        return all_ids, loss_mask
|
||||
|
|
@ -1,257 +0,0 @@
|
|||
"""Config-driven JSONL preprocessing pipeline.
|
||||
|
||||
Composes a :class:`BaseMaskBuilder` (selected by ``input.type``) with
|
||||
sharding and flush to ``.h5`` / ``.bin`` storage.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from itertools import chain
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import tqdm
|
||||
|
||||
from astrai.config.preprocess_config import PipelineConfig
|
||||
from astrai.dataset.storage import save_bin, save_h5
|
||||
from astrai.preprocessing.builder import SectionedMaskBuilder
|
||||
from astrai.tokenize import AutoTokenizer
|
||||
|
||||
_STR_TO_DTYPE: dict[str, torch.dtype] = {
|
||||
"bool": torch.bool,
|
||||
"uint8": torch.uint8,
|
||||
"int8": torch.int8,
|
||||
"int16": torch.int16,
|
||||
"int32": torch.int32,
|
||||
"int64": torch.int64,
|
||||
"float16": torch.float16,
|
||||
"float32": torch.float32,
|
||||
"float64": torch.float64,
|
||||
}
|
||||
|
||||
|
||||
def filter_by_length(text: str, min_len: int = 50, max_len: int = 2_000_000) -> bool:
    """Return True when ``len(text)`` lies in the inclusive range ``[min_len, max_len]``."""
    n_chars = len(text)
    if n_chars < min_len:
        return False
    return n_chars <= max_len
|
||||
|
||||
|
||||
def _truncate(seq: list, max_len: int, mode: str) -> list:
    """Cut ``seq`` down to at most ``max_len`` items.

    ``mode == "keep_end"`` keeps the last ``max_len`` items; any other mode
    keeps the first ``max_len``.  A sequence that already fits is returned
    as is (same object, no copy).
    """
    if len(seq) <= max_len:
        return seq
    if mode == "keep_end":
        # Use an explicit start index: ``seq[-max_len:]`` would return the
        # whole sequence when max_len is 0, because ``-0 == 0``.
        return seq[len(seq) - max_len :]
    return seq[:max_len]
|
||||
|
||||
|
||||
def pack_sequences(
    sequences: List[list],
    max_packed_len: int,
    strategy: str,
    truncation_mode: str,
) -> List[Tuple[int, int]]:
    """Build a reorder plan that groups *sequences* into length-bounded bins.

    The plan is a list of ``(orig_idx, truncated_length)`` pairs in flush
    order, where ``truncated_length`` is ``min(len(sequence), max_packed_len)``.
    Every per-item key (sequence, loss_mask, ...) has to be reordered and
    truncated the same way according to this plan.

    *strategy* selects the ordering:

    - ``"simple"``: input order is kept.
    - any other value (e.g. ``"bfd"``): best-fit decreasing bin packing.

    *truncation_mode* is accepted but not consulted by this function.
    """
    # Each sequence contributes at most max_packed_len items to a bin.
    clipped = [min(len(seq), max_packed_len) for seq in sequences]

    if strategy == "simple":
        return list(enumerate(clipped))

    # Longest first; sorted() is stable, so equal lengths keep input order.
    by_length = sorted(
        range(len(sequences)), key=lambda i: len(sequences[i]), reverse=True
    )

    members: List[List[int]] = []
    used: List[int] = []

    for idx in by_length:
        need = clipped[idx]

        # Best fit: the first bin with the least free room that still fits.
        target = None
        tightest = max_packed_len + 1
        for slot, filled in enumerate(used):
            room = max_packed_len - filled
            if room >= need and room < tightest:
                target, tightest = slot, room

        if target is None:
            members.append([idx])
            used.append(need)
        else:
            members[target].append(idx)
            used[target] += need

    return [(idx, clipped[idx]) for group in members for idx in group]
|
||||
|
||||
|
||||
class Pipeline:
    """Tokenization pipeline driven by a declarative :class:`PipelineConfig`.

    Usage::

        config = PipelineConfig.from_json("sft_pipeline.json")
        Pipeline(config, ["data.jsonl"], output_dir="out", tokenizer_path="params").run()
    """

    def __init__(
        self,
        config: PipelineConfig,
        input_paths: list[str],
        output_dir: str,
        tokenizer_path: str,
    ):
        """Store the settings and create *output_dir* if it does not exist.

        Args:
            config: Pipeline configuration.
            input_paths: Files read line by line, one JSON document per line.
            output_dir: Root directory for the per-domain shard folders.
            tokenizer_path: Passed to ``AutoTokenizer.from_pretrained``.
        """
        os.makedirs(output_dir, exist_ok=True)
        self.config = config
        self.paths = input_paths
        self.output_dir = output_dir
        self.tokenizer_path = tokenizer_path
        # Loaded by run(), or lazily by the first transform() call.
        self._tokenizer = None

        self.mask_builder = SectionedMaskBuilder()

    def transform(self, item: dict) -> Optional[dict]:
        """Run the mask builder on one raw item and return its result.

        The tokenizer is loaded on demand so this method also works when it
        is called before :meth:`run`.
        """
        if getattr(self, "_tokenizer", None) is None:
            self._tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_path)
        return self.mask_builder.build(item, self.config, self._tokenizer)

    def run(self):
        """Tokenize all input items and write them out as shards.

        Results are bucketed by their ``"domain"`` entry (``"__default__"``
        when absent). Buckets are flushed whenever the accumulated token count
        reaches ``config.output.max_tokens_per_shard`` and once more at the end
        if anything is left. Processing stops after
        ``config.preprocessing.max_items`` kept documents when that is set.
        """
        self._tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_path)
        domains: dict = defaultdict(lambda: defaultdict(list))
        total_tokens = 0
        shard_idx: dict[str, int] = defaultdict(int)
        count = 0

        pp = self.config.preprocessing
        # The input layout does not change while iterating, so decide once.
        is_multi = bool(getattr(self.config.input, "sources", None))

        for item in tqdm.tqdm(
            self._iter_items(), desc="Tokenizing", unit="docs", mininterval=0.5
        ):
            if pp.max_items and count >= pp.max_items:
                break

            result = self.transform(item)
            if result is None:
                continue

            domain = result.pop("domain", "__default__")

            if is_multi:
                ids = self._primary_ids(result)
            else:
                # Pop and re-insert so "sequence" ends up as the last key.
                ids = result.pop("sequence")
                result["sequence"] = ids

            if not ids:
                continue

            bucket = domains[domain]
            self._align_bucket(bucket, result, ids, is_multi)
            for key, val in result.items():
                bucket[key].append(val)

            count += 1
            total_tokens += len(ids)

            if total_tokens >= self.config.output.max_tokens_per_shard:
                self._flush(domains, shard_idx)
                domains.clear()
                total_tokens = 0

        if total_tokens > 0:
            self._flush(domains, shard_idx)

        print(f"Done. {count} documents tokenized.")

    @staticmethod
    def _primary_ids(result: dict) -> list:
        """Return the first non-empty list in *result* whose first element is
        an int; it is used as the id sequence for token counting.

        Returns an empty list when no value qualifies.
        """
        for val in result.values():
            if isinstance(val, list) and val and isinstance(val[0], int):
                return val
        return []

    @staticmethod
    def _align_bucket(bucket: dict, result: dict, ids: list, is_multi: bool):
        """Pad previously-accumulated keys that are missing from *result*.

        For every key already in *bucket* but absent from *result*, one entry
        is appended so the key keeps one entry per document. In multi mode
        the pad repeats the key's latest entry (or ones of ``len(ids)`` when
        the key has no entries yet); otherwise it is ones of ``len(ids)``.
        """
        for key in list(bucket.keys()):
            if key in result:
                continue
            if is_multi:
                pad = bucket[key][-1] if bucket[key] else [1] * len(ids)
                bucket[key].append(pad)
            else:
                bucket[key].append([1] * len(ids))

    def _iter_items(self):
        """Yield one parsed JSON object per non-blank line of every input file."""
        for path in self.paths:
            with open(path, "r", encoding="utf-8") as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        continue
                    yield json.loads(line)

    def _flush(self, domains, shard_idx):
        """Write one shard per domain and advance that domain's shard index.

        Args:
            domains: Mapping ``domain -> key -> list of per-document lists``.
            shard_idx: Mapping ``domain -> next shard number``; updated in place.
        """
        pp = self.config.preprocessing

        for domain, keys in domains.items():
            idx = shard_idx[domain]
            chunk_dir = os.path.join(self.output_dir, domain)

            if pp.packing_strategy != "simple" and "sequence" in keys:
                plan = pack_sequences(
                    keys["sequence"],
                    pp.max_packed_len,
                    pp.packing_strategy,
                    pp.truncation_mode,
                )
                reordered = defaultdict(list)
                # Only the order of the plan is needed; each key is truncated
                # independently by _truncate.
                for orig_idx, _ in plan:
                    for k, vals in keys.items():
                        reordered[k].append(
                            _truncate(
                                vals[orig_idx], pp.max_packed_len, pp.truncation_mode
                            )
                        )
                keys = reordered

            # Each key becomes a single flat tensor of all its documents.
            tensors = {}
            for key, ids_list in keys.items():
                dt = _STR_TO_DTYPE.get(
                    self.config.output.dtype.get(key, "int32"), torch.int32
                )
                tensors[key] = [
                    torch.tensor(list(chain.from_iterable(ids_list)), dtype=dt)
                ]

            pid_mode = self.config.output.position_ids_mode
            if pid_mode and pid_mode != "none" and "sequence" in tensors:
                pos_ids = []
                if pid_mode == "doc_reset":
                    # Positions restart at 0 for every document.
                    for item in keys["sequence"]:
                        pos_ids.extend(range(len(item)))
                else:
                    # Positions run continuously across the whole shard.
                    total = sum(len(item) for item in keys["sequence"])
                    pos_ids = list(range(total))
                tensors["position_ids"] = [torch.tensor(pos_ids, dtype=torch.int32)]

            shard_path = os.path.join(chunk_dir, f"shard_{idx:04d}")
            fmt = self.config.output.storage_format
            if fmt == "bin":
                save_bin(shard_path, tensors)
            else:
                save_h5(chunk_dir, f"data_{idx:04d}", tensors)
            shard_idx[domain] = idx + 1
            first_key = "sequence" if "sequence" in tensors else next(iter(tensors))
            tqdm.tqdm.write(
                f"  saved {domain}/shard_{idx:04d}  "
                f"({tensors[first_key][0].numel():,} tokens)"
            )
|
||||
|
|
@ -1,21 +0,0 @@
|
|||
"""Training component protocols — structural subtyping for optimizer/scheduler wrappers."""
|
||||
|
||||
from typing import Any, Protocol, runtime_checkable
|
||||
|
||||
|
||||
@runtime_checkable
class OptimizerProtocol(Protocol):
    """Structural interface for optimizer-like objects.

    No inheritance is required: an object matches when it provides ``step``,
    ``zero_grad``, ``param_groups``, ``state_dict`` and ``load_state_dict``.
    Because the protocol is runtime-checkable, ``isinstance`` can be used; it
    only verifies that these members exist, not their signatures.
    """

    def step(self, closure=None):
        ...

    def zero_grad(self):
        ...

    @property
    def param_groups(self) -> Any:
        ...

    def state_dict(self) -> dict:
        ...

    def load_state_dict(self, d: dict):
        ...
|
||||
|
||||
|
||||
@runtime_checkable
class SchedulerProtocol(Protocol):
    """Structural interface for scheduler-like objects.

    An object matches when it provides ``step``, ``state_dict``,
    ``load_state_dict`` and ``get_last_lr``. ``isinstance`` checks against
    this runtime-checkable protocol only verify that the members exist.
    """

    def step(self):
        ...

    def state_dict(self) -> dict:
        ...

    def load_state_dict(self, d: dict):
        ...

    def get_last_lr(self):
        ...
|
||||
|
|
@ -1,9 +1,7 @@
|
|||
import io
|
||||
import json
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Union
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import safetensors.torch as st
|
||||
import torch
|
||||
|
|
@ -11,172 +9,75 @@ import torch.distributed as dist
|
|||
|
||||
from astrai.parallel.setup import get_rank
|
||||
|
||||
_META_FILE = "meta.json"
|
||||
_CONFIG_FILE = "config.json"
|
||||
_WEIGHTS_FILE = "model.safetensors"
|
||||
|
||||
|
||||
def save_safetensors(state_dict: dict, path: Union[str, Path]):
    """Write *state_dict* to *path* in safetensors format."""
    target = str(path)
    st.save_file(state_dict, target)
|
||||
|
||||
|
||||
def load_safetensors(path: Union[str, Path], broadcast: bool = False) -> dict:
    """Load a safetensors file, optionally reading it on rank 0 only.

    When *broadcast* is true and the process group is initialized, rank 0
    reads the file and the resulting dict is sent to the other ranks with
    ``broadcast_object_list``. Otherwise the caller reads the file directly.
    """
    if not (broadcast and dist.is_initialized()):
        return st.load_file(str(path))

    # Non-source ranks contribute an empty placeholder that the broadcast
    # overwrites.
    payload = [st.load_file(str(path)) if get_rank() == 0 else {}]
    dist.broadcast_object_list(payload, src=0)
    return payload[0]
|
||||
|
||||
|
||||
def save_json(data: dict, path: Union[str, Path]):
    """Serialize *data* to *path* as JSON with a two-space indent.

    The file is written as UTF-8 explicitly, so the result does not depend
    on the platform's default text encoding.
    """
    with open(str(path), "w", encoding="utf-8") as f:
        json.dump(data, f, indent=2)
|
||||
|
||||
|
||||
def load_json(path: Union[str, Path], broadcast: bool = False) -> dict:
    """Read a JSON file, optionally on rank 0 only.

    Args:
        path: File to read. It is decoded as UTF-8 rather than with the
            platform default encoding, so non-ASCII content loads the same
            way on every machine.
        broadcast: When true and the process group is initialized, rank 0
            reads the file and the parsed object is sent to the other ranks
            with ``broadcast_object_list``.

    Returns:
        The parsed JSON content.
    """

    def _read() -> dict:
        with open(str(path), "r", encoding="utf-8") as f:
            return json.load(f)

    if not broadcast or not dist.is_initialized():
        return _read()

    # Non-source ranks contribute an empty placeholder that the broadcast
    # overwrites.
    payload = [_read() if get_rank() == 0 else {}]
    dist.broadcast_object_list(payload, src=0)
    return payload[0]
|
||||
|
||||
|
||||
def save_torch(obj: Any, path: Union[str, Path]):
    """Persist *obj* at *path* with ``torch.save``."""
    destination = str(path)
    torch.save(obj, destination)
|
||||
|
||||
|
||||
def load_torch(path: Union[str, Path], broadcast: bool = False) -> Any:
    """Load an object written by ``torch.save`` onto the CPU.

    Without *broadcast* (or when the process group is not initialized) the
    file is loaded directly. Otherwise rank 0 reads the raw file bytes, the
    byte count and then the bytes are broadcast to all ranks, and every rank
    deserializes its own copy from memory.

    NOTE: loading uses ``weights_only=False``, i.e. full unpickling; only use
    it on files from a trusted source.
    """
    if not (broadcast and dist.is_initialized()):
        return torch.load(str(path), map_location="cpu", weights_only=False)

    path = Path(path)
    is_source = get_rank() == 0

    if is_source:
        with open(path, "rb") as f:
            raw = f.read()
        # bytearray gives frombuffer a writable buffer to wrap.
        data_tensor = torch.frombuffer(bytearray(raw), dtype=torch.uint8)
        num_bytes = torch.tensor([len(raw)], dtype=torch.long)
    else:
        num_bytes = torch.tensor([0], dtype=torch.long)

    # First share the size so receivers can allocate, then share the payload.
    dist.broadcast(num_bytes, src=0)
    if not is_source:
        data_tensor = torch.empty(num_bytes.item(), dtype=torch.uint8)
    dist.broadcast(data_tensor, src=0)

    stream = io.BytesIO(data_tensor.numpy().tobytes())
    return torch.load(stream, map_location="cpu", weights_only=False)
|
||||
|
||||
|
||||
def save_model(config: dict, state_dict: dict, save_directory: str):
    """Write a model's config JSON and safetensors weights into *save_directory*.

    The directory (including missing parents) is created when needed.
    """
    out_dir = Path(save_directory)
    out_dir.mkdir(parents=True, exist_ok=True)

    config_file = out_dir / _CONFIG_FILE
    weights_file = out_dir / _WEIGHTS_FILE
    save_json(config, config_file)
    save_safetensors(state_dict, weights_file)
|
||||
|
||||
|
||||
def load_model_config(save_directory: str) -> dict:
    """Read the model config JSON stored in *save_directory*."""
    config_file = Path(save_directory) / _CONFIG_FILE
    return load_json(config_file)
|
||||
|
||||
|
||||
def load_model_weights(save_directory: str) -> dict:
    """Load the safetensors weights stored in *save_directory*."""
    weights_file = Path(save_directory) / _WEIGHTS_FILE
    return load_state_dict(weights_file)
|
||||
|
||||
|
||||
def load_state_dict(path: Union[str, Path], broadcast: bool = False) -> dict:
    """Load a safetensors state dict, optionally broadcasting it tensor by tensor.

    Without *broadcast* (or when the process group is not initialized) the
    file is simply loaded. Otherwise only rank 0 reads the file: it first
    broadcasts a list of ``(key, shape, dtype name)`` specs, then each tensor
    is broadcast in sorted-key order into CPU buffers that the other ranks
    allocate from the specs.
    """
    path = Path(path)
    if not (broadcast and dist.is_initialized()):
        return load_safetensors(path)

    rank = get_rank()
    is_source = rank == 0

    state_dict = load_safetensors(path) if is_source else {}
    specs = []
    if is_source:
        for name in sorted(state_dict):
            entry = state_dict[name]
            # Keep only the bare dtype name, e.g. "torch.float32" -> "float32".
            specs.append((name, list(entry.shape), str(entry.dtype).split(".")[-1]))

    holder = [specs]
    dist.broadcast_object_list(holder, src=0)
    specs = holder[0]

    # Every rank walks the specs in the same order so the collectives line up.
    for key, shape, dtype_name in specs:
        dtype = getattr(torch, dtype_name)
        if is_source:
            tensor = state_dict[key].contiguous().cpu()
        else:
            tensor = torch.empty(shape, dtype=dtype, device="cpu")
        dist.broadcast(tensor, src=0)
        if not is_source:
            state_dict[key] = tensor
    return state_dict
|
||||
|
||||
|
||||
@dataclass
|
||||
class Checkpoint:
|
||||
state_dict: Dict[str, Any] = field(default_factory=dict)
|
||||
epoch: int = 0
|
||||
iteration: int = 0
|
||||
extra: Dict[str, Any] = field(default_factory=dict)
|
||||
meta: Dict[str, Any] = field(default_factory=dict)
|
||||
config: Dict[str, Any] = field(default_factory=dict)
|
||||
def __init__(
|
||||
self,
|
||||
state_dict: Dict[str, Any],
|
||||
epoch: int = 0,
|
||||
iteration: int = 0,
|
||||
extra: Optional[Dict[str, Any]] = None,
|
||||
meta: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
self.state_dict = state_dict
|
||||
self.epoch = epoch
|
||||
self.iteration = iteration
|
||||
self.extra = extra or {}
|
||||
self.meta = meta or {}
|
||||
|
||||
def save(
|
||||
self,
|
||||
save_dir: str,
|
||||
) -> None:
|
||||
|
||||
def save(self, save_dir: str):
|
||||
save_path = Path(save_dir)
|
||||
save_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if get_rank() != 0:
|
||||
return
|
||||
rank = get_rank()
|
||||
if rank == 0:
|
||||
meta = {
|
||||
"epoch": self.epoch,
|
||||
"iteration": self.iteration,
|
||||
"timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"),
|
||||
}
|
||||
meta.update(self.meta)
|
||||
with open(save_path / "meta.json", "w") as f:
|
||||
json.dump(meta, f, indent=2)
|
||||
|
||||
meta = {
|
||||
"epoch": self.epoch,
|
||||
"iteration": self.iteration,
|
||||
"timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"),
|
||||
**self.meta,
|
||||
}
|
||||
save_json(meta, save_path / _META_FILE)
|
||||
save_json(self.config, save_path / _CONFIG_FILE)
|
||||
save_safetensors(self.state_dict, save_path / _WEIGHTS_FILE)
|
||||
for key, value in self.extra.items():
|
||||
save_torch(value, save_path / f"{key}.pt")
|
||||
st.save_file(self.state_dict, save_path / "state_dict.safetensors")
|
||||
if self.extra:
|
||||
for key, value in self.extra.items():
|
||||
torch.save(value, save_path / f"{key}.pt")
|
||||
|
||||
@classmethod
|
||||
def load(cls, save_dir: str, broadcast: bool = False) -> "Checkpoint":
|
||||
def load(
|
||||
cls,
|
||||
save_dir: str,
|
||||
) -> "Checkpoint":
|
||||
|
||||
rank = get_rank()
|
||||
save_path = Path(save_dir)
|
||||
|
||||
meta = load_json(save_path / _META_FILE, broadcast)
|
||||
config = load_json(save_path / _CONFIG_FILE, broadcast)
|
||||
state_dict = load_state_dict(save_path / _WEIGHTS_FILE, broadcast=broadcast)
|
||||
meta = {}
|
||||
if rank == 0:
|
||||
with open(Path(save_dir) / "meta.json", "r") as f:
|
||||
meta = json.load(f)
|
||||
|
||||
if dist.is_initialized():
|
||||
meta_list = [meta]
|
||||
dist.broadcast_object_list(meta_list, src=0)
|
||||
meta = meta_list[0]
|
||||
|
||||
state_dict = st.load_file(save_path / "state_dict.safetensors")
|
||||
|
||||
extra = {}
|
||||
for f in sorted(save_path.iterdir()):
|
||||
if f.suffix == ".pt":
|
||||
extra[f.stem] = load_torch(f, broadcast=broadcast)
|
||||
for f in save_path.iterdir():
|
||||
if f.suffix == ".pt" and f.stem not in ("meta",):
|
||||
extra[f.stem] = torch.load(f, map_location="cpu", weights_only=False)
|
||||
|
||||
return cls(
|
||||
state_dict=state_dict,
|
||||
epoch=meta.get("epoch", 0),
|
||||
iteration=meta.get("iteration", 0),
|
||||
extra=extra,
|
||||
config=config,
|
||||
epoch=meta["epoch"],
|
||||
iteration=meta["iteration"],
|
||||
extra=extra or None,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,10 +1,13 @@
|
|||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from jinja2 import Template
|
||||
|
||||
# Message type for chat messages
|
||||
type MessageType = Dict[str, Any]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatTemplate:
|
||||
"""A chat template with Jinja2 rendering support.
|
||||
|
||||
|
|
@ -12,24 +15,23 @@ class ChatTemplate:
|
|||
name: Unique identifier for the template.
|
||||
template_str: Jinja2 template string.
|
||||
description: Optional description.
|
||||
default_variables: Optional dictionary of default variable values.
|
||||
default_variables: Optional dictionary of default variable values
|
||||
that will be passed to the template if not overridden during rendering.
|
||||
special_tokens: Optional dictionary mapping token names to their string values.
|
||||
These tokens are automatically added to the template variables.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str = "",
|
||||
template_str: str = "",
|
||||
description: str = "",
|
||||
default_variables: Optional[Dict[str, Any]] = None,
|
||||
special_tokens: Optional[Dict[str, str]] = None,
|
||||
):
|
||||
self.name = name
|
||||
self.template_str = template_str
|
||||
self.description = description
|
||||
self.default_variables = default_variables or {}
|
||||
self.special_tokens = special_tokens or {}
|
||||
self._compiled: Template = Template(template_str)
|
||||
name: str
|
||||
template_str: str
|
||||
description: str = ""
|
||||
default_variables: Dict[str, Any] = None
|
||||
special_tokens: Dict[str, str] = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.default_variables is None:
|
||||
self.default_variables = {}
|
||||
if self.special_tokens is None:
|
||||
self.special_tokens = {}
|
||||
|
||||
@classmethod
|
||||
def from_string(
|
||||
|
|
@ -41,7 +43,7 @@ class ChatTemplate:
|
|||
) -> "ChatTemplate":
|
||||
"""Create a ChatTemplate instance directly from a template string."""
|
||||
return cls(
|
||||
name="",
|
||||
name="", # empty name for ad‑hoc templates
|
||||
template_str=template_str,
|
||||
description=description,
|
||||
default_variables=default_variables,
|
||||
|
|
@ -71,4 +73,5 @@ class ChatTemplate:
|
|||
if system_prompt is not None:
|
||||
variables["system_prompt"] = system_prompt
|
||||
|
||||
return self._compiled.render(**variables)
|
||||
jinja_template = Template(self.template_str)
|
||||
return jinja_template.render(**variables)
|
||||
|
|
|
|||
|
|
@ -4,17 +4,17 @@ from torch.optim import Optimizer
|
|||
|
||||
def _zeropower_via_newtonschulz(G: torch.Tensor, steps: int = 5):
|
||||
assert G.ndim == 2
|
||||
X = G
|
||||
X = G.bfloat16()
|
||||
scale = max(1, G.size(0) / G.size(1)) ** 0.5
|
||||
X = X / (X.norm() + 1e-7) * scale
|
||||
if steps == 0:
|
||||
return X
|
||||
return X.type_as(G)
|
||||
a, b, c = (3.4445, -4.7750, 2.0315)
|
||||
for _ in range(steps):
|
||||
A = X @ X.T
|
||||
B = A @ X
|
||||
X = a * X + b * B + c * (A @ B)
|
||||
return X
|
||||
return X.type_as(G)
|
||||
|
||||
|
||||
class Muon(Optimizer):
|
||||
|
|
@ -50,94 +50,64 @@ class Muon(Optimizer):
|
|||
if closure is not None:
|
||||
with torch.enable_grad():
|
||||
loss = closure()
|
||||
|
||||
for group in self.param_groups:
|
||||
params_2d, params_1d = [], []
|
||||
grads_2d, grads_1d = [], []
|
||||
|
||||
for p in group["params"]:
|
||||
if p.grad is None:
|
||||
continue
|
||||
if p.grad.is_sparse:
|
||||
grad = p.grad
|
||||
if grad.is_sparse:
|
||||
raise RuntimeError("Muon does not support sparse gradients")
|
||||
if p.ndim >= 2:
|
||||
params_2d.append(p)
|
||||
grads_2d.append(p.grad)
|
||||
self._muon_update(p, grad, group)
|
||||
else:
|
||||
params_1d.append(p)
|
||||
grads_1d.append(p.grad)
|
||||
|
||||
if params_2d:
|
||||
self._muon_update_foreach(params_2d, grads_2d, group)
|
||||
if params_1d:
|
||||
self._adamw_update_foreach(params_1d, grads_1d, group)
|
||||
|
||||
self._adamw_update(p, grad, group)
|
||||
return loss
|
||||
|
||||
def _muon_update_foreach(self, params_2d, grads_2d, group):
|
||||
def _muon_update(self, p, grad, group):
|
||||
lr = group["lr"]
|
||||
momentum = group["momentum"]
|
||||
wd = group["weight_decay"]
|
||||
nesterov = group["nesterov"]
|
||||
ns_steps = group["ns_steps"]
|
||||
state = self.state[p]
|
||||
|
||||
if wd != 0:
|
||||
torch._foreach_mul_(params_2d, 1 - lr * wd)
|
||||
p.mul_(1 - lr * wd)
|
||||
|
||||
if nesterov:
|
||||
grads_2d = torch._foreach_add(grads_2d, params_2d, alpha=wd)
|
||||
grad = grad.add(p, alpha=wd)
|
||||
|
||||
bufs = []
|
||||
for p, grad in zip(params_2d, grads_2d):
|
||||
state = self.state[p]
|
||||
if "momentum_buffer" not in state:
|
||||
state["momentum_buffer"] = torch.zeros_like(grad)
|
||||
bufs.append(state["momentum_buffer"])
|
||||
if "momentum_buffer" not in state:
|
||||
state["momentum_buffer"] = torch.zeros_like(grad)
|
||||
buf = state["momentum_buffer"]
|
||||
buf.lerp_(grad, 1 - momentum)
|
||||
|
||||
torch._foreach_lerp_(bufs, grads_2d, 1 - momentum)
|
||||
update = _zeropower_via_newtonschulz(buf, steps=ns_steps)
|
||||
scale = max(1, p.size(0) / p.size(1)) ** 0.5
|
||||
p.add_(update, alpha=-lr * scale)
|
||||
|
||||
for p, buf in zip(params_2d, bufs):
|
||||
update = _zeropower_via_newtonschulz(buf, steps=ns_steps)
|
||||
scale = max(1, p.size(0) / p.size(1)) ** 0.5
|
||||
p.add_(update, alpha=-lr * scale)
|
||||
|
||||
def _adamw_update_foreach(self, params_1d, grads_1d, group):
|
||||
def _adamw_update(self, p, grad, group):
|
||||
lr = group["adamw_lr"]
|
||||
betas = group["adamw_betas"]
|
||||
eps = group["adamw_eps"]
|
||||
wd = group["adamw_wd"]
|
||||
state = self.state[p]
|
||||
|
||||
steps: list[int] = []
|
||||
exp_avgs, exp_avg_sqs = [], []
|
||||
has_state = []
|
||||
for p in params_1d:
|
||||
state = self.state[p]
|
||||
if not state:
|
||||
state["step"] = 0
|
||||
state["exp_avg"] = torch.zeros_like(p)
|
||||
state["exp_avg_sq"] = torch.zeros_like(p)
|
||||
has_state.append(False)
|
||||
else:
|
||||
has_state.append(True)
|
||||
state["step"] += 1
|
||||
steps.append(state["step"])
|
||||
exp_avgs.append(state["exp_avg"])
|
||||
exp_avg_sqs.append(state["exp_avg_sq"])
|
||||
if not state:
|
||||
state["step"] = 0
|
||||
state["exp_avg"] = torch.zeros_like(p)
|
||||
state["exp_avg_sq"] = torch.zeros_like(p)
|
||||
|
||||
state["step"] += 1
|
||||
exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
|
||||
beta1, beta2 = betas
|
||||
|
||||
torch._foreach_lerp_(exp_avgs, grads_1d, 1 - beta1)
|
||||
grads_sq = torch._foreach_mul(grads_1d, grads_1d)
|
||||
torch._foreach_lerp_(exp_avg_sqs, grads_sq, 1 - beta2)
|
||||
exp_avg.lerp_(grad, 1 - beta1)
|
||||
exp_avg_sq.lerp_(grad.square(), 1 - beta2)
|
||||
|
||||
bias_correction1 = [1 - beta1**s for s in steps]
|
||||
bias_correction2 = [1 - beta2**s for s in steps]
|
||||
step = state["step"]
|
||||
bias1 = 1 - beta1**step
|
||||
bias2 = 1 - beta2**step
|
||||
|
||||
if wd != 0:
|
||||
torch._foreach_mul_(params_1d, 1 - lr * wd)
|
||||
|
||||
exp_avg_corrected = torch._foreach_div(exp_avgs, bias_correction1)
|
||||
denom = torch._foreach_div(exp_avg_sqs, bias_correction2)
|
||||
denom = torch._foreach_sqrt(denom)
|
||||
torch._foreach_add_(denom, eps)
|
||||
torch._foreach_addcdiv_(params_1d, exp_avg_corrected, denom, value=-lr)
|
||||
p.mul_(1 - lr * wd)
|
||||
denom = exp_avg_sq.sqrt().div_(bias2**0.5).add_(eps)
|
||||
p.addcdiv_(exp_avg / bias1, denom, value=-lr)
|
||||
|
|
|
|||
|
|
@ -42,7 +42,7 @@ class SchedulerFactory(BaseFactory["BaseScheduler"]):
|
|||
"""
|
||||
|
||||
@classmethod
|
||||
def _validate_component(cls, scheduler_cls: Type[BaseScheduler]):
|
||||
def _validate_component(cls, scheduler_cls: Type[BaseScheduler]) -> None:
|
||||
"""Validate that the scheduler class inherits from BaseScheduler."""
|
||||
if not issubclass(scheduler_cls, BaseScheduler):
|
||||
raise TypeError(f"{scheduler_cls.__name__} must inherit from BaseScheduler")
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
"""Training strategy implementations with factory pattern."""
|
||||
|
||||
import copy
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Callable, Dict, Union
|
||||
|
||||
|
|
@ -7,14 +8,26 @@ import torch
|
|||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
|
||||
from astrai.factory import BaseFactory
|
||||
|
||||
|
||||
def create_ref_model(model_fn, state_dict: dict) -> nn.Module:
|
||||
"""Create a frozen reference model from model_fn + full state dict."""
|
||||
ref_model = model_fn()
|
||||
ref_model.load_state_dict(state_dict)
|
||||
def unwrap_model(model: nn.Module) -> nn.Module:
    """Return the module inside a DDP wrapper, or *model* itself if it is not wrapped."""
    return model.module if isinstance(model, DDP) else model
|
||||
|
||||
|
||||
def create_ref_model(model: nn.Module) -> nn.Module:
|
||||
"""Create a reference model for DPO/GRPO training.
|
||||
|
||||
Handles DDP-wrapped models safely by unwrapping first,
|
||||
then creating a deep copy with frozen gradients.
|
||||
"""
|
||||
original_model = unwrap_model(model)
|
||||
ref_model = copy.deepcopy(original_model)
|
||||
ref_model.requires_grad_(False)
|
||||
ref_model.eval()
|
||||
return ref_model
|
||||
|
|
@ -68,22 +81,6 @@ def get_logprobs(
|
|||
return token_logprobs * shifted_mask
|
||||
|
||||
|
||||
def make_doc_boundary_mask(position_ids: Tensor) -> Tensor:
    """Build a boolean mask that is causal and confined to each document.

    A document boundary is placed wherever a position id is not greater than
    the one before it. Entry ``[b, 0, i, j]`` of the result is true when
    ``j <= i`` and positions ``i`` and ``j`` of row ``b`` fall in the same
    document. The result has shape ``(batch, 1, seq, seq)``.
    """
    batch, seq_len = position_ids.size(0), position_ids.size(1)
    dev = position_ids.device

    # 1 at every index (from the second on) where a new document starts.
    restarts = (position_ids[:, 1:] <= position_ids[:, :-1]).long()
    first_col = torch.zeros(batch, 1, dtype=torch.long, device=dev)
    # A running count of restarts gives each token its document number.
    doc_ids = torch.cat([first_col, restarts.cumsum(dim=1)], dim=1)

    same_doc = doc_ids[:, :, None] == doc_ids[:, None, :]
    causal = torch.ones(seq_len, seq_len, dtype=torch.bool, device=dev).tril()
    return (same_doc & causal)[:, None]
|
||||
|
||||
|
||||
class BaseStrategy(ABC):
|
||||
"""Abstract base class for training strategies."""
|
||||
|
||||
|
|
@ -92,8 +89,6 @@ class BaseStrategy(ABC):
|
|||
):
|
||||
self.model = model
|
||||
self.device = device
|
||||
self.executor = kwargs.pop("executor", None)
|
||||
self.model_fn = kwargs.pop("model_fn", None)
|
||||
self.extra_kwargs = kwargs
|
||||
|
||||
@abstractmethod
|
||||
|
|
@ -128,7 +123,7 @@ class StrategyFactory(BaseFactory["BaseStrategy"]):
|
|||
"""
|
||||
|
||||
@classmethod
|
||||
def _validate_component(cls, strategy_cls: type):
|
||||
def _validate_component(cls, strategy_cls: type) -> None:
|
||||
"""Validate that the strategy class inherits from BaseStrategy."""
|
||||
if not issubclass(strategy_cls, BaseStrategy):
|
||||
raise TypeError(f"{strategy_cls.__name__} must inherit from BaseStrategy")
|
||||
|
|
@ -196,19 +191,15 @@ class SFTStrategy(BaseStrategy):
|
|||
|
||||
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
|
||||
batch = move_to_device(batch, self.device)
|
||||
input_ids, target_ids, position_ids, loss_mask = (
|
||||
input_ids, target_ids, loss_mask = (
|
||||
batch["input_ids"],
|
||||
batch["target_ids"],
|
||||
batch["position_ids"],
|
||||
batch["loss_mask"],
|
||||
)
|
||||
|
||||
ignore_index = -100
|
||||
input_mask = make_doc_boundary_mask(position_ids)
|
||||
logits = self.model(input_ids=input_ids)["logits"]
|
||||
target_ids = target_ids.masked_fill(loss_mask == 0, ignore_index)
|
||||
logits = self.model(
|
||||
input_ids=input_ids, position_ids=position_ids, input_mask=input_mask
|
||||
)["logits"]
|
||||
|
||||
loss = F.cross_entropy(
|
||||
input=logits.flatten(0, 1).float(),
|
||||
|
|
@ -237,9 +228,7 @@ class DPOStrategy(BaseStrategy):
|
|||
**kwargs,
|
||||
):
|
||||
super().__init__(model, device, **kwargs)
|
||||
self.ref_model = create_ref_model(
|
||||
self.model_fn, self.executor.unwrap_model(model)
|
||||
).to(device=self.device)
|
||||
self.ref_model = create_ref_model(model)
|
||||
self.beta = beta
|
||||
self.reduction = reduction
|
||||
|
||||
|
|
@ -293,9 +282,7 @@ class GRPOStrategy(BaseStrategy):
|
|||
**kwargs,
|
||||
):
|
||||
super().__init__(model, device, **kwargs)
|
||||
self.ref_model = create_ref_model(
|
||||
self.model_fn, self.executor.unwrap_model(model)
|
||||
).to(device=self.device)
|
||||
self.ref_model = create_ref_model(model)
|
||||
self.clip_eps = clip_eps
|
||||
self.kl_coef = kl_coef
|
||||
self.group_size = group_size
|
||||
|
|
@ -305,7 +292,8 @@ class GRPOStrategy(BaseStrategy):
|
|||
|
||||
def sync_ref_model(self):
|
||||
"""Copy current model weights to ref model."""
|
||||
self.ref_model.load_state_dict(self.executor.unwrap_model(self.model))
|
||||
ref_state = self.model.state_dict()
|
||||
self.ref_model.load_state_dict(ref_state)
|
||||
|
||||
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
|
||||
self._step += 1
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ from tqdm import tqdm
|
|||
|
||||
from astrai.factory import BaseFactory
|
||||
from astrai.parallel import only_on_rank
|
||||
from astrai.parallel.setup import get_current_device, get_rank
|
||||
from astrai.parallel.setup import get_current_device
|
||||
from astrai.serialization import Checkpoint
|
||||
from astrai.trainer.metric_util import (
|
||||
ctx_get_grad_max,
|
||||
|
|
@ -51,15 +51,18 @@ class TrainCallback(Protocol):
|
|||
def on_epoch_end(self, context: TrainContext):
|
||||
"""Called at the end of each epoch."""
|
||||
|
||||
def on_step_begin(self, context: TrainContext):
|
||||
"""Called at the beginning of each step."""
|
||||
|
||||
def on_step_end(self, context: TrainContext):
|
||||
"""Called at the end of each step."""
|
||||
|
||||
def on_batch_begin(self, context: TrainContext):
|
||||
"""Called at the beginning of each batch."""
|
||||
|
||||
def on_batch_end(self, context: TrainContext):
|
||||
"""Called at the end of each batch."""
|
||||
|
||||
def on_optimizer_step(self, context: TrainContext):
|
||||
"""Called on every optimizer step (sync step only)."""
|
||||
|
||||
def on_error(self, context: TrainContext):
|
||||
"""Called when an error occurs during training."""
|
||||
|
||||
|
|
@ -85,7 +88,7 @@ class GradientClippingCallback(TrainCallback):
|
|||
def __init__(self, max_grad_norm: float):
|
||||
self.max_grad_norm = max_grad_norm
|
||||
|
||||
def on_optimizer_step(self, context: TrainContext):
|
||||
def on_step_begin(self, context: TrainContext):
|
||||
clip_grad_norm_(context.model.parameters(), self.max_grad_norm)
|
||||
|
||||
|
||||
|
|
@ -137,31 +140,44 @@ class CheckpointCallback(TrainCallback):
|
|||
save_dir: str,
|
||||
interval: int,
|
||||
weight_only: bool = False,
|
||||
state_dict_fn: Optional[Callable[[nn.Module], dict]] = None,
|
||||
save_extra_fn: Optional[Callable[["TrainContext"], dict]] = None,
|
||||
load_extra_fn: Optional[Callable[[dict, "TrainContext"], None]] = None,
|
||||
):
|
||||
self.save_dir = save_dir
|
||||
self.interval = interval
|
||||
self.weight_only = weight_only
|
||||
self.state_dict_fn = state_dict_fn
|
||||
self.save_extra_fn = save_extra_fn or CheckpointCallback.save_extra
|
||||
self.load_extra_fn = load_extra_fn or CheckpointCallback.load_extra
|
||||
self.last_ckpt_iter = 0
|
||||
|
||||
@only_on_rank(0)
|
||||
def _save_checkpoint(self, context: TrainContext):
|
||||
state_dict = context.executor.unwrap_model(context.model)
|
||||
save_path = os.path.join(
|
||||
self.save_dir, f"epoch_{context.epoch}_iter_{context.iteration}"
|
||||
)
|
||||
state_dict = (
|
||||
self.state_dict_fn(context.model)
|
||||
if self.state_dict_fn
|
||||
else context.model.state_dict()
|
||||
)
|
||||
|
||||
extra = self.save_extra_fn(context)
|
||||
context.checkpoint = Checkpoint(
|
||||
state_dict=state_dict,
|
||||
epoch=context.epoch,
|
||||
iteration=context.iteration,
|
||||
extra=extra,
|
||||
meta=context.config.to_dict(),
|
||||
)
|
||||
|
||||
context.checkpoint.save(save_path)
|
||||
self.last_ckpt_iter = context.iteration
|
||||
|
||||
if get_rank() == 0:
|
||||
save_path = os.path.join(
|
||||
self.save_dir, f"epoch_{context.epoch}_iter_{context.iteration}"
|
||||
)
|
||||
extra = self.save_extra_fn(context)
|
||||
context.checkpoint = Checkpoint(
|
||||
state_dict=state_dict,
|
||||
epoch=context.epoch,
|
||||
iteration=context.iteration,
|
||||
extra=extra,
|
||||
config=context.model_config,
|
||||
)
|
||||
context.checkpoint.save(save_path)
|
||||
def on_train_begin(self, context: TrainContext):
|
||||
if context.checkpoint and context.checkpoint.extra:
|
||||
self.load_extra_fn(context.checkpoint.extra, context)
|
||||
|
||||
def on_batch_end(self, context: TrainContext):
|
||||
if context.iteration - self.last_ckpt_iter >= self.interval:
|
||||
|
|
@ -183,6 +199,12 @@ class CheckpointCallback(TrainCallback):
|
|||
extra[name] = obj.state_dict()
|
||||
return extra
|
||||
|
||||
@staticmethod
|
||||
def load_extra(extra: dict, context: TrainContext):
|
||||
for name in CheckpointCallback.extra_keys:
|
||||
if name in extra:
|
||||
getattr(context, name).load_state_dict(extra[name])
|
||||
|
||||
|
||||
@CallbackFactory.register("progress_bar")
|
||||
class ProgressBarCallback(TrainCallback):
|
||||
|
|
@ -191,7 +213,7 @@ class ProgressBarCallback(TrainCallback):
|
|||
"""
|
||||
|
||||
def __init__(
|
||||
self, num_epoch: int, log_interval: int = 100, file: Optional[IO[str]] = None
|
||||
self, num_epoch: int, log_interval: int = 100, file: IO[str] = sys.stdout
|
||||
):
|
||||
self.num_epoch = num_epoch
|
||||
self.log_interval = log_interval
|
||||
|
|
@ -204,7 +226,7 @@ class ProgressBarCallback(TrainCallback):
|
|||
context.dataloader,
|
||||
desc=f"Epoch {context.epoch + 1}/{self.num_epoch}",
|
||||
dynamic_ncols=True,
|
||||
file=self.file or sys.stdout,
|
||||
file=self.file,
|
||||
)
|
||||
|
||||
@only_on_rank(0)
|
||||
|
|
@ -322,7 +344,7 @@ class ValidationCallback(TrainCallback):
|
|||
f"Epoch {context.epoch + 1}, Step {step_count}, Val Loss: {avg_loss:.4f}"
|
||||
)
|
||||
|
||||
def on_optimizer_step(self, context: TrainContext):
|
||||
def on_step_end(self, context: TrainContext):
|
||||
if context.val_dataloader is None:
|
||||
return
|
||||
cfg = context.config
|
||||
|
|
|
|||
|
|
@ -1,18 +1,15 @@
|
|||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Optional, Self
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.data import DataLoader, random_split
|
||||
from torch.optim import Optimizer
|
||||
from torch.optim.lr_scheduler import LRScheduler
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from astrai.config.train_config import TrainConfig
|
||||
from astrai.dataset import ResumableDistributedSampler
|
||||
from astrai.model.components.lora import inject_lora
|
||||
from astrai.parallel.executor import BaseExecutor, ExecutorFactory
|
||||
from astrai.parallel.setup import get_current_device, get_rank, get_world_size
|
||||
from astrai.protocols import OptimizerProtocol, SchedulerProtocol
|
||||
from astrai.serialization import Checkpoint, load_json, load_model_weights
|
||||
from astrai.serialization import Checkpoint
|
||||
from astrai.trainer.strategy import BaseStrategy, StrategyFactory
|
||||
|
||||
|
||||
|
|
@ -21,12 +18,10 @@ class TrainContext:
|
|||
model: nn.Module = field(default=None)
|
||||
strategy: BaseStrategy = field(default=None)
|
||||
dataloader: DataLoader = field(default=None)
|
||||
optimizer: OptimizerProtocol = field(default=None)
|
||||
scheduler: SchedulerProtocol = field(default=None)
|
||||
optimizer: Optimizer = field(default=None)
|
||||
scheduler: LRScheduler = field(default=None)
|
||||
checkpoint: Checkpoint = field(default=None)
|
||||
config: TrainConfig = field(default=None)
|
||||
model_config: dict = field(default_factory=dict)
|
||||
executor: BaseExecutor = field(default=None)
|
||||
|
||||
epoch: int = field(default=0)
|
||||
iteration: int = field(default=0)
|
||||
|
|
@ -45,91 +40,49 @@ class TrainContextBuilder:
|
|||
config: TrainConfig,
|
||||
):
|
||||
self.config = config
|
||||
self._resume_dir: Optional[str] = None
|
||||
self._checkpoint: Optional[Checkpoint] = None
|
||||
|
||||
def with_resume_dir(self, resume_dir: Optional[str]) -> Self:
|
||||
self._resume_dir = resume_dir
|
||||
def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self:
|
||||
self._checkpoint = checkpoint
|
||||
return self
|
||||
|
||||
def build(self) -> TrainContext:
|
||||
cfg = self.config
|
||||
device = get_current_device()
|
||||
|
||||
executor = ExecutorFactory.create(
|
||||
cfg.parallel_mode,
|
||||
grad_accum_steps=cfg.grad_accum_steps,
|
||||
**cfg.executor_kwargs,
|
||||
)
|
||||
|
||||
model = cfg.model_fn()
|
||||
model = model.to(device=device)
|
||||
|
||||
model_config = {}
|
||||
if self._resume_dir:
|
||||
config_path = Path(self._resume_dir) / "config.json"
|
||||
if config_path.exists():
|
||||
model_config = load_json(config_path)
|
||||
|
||||
if not model_config and hasattr(model, "config"):
|
||||
model_config = model.config.to_dict()
|
||||
|
||||
context = TrainContext(
|
||||
model=model,
|
||||
model=self.config.model,
|
||||
world_size=get_world_size(),
|
||||
rank=get_rank(),
|
||||
config=cfg,
|
||||
model_config=model_config,
|
||||
executor=executor,
|
||||
config=self.config,
|
||||
)
|
||||
|
||||
if self._resume_dir is not None:
|
||||
resume_path = Path(self._resume_dir)
|
||||
if (resume_path / "meta.json").exists():
|
||||
checkpoint = Checkpoint.load(self._resume_dir)
|
||||
state_dict = checkpoint.state_dict
|
||||
if checkpoint.config:
|
||||
context.model_config = checkpoint.config
|
||||
else:
|
||||
checkpoint = None
|
||||
state_dict = load_model_weights(self._resume_dir)
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
if checkpoint is not None:
|
||||
context.epoch = cfg.start_epoch
|
||||
context.iteration = cfg.start_batch
|
||||
context.checkpoint = checkpoint
|
||||
device = get_current_device()
|
||||
context.model = context.model.to(device=device)
|
||||
|
||||
if cfg.lora is not None:
|
||||
inject_lora(
|
||||
model,
|
||||
r=cfg.lora.r,
|
||||
alpha=cfg.lora.alpha,
|
||||
target_modules=set(cfg.lora.target_modules),
|
||||
if self.config.nprocs > 1 and self.config.parallel_wrapper:
|
||||
context.model = self.config.parallel_wrapper(context.model)
|
||||
|
||||
if self._checkpoint is not None:
|
||||
context.epoch = max(self._checkpoint.epoch, self.config.start_epoch)
|
||||
context.iteration = max(self._checkpoint.iteration, self.config.start_batch)
|
||||
context.model.load_state_dict(self._checkpoint.state_dict)
|
||||
context.checkpoint = self._checkpoint
|
||||
else:
|
||||
context.checkpoint = Checkpoint(
|
||||
state_dict=context.model.state_dict(),
|
||||
)
|
||||
|
||||
context.optimizer = cfg.optimizer_fn(model)
|
||||
context.scheduler = cfg.scheduler_fn(context.optimizer)
|
||||
|
||||
train_dataset = cfg.dataset
|
||||
val_dataset = cfg.val_dataset
|
||||
|
||||
if val_dataset is None and cfg.val_split is not None:
|
||||
n_total = len(cfg.dataset)
|
||||
n_val = max(1, int(n_total * cfg.val_split))
|
||||
n_train = n_total - n_val
|
||||
generator = torch.Generator().manual_seed(cfg.random_seed)
|
||||
train_dataset, val_dataset = random_split(
|
||||
cfg.dataset, [n_train, n_val], generator=generator
|
||||
)
|
||||
context.optimizer = self.config.optimizer_fn(context.model)
|
||||
context.scheduler = self.config.scheduler_fn(context.optimizer)
|
||||
|
||||
cfg = self.config
|
||||
sampler_offset = context.iteration * cfg.batch_per_device
|
||||
sampler = ResumableDistributedSampler(
|
||||
data_source=train_dataset,
|
||||
data_source=cfg.dataset,
|
||||
start_epoch=context.epoch,
|
||||
start_iter=sampler_offset,
|
||||
seed=cfg.random_seed,
|
||||
)
|
||||
context.dataloader = DataLoader(
|
||||
train_dataset,
|
||||
cfg.dataset,
|
||||
batch_size=cfg.batch_per_device,
|
||||
sampler=sampler,
|
||||
num_workers=cfg.num_workers,
|
||||
|
|
@ -137,16 +90,16 @@ class TrainContextBuilder:
|
|||
prefetch_factor=cfg.prefetch_factor,
|
||||
)
|
||||
|
||||
if val_dataset is not None:
|
||||
if cfg.val_dataset is not None:
|
||||
val_sampler = ResumableDistributedSampler(
|
||||
data_source=val_dataset,
|
||||
data_source=cfg.val_dataset,
|
||||
start_epoch=0,
|
||||
start_iter=0,
|
||||
seed=cfg.random_seed,
|
||||
shuffle=False,
|
||||
)
|
||||
context.val_dataloader = DataLoader(
|
||||
val_dataset,
|
||||
cfg.val_dataset,
|
||||
batch_size=cfg.batch_per_device,
|
||||
sampler=val_sampler,
|
||||
num_workers=cfg.num_workers,
|
||||
|
|
@ -154,30 +107,11 @@ class TrainContextBuilder:
|
|||
prefetch_factor=cfg.prefetch_factor,
|
||||
)
|
||||
|
||||
context.model, context.optimizer, context.dataloader, context.scheduler = (
|
||||
executor.prepare(
|
||||
model,
|
||||
context.optimizer,
|
||||
context.dataloader,
|
||||
context.scheduler,
|
||||
)
|
||||
)
|
||||
|
||||
if context.checkpoint and context.checkpoint.extra:
|
||||
extra = context.checkpoint.extra
|
||||
for name in ("optimizer", "scheduler"):
|
||||
if name in extra:
|
||||
obj = getattr(context, name, None)
|
||||
if obj is not None:
|
||||
obj.load_state_dict(extra[name])
|
||||
|
||||
context.strategy = StrategyFactory.create(
|
||||
model=context.model,
|
||||
train_type=cfg.strategy,
|
||||
train_type=self.config.strategy,
|
||||
device=device,
|
||||
executor=executor,
|
||||
model_fn=cfg.model_fn,
|
||||
**cfg.extra_kwargs,
|
||||
**self.config.extra_kwargs,
|
||||
)
|
||||
|
||||
return context
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ from typing import List, Optional
|
|||
|
||||
from astrai.config import TrainConfig
|
||||
from astrai.parallel.setup import spawn_parallel_fn
|
||||
from astrai.serialization import Checkpoint
|
||||
from astrai.trainer.train_callback import (
|
||||
CallbackFactory,
|
||||
TrainCallback,
|
||||
|
|
@ -33,6 +34,7 @@ class Trainer:
|
|||
"checkpoint",
|
||||
cfg.ckpt_dir,
|
||||
cfg.ckpt_interval,
|
||||
state_dict_fn=cfg.state_dict_fn,
|
||||
),
|
||||
CallbackFactory.create(
|
||||
"metric_logger",
|
||||
|
|
@ -53,49 +55,47 @@ class Trainer:
|
|||
if method:
|
||||
method(context)
|
||||
|
||||
def _trainer_loop(self, resume_dir: Optional[str] = None):
|
||||
context = (
|
||||
TrainContextBuilder(self.train_config).with_resume_dir(resume_dir).build()
|
||||
)
|
||||
executor = context.executor
|
||||
def _trainer_loop(self, checkpoint: Optional[Checkpoint] = None):
|
||||
cfg = self.train_config
|
||||
context = TrainContextBuilder(cfg).with_checkpoint(checkpoint).build()
|
||||
self._call_callbacks("on_train_begin", context)
|
||||
|
||||
try:
|
||||
context.model.train()
|
||||
grad_accum_steps = cfg.grad_accum_steps
|
||||
|
||||
for epoch in range(context.epoch, context.config.n_epoch):
|
||||
for epoch in range(context.epoch, cfg.n_epoch):
|
||||
context.epoch = epoch
|
||||
self._call_callbacks("on_epoch_begin", context)
|
||||
|
||||
for batch in context.dataloader:
|
||||
self._call_callbacks("on_batch_begin", context)
|
||||
loss = context.strategy(batch)
|
||||
context.loss = loss.item()
|
||||
stand_loss = loss / grad_accum_steps
|
||||
stand_loss.backward()
|
||||
context.iteration += 1
|
||||
self._call_callbacks("on_batch_end", context)
|
||||
|
||||
with executor.accumulate(context.model):
|
||||
loss = context.strategy(batch)
|
||||
context.loss = loss.item()
|
||||
stand_loss = loss / executor.grad_accum_steps
|
||||
executor.backward(stand_loss)
|
||||
context.iteration += 1
|
||||
self._call_callbacks("on_batch_end", context)
|
||||
if context.iteration % grad_accum_steps == 0:
|
||||
self._call_callbacks("on_step_begin", context)
|
||||
context.optimizer.step()
|
||||
context.optimizer.zero_grad()
|
||||
self._call_callbacks("on_step_end", context)
|
||||
|
||||
if executor.sync_gradients:
|
||||
self._call_callbacks("on_optimizer_step", context)
|
||||
context.optimizer.step()
|
||||
context.optimizer.zero_grad()
|
||||
|
||||
if context.scheduler:
|
||||
context.scheduler.step()
|
||||
if context.scheduler:
|
||||
context.scheduler.step()
|
||||
|
||||
self._call_callbacks("on_epoch_end", context)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Training failed: %s", str(e), exc_info=True)
|
||||
logger.error(f"Training failed: {str(e)}", exc_info=True)
|
||||
self._call_callbacks("on_error", context)
|
||||
raise
|
||||
finally:
|
||||
self._call_callbacks("on_train_end", context)
|
||||
|
||||
def train(self, resume_dir: Optional[str] = None):
|
||||
def train(self, checkpoint: Optional[Checkpoint] = None):
|
||||
cfg = self.train_config
|
||||
spawn_parallel_fn(
|
||||
self._trainer_loop,
|
||||
|
|
@ -105,5 +105,5 @@ class Trainer:
|
|||
master_port=cfg.master_port,
|
||||
device_type=cfg.device_type,
|
||||
start_method=cfg.start_method,
|
||||
resume_dir=resume_dir,
|
||||
checkpoint=checkpoint,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,336 +0,0 @@
|
|||
"""HumanEval code generation benchmark.
|
||||
|
||||
Generates n completions per problem, extracts function bodies, executes
|
||||
against hidden tests, and computes pass@k.
|
||||
|
||||
Usage::
|
||||
|
||||
python scripts/tools/evaluate_humaneval.py --param_path ./params \
|
||||
--data_path HumanEval.jsonl --output results.json \
|
||||
--num_samples 200 --temperature 0.8 --max_tokens 512
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import signal
|
||||
import sys
|
||||
from math import prod
|
||||
from multiprocessing import Process, Queue
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import tqdm
|
||||
|
||||
from astrai.inference import InferenceEngine
|
||||
from astrai.model import AutoModel
|
||||
from astrai.tokenize import AutoTokenizer
|
||||
|
||||
HUMANEVAL_URL = (
|
||||
"https://github.com/openai/human-eval/raw/master/data/HumanEval.jsonl.gz"
|
||||
)
|
||||
|
||||
_STOP_SEQUENCES = [
|
||||
"\nclass ",
|
||||
"\ndef ",
|
||||
"\n# ",
|
||||
"\nif __name__",
|
||||
"\nprint(",
|
||||
"\n\n\n",
|
||||
]
|
||||
|
||||
|
||||
def _download_humaneval(data_path: str):
    """Download and decompress the HumanEval dataset unless it already exists.

    The gzip archive at ``HUMANEVAL_URL`` is fetched to ``data_path + ".tmp"``,
    decompressed into ``data_path``, and the temporary archive is removed.

    Args:
        data_path: Destination path of the decompressed JSONL file.
    """
    if os.path.exists(data_path):
        return
    import gzip
    import shutil
    import urllib.request

    os.makedirs(os.path.dirname(data_path) or ".", exist_ok=True)
    print(f"Downloading HumanEval from {HUMANEVAL_URL} ...")
    tmp = data_path + ".tmp"
    try:
        urllib.request.urlretrieve(HUMANEVAL_URL, tmp)
        # Stream the decompression instead of holding the whole file in memory.
        with gzip.open(tmp, "rb") as f_in, open(data_path, "wb") as f_out:
            shutil.copyfileobj(f_in, f_out)
    except BaseException:
        # A partial output file would make the next run skip the download
        # because of the existence check above.
        if os.path.exists(data_path):
            os.remove(data_path)
        raise
    finally:
        if os.path.exists(tmp):
            os.remove(tmp)
    print(f" saved to {data_path}")
|
||||
|
||||
|
||||
def _load_problems(data_path: str) -> List[dict]:
    """Parse a JSONL file into a list of dicts, ignoring blank lines.

    Args:
        data_path: Path to a UTF-8 text file with one JSON document per line.

    Returns:
        The decoded documents in file order.
    """
    with open(data_path, "r", encoding="utf-8") as handle:
        stripped_rows = (raw.strip() for raw in handle)
        return [json.loads(row) for row in stripped_rows if row]
|
||||
|
||||
|
||||
def _extract_function_body(code: str, entry_point: str) -> Optional[str]:
    """Extract the indented body of ``entry_point`` from a completion.

    Args:
        code: Generated text, possibly containing a ``def entry_point(...)``.
        entry_point: Name of the function whose body is wanted.

    Returns:
        ``code`` unchanged when no ``def entry_point(`` header is found, the
        body lines (right-stripped, joined with newlines) when one is, or
        ``None`` when the header is found but the body is empty.
    """
    header = re.search(rf"def\s+{re.escape(entry_point)}\s*\(", code)
    if not header:
        # Use the full code as-is if we can't find the function
        return code

    # Find the colon that ends the signature. Colons nested inside brackets
    # (parameter annotations, dict defaults, slices) are skipped; stopping at
    # the first ":" would cut ``def f(a: int) -> int:`` in the middle and leak
    # the rest of the signature into the body.
    depth = 0
    body_start = None
    for idx in range(header.end() - 1, len(code)):
        ch = code[idx]
        if ch in "([{":
            depth += 1
        elif ch in ")]}":
            depth -= 1
        elif ch == ":" and depth == 0:
            body_start = idx + 1
            break
    if body_start is None:
        return code

    body_lines = []
    started = False
    for line in code[body_start:].split("\n"):
        stripped = line.rstrip()
        if not stripped:
            # Blank lines before the body are dropped, later ones are kept.
            if started:
                body_lines.append("")
            continue
        started = True
        # The first non-indented line marks the end of the function.
        if stripped.lstrip() == stripped:
            break
        body_lines.append(stripped)

    body = "\n".join(body_lines)
    return body if body.strip() else None
|
||||
|
||||
|
||||
def _trim_stop_sequences(text: str, stops: Optional[List[str]] = None) -> str:
    """Cut ``text`` at the earliest occurrence of any stop sequence.

    Args:
        text: Raw generated text.
        stops: Stop sequences to look for; defaults to ``_STOP_SEQUENCES``.

    Returns:
        ``text`` up to (excluding) the first stop sequence, or ``text``
        unchanged when none occurs.
    """
    if stops is None:
        stops = _STOP_SEQUENCES
    # Search every stop in the untouched text. Truncating after each hit would
    # make the result depend on list order: a stop that starts earlier but
    # overlaps the cut point of a previously handled stop would be missed.
    hits = [idx for idx in map(text.find, stops) if idx != -1]
    return text[: min(hits)] if hits else text
|
||||
|
||||
|
||||
def _run_completion(queue, full_code: str, entry_point: str):
    """Subprocess worker: exec ``full_code`` and report the test outcome.

    Puts ``True`` on ``queue`` when the ``check`` function defined by the code
    accepts the entry point, ``False`` when ``check`` is missing or anything
    raises an ``Exception``.
    """
    try:
        namespace = {}
        # NOTE(review): this executes model-generated code; the only isolation
        # is the separate process and the caller's timeout.
        exec(full_code, namespace)
        check = namespace.get("check")
        if check is None:
            queue.put(False)
            return
        check(namespace.get(entry_point))
        queue.put(True)
    except Exception:
        queue.put(False)


def _execute_code(problem: dict, completion: str, timeout: float = 3.0) -> bool:
    """Run the completion against hidden tests in a subprocess.

    Args:
        problem: Dict with the keys ``"prompt"``, ``"test"`` and
            ``"entry_point"``.
        completion: Function body appended to the prompt.
        timeout: Seconds to wait before the subprocess is terminated.

    Returns:
        ``True`` only when the subprocess finished in time and reported
        success; ``False`` on failure, timeout, or a missing result.
    """
    full_code = problem["prompt"] + completion + "\n" + problem["test"]

    queue: Queue = Queue()
    # The worker is a module-level function (not a closure) so it can be
    # pickled when the multiprocessing start method is "spawn".
    proc = Process(
        target=_run_completion, args=(queue, full_code, problem["entry_point"])
    )
    proc.start()
    proc.join(timeout)

    if proc.is_alive():
        proc.terminate()
        proc.join()
        return False

    try:
        return queue.get_nowait()
    except Exception:
        # Nothing was reported (e.g. the child exited via SystemExit).
        return False
|
||||
|
||||
|
||||
def _pass_at_k(n: int, c: int, k: int) -> float:
    """Unbiased estimator of pass@k.

    Computes ``1 - C(n - c, k) / C(n, k)`` as a running product.

    Args:
        n: Number of evaluated samples.
        c: Number of samples that passed.
        k: The k in pass@k.

    Returns:
        The estimated probability that at least one of k samples passes.
    """
    if c <= 0:
        # No passing sample (this includes n == 0). Without this guard the
        # shortcut below would report 1.0 whenever n < k.
        return 0.0
    if n - c < k:
        # Every size-k subset must contain at least one passing sample.
        return 1.0
    return 1.0 - float(prod(1.0 - k / np.arange(n - c + 1, n + 1)))
|
||||
|
||||
|
||||
def _deduplicate(completions: List[str]) -> List[str]:
    """Return the completions with repeats removed, keeping first-seen order."""
    # dict preserves insertion order, so its keys are the unique items in the
    # order they first appeared.
    return list(dict.fromkeys(completions))
|
||||
|
||||
|
||||
def _generate(
    engine: "InferenceEngine",
    prompt: str,
    num_samples: int,
    max_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
    batch_size: int,
) -> List[str]:
    """Sample completions for one prompt in batches and deduplicate them.

    Args:
        engine: Object exposing ``generate(prompt=..., stream=..., ...)``.
        prompt: Prompt text, repeated once per requested sample.
        num_samples: Total number of completions to request.
        max_tokens: Forwarded to ``engine.generate``.
        temperature: Forwarded to ``engine.generate``.
        top_p: Forwarded to ``engine.generate``.
        top_k: Forwarded to ``engine.generate``.
        batch_size: Maximum number of prompts per ``generate`` call.

    Returns:
        The unique completions in first-seen order.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        # With a non-positive batch size ``remaining`` never shrinks and the
        # loop below would spin forever.
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    completions = []
    remaining = num_samples

    while remaining > 0:
        current = min(batch_size, remaining)
        outputs = engine.generate(
            prompt=[prompt] * current,
            stream=False,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
        )
        # A bare string is treated as a single completion.
        if isinstance(outputs, str):
            outputs = [outputs]
        completions.extend(outputs)
        remaining -= current

    # NOTE(review): deduplication shrinks the sample count that callers feed
    # into pass@k; confirm this is the intended estimator input.
    return _deduplicate(completions)
|
||||
|
||||
|
||||
def evaluate(
    engine: InferenceEngine,
    problems: List[dict],
    num_samples: int,
    max_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
    batch_size: int,
    k_values: Tuple[int, ...] = (1, 10, 100),
) -> Dict:
    """Generate, execute and score completions for every problem.

    Args:
        engine: Inference engine passed through to ``_generate``.
        problems: Dicts with at least ``"task_id"``, ``"prompt"`` and
            ``"entry_point"`` (plus whatever ``_execute_code`` reads).
        num_samples: Completions requested per problem.
        max_tokens: Sampling parameter forwarded to ``_generate``.
        temperature: Sampling parameter forwarded to ``_generate``.
        top_p: Sampling parameter forwarded to ``_generate``.
        top_k: Sampling parameter forwarded to ``_generate``.
        batch_size: Generation batch size forwarded to ``_generate``.
        k_values: The k values for which pass@k is reported.

    Returns:
        A dict mapping each task id to ``{"task_id", "n", "passed",
        "pass@k"...}`` plus a ``"_summary"`` entry holding the mean pass@k
        over all problems (0.0 when there are no problems).
    """
    results = {}
    all_pass_at_k = {k: [] for k in k_values}

    for problem in tqdm.tqdm(problems, desc="HumanEval", unit="problem"):
        task_id = problem["task_id"]
        entry_point = problem["entry_point"]

        raw_completions = _generate(
            engine,
            problem["prompt"],
            num_samples,
            max_tokens,
            temperature,
            top_p,
            top_k,
            batch_size,
        )

        # Keep only completions from which a non-empty body was extracted.
        completions = []
        for raw in raw_completions:
            body = _extract_function_body(_trim_stop_sequences(raw), entry_point)
            if body:
                completions.append(body)

        n = len(completions)
        c = sum(1 for comp in completions if _execute_code(problem, comp))

        result = {"task_id": task_id, "n": n, "passed": c}
        for k in k_values:
            score = _pass_at_k(n, c, k)  # computed once, stored twice
            result[f"pass@{k}"] = round(score, 4)
            all_pass_at_k[k].append(score)
        results[task_id] = result

    summary = {}
    for k in k_values:
        vals = all_pass_at_k[k]
        # np.mean([]) is NaN (with a RuntimeWarning); report 0.0 instead.
        summary[f"pass@{k}"] = round(float(np.mean(vals)), 4) if vals else 0.0
    results["_summary"] = summary

    return results
|
||||
|
||||
|
||||
def main():
    """Parse CLI arguments, run the HumanEval benchmark and report pass@k."""
    parser = argparse.ArgumentParser(description="HumanEval benchmark")
    parser.add_argument(
        "--param_path", type=str, default="./params", help="Model directory"
    )
    parser.add_argument(
        "--data_path",
        type=str,
        default="./humaneval/HumanEval.jsonl",
        help="HumanEval JSONL file (auto-download if missing)",
    )
    parser.add_argument("--output", type=str, default=None, help="Output JSON path")
    parser.add_argument(
        "--num_samples",
        type=int,
        default=200,
        help="Completions per problem",
    )
    parser.add_argument(
        "--max_tokens", type=int, default=512, help="Max generation tokens"
    )
    parser.add_argument(
        "--temperature", type=float, default=0.8, help="Sampling temperature"
    )
    parser.add_argument("--top_p", type=float, default=0.95, help="Top-p sampling")
    parser.add_argument("--top_k", type=int, default=50, help="Top-k sampling")
    parser.add_argument(
        "--batch_size", type=int, default=1, help="Inference batch size"
    )
    parser.add_argument(
        "--problems",
        type=int,
        nargs="+",
        default=None,
        help="Specific problem indices (0-based)",
    )
    args = parser.parse_args()

    _download_humaneval(args.data_path)
    problems = _load_problems(args.data_path)
    if args.problems:
        # Reject negative indices as well: they would silently wrap around
        # and select problems from the end of the list.
        problems = [problems[i] for i in args.problems if 0 <= i < len(problems)]

    model = AutoModel.from_pretrained(args.param_path)
    tokenizer = AutoTokenizer.from_pretrained(args.param_path)
    model.to(device="cuda", dtype=torch.bfloat16)

    engine = InferenceEngine(
        model=model,
        tokenizer=tokenizer,
        max_batch_size=args.batch_size,
    )

    try:
        results = evaluate(
            engine=engine,
            problems=problems,
            num_samples=args.num_samples,
            max_tokens=args.max_tokens,
            temperature=args.temperature,
            top_p=args.top_p,
            top_k=args.top_k,
            batch_size=args.batch_size,
            k_values=(1, 10, 100),
        )

        # The summary is printed separately and re-attached before saving.
        summary = results.pop("_summary")
        print(f"\n{'=' * 60}")
        for k, v in summary.items():
            print(f" {k}: {v:.2%}")
        print(f"{'=' * 60}")

        if args.output:
            results["_summary"] = summary
            with open(args.output, "w", encoding="utf-8") as f:
                json.dump(results, f, indent=2, ensure_ascii=False)
            print(f"Results saved to {args.output}")
    finally:
        # Shut the engine down even when evaluation or saving fails.
        engine.shutdown()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -1,319 +0,0 @@
|
|||
"""MMLU evaluation via log-likelihood ranking."""
|
||||
|
||||
import argparse
|
||||
import csv
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import tarfile
|
||||
|
||||
import requests
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import tqdm
|
||||
|
||||
from astrai.model import AutoModel
|
||||
from astrai.tokenize import AutoTokenizer
|
||||
|
||||
MMLU_URL = "https://people.eecs.berkeley.edu/~hendrycks/data.tar"
|
||||
MMLU_SUBJECTS = [
|
||||
"abstract_algebra",
|
||||
"anatomy",
|
||||
"astronomy",
|
||||
"business_ethics",
|
||||
"clinical_knowledge",
|
||||
"college_biology",
|
||||
"college_chemistry",
|
||||
"college_computer_science",
|
||||
"college_mathematics",
|
||||
"college_medicine",
|
||||
"college_physics",
|
||||
"computer_security",
|
||||
"conceptual_physics",
|
||||
"econometrics",
|
||||
"electrical_engineering",
|
||||
"elementary_mathematics",
|
||||
"formal_logic",
|
||||
"global_facts",
|
||||
"high_school_biology",
|
||||
"high_school_chemistry",
|
||||
"high_school_computer_science",
|
||||
"high_school_european_history",
|
||||
"high_school_geography",
|
||||
"high_school_government_and_politics",
|
||||
"high_school_macroeconomics",
|
||||
"high_school_mathematics",
|
||||
"high_school_microeconomics",
|
||||
"high_school_physics",
|
||||
"high_school_psychology",
|
||||
"high_school_statistics",
|
||||
"high_school_us_history",
|
||||
"high_school_world_history",
|
||||
"human_aging",
|
||||
"human_sexuality",
|
||||
"international_law",
|
||||
"jurisprudence",
|
||||
"logical_fallacies",
|
||||
"machine_learning",
|
||||
"management",
|
||||
"marketing",
|
||||
"medical_genetics",
|
||||
"miscellaneous",
|
||||
"moral_disputes",
|
||||
"moral_scenarios",
|
||||
"nutrition",
|
||||
"philosophy",
|
||||
"prehistory",
|
||||
"professional_accounting",
|
||||
"professional_law",
|
||||
"professional_medicine",
|
||||
"professional_psychology",
|
||||
"public_relations",
|
||||
"security_studies",
|
||||
"sociology",
|
||||
"us_foreign_policy",
|
||||
"virology",
|
||||
"world_religions",
|
||||
]
|
||||
|
||||
|
||||
def _download_and_extract(url: str, data_dir: str):
    """Download a tar archive from ``url`` and unpack it into ``data_dir``.

    The archive is stored as ``data_dir/data.tar`` while it is being
    processed and removed afterwards, whether or not extraction succeeded.

    Args:
        url: Location of the tar archive.
        data_dir: Directory that receives the extracted files (created if
            missing).

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    tar_path = os.path.join(data_dir, "data.tar")
    os.makedirs(data_dir, exist_ok=True)
    print(f"Downloading MMLU data from {url}...")
    try:
        with requests.get(url, stream=True, timeout=300) as resp:
            resp.raise_for_status()
            total = int(resp.headers.get("content-length", 0))
            with tqdm.tqdm(
                total=total, unit="B", unit_scale=True, desc=" Download"
            ) as bar:
                with open(tar_path, "wb") as f:
                    for chunk in resp.iter_content(chunk_size=8192):
                        f.write(chunk)
                        bar.update(len(chunk))
        print("Extracting...")
        with tarfile.open(tar_path, "r") as tf:
            # The archive comes from the network: use the "data" extraction
            # filter where the interpreter provides it, so members cannot
            # escape data_dir through absolute paths or "..".
            if hasattr(tarfile, "data_filter"):
                tf.extractall(data_dir, filter="data")
            else:
                tf.extractall(data_dir)
    finally:
        # Never leave a (possibly partial) archive behind.
        if os.path.exists(tar_path):
            os.remove(tar_path)
|
||||
|
||||
|
||||
def download_mmlu(data_dir: str):
    """Fetch the MMLU archive and flatten its ``data/`` folder into ``data_dir``.

    Entries already present in ``data_dir`` under the same name are replaced.
    """
    _download_and_extract(MMLU_URL, data_dir)
    extracted = os.path.join(data_dir, "data")
    if os.path.exists(extracted):
        for name in os.listdir(extracted):
            target = os.path.join(data_dir, name)
            # Clear whatever currently occupies the destination name.
            if os.path.isdir(target):
                shutil.rmtree(target)
            elif os.path.exists(target):
                os.remove(target)
            os.rename(os.path.join(extracted, name), target)
        os.rmdir(extracted)
    print(f"MMLU data saved to {data_dir}")
|
||||
|
||||
|
||||
def _strip_prefix(text: str, prefix: str) -> str:
    """Drop ``prefix`` from ``text`` and strip the remainder.

    Text that does not start with ``prefix`` is returned untouched (not
    stripped).
    """
    if not text.startswith(prefix):
        return text
    remainder = text[len(prefix) :]
    return remainder.strip()
|
||||
|
||||
|
||||
def load_csv(path: str) -> list[dict]:
    """Load multiple-choice rows from a CSV file.

    Rows with fewer than six columns and rows whose first column is
    ``question`` (case-insensitive, i.e. a header) are skipped.

    Args:
        path: CSV file with the columns question, A, B, C, D, answer.

    Returns:
        One dict per row with the keys ``"question"``, ``"A"``..``"D"`` and
        ``"answer"``; a leading ``"A)"``..``"D)"`` is removed from the
        choices.
    """
    data = []
    # newline="" lets the csv module handle line endings itself, which is
    # required for quoted fields that contain embedded newlines.
    with open(path, "r", encoding="utf-8", newline="") as f:
        for row in csv.reader(f):
            if len(row) < 6:
                continue
            if row[0].strip().lower() == "question":
                continue
            data.append(
                {
                    "question": row[0].strip(),
                    "A": _strip_prefix(row[1].strip(), "A)"),
                    "B": _strip_prefix(row[2].strip(), "B)"),
                    "C": _strip_prefix(row[3].strip(), "C)"),
                    "D": _strip_prefix(row[4].strip(), "D)"),
                    "answer": row[5].strip(),
                }
            )
    return data
|
||||
|
||||
|
||||
def build_prompt(
    question: str, choices: dict, subject: str, n_shot: int, dev_data: list[dict]
) -> str:
    """Render an MMLU question, optionally preceded by solved examples.

    Args:
        question: Question text of the item to answer.
        choices: Mapping with the keys ``"A"``..``"D"``.
        subject: Subject name used in the few-shot header line.
        n_shot: Number of leading ``dev_data`` items to include.
        dev_data: Solved items (``"question"``, ``"A"``..``"D"``,
            ``"answer"``).

    Returns:
        The prompt text ending in ``"Answer:"``.
    """
    letters = ("A", "B", "C", "D")
    parts = []
    if n_shot > 0 and dev_data:
        parts.append(
            f"The following are multiple choice questions (with answers) about {subject}.\n\n"
        )
        for shot in dev_data[:n_shot]:
            parts.append(f"Question: {shot['question']}\n")
            parts.extend(f"{letter}. {shot[letter]}\n" for letter in letters)
            parts.append(f"Answer: {shot['answer']}\n\n")
    parts.append(f"Question: {question}\n")
    parts.extend(f"{letter}. {choices[letter]}\n" for letter in letters)
    parts.append("Answer:")
    return "".join(parts)
|
||||
|
||||
|
||||
def apply_chat(
|
||||
tokenizer, raw_prompt: str, n_shot: int, dev_data: list[dict] | None
|
||||
) -> str:
|
||||
"""Wrap raw MMLU prompt in the model's chat template format.
|
||||
|
||||
For few-shot, prepend example Q&A pairs as a second user/assistant exchange.
|
||||
"""
|
||||
messages = []
|
||||
if n_shot > 0 and dev_data:
|
||||
for item in dev_data[:n_shot]:
|
||||
q = f"Question: {item['question']}\n"
|
||||
for k in ("A", "B", "C", "D"):
|
||||
q += f"{k}. {item[k]}\n"
|
||||
q += "Answer:"
|
||||
messages.append({"role": "user", "content": q})
|
||||
messages.append({"role": "assistant", "content": item["answer"]})
|
||||
messages.append({"role": "user", "content": raw_prompt})
|
||||
return tokenizer.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
|
||||
|
||||
def choice_logprob(
    model, tokenizer, context_ids: list[int], choice_letter: str, device: str
) -> float:
    """Sum the log-probabilities the model assigns to the choice tokens.

    The choice letter is tokenized without special tokens and appended to the
    context. If the result is longer than ``model.config.max_len`` the oldest
    tokens are dropped from the left.

    Args:
        model: Callable returning a mapping with a ``"logits"`` entry.
        tokenizer: Object exposing ``encode(text, add_special_tokens=...)``.
        context_ids: Token ids of the prompt.
        choice_letter: Text whose likelihood is scored.
        device: Device the input tensor is created on.

    Returns:
        The summed log-probability of the choice tokens.
    """
    choice_ids = tokenizer.encode(choice_letter, add_special_tokens=False)
    input_ids = context_ids + choice_ids
    ctx_len = len(context_ids)

    excess = len(input_ids) - model.config.max_len
    if excess > 0:
        # Left-truncate; the context shrinks, the choice tokens stay intact.
        input_ids = input_ids[excess:]
        ctx_len = len(input_ids) - len(choice_ids)

    batch = torch.tensor([input_ids], device=device, dtype=torch.long)
    with torch.inference_mode():
        logits = model(batch)["logits"][0]

    n_positions = len(logits)
    score = 0.0
    for offset, token_id in enumerate(choice_ids):
        # The logits at position p are used to score the token at p + 1.
        position = ctx_len - 1 + offset
        if position >= n_positions:
            break
        score += F.log_softmax(logits[position], dim=-1)[token_id].item()
    return score
|
||||
|
||||
|
||||
def evaluate_subject(
|
||||
model,
|
||||
tokenizer,
|
||||
subject: str,
|
||||
test_data: list[dict],
|
||||
dev_data: list[dict] | None,
|
||||
device: str,
|
||||
n_shot: int,
|
||||
) -> tuple[float, int, int]:
|
||||
correct = 0
|
||||
total = 0
|
||||
for item in tqdm.tqdm(test_data, desc=f"{subject:40s}", leave=False):
|
||||
raw_prompt = build_prompt(
|
||||
item["question"], item, subject, n_shot, dev_data or []
|
||||
)
|
||||
context = apply_chat(tokenizer, raw_prompt, n_shot, dev_data or [])
|
||||
context_ids = tokenizer.encode(context)
|
||||
scores = {
|
||||
c: choice_logprob(model, tokenizer, context_ids, c, device)
|
||||
for c in ("A", "B", "C", "D")
|
||||
}
|
||||
if max(scores, key=scores.get) == item["answer"]:
|
||||
correct += 1
|
||||
total += 1
|
||||
return correct / total, correct, total
|
||||
|
||||
|
||||
def main():
    """Parse CLI arguments, evaluate the selected MMLU subjects and report."""
    has_cuda = torch.cuda.is_available()

    cli = argparse.ArgumentParser(description="MMLU evaluation")
    cli.add_argument(
        "--param_path", type=str, default="./params", help="Model directory"
    )
    cli.add_argument(
        "--data_dir", type=str, default="./mmlu_data", help="MMLU data directory"
    )
    cli.add_argument("--download", action="store_true", help="Download MMLU data")
    cli.add_argument(
        "--n_shot", type=int, default=5, help="Few-shot examples (0 for zero-shot)"
    )
    cli.add_argument(
        "--subjects", type=str, nargs="+", help="Specific subjects (default: all)"
    )
    cli.add_argument("--output", type=str, help="Output JSON path")
    cli.add_argument("--split", type=str, default="test", choices=["test", "val"])
    cli.add_argument(
        "--device",
        type=str,
        default="cuda" if has_cuda else "cpu",
        help="Device",
    )
    cli.add_argument(
        "--dtype",
        type=str,
        default="bfloat16" if has_cuda else "float32",
        help="Torch dtype",
    )
    args = cli.parse_args()

    # Fetch the data on request or when the data directory is missing.
    if args.download or not os.path.exists(args.data_dir):
        download_mmlu(args.data_dir)

    model = AutoModel.from_pretrained(args.param_path)
    tokenizer = AutoTokenizer.from_pretrained(args.param_path)
    device = args.device
    model.to(device=device, dtype=getattr(torch, args.dtype))
    model.eval()

    results = {}
    total_correct = 0
    total_questions = 0

    for subject in args.subjects or MMLU_SUBJECTS:
        dev_path = os.path.join(args.data_dir, "dev", f"{subject}_dev.csv")
        test_path = os.path.join(
            args.data_dir, args.split, f"{subject}_{args.split}.csv"
        )

        if not os.path.exists(test_path):
            print(f" Skipping {subject}: test file not found")
            continue

        # Few-shot examples are optional; the evaluation split is required.
        dev_data = load_csv(dev_path) if os.path.exists(dev_path) else None
        test_data = load_csv(test_path)

        acc, corr, tot = evaluate_subject(
            model, tokenizer, subject, test_data, dev_data, device, args.n_shot
        )
        results[subject] = {"accuracy": round(acc, 4), "correct": corr, "total": tot}
        total_correct += corr
        total_questions += tot
        print(f" {subject:40s} {acc:.2%} ({corr}/{tot})")

    overall = total_correct / total_questions if total_questions else 0
    print(f"\n{'=' * 70}")
    print(f" Overall: {overall:.2%} ({total_correct}/{total_questions})")
    results["_overall"] = {
        "accuracy": round(overall, 4),
        "correct": total_correct,
        "total": total_questions,
    }

    if args.output:
        with open(args.output, "w", encoding="utf-8") as f:
            json.dump(results, f, indent=2)
        print(f"Results saved to {args.output}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -10,11 +10,11 @@ from astrai.tokenize import AutoTokenizer
|
|||
|
||||
|
||||
def process_file(
|
||||
param_path: str, input_file: str, output_file: str, batch_size: int, text_key: str
|
||||
model_dir: str, input_file: str, output_file: str, batch_size: int, text_key: str
|
||||
):
|
||||
# Load model and tokenizer
|
||||
model = AutoModel.from_pretrained(param_path)
|
||||
tokenizer = AutoTokenizer.from_pretrained(param_path)
|
||||
model = AutoModel.from_pretrained(model_dir)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_dir)
|
||||
model.to(device="cuda", dtype=torch.bfloat16)
|
||||
|
||||
with open(input_file, "r", encoding="utf-8") as f:
|
||||
|
|
@ -44,8 +44,8 @@ def process_file(
|
|||
|
||||
for seq in batch_encoded:
|
||||
pad_len = max_len - len(seq)
|
||||
padded_seq = seq + [tokenizer.pad_id] * pad_len
|
||||
mask = [True] * len(seq) + [False] * pad_len
|
||||
padded_seq = [tokenizer.pad_id] * pad_len + seq
|
||||
mask = [False] * pad_len + [True] * len(seq)
|
||||
padded_ids.append(padded_seq)
|
||||
masks.append(mask)
|
||||
|
||||
|
|
@ -88,7 +88,7 @@ def process_file(
|
|||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Run perplexity with a Khaosz model.")
|
||||
parser.add_argument(
|
||||
"--param_path", type=str, required=True, help="Path to the model directory."
|
||||
"--model_dir", type=str, required=True, help="Path to the model directory."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--input_file", type=str, required=True, help="Path to the input file."
|
||||
|
|
|
|||
|
|
@ -1,38 +0,0 @@
|
|||
"""CLI: JSONL → tokenized .h5/.bin via config-driven Pipeline."""
|
||||
|
||||
import argparse
|
||||
|
||||
from astrai.config.preprocess_config import PipelineConfig
|
||||
from astrai.preprocessing.pipeline import Pipeline
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Raw JSONL → tokenized .h5/.bin via config-driven Pipeline"
|
||||
)
|
||||
parser.add_argument(
|
||||
"inputs", nargs="+", metavar="JSONL", help="One or more JSONL files"
|
||||
)
|
||||
parser.add_argument("--output_dir", "-o", required=True, help="Output directory")
|
||||
parser.add_argument(
|
||||
"--config", "-c", required=True, help="Path to pipeline config JSON"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tokenizer_path",
|
||||
default="params",
|
||||
help="Path to tokenizer directory (default: params)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
config = PipelineConfig.from_json(args.config)
|
||||
|
||||
Pipeline(
|
||||
config=config,
|
||||
input_paths=args.inputs,
|
||||
output_dir=args.output_dir,
|
||||
tokenizer_path=args.tokenizer_path,
|
||||
).run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -18,7 +18,7 @@ def main():
|
|||
"--reload", action="store_true", help="Enable auto-reload for development"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--param_path",
|
||||
"--param-path",
|
||||
type=Path,
|
||||
default=None,
|
||||
help="Path to model parameters (default: project_root/params)",
|
||||
|
|
|
|||
|
|
@ -2,13 +2,16 @@ import argparse
|
|||
import os
|
||||
from functools import partial
|
||||
|
||||
import safetensors.torch as st
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
|
||||
from astrai.config import AutoRegressiveLMConfig, TrainConfig
|
||||
from astrai.dataset import DatasetFactory
|
||||
from astrai.model import AutoRegressiveLM
|
||||
from astrai.model.components.decoder_block import DecoderBlock
|
||||
from astrai.parallel import get_rank
|
||||
from astrai.trainer import SchedulerFactory, Trainer
|
||||
|
||||
|
||||
|
|
@ -116,12 +119,6 @@ def parse_args() -> argparse.Namespace:
|
|||
default=0.05,
|
||||
help="cross_entropy function label smoothing parameter",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gradient_checkpointing",
|
||||
action=argparse.BooleanOptionalAction,
|
||||
default=False,
|
||||
help="Enable activation checkpointing for DecoderBlock modules.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--ckpt_interval",
|
||||
|
|
@ -135,36 +132,6 @@ def parse_args() -> argparse.Namespace:
|
|||
default="checkpoint",
|
||||
help="Directory to save checkpoints.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--val_split",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Ratio to split from training dataset for validation (e.g. 0.05).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--val_step",
|
||||
type=int,
|
||||
default=1000,
|
||||
help="Number of optimizer steps between validation runs.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--metrics",
|
||||
nargs="*",
|
||||
default=["loss", "lr"],
|
||||
help="Metrics to log (e.g. --metrics loss lr val_loss). Default: loss lr.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--log_dir",
|
||||
type=str,
|
||||
default="checkpoint/logs",
|
||||
help="Directory for metric logs.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--log_interval",
|
||||
type=int,
|
||||
default=100,
|
||||
help="Number of batch iterations between metric logs.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--grpo_sync_interval",
|
||||
type=int,
|
||||
|
|
@ -178,32 +145,7 @@ def parse_args() -> argparse.Namespace:
|
|||
"--start_batch", type=int, default=0, help="Start batch for training."
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--master_addr",
|
||||
type=str,
|
||||
default="localhost",
|
||||
help="Master node address for distributed training.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--master_port",
|
||||
type=str,
|
||||
default="29500",
|
||||
help="Master node port for distributed training.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--backend",
|
||||
type=str,
|
||||
default="nccl",
|
||||
help="Distributed training backend.",
|
||||
)
|
||||
parser.add_argument("--nprocs", type=int, default=1, help="Number of GPUs to use.")
|
||||
parser.add_argument(
|
||||
"--parallel_mode",
|
||||
type=str,
|
||||
default="none",
|
||||
choices=["none", "ddp", "fsdp"],
|
||||
help="Parallel training strategy (none, ddp, fsdp).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device_type", type=str, default="cuda", help="Device type to use."
|
||||
)
|
||||
|
|
@ -220,11 +162,21 @@ def parse_args() -> argparse.Namespace:
|
|||
return args
|
||||
|
||||
|
||||
def create_model(config):
|
||||
return AutoRegressiveLM(config).to(dtype=torch.bfloat16)
|
||||
def ddp_wrap(model: nn.Module):
|
||||
local_rank = get_rank()
|
||||
ddp_model = DDP(
|
||||
model,
|
||||
device_ids=[local_rank],
|
||||
output_device=local_rank,
|
||||
static_graph=True,
|
||||
find_unused_parameters=False,
|
||||
gradient_as_bucket_view=True,
|
||||
broadcast_buffers=False,
|
||||
)
|
||||
return ddp_model
|
||||
|
||||
|
||||
def create_optimizer(model, **kwargs) -> optim.Optimizer:
|
||||
def create_optimizer(model: nn.Module, **kwargs) -> optim.Optimizer:
|
||||
return optim.AdamW(model.parameters(), fused=True, **kwargs)
|
||||
|
||||
|
||||
|
|
@ -234,6 +186,12 @@ def create_scheduler(
|
|||
return SchedulerFactory.create(optimizer, **kwargs)
|
||||
|
||||
|
||||
def prepare_checkpoint(model: nn.Module) -> dict:
|
||||
if isinstance(model, DDP):
|
||||
return model.module.state_dict()
|
||||
return model.state_dict()
|
||||
|
||||
|
||||
def compute_total_steps(
|
||||
dataset_len: int,
|
||||
n_epoch: int,
|
||||
|
|
@ -264,11 +222,6 @@ def train(
|
|||
warmup_ratio: float,
|
||||
ckpt_interval: int,
|
||||
ckpt_dir: str,
|
||||
val_split: float,
|
||||
val_step: int,
|
||||
metrics: list[str],
|
||||
log_dir: str,
|
||||
log_interval: int,
|
||||
dpo_beta: float,
|
||||
grpo_clip_eps: float,
|
||||
grpo_kl_coef: float,
|
||||
|
|
@ -282,21 +235,14 @@ def train(
|
|||
random_seed: int,
|
||||
num_workers: int,
|
||||
pin_memory: bool,
|
||||
gradient_checkpointing: bool,
|
||||
window_size: int,
|
||||
stride: int,
|
||||
nprocs: int,
|
||||
parallel_mode: str,
|
||||
device_type: str,
|
||||
backend: str,
|
||||
master_addr: str,
|
||||
master_port: str,
|
||||
start_method: str,
|
||||
):
|
||||
assert train_type in ["seq", "sft", "dpo", "grpo"]
|
||||
assert os.path.exists(param_path)
|
||||
if nprocs > 1 and parallel_mode == "none":
|
||||
raise ValueError("--nprocs > 1 requires --parallel_mode to be 'ddp' or 'fsdp'")
|
||||
|
||||
# Load config
|
||||
config_path = os.path.join(param_path, "config.json")
|
||||
|
|
@ -305,6 +251,17 @@ def train(
|
|||
if window_size is None:
|
||||
window_size = config.max_len
|
||||
|
||||
# Create bare AutoRegressiveLM (for training, no tokenizer needed)
|
||||
model = AutoRegressiveLM(config)
|
||||
|
||||
# Load weights if available
|
||||
weights_path = os.path.join(param_path, "model.safetensors")
|
||||
if os.path.exists(weights_path):
|
||||
state_dict = st.load_file(weights_path)
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
|
||||
model = model.to(dtype=torch.bfloat16)
|
||||
|
||||
strategy_kwargs = {
|
||||
"beta": dpo_beta,
|
||||
"label_smoothing": label_smoothing,
|
||||
|
|
@ -314,12 +271,6 @@ def train(
|
|||
"sync_interval": grpo_sync_interval,
|
||||
}
|
||||
|
||||
executor_kwargs = {
|
||||
"gradient_as_bucket_view": True,
|
||||
"broadcast_buffers": False,
|
||||
}
|
||||
|
||||
model_fn = partial(create_model, config)
|
||||
dataset = DatasetFactory.load(
|
||||
train_type=train_type,
|
||||
load_path=data_root_path,
|
||||
|
|
@ -350,10 +301,8 @@ def train(
|
|||
},
|
||||
)
|
||||
|
||||
grad_ckpt_modules = [DecoderBlock] if gradient_checkpointing else []
|
||||
|
||||
train_config = TrainConfig(
|
||||
model_fn=model_fn,
|
||||
model=model,
|
||||
strategy=train_type,
|
||||
dataset=dataset,
|
||||
optimizer_fn=optimizer_fn,
|
||||
|
|
@ -370,24 +319,15 @@ def train(
|
|||
num_workers=num_workers,
|
||||
pin_memory=pin_memory,
|
||||
nprocs=nprocs,
|
||||
backend=backend,
|
||||
master_addr=master_addr,
|
||||
master_port=master_port,
|
||||
parallel_mode=parallel_mode,
|
||||
parallel_wrapper=ddp_wrap,
|
||||
state_dict_fn=prepare_checkpoint,
|
||||
device_type=device_type,
|
||||
start_method=start_method,
|
||||
val_split=val_split,
|
||||
val_step=val_step,
|
||||
metrics=metrics,
|
||||
log_dir=log_dir,
|
||||
log_interval=log_interval,
|
||||
gradient_checkpointing_modules=grad_ckpt_modules,
|
||||
executor_kwargs=executor_kwargs,
|
||||
extra_kwargs=strategy_kwargs,
|
||||
)
|
||||
|
||||
trainer = Trainer(train_config)
|
||||
trainer.train(resume_dir=param_path)
|
||||
trainer.train()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -1,202 +0,0 @@
|
|||
import tempfile
|
||||
|
||||
import pytest
|
||||
from tokenizers import Tokenizer, models, pre_tokenizers, trainers
|
||||
|
||||
from astrai.config.preprocess_config import (
|
||||
InputConfig,
|
||||
PipelineConfig,
|
||||
ProcessingConfig,
|
||||
)
|
||||
from astrai.tokenize import AutoTokenizer
|
||||
|
||||
_SPECIAL_TOKENS_CONFIG = {
|
||||
"bos_token": "<|begin_of_sentence|>",
|
||||
"eos_token": "<|end_of_sentence|>",
|
||||
"pad_token": "<|_pad_|>",
|
||||
"unk_token": "<|_unk_|>",
|
||||
"im_start": "<|im_start|>",
|
||||
"im_end": "<|im_end|>",
|
||||
}
|
||||
|
||||
_SPECIAL_TOKENS = list(_SPECIAL_TOKENS_CONFIG.values())
|
||||
|
||||
_CHAT_TEMPLATE = (
|
||||
"{% for message in messages %}"
|
||||
"{% if message['role'] == 'system' %}"
|
||||
"<|im_start|>system\n{{ message['content'] }}<|im_end|>\n"
|
||||
"{% elif message['role'] == 'user' %}"
|
||||
"<|im_start|>user\n{{ message['content'] }}<|im_end|>\n"
|
||||
"{% elif message['role'] == 'assistant' %}"
|
||||
"<|im_start|>assistant\n{{ message['content'] }}<|im_end|>\n"
|
||||
"{% endif %}"
|
||||
"{% endfor %}"
|
||||
"{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"
|
||||
)
|
||||
|
||||
_CHAT_SECTIONS = [{"field": "messages", "action": "$role", "template": True}]
|
||||
|
||||
_INSTRUCTION_SECTIONS = [
|
||||
{"field": "prompt", "action": "mask", "add_special_tokens": True},
|
||||
{"field": "response", "action": "train"},
|
||||
]
|
||||
|
||||
_TEXT_SECTIONS = [{"field": "text", "action": "train"}]
|
||||
|
||||
_GRPO_RESPONSE_SECTIONS = [{"field": "responses", "action": "train"}]
|
||||
|
||||
|
||||
def _build_chat_tokenizer():
|
||||
tok = Tokenizer(models.BPE())
|
||||
tok.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
|
||||
tr = trainers.BpeTrainer(
|
||||
vocab_size=512,
|
||||
min_frequency=1,
|
||||
special_tokens=_SPECIAL_TOKENS,
|
||||
)
|
||||
train_data = [
|
||||
"hello world",
|
||||
"Hi there!",
|
||||
"You are helpful.",
|
||||
"What is 2+2?",
|
||||
"Tell me a story about dragons and knights.",
|
||||
"Sure, here is a tale.",
|
||||
"Translate to French: Hello",
|
||||
"Bonjour",
|
||||
"Artificial Intelligence is a field of computer science.",
|
||||
"system",
|
||||
"user",
|
||||
"assistant",
|
||||
"<|im_start|>",
|
||||
"<|im_end|>",
|
||||
*[chr(i) for i in range(32, 127)],
|
||||
]
|
||||
tok.train_from_iterator(train_data, tr)
|
||||
|
||||
auto_tok = AutoTokenizer()
|
||||
auto_tok._tokenizer = tok
|
||||
auto_tok._special_token_map = {
|
||||
"bos_token": "<|begin_of_sentence|>",
|
||||
"eos_token": "<|end_of_sentence|>",
|
||||
"pad_token": "<|_pad_|>",
|
||||
"unk_token": "<|_unk_|>",
|
||||
}
|
||||
auto_tok.set_chat_template(_CHAT_TEMPLATE)
|
||||
return auto_tok
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def chat_tokenizer():
|
||||
return _build_chat_tokenizer()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_dir():
|
||||
d = tempfile.mkdtemp()
|
||||
yield d
|
||||
import shutil
|
||||
|
||||
shutil.rmtree(d, ignore_errors=True)
|
||||
|
||||
|
||||
def make_chat_config():
|
||||
return PipelineConfig(
|
||||
input=InputConfig(sections=_CHAT_SECTIONS),
|
||||
mask={"system": "mask", "user": "mask", "assistant": "train"},
|
||||
mask_default="mask",
|
||||
preprocessing=ProcessingConfig(max_seq_len=2048),
|
||||
)
|
||||
|
||||
|
||||
def make_instruction_config():
|
||||
return PipelineConfig(
|
||||
input=InputConfig(sections=_INSTRUCTION_SECTIONS),
|
||||
mask={"prompt": "mask", "response": "train"},
|
||||
mask_default="mask",
|
||||
preprocessing=ProcessingConfig(max_seq_len=2048),
|
||||
)
|
||||
|
||||
|
||||
def make_text_config():
|
||||
return PipelineConfig(
|
||||
input=InputConfig(sections=_TEXT_SECTIONS),
|
||||
preprocessing=ProcessingConfig(
|
||||
max_seq_len=2048, min_chars=1, max_chars=2_000_000
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def make_dpo_chat_config():
|
||||
return PipelineConfig(
|
||||
input=InputConfig(
|
||||
sources={
|
||||
"chosen": {
|
||||
"sections": [
|
||||
{"field": "chosen", "action": "$role", "template": True}
|
||||
]
|
||||
},
|
||||
"rejected": {
|
||||
"sections": [
|
||||
{"field": "rejected", "action": "$role", "template": True}
|
||||
]
|
||||
},
|
||||
}
|
||||
),
|
||||
mask={"user": "mask", "assistant": "train"},
|
||||
mask_default="mask",
|
||||
preprocessing=ProcessingConfig(max_seq_len=2048),
|
||||
)
|
||||
|
||||
|
||||
def make_grpo_config():
|
||||
return PipelineConfig(
|
||||
input=InputConfig(
|
||||
sources={
|
||||
"prompts": {
|
||||
"sections": [
|
||||
{"field": "prompt", "action": "mask", "template": True}
|
||||
]
|
||||
},
|
||||
"responses": {
|
||||
"sections": _GRPO_RESPONSE_SECTIONS,
|
||||
"list_field": True,
|
||||
"mask_key": "masks",
|
||||
},
|
||||
"rewards": {
|
||||
"sections": [{"field": "rewards", "action": "value"}],
|
||||
},
|
||||
}
|
||||
),
|
||||
mask={"user": "mask", "assistant": "train"},
|
||||
mask_default="mask",
|
||||
preprocessing=ProcessingConfig(max_seq_len=2048),
|
||||
)
|
||||
|
||||
|
||||
def make_grpo_no_template_config():
|
||||
return PipelineConfig(
|
||||
input=InputConfig(
|
||||
sources={
|
||||
"prompts": {
|
||||
"sections": [
|
||||
{
|
||||
"field": "prompt",
|
||||
"action": "mask",
|
||||
"add_special_tokens": True,
|
||||
}
|
||||
]
|
||||
},
|
||||
"responses": {
|
||||
"sections": _GRPO_RESPONSE_SECTIONS,
|
||||
"list_field": True,
|
||||
"mask_key": "masks",
|
||||
},
|
||||
"rewards": {
|
||||
"sections": [{"field": "rewards", "action": "value"}],
|
||||
},
|
||||
}
|
||||
),
|
||||
mask={"user": "mask", "assistant": "train"},
|
||||
mask_default="mask",
|
||||
preprocessing=ProcessingConfig(max_seq_len=2048),
|
||||
)
|
||||
|
|
@ -1,4 +1,3 @@
|
|||
import os
|
||||
import tempfile
|
||||
|
||||
import torch
|
||||
|
|
@ -37,6 +36,7 @@ def test_single_process():
|
|||
|
||||
|
||||
def test_checkpoint_with_extra():
|
||||
"""Verify extra keys are saved as individual .pt files and loaded back."""
|
||||
model = torch.nn.Linear(10, 5)
|
||||
optimizer = AdamW(model.parameters(), lr=1e-3)
|
||||
optimizer.step()
|
||||
|
|
@ -52,6 +52,8 @@ def test_checkpoint_with_extra():
|
|||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
checkpoint.save(tmpdir)
|
||||
|
||||
import os
|
||||
|
||||
assert os.path.exists(os.path.join(tmpdir, "optimizer.pt"))
|
||||
assert os.path.exists(os.path.join(tmpdir, "scheduler.pt"))
|
||||
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import json
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
|
|
@ -6,11 +7,12 @@ import torch
|
|||
|
||||
from astrai.dataset.dataset import DatasetFactory, SEQDataset
|
||||
from astrai.dataset.storage import (
|
||||
H5Store,
|
||||
StoreFactory,
|
||||
BaseSegmentFetcher,
|
||||
H5Storage,
|
||||
MultiSegmentFetcher,
|
||||
StorageFactory,
|
||||
detect_format,
|
||||
load_bin,
|
||||
save_bin,
|
||||
load_json,
|
||||
save_h5,
|
||||
)
|
||||
|
||||
|
|
@ -98,7 +100,6 @@ def test_sft_dataset_with_random_data(base_test_env):
|
|||
dummy_data = {
|
||||
"sequence": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64)],
|
||||
"loss_mask": [torch.ones(seq_length, dtype=torch.bool)],
|
||||
"position_ids": [torch.arange(seq_length, dtype=torch.int32)],
|
||||
}
|
||||
|
||||
save_h5(test_dir, "sft_data", dummy_data)
|
||||
|
|
@ -156,6 +157,111 @@ def test_dataset_with_custom_stride(base_test_env):
|
|||
assert len(dataset) > len(default_stride_dataset)
|
||||
|
||||
|
||||
# ============== JSON Storage Tests (raw text + tokenizer) ==============
|
||||
|
||||
|
||||
def _make_tokenizer_fn(tokenizer):
|
||||
"""Wrap tokenizer.encode() as a str -> List[int] callable."""
|
||||
return lambda text: tokenizer.encode(text, add_special_tokens=False)
|
||||
|
||||
|
||||
def test_seq_dataset_from_json_text(base_test_env):
|
||||
"""Test loading SEQ dataset from raw-text JSON with tokenizer"""
|
||||
tokenizer = base_test_env["tokenizer"]
|
||||
tokenizer_fn = _make_tokenizer_fn(tokenizer)
|
||||
test_dir = base_test_env["test_dir"]
|
||||
data_dir = os.path.join(test_dir, "json_text")
|
||||
os.makedirs(data_dir, exist_ok=True)
|
||||
|
||||
texts = [
|
||||
"hello world this is a test sentence for tokenizer",
|
||||
"another sentence with different words and tokens",
|
||||
"machine learning is fascinating and powerful",
|
||||
]
|
||||
|
||||
json_path = os.path.join(data_dir, "seq_data.json")
|
||||
with open(json_path, "w", encoding="utf-8") as f:
|
||||
json.dump({"sequence": texts}, f, ensure_ascii=False)
|
||||
|
||||
dataset = DatasetFactory.load(
|
||||
train_type="seq",
|
||||
load_path=data_dir,
|
||||
window_size=16,
|
||||
tokenizer=tokenizer_fn,
|
||||
)
|
||||
assert dataset is not None
|
||||
assert len(dataset) > 0
|
||||
assert dataset.count > 0
|
||||
assert "sequence" in dataset.keys
|
||||
|
||||
item = dataset[0]
|
||||
assert "input_ids" in item
|
||||
assert "target_ids" in item
|
||||
assert item["input_ids"].shape[0] == 16
|
||||
|
||||
|
||||
def test_sft_dataset_from_json_text(base_test_env):
|
||||
"""Test loading SFT dataset from raw-text JSON with tokenizer"""
|
||||
tokenizer = base_test_env["tokenizer"]
|
||||
tokenizer_fn = _make_tokenizer_fn(tokenizer)
|
||||
test_dir = base_test_env["test_dir"]
|
||||
data_dir = os.path.join(test_dir, "json_sft")
|
||||
os.makedirs(data_dir, exist_ok=True)
|
||||
|
||||
texts = [
|
||||
"user asks a question about the weather",
|
||||
"assistant provides a helpful response to the user",
|
||||
]
|
||||
|
||||
json_path = os.path.join(data_dir, "sft_data.json")
|
||||
with open(json_path, "w", encoding="utf-8") as f:
|
||||
json.dump(
|
||||
{"sequence": texts, "loss_mask": texts},
|
||||
f,
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
dataset = DatasetFactory.load(
|
||||
train_type="sft",
|
||||
load_path=data_dir,
|
||||
window_size=16,
|
||||
tokenizer=tokenizer_fn,
|
||||
)
|
||||
assert dataset is not None
|
||||
assert len(dataset) > 0
|
||||
|
||||
item = dataset[0]
|
||||
assert "loss_mask" in item
|
||||
|
||||
|
||||
def test_json_storage_explicit_tokenizer(base_test_env):
|
||||
"""Test explicit JSON storage with tokenizer"""
|
||||
tokenizer = base_test_env["tokenizer"]
|
||||
tokenizer_fn = _make_tokenizer_fn(tokenizer)
|
||||
test_dir = base_test_env["test_dir"]
|
||||
data_dir = os.path.join(test_dir, "json_explicit")
|
||||
os.makedirs(data_dir, exist_ok=True)
|
||||
|
||||
texts = ["abcdefghijklmnopqrstuvwxyz" * 10]
|
||||
|
||||
json_path = os.path.join(data_dir, "data.json")
|
||||
with open(json_path, "w", encoding="utf-8") as f:
|
||||
json.dump({"sequence": texts}, f, ensure_ascii=False)
|
||||
|
||||
token_count = len(tokenizer_fn(texts[0]))
|
||||
|
||||
dataset = DatasetFactory.load(
|
||||
train_type="seq",
|
||||
load_path=data_dir,
|
||||
window_size=32,
|
||||
storage_type="json",
|
||||
tokenizer=tokenizer_fn,
|
||||
)
|
||||
assert dataset is not None
|
||||
assert len(dataset) > 0
|
||||
assert dataset.count == token_count
|
||||
|
||||
|
||||
def test_dataset_count_property(base_test_env):
|
||||
"""Test the count property returns correct raw token count"""
|
||||
test_dir = base_test_env["test_dir"]
|
||||
|
|
@ -212,29 +318,37 @@ def test_unloaded_dataset_len():
|
|||
assert len(dataset) == 0
|
||||
|
||||
|
||||
def test_store_unloaded_len():
|
||||
"""Unloaded Store has __len__ == 0"""
|
||||
store = H5Store()
|
||||
assert len(store) == 0
|
||||
assert store.keys == []
|
||||
def test_base_segment_fetcher_empty():
|
||||
"""BaseSegmentFetcher with empty segments list"""
|
||||
fetcher = BaseSegmentFetcher([])
|
||||
assert len(fetcher) == 0
|
||||
with pytest.raises(ValueError, match="out of bounds"):
|
||||
fetcher.fetch_data(0, 1)
|
||||
|
||||
|
||||
def test_store_fetch_begin_equals_end(base_test_env):
|
||||
"""Store.fetch with begin == end returns empty tensor"""
|
||||
def test_base_segment_fetcher_begin_equals_end(base_test_env):
|
||||
"""fetch_data with begin == end returns empty tensor"""
|
||||
test_dir = base_test_env["test_dir"]
|
||||
dummy = {"sequence": [torch.randint(0, 1000, (100,), dtype=torch.int64)]}
|
||||
save_h5(test_dir, "empty_fetch", dummy)
|
||||
|
||||
dataset = DatasetFactory.load("seq", test_dir, window_size=32)
|
||||
result = dataset.storage.fetch(10, 10, "sequence")
|
||||
fetcher = dataset.storage._fetcher.multi_fetchers["sequence"]
|
||||
result = fetcher.fetch_data(10, 10)
|
||||
assert result.numel() == 0
|
||||
|
||||
|
||||
def test_store_fetch_before_load():
|
||||
"""Store.fetch before load raises RuntimeError"""
|
||||
store = H5Store()
|
||||
def test_multi_segment_fetcher_empty_dict():
|
||||
"""MultiSegmentFetcher with empty dict has __len__ == 0"""
|
||||
fetcher = MultiSegmentFetcher({})
|
||||
assert len(fetcher) == 0
|
||||
|
||||
|
||||
def test_storage_fetch_before_load():
|
||||
"""BaseStorage.fetch before load raises RuntimeError"""
|
||||
storage = H5Storage()
|
||||
with pytest.raises(RuntimeError, match="not loaded"):
|
||||
store.fetch(0, 10, "sequence")
|
||||
storage.fetch(0, 10, "sequence")
|
||||
|
||||
|
||||
def test_detect_format_nonexistent_path():
|
||||
|
|
@ -253,192 +367,54 @@ def test_detect_format_unsupported_file(base_test_env):
|
|||
detect_format(path)
|
||||
|
||||
|
||||
def test_create_store_invalid_type():
|
||||
"""StoreFactory.create raises ValueError for unknown type"""
|
||||
def test_create_storage_invalid_type():
|
||||
"""StorageFactory.create raises ValueError for unknown type"""
|
||||
with pytest.raises(ValueError, match="Unknown component"):
|
||||
StoreFactory.create("parquet")
|
||||
StorageFactory.create("parquet")
|
||||
|
||||
|
||||
def test_store_multi_segment_concat(base_test_env):
|
||||
"""Multi-segment H5 data is concatenated into single tensor at load time"""
|
||||
import os
|
||||
|
||||
def test_json_pretokenized_without_tokenizer(base_test_env):
|
||||
"""Pre-tokenized JSON (List[List[int]]) loads without tokenizer"""
|
||||
test_dir = base_test_env["test_dir"]
|
||||
data_dir = os.path.join(test_dir, "multi_seg")
|
||||
data_dir = os.path.join(test_dir, "json_pretok")
|
||||
os.makedirs(data_dir, exist_ok=True)
|
||||
|
||||
json_path = os.path.join(data_dir, "data.json")
|
||||
with open(json_path, "w", encoding="utf-8") as f:
|
||||
json.dump({"sequence": [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]}, f)
|
||||
|
||||
dataset = DatasetFactory.load("seq", data_dir, window_size=4, storage_type="json")
|
||||
assert len(dataset) > 0
|
||||
assert dataset.count == 10
|
||||
|
||||
item = dataset[0]
|
||||
assert item["input_ids"].tolist() == [1, 2, 3, 4]
|
||||
assert item["target_ids"].tolist() == [2, 3, 4, 5]
|
||||
|
||||
|
||||
def test_load_json_skips_config_file(base_test_env):
|
||||
"""load_json skips scalar-value config files"""
|
||||
test_dir = base_test_env["test_dir"]
|
||||
with open(os.path.join(test_dir, "config.json"), "w") as f:
|
||||
json.dump({"vocab_size": 1000, "dim": 16}, f)
|
||||
|
||||
with open(os.path.join(test_dir, "data.json"), "w") as f:
|
||||
json.dump({"sequence": [[1, 2, 3, 4, 5]]}, f)
|
||||
|
||||
result = load_json(test_dir)
|
||||
assert "sequence" in result
|
||||
assert "vocab_size" not in result
|
||||
assert len(result["sequence"]) == 1
|
||||
|
||||
|
||||
def test_base_segment_fetcher_multi_segment():
|
||||
"""fetch_data across multiple segment boundaries"""
|
||||
segs = [
|
||||
torch.tensor([1, 2, 3]),
|
||||
torch.tensor([4, 5, 6, 7]),
|
||||
torch.tensor([8, 9]),
|
||||
]
|
||||
save_h5(data_dir, "data", {"sequence": segs})
|
||||
|
||||
store = StoreFactory.create("h5")
|
||||
store.load(data_dir)
|
||||
assert len(store) == 9
|
||||
result = store.fetch(2, 7, "sequence")
|
||||
fetcher = BaseSegmentFetcher(segs)
|
||||
assert len(fetcher) == 9
|
||||
result = fetcher.fetch_data(2, 7)
|
||||
assert result.tolist() == [3, 4, 5, 6, 7]
|
||||
|
||||
|
||||
def test_save_load_bin_roundtrip(base_test_env):
|
||||
"""save_bin + load_bin roundtrip preserves data"""
|
||||
test_dir = base_test_env["test_dir"]
|
||||
|
||||
data = {
|
||||
"sequence": [torch.tensor([1, 2, 3, 4, 5], dtype=torch.int64)],
|
||||
"loss_mask": [torch.tensor([0, 1, 1, 0, 1], dtype=torch.int64)],
|
||||
}
|
||||
save_bin(test_dir, data)
|
||||
result = load_bin(test_dir)
|
||||
|
||||
assert "sequence" in result
|
||||
assert "loss_mask" in result
|
||||
assert result["sequence"][0].tolist() == [1, 2, 3, 4, 5]
|
||||
assert result["loss_mask"][0].tolist() == [0, 1, 1, 0, 1]
|
||||
|
||||
|
||||
def test_mmap_store_load_and_fetch(base_test_env):
|
||||
"""MmapStore loads bin data and fetches correctly"""
|
||||
test_dir = base_test_env["test_dir"]
|
||||
|
||||
data = {
|
||||
"sequence": [torch.randint(0, 1000, (200,), dtype=torch.int64)],
|
||||
}
|
||||
save_bin(test_dir, data)
|
||||
|
||||
store = StoreFactory.create("bin")
|
||||
store.load(test_dir)
|
||||
assert len(store) == 200
|
||||
assert "sequence" in store.keys
|
||||
|
||||
result = store.fetch(10, 20, "sequence")
|
||||
assert result.tolist() == data["sequence"][0][10:20].tolist()
|
||||
|
||||
|
||||
def test_mmap_dataset_load(base_test_env):
|
||||
"""DatasetFactory.load auto-detects bin format"""
|
||||
test_dir = base_test_env["test_dir"]
|
||||
|
||||
data = {
|
||||
"sequence": [torch.randint(0, 1000, (200,), dtype=torch.int64)],
|
||||
}
|
||||
save_bin(test_dir, data)
|
||||
|
||||
dataset = DatasetFactory.load("seq", test_dir, window_size=64)
|
||||
assert len(dataset) > 0
|
||||
assert dataset.count == 200
|
||||
assert dataset[0]["input_ids"].shape[0] == 64
|
||||
|
||||
|
||||
def test_normalize_empty_key():
|
||||
"""_normalize with empty tensor list does not crash"""
|
||||
store = H5Store()
|
||||
store._normalize({"sequence": []})
|
||||
assert len(store) == 0
|
||||
assert store.keys == ["sequence"]
|
||||
|
||||
|
||||
def test_normalize_mixed_empty_key():
|
||||
"""_normalize with empty + non-empty keys returns min=0"""
|
||||
store = H5Store()
|
||||
store._normalize({"sequence": [torch.tensor([1, 2, 3])], "loss_mask": []})
|
||||
assert len(store) == 0
|
||||
assert set(store.keys) == {"sequence", "loss_mask"}
|
||||
|
||||
|
||||
def test_grpo_dataset_dtype(base_test_env):
|
||||
"""GRPODataset returns correct dtypes"""
|
||||
test_dir = base_test_env["test_dir"]
|
||||
|
||||
seq_len = 100
|
||||
data = {
|
||||
"prompts": [torch.randint(0, 100, (seq_len,), dtype=torch.int32)],
|
||||
"responses": [torch.randint(0, 100, (seq_len,), dtype=torch.int32)],
|
||||
"masks": [torch.ones(seq_len, dtype=torch.int32)],
|
||||
"rewards": [torch.ones(seq_len, dtype=torch.float32)],
|
||||
}
|
||||
save_h5(test_dir, "grpo_dtype", data)
|
||||
|
||||
dataset = DatasetFactory.load("grpo", test_dir, window_size=32)
|
||||
item = dataset[0]
|
||||
|
||||
assert item["prompts"].dtype == torch.long
|
||||
assert item["responses"].dtype == torch.long
|
||||
assert item["masks"].dtype == torch.bool
|
||||
assert item["rewards"].dtype == torch.float32
|
||||
|
||||
|
||||
def test_grpo_dataset_load(base_test_env):
|
||||
"""GRPODataset loads and returns correct keys"""
|
||||
test_dir = base_test_env["test_dir"]
|
||||
seq_len = 200
|
||||
data = {
|
||||
"prompts": [torch.randint(0, 1000, (seq_len,), dtype=torch.int64)],
|
||||
"responses": [torch.randint(0, 1000, (seq_len,), dtype=torch.int64)],
|
||||
"masks": [torch.ones(seq_len, dtype=torch.int64)],
|
||||
"rewards": [torch.rand(seq_len, dtype=torch.float32)],
|
||||
}
|
||||
save_h5(test_dir, "grpo_test", data)
|
||||
|
||||
dataset = DatasetFactory.load("grpo", test_dir, window_size=64)
|
||||
assert len(dataset) > 0
|
||||
item = dataset[0]
|
||||
assert "prompts" in item
|
||||
assert "responses" in item
|
||||
assert "masks" in item
|
||||
assert "rewards" in item
|
||||
assert item["prompts"].shape[0] == 64
|
||||
assert item["responses"].shape[0] == 64
|
||||
|
||||
|
||||
def test_detect_format_bin_dir(base_test_env):
|
||||
"""detect_format returns 'bin' for directory with .bin + meta.json"""
|
||||
test_dir = base_test_env["test_dir"]
|
||||
save_bin(test_dir, {"sequence": [torch.randint(0, 100, (10,))]})
|
||||
assert detect_format(test_dir) == "bin"
|
||||
|
||||
|
||||
def test_store_fetch_multi_key(base_test_env):
|
||||
"""Store.fetch with List[str] returns Dict[str, Tensor]"""
|
||||
test_dir = base_test_env["test_dir"]
|
||||
save_h5(
|
||||
test_dir,
|
||||
"multi_key",
|
||||
{
|
||||
"sequence": [torch.randint(0, 100, (100,), dtype=torch.int64)],
|
||||
"loss_mask": [torch.ones(100, dtype=torch.int64)],
|
||||
},
|
||||
)
|
||||
|
||||
store = StoreFactory.create("h5")
|
||||
store.load(test_dir)
|
||||
result = store.fetch(10, 20, ["sequence", "loss_mask"])
|
||||
assert isinstance(result, dict)
|
||||
assert result["sequence"].shape[0] == 10
|
||||
assert result["loss_mask"].shape[0] == 10
|
||||
|
||||
|
||||
def test_store_fetch_out_of_bounds(base_test_env):
|
||||
"""Store.fetch raises ValueError for out-of-bounds indices"""
|
||||
test_dir = base_test_env["test_dir"]
|
||||
save_h5(test_dir, "bounds", {"sequence": [torch.randint(0, 100, (50,))]})
|
||||
|
||||
store = StoreFactory.create("h5")
|
||||
store.load(test_dir)
|
||||
with pytest.raises(ValueError, match="out of bounds"):
|
||||
store.fetch(-1, 10, "sequence")
|
||||
with pytest.raises(ValueError, match="out of bounds"):
|
||||
store.fetch(0, 51, "sequence")
|
||||
with pytest.raises(ValueError, match="out of bounds"):
|
||||
store.fetch(50, 50, "sequence")
|
||||
|
||||
|
||||
def test_dataset_load_explicit_storage_type(base_test_env):
|
||||
"""DatasetFactory.load with explicit storage_type bypasses auto-detect"""
|
||||
test_dir = base_test_env["test_dir"]
|
||||
save_h5(test_dir, "explicit", {"sequence": [torch.randint(0, 100, (200,))]})
|
||||
|
||||
dataset = DatasetFactory.load("seq", test_dir, window_size=64, storage_type="h5")
|
||||
assert len(dataset) > 0
|
||||
assert dataset.count == 200
|
||||
|
|
|
|||
|
|
@ -1,396 +0,0 @@
|
|||
from astrai.config.preprocess_config import (
|
||||
InputConfig,
|
||||
OutputConfig,
|
||||
PipelineConfig,
|
||||
ProcessingConfig,
|
||||
)
|
||||
from astrai.preprocessing.builder import (
|
||||
MaskBuilderFactory,
|
||||
SectionedMaskBuilder,
|
||||
)
|
||||
from tests.data.conftest import (
|
||||
_CHAT_SECTIONS,
|
||||
_INSTRUCTION_SECTIONS,
|
||||
_TEXT_SECTIONS,
|
||||
make_chat_config,
|
||||
make_dpo_chat_config,
|
||||
make_grpo_config,
|
||||
make_instruction_config,
|
||||
make_text_config,
|
||||
)
|
||||
|
||||
|
||||
def test_chat_simple(chat_tokenizer):
|
||||
config = make_chat_config()
|
||||
builder = SectionedMaskBuilder()
|
||||
item = {
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are helpful."},
|
||||
{"role": "user", "content": "Hello."},
|
||||
{"role": "assistant", "content": "Hi there!"},
|
||||
]
|
||||
}
|
||||
result = builder.build(item, config, chat_tokenizer)
|
||||
assert result is not None
|
||||
assert "sequence" in result
|
||||
assert "loss_mask" in result
|
||||
assert len(result["sequence"]) == len(result["loss_mask"])
|
||||
|
||||
ids = chat_tokenizer.decode(result["sequence"], skip_special_tokens=False)
|
||||
assert "system" in ids.lower() or "<|im_start|>system" in ids
|
||||
assert "assistant" in ids.lower() or "<|im_start|>assistant" in ids
|
||||
|
||||
total = len(result["sequence"])
|
||||
trained = sum(result["loss_mask"])
|
||||
assert trained > 0
|
||||
assert trained < total
|
||||
|
||||
|
||||
def test_chat_mask_only_assistant(chat_tokenizer):
|
||||
config = make_chat_config()
|
||||
builder = SectionedMaskBuilder()
|
||||
item = {
|
||||
"messages": [
|
||||
{"role": "user", "content": "What is 2+2?"},
|
||||
{"role": "assistant", "content": "4"},
|
||||
]
|
||||
}
|
||||
result = builder.build(item, config, chat_tokenizer)
|
||||
mask = result["loss_mask"]
|
||||
ids = result["sequence"]
|
||||
assert len(ids) == len(mask)
|
||||
|
||||
trained = [i for i, m in enumerate(mask) if m == 1]
|
||||
masked = [i for i, m in enumerate(mask) if m == 0]
|
||||
assert len(trained) > 0
|
||||
assert len(masked) > 0
|
||||
|
||||
|
||||
def test_chat_all_masked(chat_tokenizer):
|
||||
config = PipelineConfig(
|
||||
input=InputConfig(sections=_CHAT_SECTIONS),
|
||||
mask={"system": "mask", "user": "mask", "assistant": "mask"},
|
||||
mask_default="mask",
|
||||
preprocessing=ProcessingConfig(max_seq_len=2048),
|
||||
)
|
||||
builder = SectionedMaskBuilder()
|
||||
item = {
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are helpful."},
|
||||
{"role": "assistant", "content": "Hi there!"},
|
||||
]
|
||||
}
|
||||
result = builder.build(item, config, chat_tokenizer)
|
||||
assert sum(result["loss_mask"]) == 0
|
||||
|
||||
|
||||
def test_chat_all_trained(chat_tokenizer):
|
||||
config = PipelineConfig(
|
||||
input=InputConfig(sections=_CHAT_SECTIONS),
|
||||
mask={},
|
||||
mask_default="train",
|
||||
preprocessing=ProcessingConfig(max_seq_len=2048),
|
||||
)
|
||||
builder = SectionedMaskBuilder()
|
||||
item = {
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are helpful."},
|
||||
{"role": "assistant", "content": "Hi there!"},
|
||||
]
|
||||
}
|
||||
result = builder.build(item, config, chat_tokenizer)
|
||||
assert sum(result["loss_mask"]) == len(result["sequence"]) - 1
|
||||
|
||||
|
||||
def test_chat_empty_messages(chat_tokenizer):
|
||||
config = make_chat_config()
|
||||
builder = SectionedMaskBuilder()
|
||||
assert builder.build({"messages": []}, config, chat_tokenizer) is None
|
||||
assert builder.build({}, config, chat_tokenizer) is None
|
||||
|
||||
|
||||
def test_chat_domain_extraction(chat_tokenizer):
|
||||
config = PipelineConfig(
|
||||
input=InputConfig(sections=_CHAT_SECTIONS),
|
||||
mask={"assistant": "train"},
|
||||
mask_default="mask",
|
||||
preprocessing=ProcessingConfig(max_seq_len=2048),
|
||||
output=OutputConfig(domain_key="source"),
|
||||
)
|
||||
builder = SectionedMaskBuilder()
|
||||
item = {
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hi"},
|
||||
{"role": "assistant", "content": "Hello"},
|
||||
],
|
||||
"source": "wiki",
|
||||
}
|
||||
result = builder.build(item, config, chat_tokenizer)
|
||||
assert result["domain"] == "wiki"
|
||||
|
||||
|
||||
def test_chat_truncation(chat_tokenizer):
|
||||
config = PipelineConfig(
|
||||
input=InputConfig(sections=_CHAT_SECTIONS),
|
||||
mask={"assistant": "train"},
|
||||
mask_default="mask",
|
||||
preprocessing=ProcessingConfig(max_seq_len=10),
|
||||
)
|
||||
builder = SectionedMaskBuilder()
|
||||
item = {
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Tell me a very long story about dragons and knights and magic.",
|
||||
},
|
||||
{"role": "assistant", "content": "Sure! Here is a tale..."},
|
||||
]
|
||||
}
|
||||
result = builder.build(item, config, chat_tokenizer)
|
||||
assert len(result["sequence"]) <= 10
|
||||
assert len(result["loss_mask"]) == len(result["sequence"])
|
||||
|
||||
|
||||
def test_instruction_basic(test_tokenizer):
|
||||
config = make_instruction_config()
|
||||
builder = SectionedMaskBuilder()
|
||||
item = {"prompt": "Translate to French: Hello", "response": "Bonjour"}
|
||||
result = builder.build(item, config, test_tokenizer)
|
||||
assert result is not None
|
||||
assert len(result["sequence"]) == len(result["loss_mask"])
|
||||
|
||||
|
||||
def test_instruction_prompt_masked(test_tokenizer):
|
||||
config = make_instruction_config()
|
||||
builder = SectionedMaskBuilder()
|
||||
item = {"prompt": "hello", "response": "world"}
|
||||
result = builder.build(item, config, test_tokenizer)
|
||||
mask = result["loss_mask"]
|
||||
ids = result["sequence"]
|
||||
|
||||
prompt_ids = test_tokenizer.encode("hello", add_special_tokens=True)
|
||||
p_len = min(len(prompt_ids), len(ids))
|
||||
assert all(m == 0 for m in mask[:p_len])
|
||||
if p_len < len(ids):
|
||||
assert all(m == 1 for m in mask[p_len:])
|
||||
|
||||
|
||||
def test_instruction_train_on_prompt(test_tokenizer):
|
||||
config = PipelineConfig(
|
||||
input=InputConfig(
|
||||
sections=[
|
||||
{"field": "prompt", "action": "train", "add_special_tokens": True},
|
||||
{"field": "response", "action": "mask"},
|
||||
]
|
||||
),
|
||||
preprocessing=ProcessingConfig(max_seq_len=2048),
|
||||
)
|
||||
builder = SectionedMaskBuilder()
|
||||
item = {"prompt": "hello", "response": "world"}
|
||||
result = builder.build(item, config, test_tokenizer)
|
||||
mask = result["loss_mask"]
|
||||
ids = result["sequence"]
|
||||
|
||||
prompt_ids = test_tokenizer.encode("hello", add_special_tokens=True)
|
||||
p_len = min(len(prompt_ids), len(ids))
|
||||
assert all(m == 1 for m in mask[:p_len])
|
||||
|
||||
|
||||
def test_text_basic(test_tokenizer):
|
||||
config = make_text_config()
|
||||
builder = SectionedMaskBuilder()
|
||||
item = {"text": "Hello world. This is a test document."}
|
||||
result = builder.build(item, config, test_tokenizer)
|
||||
assert result is not None
|
||||
assert "sequence" in result
|
||||
assert len(result["sequence"]) > 0
|
||||
assert "loss_mask" not in result
|
||||
|
||||
|
||||
def test_text_empty(test_tokenizer):
|
||||
config = make_text_config()
|
||||
builder = SectionedMaskBuilder()
|
||||
assert builder.build({"text": ""}, config, test_tokenizer) is None
|
||||
assert builder.build({"text": " "}, config, test_tokenizer) is None
|
||||
|
||||
|
||||
def test_text_too_short(test_tokenizer):
|
||||
config = PipelineConfig(
|
||||
input=InputConfig(sections=_TEXT_SECTIONS),
|
||||
preprocessing=ProcessingConfig(min_chars=100),
|
||||
)
|
||||
builder = SectionedMaskBuilder()
|
||||
assert builder.build({"text": "short"}, config, test_tokenizer) is None
|
||||
|
||||
|
||||
def test_text_truncation(test_tokenizer):
|
||||
config = PipelineConfig(
|
||||
input=InputConfig(sections=_TEXT_SECTIONS),
|
||||
preprocessing=ProcessingConfig(max_seq_len=3, min_chars=1),
|
||||
)
|
||||
builder = SectionedMaskBuilder()
|
||||
item = {"text": "This is a very long text that should be truncated"}
|
||||
result = builder.build(item, config, test_tokenizer)
|
||||
assert len(result["sequence"]) <= 3
|
||||
|
||||
|
||||
def test_sectioned_chat(chat_tokenizer):
|
||||
config = PipelineConfig(
|
||||
input=InputConfig(sections=_CHAT_SECTIONS),
|
||||
mask={"system": "mask", "user": "mask", "assistant": "train"},
|
||||
mask_default="mask",
|
||||
preprocessing=ProcessingConfig(max_seq_len=2048),
|
||||
)
|
||||
builder = SectionedMaskBuilder()
|
||||
item = {
|
||||
"messages": [
|
||||
{"role": "user", "content": "What is 2+2?"},
|
||||
{"role": "assistant", "content": "4"},
|
||||
]
|
||||
}
|
||||
result = builder.build(item, config, chat_tokenizer)
|
||||
assert result is not None
|
||||
assert len(result["sequence"]) == len(result["loss_mask"])
|
||||
assert sum(result["loss_mask"]) > 0
|
||||
assert 0 in result["loss_mask"]
|
||||
|
||||
|
||||
def test_sectioned_instruction(test_tokenizer):
|
||||
config = PipelineConfig(
|
||||
input=InputConfig(sections=_INSTRUCTION_SECTIONS),
|
||||
preprocessing=ProcessingConfig(max_seq_len=2048, min_chars=0),
|
||||
)
|
||||
builder = SectionedMaskBuilder()
|
||||
item = {"prompt": "Q: Why?", "response": "A: Because."}
|
||||
result = builder.build(item, config, test_tokenizer)
|
||||
assert result is not None
|
||||
mask = result["loss_mask"]
|
||||
assert mask[0] == 0
|
||||
assert mask[-1] == 1
|
||||
|
||||
|
||||
def test_sectioned_text(test_tokenizer):
|
||||
config = PipelineConfig(
|
||||
input=InputConfig(sections=_TEXT_SECTIONS),
|
||||
preprocessing=ProcessingConfig(max_seq_len=2048, min_chars=1),
|
||||
)
|
||||
builder = SectionedMaskBuilder()
|
||||
item = {"text": "Hello world, this is a test."}
|
||||
result = builder.build(item, config, test_tokenizer)
|
||||
assert result is not None
|
||||
assert "loss_mask" not in result
|
||||
|
||||
|
||||
def test_sectioned_text_too_short(test_tokenizer):
|
||||
config = PipelineConfig(
|
||||
input=InputConfig(sections=_TEXT_SECTIONS),
|
||||
preprocessing=ProcessingConfig(max_seq_len=2048, min_chars=100),
|
||||
)
|
||||
builder = SectionedMaskBuilder()
|
||||
assert builder.build({"text": "short"}, config, test_tokenizer) is None
|
||||
|
||||
|
||||
def test_factory_registered():
|
||||
names = MaskBuilderFactory._registry.list_names()
|
||||
assert "sectioned" in names
|
||||
|
||||
|
||||
def test_factory_create():
|
||||
builder = MaskBuilderFactory.create("sectioned")
|
||||
assert isinstance(builder, SectionedMaskBuilder)
|
||||
|
||||
|
||||
def test_dpo_chat_basic(chat_tokenizer):
|
||||
config = make_dpo_chat_config()
|
||||
builder = SectionedMaskBuilder()
|
||||
item = {
|
||||
"chosen": [
|
||||
{"role": "user", "content": "What is 2+2?"},
|
||||
{"role": "assistant", "content": "4"},
|
||||
],
|
||||
"rejected": [
|
||||
{"role": "user", "content": "What is 2+2?"},
|
||||
{"role": "assistant", "content": "5"},
|
||||
],
|
||||
}
|
||||
result = builder.build(item, config, chat_tokenizer)
|
||||
assert result is not None
|
||||
assert "chosen" in result
|
||||
assert "rejected" in result
|
||||
assert "chosen_mask" in result
|
||||
assert "rejected_mask" in result
|
||||
assert "domain" in result
|
||||
assert len(result["chosen"]) == len(result["chosen_mask"])
|
||||
assert len(result["rejected"]) == len(result["rejected_mask"])
|
||||
assert sum(result["chosen_mask"]) > 0
|
||||
assert sum(result["rejected_mask"]) > 0
|
||||
|
||||
|
||||
def test_dpo_chosen_only_trained(chat_tokenizer):
|
||||
config = make_dpo_chat_config()
|
||||
builder = SectionedMaskBuilder()
|
||||
item = {
|
||||
"chosen": [
|
||||
{"role": "user", "content": "Hi"},
|
||||
{"role": "assistant", "content": "Hello"},
|
||||
],
|
||||
"rejected": [
|
||||
{"role": "user", "content": "Hi"},
|
||||
{"role": "assistant", "content": "Go away"},
|
||||
],
|
||||
}
|
||||
result = builder.build(item, config, chat_tokenizer)
|
||||
assert 0 in result["chosen_mask"]
|
||||
assert 1 in result["chosen_mask"]
|
||||
assert 0 in result["rejected_mask"]
|
||||
assert 1 in result["rejected_mask"]
|
||||
|
||||
|
||||
def test_dpo_missing_field_is_none(chat_tokenizer):
|
||||
config = make_dpo_chat_config()
|
||||
builder = SectionedMaskBuilder()
|
||||
assert builder.build({"chosen": [], "rejected": []}, config, chat_tokenizer) is None
|
||||
|
||||
|
||||
def test_grpo_basic(chat_tokenizer):
|
||||
config = make_grpo_config()
|
||||
builder = SectionedMaskBuilder()
|
||||
item = {
|
||||
"prompt": [{"role": "user", "content": "What is 2+2?"}],
|
||||
"responses": ["4", "The answer is four", "Four", "2+2=4"],
|
||||
"rewards": [1.0, 0.5, 0.8, 0.2],
|
||||
}
|
||||
result = builder.build(item, config, chat_tokenizer)
|
||||
assert result is not None
|
||||
assert "prompts" in result
|
||||
assert "responses" in result
|
||||
assert "masks" in result
|
||||
assert "rewards" in result
|
||||
assert len(result["responses"]) == len(result["masks"])
|
||||
assert result["rewards"] == [1.0, 0.5, 0.8, 0.2]
|
||||
|
||||
|
||||
def test_grpo_response_tokens_all_trained(chat_tokenizer):
|
||||
config = make_grpo_config()
|
||||
builder = SectionedMaskBuilder()
|
||||
item = {
|
||||
"prompt": [{"role": "user", "content": "Q"}],
|
||||
"responses": ["A", "B"],
|
||||
"rewards": [0.8, 0.2],
|
||||
}
|
||||
result = builder.build(item, config, chat_tokenizer)
|
||||
masks = result["masks"]
|
||||
assert all(m == 1 for m in masks)
|
||||
assert len(masks) == len(result["responses"])
|
||||
|
||||
|
||||
def test_grpo_single_reward(chat_tokenizer):
|
||||
config = make_grpo_config()
|
||||
builder = SectionedMaskBuilder()
|
||||
item = {
|
||||
"prompt": [{"role": "user", "content": "Q"}],
|
||||
"responses": ["A"],
|
||||
"rewards": 0.9,
|
||||
}
|
||||
result = builder.build(item, config, chat_tokenizer)
|
||||
assert result["rewards"] == [0.9]
|
||||
|
|
@ -1,77 +0,0 @@
|
|||
import os
|
||||
|
||||
from astrai.config.preprocess_config import (
|
||||
InputConfig,
|
||||
PipelineConfig,
|
||||
)
|
||||
from tests.data.conftest import (
|
||||
_INSTRUCTION_SECTIONS,
|
||||
_TEXT_SECTIONS,
|
||||
make_dpo_chat_config,
|
||||
)
|
||||
|
||||
|
||||
def test_default_values():
|
||||
config = PipelineConfig()
|
||||
assert config.version == 1
|
||||
assert config.mask == {}
|
||||
assert config.mask_default == "mask"
|
||||
assert config.preprocessing.max_seq_len == 2048
|
||||
assert config.output.storage_format == "bin"
|
||||
assert config.input.sections is None
|
||||
|
||||
|
||||
def test_from_dict_flat():
|
||||
data = {
|
||||
"version": 1,
|
||||
"input": {
|
||||
"sections": [{"field": "messages", "action": "$role", "template": True}]
|
||||
},
|
||||
"mask": {"system": "mask", "assistant": "train"},
|
||||
"mask_default": "mask",
|
||||
"preprocessing": {"max_seq_len": 1024},
|
||||
"output": {"storage_format": "h5"},
|
||||
}
|
||||
config = PipelineConfig.from_dict(data)
|
||||
assert config.input.sections == [
|
||||
{"field": "messages", "action": "$role", "template": True}
|
||||
]
|
||||
assert config.mask == {"system": "mask", "assistant": "train"}
|
||||
assert config.preprocessing.max_seq_len == 1024
|
||||
assert config.output.storage_format == "h5"
|
||||
|
||||
|
||||
def test_to_dict_roundtrip():
|
||||
config = PipelineConfig(
|
||||
input=InputConfig(sections=_INSTRUCTION_SECTIONS),
|
||||
mask={"prompt": "mask", "response": "train"},
|
||||
mask_default="mask",
|
||||
)
|
||||
d = config.to_dict()
|
||||
config2 = PipelineConfig.from_dict(d)
|
||||
assert config2.input.sections == _INSTRUCTION_SECTIONS
|
||||
assert config2.mask == {"prompt": "mask", "response": "train"}
|
||||
|
||||
|
||||
def test_to_json_from_json(temp_dir):
|
||||
config = PipelineConfig(
|
||||
input=InputConfig(sections=_TEXT_SECTIONS),
|
||||
mask={"text": "train"},
|
||||
mask_default="mask",
|
||||
)
|
||||
path = os.path.join(temp_dir, "config.json")
|
||||
config.to_json(path)
|
||||
loaded = PipelineConfig.from_json(path)
|
||||
assert loaded.input.sections == _TEXT_SECTIONS
|
||||
assert loaded.mask == {"text": "train"}
|
||||
|
||||
|
||||
def test_dpo_config_roundtrip(temp_dir):
|
||||
config = make_dpo_chat_config()
|
||||
path = os.path.join(temp_dir, "config.json")
|
||||
config.to_json(path)
|
||||
loaded = PipelineConfig.from_json(path)
|
||||
assert loaded.input.sources is not None
|
||||
assert "chosen" in loaded.input.sources
|
||||
assert "rejected" in loaded.input.sources
|
||||
assert loaded.input.sections is None
|
||||
|
|
@ -1,349 +0,0 @@
|
|||
import json
|
||||
import os
|
||||
|
||||
from astrai.config.preprocess_config import (
|
||||
InputConfig,
|
||||
OutputConfig,
|
||||
PipelineConfig,
|
||||
ProcessingConfig,
|
||||
)
|
||||
from astrai.preprocessing.pipeline import Pipeline, filter_by_length
|
||||
from tests.data.conftest import (
|
||||
_CHAT_SECTIONS,
|
||||
_CHAT_TEMPLATE,
|
||||
_INSTRUCTION_SECTIONS,
|
||||
_SPECIAL_TOKENS_CONFIG,
|
||||
_TEXT_SECTIONS,
|
||||
make_dpo_chat_config,
|
||||
make_grpo_no_template_config,
|
||||
)
|
||||
|
||||
|
||||
def test_filter_by_length():
|
||||
assert filter_by_length("hello world", min_len=5)
|
||||
assert not filter_by_length("hi", min_len=5)
|
||||
assert not filter_by_length("x" * 100, max_len=50)
|
||||
assert filter_by_length("just right", min_len=5, max_len=20)
|
||||
|
||||
|
||||
def test_full_chat_pipeline(temp_dir, chat_tokenizer):
|
||||
tokenizer_dir = os.path.join(temp_dir, "tok")
|
||||
os.makedirs(tokenizer_dir, exist_ok=True)
|
||||
chat_tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
|
||||
with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f:
|
||||
json.dump(
|
||||
{
|
||||
"special_tokens": _SPECIAL_TOKENS_CONFIG,
|
||||
"chat_template": _CHAT_TEMPLATE,
|
||||
},
|
||||
f,
|
||||
)
|
||||
|
||||
jsonl_path = os.path.join(temp_dir, "chat.jsonl")
|
||||
with open(jsonl_path, "w", encoding="utf-8") as f:
|
||||
f.write(
|
||||
json.dumps(
|
||||
{
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are helpful."},
|
||||
{"role": "user", "content": "Hi."},
|
||||
{"role": "assistant", "content": "Hello!"},
|
||||
]
|
||||
}
|
||||
)
|
||||
+ "\n"
|
||||
)
|
||||
f.write(
|
||||
json.dumps(
|
||||
{
|
||||
"messages": [
|
||||
{"role": "user", "content": "What is 2+2?"},
|
||||
{"role": "assistant", "content": "4"},
|
||||
]
|
||||
}
|
||||
)
|
||||
+ "\n"
|
||||
)
|
||||
|
||||
config = PipelineConfig(
|
||||
input=InputConfig(sections=_CHAT_SECTIONS),
|
||||
mask={"system": "mask", "user": "mask", "assistant": "train"},
|
||||
mask_default="mask",
|
||||
preprocessing=ProcessingConfig(max_seq_len=2048),
|
||||
output=OutputConfig(storage_format="bin", domain_key=None),
|
||||
)
|
||||
|
||||
out_dir = os.path.join(temp_dir, "output")
|
||||
Pipeline(
|
||||
config=config,
|
||||
input_paths=[jsonl_path],
|
||||
output_dir=out_dir,
|
||||
tokenizer_path=tokenizer_dir,
|
||||
).run()
|
||||
|
||||
meta_path = os.path.join(out_dir, "__default__", "shard_0000", "meta.json")
|
||||
assert os.path.exists(meta_path)
|
||||
with open(meta_path, "r") as f:
|
||||
meta = json.load(f)
|
||||
assert "sequence" in meta
|
||||
assert "loss_mask" in meta
|
||||
assert meta["sequence"]["dtype"] == "int32"
|
||||
assert meta["loss_mask"]["dtype"] == "int32"
|
||||
|
||||
|
||||
def test_full_text_pipeline(temp_dir, test_tokenizer):
|
||||
tokenizer_dir = os.path.join(temp_dir, "tok")
|
||||
os.makedirs(tokenizer_dir, exist_ok=True)
|
||||
test_tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
|
||||
with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f:
|
||||
json.dump(
|
||||
{
|
||||
"special_tokens": {
|
||||
"pad_token": "<|_pad_|>",
|
||||
"unk_token": "<|_unk_|>",
|
||||
}
|
||||
},
|
||||
f,
|
||||
)
|
||||
|
||||
jsonl_path = os.path.join(temp_dir, "text.jsonl")
|
||||
with open(jsonl_path, "w", encoding="utf-8") as f:
|
||||
f.write(
|
||||
json.dumps(
|
||||
{
|
||||
"text": "Hello world this is a test document with enough characters to pass the minimum length filter."
|
||||
}
|
||||
)
|
||||
+ "\n"
|
||||
)
|
||||
f.write(
|
||||
json.dumps(
|
||||
{
|
||||
"text": "Another document for testing purposes with sufficient length to be processed."
|
||||
}
|
||||
)
|
||||
+ "\n"
|
||||
)
|
||||
|
||||
config = PipelineConfig(
|
||||
input=InputConfig(sections=_TEXT_SECTIONS),
|
||||
preprocessing=ProcessingConfig(max_seq_len=2048, min_chars=10),
|
||||
output=OutputConfig(storage_format="bin"),
|
||||
)
|
||||
|
||||
out_dir = os.path.join(temp_dir, "output")
|
||||
Pipeline(
|
||||
config=config,
|
||||
input_paths=[jsonl_path],
|
||||
output_dir=out_dir,
|
||||
tokenizer_path=tokenizer_dir,
|
||||
).run()
|
||||
|
||||
meta_path = os.path.join(out_dir, "__default__", "shard_0000", "meta.json")
|
||||
assert os.path.exists(meta_path)
|
||||
with open(meta_path, "r") as f:
|
||||
meta = json.load(f)
|
||||
assert "sequence" in meta
|
||||
assert "loss_mask" not in meta
|
||||
assert meta["sequence"]["dtype"] == "int32"
|
||||
|
||||
|
||||
def test_full_instruction_pipeline(temp_dir, test_tokenizer):
|
||||
tokenizer_dir = os.path.join(temp_dir, "tok")
|
||||
os.makedirs(tokenizer_dir, exist_ok=True)
|
||||
test_tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
|
||||
with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f:
|
||||
json.dump(
|
||||
{
|
||||
"special_tokens": {
|
||||
"pad_token": "<|_pad_|>",
|
||||
"unk_token": "<|_unk_|>",
|
||||
}
|
||||
},
|
||||
f,
|
||||
)
|
||||
|
||||
jsonl_path = os.path.join(temp_dir, "instruct.jsonl")
|
||||
with open(jsonl_path, "w", encoding="utf-8") as f:
|
||||
f.write(
|
||||
json.dumps(
|
||||
{
|
||||
"prompt": "Tell me a joke",
|
||||
"response": "Why did the chicken cross the road?",
|
||||
}
|
||||
)
|
||||
+ "\n"
|
||||
)
|
||||
f.write(
|
||||
json.dumps(
|
||||
{
|
||||
"prompt": "What is AI?",
|
||||
"response": "Artificial Intelligence is a field of computer science.",
|
||||
}
|
||||
)
|
||||
+ "\n"
|
||||
)
|
||||
|
||||
config = PipelineConfig(
|
||||
input=InputConfig(sections=_INSTRUCTION_SECTIONS),
|
||||
mask={"prompt": "mask", "response": "train"},
|
||||
mask_default="mask",
|
||||
preprocessing=ProcessingConfig(max_seq_len=2048),
|
||||
output=OutputConfig(storage_format="bin"),
|
||||
)
|
||||
|
||||
out_dir = os.path.join(temp_dir, "output")
|
||||
Pipeline(
|
||||
config=config,
|
||||
input_paths=[jsonl_path],
|
||||
output_dir=out_dir,
|
||||
tokenizer_path=tokenizer_dir,
|
||||
).run()
|
||||
|
||||
meta_path = os.path.join(out_dir, "__default__", "shard_0000", "meta.json")
|
||||
assert os.path.exists(meta_path)
|
||||
with open(meta_path, "r") as f:
|
||||
meta = json.load(f)
|
||||
assert "sequence" in meta
|
||||
assert "loss_mask" in meta
|
||||
assert meta["sequence"]["dtype"] == "int32"
|
||||
assert meta["loss_mask"]["dtype"] == "int32"
|
||||
|
||||
|
||||
def test_dtype_override(temp_dir, test_tokenizer):
|
||||
tokenizer_dir = os.path.join(temp_dir, "tok")
|
||||
os.makedirs(tokenizer_dir, exist_ok=True)
|
||||
test_tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
|
||||
with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f:
|
||||
json.dump(
|
||||
{
|
||||
"special_tokens": {
|
||||
"pad_token": "<|_pad_|>",
|
||||
"unk_token": "<|_unk_|>",
|
||||
}
|
||||
},
|
||||
f,
|
||||
)
|
||||
|
||||
jsonl_path = os.path.join(temp_dir, "data.jsonl")
|
||||
with open(jsonl_path, "w", encoding="utf-8") as f:
|
||||
f.write(json.dumps({"prompt": "Q", "response": "A"}) + "\n")
|
||||
|
||||
config = PipelineConfig(
|
||||
input=InputConfig(sections=_INSTRUCTION_SECTIONS),
|
||||
mask={"prompt": "mask", "response": "train"},
|
||||
mask_default="mask",
|
||||
preprocessing=ProcessingConfig(max_seq_len=2048),
|
||||
output=OutputConfig(storage_format="bin", dtype={"loss_mask": "bool"}),
|
||||
)
|
||||
|
||||
out_dir = os.path.join(temp_dir, "output")
|
||||
Pipeline(
|
||||
config=config,
|
||||
input_paths=[jsonl_path],
|
||||
output_dir=out_dir,
|
||||
tokenizer_path=tokenizer_dir,
|
||||
).run()
|
||||
|
||||
meta_path = os.path.join(out_dir, "__default__", "shard_0000", "meta.json")
|
||||
with open(meta_path, "r") as f:
|
||||
meta = json.load(f)
|
||||
assert meta["sequence"]["dtype"] == "int32"
|
||||
assert meta["loss_mask"]["dtype"] == "bool"
|
||||
|
||||
|
||||
def test_dpo_pipeline(temp_dir, chat_tokenizer):
|
||||
tokenizer_dir = os.path.join(temp_dir, "tok")
|
||||
os.makedirs(tokenizer_dir, exist_ok=True)
|
||||
chat_tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
|
||||
with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f:
|
||||
json.dump(
|
||||
{
|
||||
"special_tokens": _SPECIAL_TOKENS_CONFIG,
|
||||
"chat_template": _CHAT_TEMPLATE,
|
||||
},
|
||||
f,
|
||||
)
|
||||
|
||||
jsonl_path = os.path.join(temp_dir, "dpo.jsonl")
|
||||
with open(jsonl_path, "w", encoding="utf-8") as f:
|
||||
f.write(
|
||||
json.dumps(
|
||||
{
|
||||
"chosen": [
|
||||
{"role": "user", "content": "Hi."},
|
||||
{"role": "assistant", "content": "Hello!"},
|
||||
],
|
||||
"rejected": [
|
||||
{"role": "user", "content": "Hi."},
|
||||
{"role": "assistant", "content": "Go away."},
|
||||
],
|
||||
}
|
||||
)
|
||||
+ "\n"
|
||||
)
|
||||
|
||||
out_dir = os.path.join(temp_dir, "output")
|
||||
Pipeline(
|
||||
config=make_dpo_chat_config(),
|
||||
input_paths=[jsonl_path],
|
||||
output_dir=out_dir,
|
||||
tokenizer_path=tokenizer_dir,
|
||||
).run()
|
||||
|
||||
meta_path = os.path.join(out_dir, "__default__", "shard_0000", "meta.json")
|
||||
assert os.path.exists(meta_path)
|
||||
with open(meta_path, "r") as f:
|
||||
meta = json.load(f)
|
||||
assert "chosen" in meta
|
||||
assert "rejected" in meta
|
||||
assert "chosen_mask" in meta
|
||||
assert "rejected_mask" in meta
|
||||
assert "sequence" not in meta
|
||||
|
||||
|
||||
def test_grpo_pipeline(temp_dir, test_tokenizer):
|
||||
tokenizer_dir = os.path.join(temp_dir, "tok")
|
||||
os.makedirs(tokenizer_dir, exist_ok=True)
|
||||
test_tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
|
||||
with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f:
|
||||
json.dump(
|
||||
{
|
||||
"special_tokens": {
|
||||
"pad_token": "<|_pad_|>",
|
||||
"unk_token": "<|_unk_|>",
|
||||
}
|
||||
},
|
||||
f,
|
||||
)
|
||||
|
||||
jsonl_path = os.path.join(temp_dir, "grpo.jsonl")
|
||||
with open(jsonl_path, "w", encoding="utf-8") as f:
|
||||
f.write(
|
||||
json.dumps(
|
||||
{
|
||||
"prompt": "Question?",
|
||||
"responses": ["Answer A", "Answer B"],
|
||||
"rewards": [0.8, 0.3],
|
||||
}
|
||||
)
|
||||
+ "\n"
|
||||
)
|
||||
|
||||
out_dir = os.path.join(temp_dir, "output")
|
||||
Pipeline(
|
||||
config=make_grpo_no_template_config(),
|
||||
input_paths=[jsonl_path],
|
||||
output_dir=out_dir,
|
||||
tokenizer_path=tokenizer_dir,
|
||||
).run()
|
||||
|
||||
meta_path = os.path.join(out_dir, "__default__", "shard_0000", "meta.json")
|
||||
assert os.path.exists(meta_path)
|
||||
with open(meta_path, "r") as f:
|
||||
meta = json.load(f)
|
||||
assert "prompts" in meta
|
||||
assert "responses" in meta
|
||||
assert "masks" in meta
|
||||
assert "rewards" in meta
|
||||
assert "sequence" not in meta
|
||||
|
|
@ -5,22 +5,21 @@ from unittest.mock import MagicMock
|
|||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from astrai.inference import get_app
|
||||
from astrai.inference import app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
"""Provide a test client for the FastAPI app."""
|
||||
_app = get_app()
|
||||
_app.state.server_config = {
|
||||
app.state.server_config = {
|
||||
"device": "cpu",
|
||||
"dtype": "bfloat16",
|
||||
"param_path": None,
|
||||
"max_batch_size": 1,
|
||||
"_test": True,
|
||||
}
|
||||
_app.state.engine = None
|
||||
return TestClient(_app)
|
||||
app.state.engine = None
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -50,5 +49,5 @@ def mock_engine():
|
|||
@pytest.fixture
|
||||
def loaded_model(client, mock_engine):
|
||||
"""Simulate that the engine is loaded."""
|
||||
get_app().state.engine = mock_engine
|
||||
app.state.engine = mock_engine
|
||||
return mock_engine
|
||||
|
|
|
|||
|
|
@ -1,286 +0,0 @@
|
|||
"""Unit tests for protocol builders, StopChecker, GenContext, StopInfo."""
|
||||
|
||||
import json
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from astrai.inference.api.anthropic import AnthropicResponseBuilder
|
||||
from astrai.inference.api.openai import OpenAIResponseBuilder
|
||||
from astrai.inference.api.protocol import GenContext, StopChecker, StopInfo
|
||||
from astrai.inference.engine import GenerationRequest
|
||||
|
||||
|
||||
def _make_ctx(**kwargs):
|
||||
defaults = {
|
||||
"resp_id": "test-123",
|
||||
"created": 1000,
|
||||
"model": "test-model",
|
||||
"prompt_tokens": 10,
|
||||
"completion_tokens": 5,
|
||||
}
|
||||
defaults.update(kwargs)
|
||||
return GenContext(**defaults)
|
||||
|
||||
|
||||
def _sse_payloads(events):
|
||||
payloads = []
|
||||
for chunk in events:
|
||||
for line in chunk.strip().split("\n"):
|
||||
if line.startswith("data: "):
|
||||
try:
|
||||
payloads.append(json.loads(line[6:]))
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
return payloads
|
||||
|
||||
|
||||
class TestStopChecker:
|
||||
def test_check_finds_match(self):
|
||||
sc = StopChecker(["stop", "end"])
|
||||
assert sc.check("hello stop world") == "stop"
|
||||
|
||||
def test_check_returns_none_when_no_match(self):
|
||||
sc = StopChecker(["stop"])
|
||||
assert sc.check("hello world") is None
|
||||
|
||||
def test_check_empty_sequences(self):
|
||||
sc = StopChecker([])
|
||||
assert sc.check("hello") is None
|
||||
|
||||
|
||||
class TestGenContext:
|
||||
def test_defaults(self):
|
||||
ctx = GenContext(resp_id="a", created=1, model="m", prompt_tokens=10)
|
||||
assert ctx.completion_tokens == 0
|
||||
|
||||
def test_fields_mutable(self):
|
||||
ctx = GenContext(resp_id="a", created=1, model="m", prompt_tokens=10)
|
||||
ctx.completion_tokens = 42
|
||||
assert ctx.completion_tokens == 42
|
||||
|
||||
|
||||
class TestStopInfo:
|
||||
def test_defaults(self):
|
||||
s = StopInfo()
|
||||
assert s.matched is None
|
||||
assert s.body == ""
|
||||
assert s.yielded == ""
|
||||
|
||||
def test_with_values(self):
|
||||
s = StopInfo(matched="stop", body="hello stop", yielded="hello ")
|
||||
assert s.matched == "stop"
|
||||
assert s.body == "hello stop"
|
||||
assert s.yielded == "hello "
|
||||
|
||||
|
||||
class TestOpenAIResponseBuilder:
|
||||
@pytest.fixture
|
||||
def builder(self):
|
||||
builder = OpenAIResponseBuilder()
|
||||
req = MagicMock()
|
||||
req.messages = [MagicMock(role="user", content="Hello")]
|
||||
req.stop = None
|
||||
req.model = "astrai"
|
||||
engine = MagicMock()
|
||||
engine.tokenizer.apply_chat_template.return_value = "Hello"
|
||||
builder.prepare(req, engine)
|
||||
return builder
|
||||
|
||||
def test_prepare_returns_prompt_ctx_stops(self, builder):
|
||||
req = MagicMock()
|
||||
req.messages = [MagicMock(role="user", content="Hi")]
|
||||
req.stop = ["END"]
|
||||
req.model = "gpt"
|
||||
engine = MagicMock()
|
||||
engine.tokenizer.apply_chat_template.return_value = "Hi"
|
||||
prompt, ctx, stops = builder.prepare(req, engine)
|
||||
assert prompt == "Hi"
|
||||
assert ctx.model == "gpt"
|
||||
assert ctx.prompt_tokens == 0
|
||||
assert stops == ["END"]
|
||||
|
||||
def test_prepare_no_stop_returns_empty_list(self, builder):
|
||||
req = MagicMock()
|
||||
req.messages = []
|
||||
req.stop = None
|
||||
req.model = "x"
|
||||
engine = MagicMock()
|
||||
engine.tokenizer.apply_chat_template.return_value = ""
|
||||
_, _, stops = builder.prepare(req, engine)
|
||||
assert stops == []
|
||||
|
||||
def test_format_stream_start(self, builder):
|
||||
ctx = _make_ctx()
|
||||
events = builder.format_stream_start(ctx)
|
||||
payloads = _sse_payloads(events)
|
||||
assert len(payloads) == 1
|
||||
p = payloads[0]
|
||||
assert p["object"] == "chat.completion.chunk"
|
||||
assert p["choices"][0]["delta"]["role"] == "assistant"
|
||||
assert p["choices"][0]["finish_reason"] is None
|
||||
|
||||
def test_format_chunk(self, builder):
|
||||
event = builder.format_chunk("hello")
|
||||
payload = json.loads(event.split("data: ", 1)[1])
|
||||
assert payload["choices"][0]["delta"]["content"] == "hello"
|
||||
assert payload["choices"][0]["finish_reason"] is None
|
||||
|
||||
def test_format_stream_end(self, builder):
|
||||
ctx = _make_ctx(completion_tokens=5)
|
||||
stop = StopInfo(matched="stop")
|
||||
events = builder.format_stream_end(ctx, stop)
|
||||
payloads = _sse_payloads(events)
|
||||
finish = payloads[0]
|
||||
assert finish["choices"][0]["finish_reason"] == "stop"
|
||||
usage = payloads[1]
|
||||
assert usage["completion_tokens"] == 5
|
||||
assert usage["total_tokens"] == 15
|
||||
|
||||
def test_format_response(self, builder):
|
||||
ctx = _make_ctx()
|
||||
stop = StopInfo()
|
||||
resp = builder.format_response(ctx, "hello", stop)
|
||||
assert resp["object"] == "chat.completion"
|
||||
assert resp["choices"][0]["message"]["content"] == "hello"
|
||||
assert resp["usage"]["prompt_tokens"] == 10
|
||||
|
||||
|
||||
class TestAnthropicResponseBuilder:
|
||||
@pytest.fixture
|
||||
def builder(self):
|
||||
builder = AnthropicResponseBuilder()
|
||||
req = MagicMock()
|
||||
req.messages = [MagicMock(role="user", content="Hello")]
|
||||
req.model = "claude"
|
||||
engine = MagicMock()
|
||||
engine.tokenizer.apply_chat_template.return_value = "Hello"
|
||||
req.system = None
|
||||
builder.prepare(req, engine)
|
||||
return builder
|
||||
|
||||
def test_prepare_messages(self, builder):
|
||||
req = MagicMock()
|
||||
req.messages = [MagicMock(role="user", content="Hi")]
|
||||
req.model = "claude"
|
||||
req.system = None
|
||||
req.stop_sequences = None
|
||||
engine = MagicMock()
|
||||
engine.tokenizer.apply_chat_template.return_value = "Hi"
|
||||
prompt, ctx, stops = builder.prepare(req, engine)
|
||||
assert prompt == "Hi"
|
||||
assert stops == []
|
||||
|
||||
def test_prepare_with_stop_sequences(self, builder):
|
||||
req = MagicMock()
|
||||
req.messages = []
|
||||
req.model = "x"
|
||||
req.stop_sequences = ["stop", "end"]
|
||||
req.system = None
|
||||
engine = MagicMock()
|
||||
engine.tokenizer.apply_chat_template.return_value = ""
|
||||
_, _, stops = builder.prepare(req, engine)
|
||||
assert stops == ["stop", "end"]
|
||||
|
||||
def test_format_stream_start(self, builder):
|
||||
ctx = _make_ctx(prompt_tokens=3)
|
||||
events = builder.format_stream_start(ctx)
|
||||
payloads = _sse_payloads(events)
|
||||
assert len(payloads) == 2
|
||||
assert payloads[0]["type"] == "message_start"
|
||||
assert payloads[0]["message"]["usage"]["input_tokens"] == 3
|
||||
assert payloads[1]["type"] == "content_block_start"
|
||||
|
||||
def test_format_chunk(self, builder):
|
||||
event = builder.format_chunk("tok")
|
||||
payload = json.loads(event.split("data: ", 1)[1])
|
||||
assert payload["type"] == "content_block_delta"
|
||||
assert payload["delta"]["text"] == "tok"
|
||||
|
||||
def test_format_stream_end_no_stop(self, builder):
|
||||
ctx = _make_ctx(completion_tokens=3)
|
||||
stop = StopInfo()
|
||||
events = builder.format_stream_end(ctx, stop)
|
||||
payloads = _sse_payloads(events)
|
||||
# content_block_stop, message_delta, message_stop
|
||||
types = [p["type"] for p in payloads]
|
||||
assert types == ["content_block_stop", "message_delta", "message_stop"]
|
||||
assert payloads[1]["delta"]["stop_reason"] == "end_turn"
|
||||
|
||||
def test_format_stream_end_with_stop_trims_and_emits_remaining(self, builder):
|
||||
ctx = _make_ctx(completion_tokens=7)
|
||||
stop = StopInfo(
|
||||
matched="END",
|
||||
body="Hello world END extra",
|
||||
yielded="Hello ",
|
||||
)
|
||||
events = builder.format_stream_end(ctx, stop)
|
||||
payloads = _sse_payloads(events)
|
||||
# unyielded delta, content_block_stop, message_delta, message_stop
|
||||
types = [p["type"] for p in payloads]
|
||||
assert types == [
|
||||
"content_block_delta",
|
||||
"content_block_stop",
|
||||
"message_delta",
|
||||
"message_stop",
|
||||
]
|
||||
assert payloads[0]["delta"]["text"] == "world "
|
||||
assert payloads[2]["delta"]["stop_reason"] == "stop_sequence"
|
||||
assert payloads[2]["delta"]["stop_sequence"] == "END"
|
||||
|
||||
def test_format_stream_end_stop_trimmed_already_yielded(self, builder):
|
||||
ctx = _make_ctx()
|
||||
stop = StopInfo(
|
||||
matched="END",
|
||||
body="Hello END",
|
||||
yielded="Hello ",
|
||||
)
|
||||
events = builder.format_stream_end(ctx, stop)
|
||||
payloads = _sse_payloads(events)
|
||||
# No unyielded delta (everything already sent)
|
||||
types = [p["type"] for p in payloads]
|
||||
assert types == ["content_block_stop", "message_delta", "message_stop"]
|
||||
|
||||
def test_format_response_with_stop_trims_content(self, builder):
|
||||
ctx = _make_ctx()
|
||||
stop = StopInfo(matched="STOP", body="text STOP extra", yielded="text ")
|
||||
resp = builder.format_response(ctx, "text STOP extra", stop)
|
||||
assert resp["content"][0]["text"] == "text "
|
||||
assert resp["stop_reason"] == "stop_sequence"
|
||||
assert resp["stop_sequence"] == "STOP"
|
||||
|
||||
def test_format_response_no_stop(self, builder):
|
||||
ctx = _make_ctx()
|
||||
stop = StopInfo()
|
||||
resp = builder.format_response(ctx, "full text", stop)
|
||||
assert resp["content"][0]["text"] == "full text"
|
||||
assert resp["stop_reason"] == "end_turn"
|
||||
|
||||
|
||||
class TestGenerationRequestValidation:
|
||||
def test_valid_params(self):
|
||||
gr = GenerationRequest(
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
top_k=50,
|
||||
top_p=0.9,
|
||||
temperature=0.7,
|
||||
)
|
||||
assert gr.top_k == 50
|
||||
|
||||
def test_invalid_top_p_raises(self):
|
||||
with pytest.raises(ValueError, match="top_p"):
|
||||
GenerationRequest(messages=[{"role": "user", "content": "hi"}], top_p=1.5)
|
||||
|
||||
def test_invalid_top_k_raises(self):
|
||||
with pytest.raises(ValueError, match="top_k"):
|
||||
GenerationRequest(messages=[{"role": "user", "content": "hi"}], top_k=-1)
|
||||
|
||||
def test_invalid_temperature_raises(self):
|
||||
with pytest.raises(ValueError, match="temperature"):
|
||||
GenerationRequest(
|
||||
messages=[{"role": "user", "content": "hi"}], temperature=-0.1
|
||||
)
|
||||
|
||||
def test_top_k_zero_valid(self):
|
||||
gr = GenerationRequest(messages=[{"role": "user", "content": "hi"}], top_k=0)
|
||||
assert gr.top_k == 0
|
||||
|
|
@ -173,21 +173,3 @@ def test_scheduler_concurrent_get_stats(mock_model_and_tokenizer):
|
|||
for stats in results["stats"]:
|
||||
assert "total_tasks" in stats
|
||||
assert stats["total_tasks"] >= 0
|
||||
|
||||
|
||||
def test_prefill_skips_fully_cached_tasks(mock_model_and_tokenizer):
|
||||
"""Tasks whose entire prompt is cached skip the prefill phase."""
|
||||
mock_model, mock_tokenizer = mock_model_and_tokenizer
|
||||
|
||||
with patch("astrai.inference.core.scheduler.AutoModel"):
|
||||
with patch("astrai.inference.core.scheduler.AutoTokenizer"):
|
||||
scheduler = InferenceScheduler(
|
||||
model=mock_model,
|
||||
tokenizer=mock_tokenizer,
|
||||
max_batch_size=4,
|
||||
device="cpu",
|
||||
)
|
||||
|
||||
task_id = scheduler.add_task("short prompt", stream_callback=lambda t: None)
|
||||
scheduler.stop()
|
||||
assert task_id.startswith("task_")
|
||||
|
|
|
|||
|
|
@ -2,12 +2,12 @@
|
|||
|
||||
import pytest
|
||||
|
||||
from astrai.inference import get_app
|
||||
from astrai.inference import app
|
||||
|
||||
|
||||
def test_health_no_model(client):
|
||||
"""GET /health should return 200 even when engine not loaded."""
|
||||
get_app().state.engine = None
|
||||
app.state.engine = None
|
||||
response = client.get("/health")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
|
|
@ -30,7 +30,7 @@ def test_chat_completions_non_stream(client, loaded_model):
|
|||
async def async_gen():
|
||||
yield "Assistant reply"
|
||||
|
||||
get_app().state.engine = loaded_model
|
||||
app.state.engine = loaded_model
|
||||
loaded_model.generate_async.return_value = async_gen()
|
||||
response = client.post(
|
||||
"/v1/chat/completions",
|
||||
|
|
@ -56,7 +56,7 @@ def test_chat_completions_stream(client, loaded_model):
|
|||
yield "cumulative1"
|
||||
yield "cumulative2"
|
||||
|
||||
get_app().state.engine = loaded_model
|
||||
app.state.engine = loaded_model
|
||||
loaded_model.generate_async.return_value = async_gen()
|
||||
response = client.post(
|
||||
"/v1/chat/completions",
|
||||
|
|
@ -83,7 +83,7 @@ def test_messages_non_stream(client, loaded_model):
|
|||
async def async_gen():
|
||||
yield "Assistant reply"
|
||||
|
||||
get_app().state.engine = loaded_model
|
||||
app.state.engine = loaded_model
|
||||
loaded_model.generate_async.return_value = async_gen()
|
||||
response = client.post(
|
||||
"/v1/messages",
|
||||
|
|
@ -111,7 +111,7 @@ def test_messages_stream(client, loaded_model):
|
|||
yield "cumulative1"
|
||||
yield "cumulative2"
|
||||
|
||||
get_app().state.engine = loaded_model
|
||||
app.state.engine = loaded_model
|
||||
loaded_model.generate_async.return_value = async_gen()
|
||||
response = client.post(
|
||||
"/v1/messages",
|
||||
|
|
@ -141,7 +141,7 @@ def test_messages_with_system(client, loaded_model):
|
|||
async def async_gen():
|
||||
yield "Reply"
|
||||
|
||||
get_app().state.engine = loaded_model
|
||||
app.state.engine = loaded_model
|
||||
loaded_model.generate_async.return_value = async_gen()
|
||||
response = client.post(
|
||||
"/v1/messages",
|
||||
|
|
@ -165,7 +165,7 @@ def test_chat_completions_stop_sequence(client, loaded_model):
|
|||
yield "X"
|
||||
yield "world"
|
||||
|
||||
get_app().state.engine = loaded_model
|
||||
app.state.engine = loaded_model
|
||||
loaded_model.generate_async.return_value = async_gen()
|
||||
response = client.post(
|
||||
"/v1/chat/completions",
|
||||
|
|
@ -191,7 +191,7 @@ def test_chat_completions_stop_sequence_stream(client, loaded_model):
|
|||
yield "X"
|
||||
yield "world"
|
||||
|
||||
get_app().state.engine = loaded_model
|
||||
app.state.engine = loaded_model
|
||||
loaded_model.generate_async.return_value = async_gen()
|
||||
response = client.post(
|
||||
"/v1/chat/completions",
|
||||
|
|
|
|||
|
|
@ -1,355 +0,0 @@
|
|||
import tempfile
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from astrai.config.model_config import AutoRegressiveLMConfig
|
||||
from astrai.model import AutoRegressiveLM
|
||||
from astrai.model.components.linear import Linear
|
||||
from astrai.model.components.lora import (
|
||||
LoRAConfig,
|
||||
LoRALinear,
|
||||
_collect_lora_info,
|
||||
_get_lora_count,
|
||||
inject_lora,
|
||||
load_lora,
|
||||
merge_lora,
|
||||
save_lora,
|
||||
)
|
||||
|
||||
MODEL_KWARGS = dict(
|
||||
vocab_size=1000,
|
||||
dim=64,
|
||||
n_heads=4,
|
||||
n_kv_heads=2,
|
||||
dim_ffn=128,
|
||||
n_layers=2,
|
||||
max_len=32,
|
||||
norm_eps=1e-5,
|
||||
)
|
||||
|
||||
|
||||
def _make_model(**kwargs):
|
||||
kw = {**MODEL_KWARGS, **kwargs}
|
||||
config = AutoRegressiveLMConfig(**kw)
|
||||
model = AutoRegressiveLM(config)
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
|
||||
def test_loralinear_init():
|
||||
base = Linear(64, 128)
|
||||
lora = LoRALinear(base, r=8, alpha=16)
|
||||
|
||||
assert lora.weight is base.weight
|
||||
assert not lora.weight.requires_grad
|
||||
assert lora.lora_A.shape == (8, 64)
|
||||
assert lora.lora_B.shape == (128, 8)
|
||||
assert lora.scaling == 2.0
|
||||
assert not lora._merged
|
||||
assert lora.lora_A.requires_grad
|
||||
assert lora.lora_B.requires_grad
|
||||
|
||||
|
||||
def test_loralinear_forward_init_zero_delta():
|
||||
base = Linear(4, 4)
|
||||
with torch.no_grad():
|
||||
base.weight.zero_()
|
||||
|
||||
x = torch.randn(2, 4)
|
||||
lora = LoRALinear(base, r=2, alpha=2)
|
||||
base_out = base(x)
|
||||
lora_out = lora(x)
|
||||
|
||||
assert torch.allclose(base_out, lora_out)
|
||||
|
||||
|
||||
def test_loralinear_forward_with_delta():
|
||||
base = Linear(4, 4)
|
||||
with torch.no_grad():
|
||||
base.weight.zero_()
|
||||
|
||||
x = torch.randn(2, 4)
|
||||
lora = LoRALinear(base, r=2, alpha=2)
|
||||
base_out = base(x)
|
||||
|
||||
with torch.no_grad():
|
||||
lora.lora_B.fill_(1.0)
|
||||
|
||||
lora_out = lora(x)
|
||||
assert not torch.allclose(base_out, lora_out)
|
||||
|
||||
|
||||
def test_loralinear_merge():
|
||||
base = Linear(4, 4)
|
||||
with torch.no_grad():
|
||||
base.weight.zero_()
|
||||
|
||||
x = torch.randn(2, 4)
|
||||
lora = LoRALinear(base, r=2, alpha=2)
|
||||
with torch.no_grad():
|
||||
lora.lora_B.fill_(1.0)
|
||||
|
||||
out_before = lora(x).clone()
|
||||
lora.merge()
|
||||
out_after = lora(x)
|
||||
|
||||
torch.testing.assert_close(out_before, out_after)
|
||||
assert lora._merged
|
||||
assert not hasattr(lora, "lora_A")
|
||||
|
||||
|
||||
def test_loralinear_merge_is_idempotent():
|
||||
base = Linear(4, 4)
|
||||
with torch.no_grad():
|
||||
base.weight.zero_()
|
||||
|
||||
lora = LoRALinear(base, r=2, alpha=2)
|
||||
with torch.no_grad():
|
||||
lora.lora_B.fill_(1.0)
|
||||
|
||||
lora.merge()
|
||||
lora.merge()
|
||||
|
||||
|
||||
def test_inject_lora_default_target():
|
||||
model = _make_model()
|
||||
n_before = sum(1 for m in model.modules() if isinstance(m, Linear))
|
||||
|
||||
inject_lora(model, r=4, alpha=8)
|
||||
|
||||
lora_count = _get_lora_count(model)
|
||||
assert lora_count > 0
|
||||
assert lora_count < n_before
|
||||
|
||||
|
||||
def test_inject_lora_ffn():
|
||||
model = _make_model()
|
||||
from astrai.model.components.lora import TARGET_MODULES_FFN
|
||||
|
||||
inject_lora(model, r=4, alpha=8, target_modules=TARGET_MODULES_FFN)
|
||||
assert _get_lora_count(model) > 0
|
||||
|
||||
|
||||
def test_inject_lora_returns_config():
|
||||
model = _make_model()
|
||||
cfg = inject_lora(model, r=8, alpha=32)
|
||||
assert isinstance(cfg, LoRAConfig)
|
||||
assert cfg.r == 8
|
||||
assert cfg.alpha == 32
|
||||
|
||||
|
||||
def test_inject_lora_no_matching_targets_warns(caplog):
|
||||
model = _make_model()
|
||||
inject_lora(model, r=4, alpha=8, target_modules={"nonexistent"})
|
||||
assert "No LoRA layers injected" in caplog.text
|
||||
|
||||
|
||||
def test_inject_lora_preserves_base_output():
|
||||
model = _make_model()
|
||||
x = torch.randint(0, 1000, (2, 16))
|
||||
|
||||
with torch.no_grad():
|
||||
out_before = model(x)["logits"].clone()
|
||||
|
||||
inject_lora(model, r=4, alpha=8)
|
||||
|
||||
with torch.no_grad():
|
||||
out_after = model(x)["logits"]
|
||||
|
||||
torch.testing.assert_close(out_before, out_after)
|
||||
|
||||
|
||||
def test_inject_lora_does_not_reinject():
|
||||
model = _make_model()
|
||||
inject_lora(model, r=4, alpha=8, target_modules={"q_proj"})
|
||||
first_count = _get_lora_count(model)
|
||||
|
||||
inject_lora(model, r=2, alpha=4, target_modules={"q_proj"})
|
||||
assert _get_lora_count(model) == first_count
|
||||
|
||||
|
||||
def test_inject_lora_adds_new_modules():
|
||||
model = _make_model()
|
||||
inject_lora(model, r=4, alpha=8, target_modules={"q_proj"})
|
||||
first = _get_lora_count(model)
|
||||
|
||||
inject_lora(model, r=4, alpha=8, target_modules={"v_proj"})
|
||||
assert _get_lora_count(model) > first
|
||||
|
||||
|
||||
def test_inject_lora_on_mla_model():
|
||||
model = _make_model(
|
||||
attn_type="mla", kv_lora_rank=16, qk_nope_head_dim=16, qk_rope_head_dim=16
|
||||
)
|
||||
inject_lora(model, r=4, alpha=8, target_modules={"q_proj", "o_proj"})
|
||||
assert _get_lora_count(model) > 0
|
||||
|
||||
|
||||
def test_inject_lora_on_moe_model():
|
||||
model = _make_model(
|
||||
ffn_type="moe",
|
||||
n_routed_experts=4,
|
||||
n_shared_experts=1,
|
||||
n_activated_experts=2,
|
||||
dim_ffn=32,
|
||||
)
|
||||
inject_lora(model, r=4, alpha=8, target_modules={"up", "gate", "down"})
|
||||
assert _get_lora_count(model) > 0
|
||||
|
||||
|
||||
def test_state_dict_key_format():
|
||||
model = _make_model()
|
||||
inject_lora(model, r=4, alpha=8, target_modules={"q_proj"})
|
||||
|
||||
sd = model.state_dict()
|
||||
assert "layers.0.attention.q_proj.weight" in sd
|
||||
assert "layers.0.attention.q_proj.lora_A" in sd
|
||||
assert "layers.0.attention.q_proj.lora_B" in sd
|
||||
|
||||
|
||||
def test_only_lora_params_trainable():
|
||||
model = _make_model()
|
||||
inject_lora(model, r=4, alpha=8, target_modules={"q_proj", "v_proj"})
|
||||
|
||||
for name, param in model.named_parameters():
|
||||
if isinstance(name.split(".")[-1], str) and "lora" in name:
|
||||
assert param.requires_grad, f"lora param should be trainable: {name}"
|
||||
elif any(name.endswith(f".{t}.weight") for t in ("q_proj", "v_proj")):
|
||||
assert not param.requires_grad, f"injected weight should be frozen: {name}"
|
||||
|
||||
|
||||
def test_state_dict_after_inject_consistent_with_original():
|
||||
model = _make_model()
|
||||
sd_before = {k: v for k, v in model.state_dict().items()}
|
||||
|
||||
inject_lora(model, r=4, alpha=8, target_modules={"q_proj"})
|
||||
sd_after = model.state_dict()
|
||||
|
||||
# original keys unchanged
|
||||
for k in sd_before:
|
||||
assert k in sd_after
|
||||
assert sd_before[k].shape == sd_after[k].shape
|
||||
|
||||
# new lora keys present
|
||||
lora_keys = [k for k in sd_after if "lora" in k]
|
||||
assert len(lora_keys) > 0
|
||||
|
||||
|
||||
def test_save_load_roundtrip():
|
||||
model = _make_model()
|
||||
cfg = inject_lora(model, r=4, alpha=8, target_modules={"q_proj"})
|
||||
|
||||
with torch.no_grad():
|
||||
for m in model.modules():
|
||||
if isinstance(m, LoRALinear):
|
||||
m.lora_B.fill_(0.5)
|
||||
|
||||
x = torch.randint(0, 1000, (2, 16))
|
||||
with torch.no_grad():
|
||||
out_src = model(x)["logits"].clone()
|
||||
|
||||
tmpdir = tempfile.mkdtemp()
|
||||
save_lora(model, tmpdir, cfg)
|
||||
|
||||
model2 = _make_model()
|
||||
model2.load_state_dict(model.state_dict(), strict=False)
|
||||
load_lora(model2, tmpdir)
|
||||
|
||||
with torch.no_grad():
|
||||
out_dst = model2(x)["logits"]
|
||||
|
||||
torch.testing.assert_close(out_src, out_dst)
|
||||
|
||||
|
||||
def test_save_after_merge_raises():
|
||||
model = _make_model()
|
||||
cfg = inject_lora(model, r=4, alpha=8, target_modules={"q_proj"})
|
||||
|
||||
with torch.no_grad():
|
||||
for m in model.modules():
|
||||
if isinstance(m, LoRALinear):
|
||||
m.lora_B.fill_(0.5)
|
||||
|
||||
tmpdir = tempfile.mkdtemp()
|
||||
save_lora(model, tmpdir, cfg)
|
||||
merge_lora(model)
|
||||
|
||||
tmpdir2 = tempfile.mkdtemp()
|
||||
with pytest.raises(RuntimeError, match="No LoRA parameters"):
|
||||
save_lora(model, tmpdir2, cfg)
|
||||
|
||||
|
||||
def test_load_lora_on_already_injected():
|
||||
model = _make_model()
|
||||
inject_lora(model, r=4, alpha=8, target_modules={"q_proj"})
|
||||
|
||||
with torch.no_grad():
|
||||
for m in model.modules():
|
||||
if isinstance(m, LoRALinear):
|
||||
m.lora_B.fill_(0.5)
|
||||
|
||||
tmpdir = tempfile.mkdtemp()
|
||||
save_lora(model, tmpdir, LoRAConfig(r=4, alpha=8, target_modules=("q_proj",)))
|
||||
|
||||
model2 = _make_model()
|
||||
model2.load_state_dict(model.state_dict(), strict=False)
|
||||
inject_lora(model2, r=4, alpha=8, target_modules={"q_proj"})
|
||||
|
||||
# load onto already-injected model
|
||||
load_lora(model2, tmpdir)
|
||||
assert _get_lora_count(model2) > 0
|
||||
|
||||
|
||||
def test_load_lora_mismatched_r_raises():
|
||||
model = _make_model()
|
||||
cfg = inject_lora(model, r=8, alpha=16, target_modules={"q_proj"})
|
||||
|
||||
with torch.no_grad():
|
||||
for m in model.modules():
|
||||
if isinstance(m, LoRALinear):
|
||||
m.lora_B.fill_(0.5)
|
||||
|
||||
tmpdir = tempfile.mkdtemp()
|
||||
save_lora(model, tmpdir, cfg)
|
||||
|
||||
model2 = _make_model()
|
||||
model2.load_state_dict(model.state_dict(), strict=False)
|
||||
inject_lora(model2, r=4, alpha=8, target_modules={"q_proj"})
|
||||
|
||||
with pytest.raises(RuntimeError, match="size mismatch"):
|
||||
load_lora(model2, tmpdir) # strict=False, only lora keys
|
||||
|
||||
|
||||
def test_merge_preserves_output():
|
||||
model = _make_model()
|
||||
inject_lora(model, r=4, alpha=8, target_modules={"q_proj"})
|
||||
|
||||
with torch.no_grad():
|
||||
for m in model.modules():
|
||||
if isinstance(m, LoRALinear):
|
||||
m.lora_B.fill_(0.5)
|
||||
|
||||
x = torch.randint(0, 1000, (2, 16))
|
||||
with torch.no_grad():
|
||||
out_before = model(x)["logits"].clone()
|
||||
|
||||
merge_lora(model)
|
||||
|
||||
with torch.no_grad():
|
||||
out_after = model(x)["logits"]
|
||||
torch.testing.assert_close(out_before, out_after)
|
||||
|
||||
|
||||
def test_merge_no_lora_warns(caplog):
|
||||
model = _make_model()
|
||||
merge_lora(model)
|
||||
assert "No LoRA layers to merge" in caplog.text
|
||||
|
||||
|
||||
def test_collect_lora_info():
|
||||
model = _make_model()
|
||||
info = _collect_lora_info(model)
|
||||
assert "q_proj" in info
|
||||
assert "o_proj" in info
|
||||
assert "q_proj" in info # each layer has one
|
||||
|
|
@ -1,5 +1,3 @@
|
|||
import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
|
|
@ -27,7 +25,7 @@ class TrainerDataset(Dataset):
|
|||
|
||||
|
||||
def create_train_config(
|
||||
model_fn,
|
||||
model: torch.nn.Module,
|
||||
dataset: Dataset,
|
||||
test_dir: str,
|
||||
device: str,
|
||||
|
|
@ -43,7 +41,7 @@ def create_train_config(
|
|||
"""Factory function to create common TrainConfig for tests.
|
||||
|
||||
Args:
|
||||
model_fn: Model factory (callable returning nn.Module)
|
||||
model: The model to train
|
||||
dataset: Training dataset
|
||||
test_dir: Checkpoint directory
|
||||
device: Device type ("cuda" or "cpu")
|
||||
|
|
@ -70,12 +68,11 @@ def create_train_config(
|
|||
|
||||
return TrainConfig(
|
||||
strategy=strategy,
|
||||
model_fn=model_fn,
|
||||
model=model,
|
||||
dataset=dataset,
|
||||
optimizer_fn=optimizer_fn,
|
||||
scheduler_fn=scheduler_fn,
|
||||
ckpt_dir=test_dir,
|
||||
log_dir=os.path.join(test_dir, "logs"),
|
||||
n_epoch=n_epoch,
|
||||
batch_per_device=batch_per_device,
|
||||
ckpt_interval=ckpt_interval,
|
||||
|
|
|
|||
|
|
@ -1,5 +1,3 @@
|
|||
import os
|
||||
|
||||
import torch
|
||||
|
||||
from astrai.config.train_config import TrainConfig
|
||||
|
|
@ -106,13 +104,12 @@ def test_gradient_checkpointing_trainer_integration(base_test_env, random_datase
|
|||
)
|
||||
|
||||
train_config = TrainConfig(
|
||||
model_fn=lambda: base_test_env["model"],
|
||||
model=base_test_env["model"],
|
||||
strategy="seq",
|
||||
dataset=random_dataset,
|
||||
optimizer_fn=optimizer_fn,
|
||||
scheduler_fn=scheduler_fn,
|
||||
ckpt_dir=base_test_env["test_dir"],
|
||||
log_dir=os.path.join(base_test_env["test_dir"], "logs"),
|
||||
n_epoch=1,
|
||||
batch_per_device=2,
|
||||
ckpt_interval=3,
|
||||
|
|
@ -140,13 +137,12 @@ def test_callback_integration(base_test_env, random_dataset):
|
|||
)
|
||||
|
||||
train_config = TrainConfig(
|
||||
model_fn=lambda: base_test_env["model"],
|
||||
model=base_test_env["model"],
|
||||
strategy="seq",
|
||||
dataset=random_dataset,
|
||||
optimizer_fn=optimizer_fn,
|
||||
scheduler_fn=scheduler_fn,
|
||||
ckpt_dir=base_test_env["test_dir"],
|
||||
log_dir=os.path.join(base_test_env["test_dir"], "logs"),
|
||||
n_epoch=1,
|
||||
batch_per_device=2,
|
||||
ckpt_interval=3,
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ import numpy as np
|
|||
import torch
|
||||
|
||||
from astrai.config.train_config import TrainConfig
|
||||
from astrai.serialization import Checkpoint
|
||||
from astrai.trainer.schedule import SchedulerFactory
|
||||
from astrai.trainer.trainer import Trainer
|
||||
|
||||
|
|
@ -23,10 +24,9 @@ def test_early_stopping_simulation(base_test_env, early_stopping_dataset):
|
|||
strategy="seq",
|
||||
optimizer_fn=optimizer_fn,
|
||||
scheduler_fn=scheduler_fn,
|
||||
model_fn=lambda: base_test_env["model"],
|
||||
model=base_test_env["model"],
|
||||
dataset=early_stopping_dataset,
|
||||
ckpt_dir=base_test_env["test_dir"],
|
||||
log_dir=os.path.join(base_test_env["test_dir"], "logs"),
|
||||
n_epoch=2,
|
||||
batch_per_device=2,
|
||||
ckpt_interval=1,
|
||||
|
|
@ -38,20 +38,17 @@ def test_early_stopping_simulation(base_test_env, early_stopping_dataset):
|
|||
trainer = Trainer(train_config)
|
||||
|
||||
# Should handle early stopping gracefully
|
||||
checkpoint = None
|
||||
try:
|
||||
trainer.train()
|
||||
checkpoint = trainer.train()
|
||||
except Exception:
|
||||
# Handle any exceptions
|
||||
pass
|
||||
|
||||
# Resume from latest checkpoint
|
||||
load_dir = os.path.join(base_test_env["test_dir"], "epoch_0_iter_2")
|
||||
trainer = Trainer(train_config)
|
||||
trainer.train(resume_dir=load_dir)
|
||||
checkpoint = Checkpoint.load(load_dir)
|
||||
trainer.train(checkpoint)
|
||||
|
||||
# Verify checkpoint was saved at expected iteration
|
||||
load_dir = os.path.join(base_test_env["test_dir"], "epoch_1_iter_10")
|
||||
import json
|
||||
|
||||
with open(os.path.join(load_dir, "meta.json")) as f:
|
||||
meta = json.load(f)
|
||||
assert meta["iteration"] == 10
|
||||
checkpoint = Checkpoint.load(load_dir)
|
||||
assert checkpoint.iteration == 10
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ def test_different_batch_sizes(base_test_env, random_dataset, train_config_facto
|
|||
|
||||
for batch_per_device in batch_sizes:
|
||||
train_config = train_config_factory(
|
||||
model_fn=lambda: base_test_env["model"],
|
||||
model=base_test_env["model"],
|
||||
dataset=random_dataset,
|
||||
test_dir=base_test_env["test_dir"],
|
||||
device=base_test_env["device"],
|
||||
|
|
@ -25,7 +25,7 @@ def test_gradient_accumulation(base_test_env, random_dataset, train_config_facto
|
|||
|
||||
for grad_accum_steps in grad_accum_steps_list:
|
||||
train_config = train_config_factory(
|
||||
model_fn=lambda: base_test_env["model"],
|
||||
model=base_test_env["model"],
|
||||
dataset=random_dataset,
|
||||
test_dir=base_test_env["test_dir"],
|
||||
device=base_test_env["device"],
|
||||
|
|
@ -50,7 +50,7 @@ def test_memory_efficient_training(base_test_env, random_dataset, train_config_f
|
|||
|
||||
for config in small_batch_configs:
|
||||
train_config = train_config_factory(
|
||||
model_fn=lambda: base_test_env["model"],
|
||||
model=base_test_env["model"],
|
||||
dataset=random_dataset,
|
||||
test_dir=base_test_env["test_dir"],
|
||||
device=base_test_env["device"],
|
||||
|
|
|
|||
Loading…
Reference in New Issue