Compare commits
No commits in common. "main" and "v1.3.6" have entirely different histories.
|
|
@ -82,7 +82,6 @@ export CUDA_VISIBLE_DEVICES=0,1,2,3
|
||||||
|
|
||||||
nohup python scripts/tools/train.py \
|
nohup python scripts/tools/train.py \
|
||||||
--nprocs=4 \
|
--nprocs=4 \
|
||||||
--parallel_mode=ddp \
|
|
||||||
--train_type=seq \
|
--train_type=seq \
|
||||||
--data_root_path=/path/to/dataset \
|
--data_root_path=/path/to/dataset \
|
||||||
--param_path=/path/to/model \
|
--param_path=/path/to/model \
|
||||||
|
|
@ -109,8 +108,8 @@ Full reference at [Parameter Guide](assets/docs/params.md).
|
||||||
```bash
|
```bash
|
||||||
python scripts/tools/generate.py \
|
python scripts/tools/generate.py \
|
||||||
--param_path /path/to/model \
|
--param_path /path/to/model \
|
||||||
--input_json_file /path/to/input.jsonl \
|
--input_json_file /path/to/input.json \
|
||||||
--output_json_file /path/to/output.jsonl
|
--output_json_file /path/to/output.json
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Docker
|
#### Docker
|
||||||
|
|
@ -225,7 +224,6 @@ Watch a video walkthrough on [bilibili](https://www.bilibili.com/video/BV1fuLB6y
|
||||||
| [Training](./assets/docs/training.md) | Training loop, strategies & formulas |
|
| [Training](./assets/docs/training.md) | Training loop, strategies & formulas |
|
||||||
| [Inference](./assets/docs/inference.md) | KVCache, continuous batching, sampling & HTTP API |
|
| [Inference](./assets/docs/inference.md) | KVCache, continuous batching, sampling & HTTP API |
|
||||||
| [Data Flow](./assets/docs/dataflow.md) | Data pipeline, storage backends & dataset architecture |
|
| [Data Flow](./assets/docs/dataflow.md) | Data pipeline, storage backends & dataset architecture |
|
||||||
| [Preprocessing](./assets/docs/preprocessing.md) | Declarative JSON-driven data preprocessing |
|
|
||||||
|
|
||||||
### Contributing
|
### Contributing
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -88,7 +88,6 @@ export CUDA_VISIBLE_DEVICES=0,1,2,3
|
||||||
|
|
||||||
nohup python scripts/tools/train.py \
|
nohup python scripts/tools/train.py \
|
||||||
--nprocs=4 \
|
--nprocs=4 \
|
||||||
--parallel_mode=ddp \
|
|
||||||
--train_type=seq \
|
--train_type=seq \
|
||||||
--data_root_path=/path/to/dataset \
|
--data_root_path=/path/to/dataset \
|
||||||
--param_path=/path/to/model \
|
--param_path=/path/to/model \
|
||||||
|
|
@ -115,8 +114,8 @@ nohup python scripts/tools/train.py \
|
||||||
```bash
|
```bash
|
||||||
python scripts/tools/generate.py \
|
python scripts/tools/generate.py \
|
||||||
--param_path /path/to/model \
|
--param_path /path/to/model \
|
||||||
--input_json_file /path/to/input.jsonl \
|
--input_json_file /path/to/input.json \
|
||||||
--output_json_file /path/to/output.jsonl
|
--output_json_file /path/to/output.json
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Docker
|
#### Docker
|
||||||
|
|
@ -231,7 +230,6 @@ python scripts/demo/generate_ar.py
|
||||||
| [训练文档](./training.md) | 训练循环、策略与公式 |
|
| [训练文档](./training.md) | 训练循环、策略与公式 |
|
||||||
| [推理文档](./inference.md) | KVCache、连续批处理、采样与 HTTP API |
|
| [推理文档](./inference.md) | KVCache、连续批处理、采样与 HTTP API |
|
||||||
| [数据流程](./dataflow.md) | 数据管道、存储后端与数据集架构 |
|
| [数据流程](./dataflow.md) | 数据管道、存储后端与数据集架构 |
|
||||||
| [数据预处理](./preprocessing.md) | 声明式 JSON 驱动数据预处理 |
|
|
||||||
|
|
||||||
### 贡献
|
### 贡献
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -8,8 +8,6 @@ classDiagram
|
||||||
class BaseConfig {
|
class BaseConfig {
|
||||||
+to_dict() Dict
|
+to_dict() Dict
|
||||||
+from_dict(d) Self
|
+from_dict(d) Self
|
||||||
+from_json(path) Self
|
|
||||||
+to_json(path)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
class BaseModelConfig {
|
class BaseModelConfig {
|
||||||
|
|
@ -19,43 +17,41 @@ classDiagram
|
||||||
}
|
}
|
||||||
|
|
||||||
class AutoRegressiveLMConfig {
|
class AutoRegressiveLMConfig {
|
||||||
+Optional[int] vocab_size
|
+int vocab_size
|
||||||
+Optional[int] dim
|
+int dim
|
||||||
+Optional[int] n_layers
|
+int n_layers
|
||||||
+Optional[float] norm_eps
|
+float norm_eps
|
||||||
+Optional[int] dim_ffn
|
+int dim_ffn
|
||||||
+Optional[bool] tie_weight
|
+bool tie_weight
|
||||||
+Optional[dict] rope_scaling
|
+int max_len
|
||||||
+Optional[int] max_len
|
+float rope_theta
|
||||||
+Optional[float] rope_theta
|
|
||||||
+str attn_type
|
+str attn_type
|
||||||
+Optional[int] n_heads
|
+int n_heads
|
||||||
+Optional[int] n_kv_heads
|
+int n_kv_heads
|
||||||
+Optional[bool] use_qk_norm
|
+bool use_qk_norm
|
||||||
+Optional[bool] use_gated_attention
|
+bool use_gated_attention
|
||||||
+Optional[int] kv_lora_rank
|
+Optional[int] kv_lora_rank
|
||||||
+Optional[int] qk_nope_head_dim
|
+Optional[int] qk_nope_head_dim
|
||||||
+Optional[int] qk_rope_head_dim
|
+Optional[int] qk_rope_head_dim
|
||||||
+str ffn_type
|
+str ffn_type
|
||||||
+Optional[int] n_routed_experts
|
+int n_routed_experts
|
||||||
+Optional[int] n_shared_experts
|
+int n_shared_experts
|
||||||
+Optional[int] n_activated_experts
|
+int n_activated_experts
|
||||||
+Optional[str] topk_method
|
+Optional[str] topk_method
|
||||||
}
|
}
|
||||||
|
|
||||||
class EncoderConfig {
|
class EncoderConfig {
|
||||||
+Optional[int] vocab_size
|
+int vocab_size
|
||||||
+Optional[int] dim
|
+int dim
|
||||||
+Optional[int] n_layers
|
+int n_layers
|
||||||
+Optional[float] norm_eps
|
+float norm_eps
|
||||||
+Optional[int] dim_ffn
|
+int dim_ffn
|
||||||
+Optional[int] max_len
|
+int max_len
|
||||||
+Optional[float] rope_theta
|
+float rope_theta
|
||||||
+Optional[int] n_heads
|
+int n_heads
|
||||||
+Optional[int] n_kv_heads
|
+int n_kv_heads
|
||||||
+Optional[bool] use_qk_norm
|
+bool use_qk_norm
|
||||||
+Optional[bool] use_gated_attention
|
+bool use_gated_attention
|
||||||
+Optional[dict] rope_scaling
|
|
||||||
+Optional[str] pooling_type
|
+Optional[str] pooling_type
|
||||||
+Optional[bool] normalize_embeddings
|
+Optional[bool] normalize_embeddings
|
||||||
}
|
}
|
||||||
|
|
@ -66,40 +62,8 @@ classDiagram
|
||||||
+load(raw) BaseConfig
|
+load(raw) BaseConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
class InputConfig {
|
|
||||||
+str type
|
|
||||||
+str messages_key
|
|
||||||
+str prompt_key
|
|
||||||
+str response_key
|
|
||||||
+str text_key
|
|
||||||
}
|
|
||||||
|
|
||||||
class ProcessingConfig {
|
|
||||||
+int max_seq_len
|
|
||||||
+int min_chars
|
|
||||||
+int max_chars
|
|
||||||
+bool deduplicate
|
|
||||||
+Optional[int] max_items
|
|
||||||
}
|
|
||||||
|
|
||||||
class OutputConfig {
|
|
||||||
+Optional[str] domain_key
|
|
||||||
+str storage_format
|
|
||||||
+int max_tokens_per_shard
|
|
||||||
}
|
|
||||||
|
|
||||||
class PipelineConfig {
|
|
||||||
+int version
|
|
||||||
+InputConfig input
|
|
||||||
+dict mask
|
|
||||||
+str mask_default
|
|
||||||
+ProcessingConfig preprocessing
|
|
||||||
+OutputConfig output
|
|
||||||
+from_dict(d) Self
|
|
||||||
}
|
|
||||||
|
|
||||||
class TrainConfig {
|
class TrainConfig {
|
||||||
+Callable[[], nn.Module] model_fn
|
+nn.Module model
|
||||||
+str strategy
|
+str strategy
|
||||||
+Dataset dataset
|
+Dataset dataset
|
||||||
+Callable optimizer_fn
|
+Callable optimizer_fn
|
||||||
|
|
@ -116,7 +80,6 @@ classDiagram
|
||||||
+str log_dir
|
+str log_dir
|
||||||
+int log_interval
|
+int log_interval
|
||||||
+List[str] metrics
|
+List[str] metrics
|
||||||
+Optional[LoRAConfig] lora
|
|
||||||
+int random_seed
|
+int random_seed
|
||||||
+int num_workers
|
+int num_workers
|
||||||
+Optional[int] prefetch_factor
|
+Optional[int] prefetch_factor
|
||||||
|
|
@ -125,12 +88,12 @@ classDiagram
|
||||||
+str backend
|
+str backend
|
||||||
+str master_addr
|
+str master_addr
|
||||||
+str master_port
|
+str master_port
|
||||||
|
+Callable parallel_wrapper
|
||||||
|
+Callable state_dict_fn
|
||||||
+str start_method
|
+str start_method
|
||||||
+str device_type
|
+str device_type
|
||||||
+Optional[Dataset] val_dataset
|
+Optional[Dataset] val_dataset
|
||||||
+int val_step
|
+int val_step
|
||||||
+str parallel_mode
|
|
||||||
+dict executor_kwargs
|
|
||||||
+dict extra_kwargs
|
+dict extra_kwargs
|
||||||
+validate()
|
+validate()
|
||||||
}
|
}
|
||||||
|
|
@ -141,8 +104,8 @@ classDiagram
|
||||||
class BaseDataset {
|
class BaseDataset {
|
||||||
+int window_size
|
+int window_size
|
||||||
+int stride
|
+int stride
|
||||||
+Optional[Store] storage
|
+Optional[BaseStorage] storage
|
||||||
+load(load_path, storage_type)
|
+load(load_path, storage_type, tokenizer)
|
||||||
+__getitem__(index)
|
+__getitem__(index)
|
||||||
+__len__()
|
+__len__()
|
||||||
}
|
}
|
||||||
|
|
@ -163,25 +126,38 @@ classDiagram
|
||||||
+__getitem__(index) Dict
|
+__getitem__(index) Dict
|
||||||
}
|
}
|
||||||
|
|
||||||
class Store {
|
class BaseSegmentFetcher {
|
||||||
+Dict[str, List[Tensor]] _data
|
+List[Tensor] segments
|
||||||
+Dict[str, List[int]] _cum
|
+List[int] cum_lengths
|
||||||
+int _length
|
+int total_length
|
||||||
|
+fetch_data(begin_idx, end_idx) Tensor
|
||||||
|
}
|
||||||
|
|
||||||
|
class BaseStorage {
|
||||||
|
+MultiSegmentFetcher _fetcher
|
||||||
+keys (property)
|
+keys (property)
|
||||||
+load(path)
|
+load(load_path, tokenizer)
|
||||||
+fetch(begin, end, keys)
|
+fetch(begin, end, keys)
|
||||||
+__len__()
|
+__len__()
|
||||||
-_fetch_key(key, begin, end) Tensor
|
|
||||||
-_normalize(raw)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
class H5Store {
|
class H5Storage {
|
||||||
+load(path)
|
+load(load_path, tokenizer)
|
||||||
|
+fetch(begin, end, keys) Dict
|
||||||
|
+keys() List
|
||||||
}
|
}
|
||||||
|
|
||||||
class MmapStore {
|
class JSONStorage {
|
||||||
+List _mmap_refs
|
+load(load_path, tokenizer)
|
||||||
+load(path)
|
+fetch(begin, end, keys) Dict
|
||||||
|
+keys() List
|
||||||
|
}
|
||||||
|
|
||||||
|
class MultiSegmentFetcher {
|
||||||
|
+Dict multi_fetchers
|
||||||
|
+List multi_keys
|
||||||
|
+key_fetch(begin_idx, end_idx, keys) Dict
|
||||||
|
+fetch_data(begin_idx, end_idx) Dict
|
||||||
}
|
}
|
||||||
|
|
||||||
class ResumableDistributedSampler {
|
class ResumableDistributedSampler {
|
||||||
|
|
@ -189,17 +165,17 @@ classDiagram
|
||||||
+int iter
|
+int iter
|
||||||
}
|
}
|
||||||
|
|
||||||
class StoreFactory {
|
class StorageFactory {
|
||||||
+Registry _registry
|
+Registry _registry
|
||||||
+register(name) decorator
|
+register(name) decorator
|
||||||
+create(storage_type) Store
|
+create(storage_type) BaseStorage
|
||||||
}
|
}
|
||||||
|
|
||||||
class DatasetFactory {
|
class DatasetFactory {
|
||||||
+Registry _registry
|
+Registry _registry
|
||||||
+register(name) decorator
|
+register(name) decorator
|
||||||
+create(train_type, window_size, stride) BaseDataset
|
+create(train_type, window_size, stride) BaseDataset
|
||||||
+load(train_type, load_path, window_size, stride, storage_type) BaseDataset
|
+load(train_type, load_path, window_size, stride, storage_type, tokenizer) BaseDataset
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -210,9 +186,8 @@ classDiagram
|
||||||
+int iteration
|
+int iteration
|
||||||
+dict extra
|
+dict extra
|
||||||
+dict meta
|
+dict meta
|
||||||
+dict config
|
|
||||||
+save(save_dir)
|
+save(save_dir)
|
||||||
+load(save_dir, broadcast) Checkpoint
|
+load(save_dir) Checkpoint
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -220,8 +195,8 @@ classDiagram
|
||||||
class AutoModel {
|
class AutoModel {
|
||||||
+BaseModelConfig config
|
+BaseModelConfig config
|
||||||
+Registry _registry
|
+Registry _registry
|
||||||
+register(name) decorator
|
+register(model_type) decorator
|
||||||
+get_component_class(name) Type
|
+get_component_class(model_type) Type
|
||||||
+from_pretrained(path, disable_random_init, strict) nn.Module
|
+from_pretrained(path, disable_random_init, strict) nn.Module
|
||||||
+save_pretrained(save_directory)
|
+save_pretrained(save_directory)
|
||||||
+to(*args, **kwargs) Self
|
+to(*args, **kwargs) Self
|
||||||
|
|
@ -235,7 +210,7 @@ classDiagram
|
||||||
+RMSNorm norm
|
+RMSNorm norm
|
||||||
+Linear lm_head
|
+Linear lm_head
|
||||||
+forward(input_ids, input_mask, paged_cache, position_ids) Dict[str, Tensor]
|
+forward(input_ids, input_mask, paged_cache, position_ids) Dict[str, Tensor]
|
||||||
+load_state_dict(state_dict, strict, assign)
|
+load_state_dict(state_dict)
|
||||||
+state_dict()
|
+state_dict()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -260,7 +235,6 @@ classDiagram
|
||||||
}
|
}
|
||||||
|
|
||||||
class GQA {
|
class GQA {
|
||||||
+int dim
|
|
||||||
+int n_heads
|
+int n_heads
|
||||||
+int n_kv_heads
|
+int n_kv_heads
|
||||||
+int head_dim
|
+int head_dim
|
||||||
|
|
@ -275,7 +249,6 @@ classDiagram
|
||||||
}
|
}
|
||||||
|
|
||||||
class MLA {
|
class MLA {
|
||||||
+int dim
|
|
||||||
+int n_heads
|
+int n_heads
|
||||||
+int n_kv_heads
|
+int n_kv_heads
|
||||||
+int head_dim
|
+int head_dim
|
||||||
|
|
@ -284,13 +257,11 @@ classDiagram
|
||||||
+int qk_rope_head_dim
|
+int qk_rope_head_dim
|
||||||
+int n_rep
|
+int n_rep
|
||||||
+int layer_id
|
+int layer_id
|
||||||
+bool use_qk_norm
|
|
||||||
+bool use_gated_attention
|
+bool use_gated_attention
|
||||||
+Linear q_proj, kv_a_proj, kv_b_proj
|
+Linear q_proj, kv_a_proj, kv_b_proj
|
||||||
+Linear o_proj
|
+Linear o_proj
|
||||||
+Linear gate # only if use_gated_attention
|
+Linear gate # only if use_gated_attention
|
||||||
+RMSNorm kv_norm
|
+RMSNorm kv_norm
|
||||||
+RMSNorm q_norm, k_norm # only if use_qk_norm
|
|
||||||
+forward(x, rotary_emb, attn_mask, paged_cache) Tensor
|
+forward(x, rotary_emb, attn_mask, paged_cache) Tensor
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -336,7 +307,6 @@ classDiagram
|
||||||
+int dim
|
+int dim
|
||||||
+int max_len
|
+int max_len
|
||||||
+float base
|
+float base
|
||||||
+Optional[Dict] rope_scaling
|
|
||||||
+forward(x, position_ids=None) Tensor
|
+forward(x, position_ids=None) Tensor
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -346,42 +316,13 @@ classDiagram
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace preprocessing {
|
|
||||||
class BaseMaskBuilder {
|
|
||||||
<<abstract>>
|
|
||||||
+build(item, config, tokenizer) Optional[dict]
|
|
||||||
}
|
|
||||||
|
|
||||||
class ChatMaskBuilder {
|
|
||||||
+build(item, config, tokenizer) Optional[dict]
|
|
||||||
}
|
|
||||||
|
|
||||||
class InstructionMaskBuilder {
|
|
||||||
+build(item, config, tokenizer) Optional[dict]
|
|
||||||
}
|
|
||||||
|
|
||||||
class TextMaskBuilder {
|
|
||||||
+build(item, config, tokenizer) Optional[dict]
|
|
||||||
}
|
|
||||||
|
|
||||||
class Pipeline {
|
|
||||||
+PipelineConfig config
|
|
||||||
+List[str] paths
|
|
||||||
+str output_dir
|
|
||||||
+str tokenizer_path
|
|
||||||
+BaseMaskBuilder mask_builder
|
|
||||||
+transform(item) Optional[dict]
|
|
||||||
+run()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
namespace tokenize {
|
namespace tokenize {
|
||||||
class AutoTokenizer {
|
class AutoTokenizer {
|
||||||
+vocab_size int
|
+vocab_size int
|
||||||
+encode(tokens, out_ids, is_pretokenized, add_special_tokens) List
|
+encode(tokens, out_ids, add_special_tokens) List[int]
|
||||||
+decode(tokens, skip_special_tokens) str
|
+decode(tokens, skip_special_tokens) str
|
||||||
+__getattr__(name) Any (bos_id, eos_id, pad_id, stop_ids)
|
+__getattr__(name) Any (bos_id, eos_id, pad_id, stop_ids)
|
||||||
+apply_chat_template(messages, system_prompt, tokenize, add_generation_prompt) Union[str, List[int]]
|
+apply_chat_template(messages, tokenize) Union[str, List[int]]
|
||||||
+set_chat_template(template)
|
+set_chat_template(template)
|
||||||
+load(path)
|
+load(path)
|
||||||
+from_pretrained(path) AutoTokenizer
|
+from_pretrained(path) AutoTokenizer
|
||||||
|
|
@ -389,7 +330,7 @@ classDiagram
|
||||||
}
|
}
|
||||||
|
|
||||||
class ChatTemplate {
|
class ChatTemplate {
|
||||||
+str template_str
|
+String template_str
|
||||||
+render(messages, system_prompt, **extra_variables) str
|
+render(messages, system_prompt, **extra_variables) str
|
||||||
+from_string(template) ChatTemplate
|
+from_string(template) ChatTemplate
|
||||||
}
|
}
|
||||||
|
|
@ -409,32 +350,24 @@ classDiagram
|
||||||
+create(name, *args, **kwargs) T
|
+create(name, *args, **kwargs) T
|
||||||
+list_registered() list
|
+list_registered() list
|
||||||
}
|
}
|
||||||
|
|
||||||
class MaskBuilderFactory {
|
|
||||||
+Registry _registry
|
|
||||||
+register(name) decorator
|
|
||||||
+create(input_type, config, tokenizer) BaseMaskBuilder
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace trainer {
|
namespace trainer {
|
||||||
class Trainer {
|
class Trainer {
|
||||||
+TrainConfig train_config
|
+TrainConfig train_config
|
||||||
+List[TrainCallback] callbacks
|
+List[TrainCallback] callbacks
|
||||||
+train(resume_dir)
|
+train(checkpoint)
|
||||||
-_get_default_callbacks() List[TrainCallback]
|
+_get_default_callbacks() List[TrainCallback]
|
||||||
}
|
}
|
||||||
|
|
||||||
class TrainContext {
|
class TrainContext {
|
||||||
+nn.Module model
|
+nn.Module model
|
||||||
+BaseStrategy strategy
|
+BaseStrategy strategy
|
||||||
+DataLoader dataloader
|
+DataLoader dataloader
|
||||||
+OptimizerProtocol optimizer
|
+Optimizer optimizer
|
||||||
+SchedulerProtocol scheduler
|
+LRScheduler scheduler
|
||||||
+Checkpoint checkpoint
|
+Checkpoint checkpoint
|
||||||
+TrainConfig config
|
+TrainConfig config
|
||||||
+dict model_config
|
|
||||||
+BaseExecutor executor
|
|
||||||
+int epoch
|
+int epoch
|
||||||
+int iteration
|
+int iteration
|
||||||
+float loss
|
+float loss
|
||||||
|
|
@ -447,17 +380,13 @@ classDiagram
|
||||||
|
|
||||||
class TrainContextBuilder {
|
class TrainContextBuilder {
|
||||||
+TrainConfig config
|
+TrainConfig config
|
||||||
+with_resume_dir(resume_dir) TrainContextBuilder
|
+with_checkpoint(checkpoint) TrainContextBuilder
|
||||||
+build() TrainContext
|
+build() TrainContext
|
||||||
}
|
}
|
||||||
|
|
||||||
class BaseStrategy {
|
class BaseStrategy {
|
||||||
+Callable model
|
+Union[Callable, nn.Module] model
|
||||||
+Optional[BaseExecutor] executor
|
|
||||||
+Optional[Callable] model_fn
|
|
||||||
+dict extra_kwargs
|
|
||||||
+str device
|
+str device
|
||||||
+__call__(batch) Tensor
|
|
||||||
+compute_loss(batch) Tensor
|
+compute_loss(batch) Tensor
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -498,8 +427,6 @@ classDiagram
|
||||||
class BaseScheduler {
|
class BaseScheduler {
|
||||||
+get_lr() List[float]
|
+get_lr() List[float]
|
||||||
+step()
|
+step()
|
||||||
+state_dict() dict
|
|
||||||
+load_state_dict(d)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
class SchedulerFactory {
|
class SchedulerFactory {
|
||||||
|
|
@ -511,7 +438,6 @@ classDiagram
|
||||||
class CosineScheduler {
|
class CosineScheduler {
|
||||||
+int warmup_steps
|
+int warmup_steps
|
||||||
+int lr_decay_steps
|
+int lr_decay_steps
|
||||||
+int total_steps
|
|
||||||
+float min_rate
|
+float min_rate
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -528,15 +454,16 @@ classDiagram
|
||||||
+on_train_end(context)
|
+on_train_end(context)
|
||||||
+on_epoch_begin(context)
|
+on_epoch_begin(context)
|
||||||
+on_epoch_end(context)
|
+on_epoch_end(context)
|
||||||
|
+on_step_begin(context)
|
||||||
|
+on_step_end(context)
|
||||||
+on_batch_begin(context)
|
+on_batch_begin(context)
|
||||||
+on_batch_end(context)
|
+on_batch_end(context)
|
||||||
+on_optimizer_step(context)
|
|
||||||
+on_error(context)
|
+on_error(context)
|
||||||
}
|
}
|
||||||
|
|
||||||
class GradientClippingCallback {
|
class GradientClippingCallback {
|
||||||
+float max_grad_norm
|
+float max_grad_norm
|
||||||
+on_optimizer_step(context)
|
+on_step_begin(context)
|
||||||
}
|
}
|
||||||
|
|
||||||
class GradientCheckpointingCallback {
|
class GradientCheckpointingCallback {
|
||||||
|
|
@ -549,12 +476,16 @@ classDiagram
|
||||||
+str save_dir
|
+str save_dir
|
||||||
+int interval
|
+int interval
|
||||||
+bool weight_only
|
+bool weight_only
|
||||||
|
+Callable state_dict_fn
|
||||||
+Callable save_extra_fn
|
+Callable save_extra_fn
|
||||||
-_save_checkpoint(context)
|
+Callable load_extra_fn
|
||||||
|
+_save_checkpoint(context)
|
||||||
|
+on_train_begin(context)
|
||||||
+on_batch_end(context)
|
+on_batch_end(context)
|
||||||
+on_train_end(context)
|
+on_train_end(context)
|
||||||
+on_error(context)
|
+on_error(context)
|
||||||
+save_extra(context) dict$
|
+save_extra(context)$
|
||||||
|
+load_extra(extra, context)$
|
||||||
}
|
}
|
||||||
|
|
||||||
class ProgressBarCallback {
|
class ProgressBarCallback {
|
||||||
|
|
@ -567,7 +498,7 @@ classDiagram
|
||||||
}
|
}
|
||||||
|
|
||||||
class MetricLoggerCallback {
|
class MetricLoggerCallback {
|
||||||
+Path log_dir
|
+str log_dir
|
||||||
+int save_interval
|
+int save_interval
|
||||||
+int log_interval
|
+int log_interval
|
||||||
+List[str] metrics
|
+List[str] metrics
|
||||||
|
|
@ -577,8 +508,8 @@ classDiagram
|
||||||
}
|
}
|
||||||
|
|
||||||
class ValidationCallback {
|
class ValidationCallback {
|
||||||
-_run_validation(context)
|
+_run_validation(context)
|
||||||
+on_optimizer_step(context)
|
+on_step_end(context)
|
||||||
}
|
}
|
||||||
|
|
||||||
class CallbackFactory {
|
class CallbackFactory {
|
||||||
|
|
@ -591,12 +522,7 @@ classDiagram
|
||||||
+float lr
|
+float lr
|
||||||
+float momentum
|
+float momentum
|
||||||
+float weight_decay
|
+float weight_decay
|
||||||
+bool nesterov
|
|
||||||
+int ns_steps
|
+int ns_steps
|
||||||
+Optional[float] adamw_lr
|
|
||||||
+tuple adamw_betas
|
|
||||||
+float adamw_eps
|
|
||||||
+float adamw_wd
|
|
||||||
+step(closure) Optional[float]
|
+step(closure) Optional[float]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -617,8 +543,6 @@ classDiagram
|
||||||
+AutoModel model
|
+AutoModel model
|
||||||
+AutoTokenizer tokenizer
|
+AutoTokenizer tokenizer
|
||||||
+KVCache page_cache
|
+KVCache page_cache
|
||||||
+Optional[str] device
|
|
||||||
+Optional[torch.dtype] dtype
|
|
||||||
+execute_prefill(tasks, prompt_len, start_pos)
|
+execute_prefill(tasks, prompt_len, start_pos)
|
||||||
+execute_decode(tasks) List[int]
|
+execute_decode(tasks) List[int]
|
||||||
}
|
}
|
||||||
|
|
@ -630,9 +554,7 @@ classDiagram
|
||||||
+bool _running
|
+bool _running
|
||||||
+Thread _loop_thread
|
+Thread _loop_thread
|
||||||
+int max_seq_len
|
+int max_seq_len
|
||||||
+str device
|
+add_task(prompt, max_tokens, temperature, top_p, top_k, stream_callback) str
|
||||||
+torch.dtype dtype
|
|
||||||
+add_task(prompt, **kwargs) str
|
|
||||||
+remove_task(task_id)
|
+remove_task(task_id)
|
||||||
+start()
|
+start()
|
||||||
+stop()
|
+stop()
|
||||||
|
|
@ -710,7 +632,7 @@ classDiagram
|
||||||
class Task {
|
class Task {
|
||||||
+str task_id
|
+str task_id
|
||||||
+List prompt_ids
|
+List prompt_ids
|
||||||
+Optional[int] max_tokens
|
+int max_tokens
|
||||||
+float temperature
|
+float temperature
|
||||||
+float top_p
|
+float top_p
|
||||||
+int top_k
|
+int top_k
|
||||||
|
|
@ -719,8 +641,8 @@ classDiagram
|
||||||
+int input_tokens
|
+int input_tokens
|
||||||
+int output_tokens
|
+int output_tokens
|
||||||
+float arrival_time
|
+float arrival_time
|
||||||
+Optional[float] finish_time
|
+float finish_time
|
||||||
+Optional[Callable] stream_callback
|
+Callable stream_callback
|
||||||
+int next_pos
|
+int next_pos
|
||||||
+is_finished(stop_ids) bool
|
+is_finished(stop_ids) bool
|
||||||
}
|
}
|
||||||
|
|
@ -735,24 +657,15 @@ classDiagram
|
||||||
|
|
||||||
class TaskManager {
|
class TaskManager {
|
||||||
+AutoTokenizer tokenizer
|
+AutoTokenizer tokenizer
|
||||||
+int max_batch_size
|
|
||||||
+int max_seq_len
|
|
||||||
+int max_prompt_len
|
|
||||||
+Deque waiting_queue
|
+Deque waiting_queue
|
||||||
+List active_tasks
|
+List active_tasks
|
||||||
+add_task(prompt, max_tokens, temperature, top_p, top_k, stream_callback) str
|
+add_task(prompt, **kwargs) str
|
||||||
+remove_task(task_id) List[Task]
|
+remove_task(task_id) List[Task]
|
||||||
+remove_finished_tasks(stop_ids) List[Task]
|
+remove_finished_tasks(stop_ids) List[Task]
|
||||||
+pull_candidates(n) List[Task]
|
+pull_candidates(n) List[Task]
|
||||||
+activate(task)
|
+activate(task)
|
||||||
+return_to_waiting(tasks)
|
+return_to_waiting(tasks)
|
||||||
+get_active_tasks() List[Task]
|
+get_active_tasks() List[Task]
|
||||||
+has_work() bool
|
|
||||||
+wait_for_tasks(timeout)
|
|
||||||
+get_waiting_tasks() List[Task]
|
|
||||||
+clear_queues()
|
|
||||||
+wake()
|
|
||||||
+get_stats() Dict
|
|
||||||
}
|
}
|
||||||
|
|
||||||
class GenerationRequest {
|
class GenerationRequest {
|
||||||
|
|
@ -831,65 +744,56 @@ classDiagram
|
||||||
+str model
|
+str model
|
||||||
+List[AnthropicMessage] messages
|
+List[AnthropicMessage] messages
|
||||||
+Optional[str] system
|
+Optional[str] system
|
||||||
+Optional[float] temperature
|
+float temperature
|
||||||
+Optional[float] top_p
|
+float top_p
|
||||||
+Optional[int] top_k
|
+int top_k
|
||||||
+int max_tokens
|
+int max_tokens
|
||||||
+Optional[bool] stream
|
+bool stream
|
||||||
+Optional[List[str]] stop_sequences
|
+Optional[List[str]] stop_sequences
|
||||||
}
|
}
|
||||||
|
|
||||||
class ResponseBuilder {
|
|
||||||
<<abstract>>
|
|
||||||
+prepare(request, tokenizer) Tuple[str, GenContext, List[str]]
|
|
||||||
+format_stream_start(ctx) List[str]
|
|
||||||
+format_chunk(token) str
|
|
||||||
+format_stream_end(ctx, stop) List[str]
|
|
||||||
+format_response(ctx, content, stop) Dict
|
|
||||||
}
|
|
||||||
|
|
||||||
class OpenAIResponseBuilder {
|
|
||||||
+prepare(request, tokenizer) Tuple
|
|
||||||
+format_stream_start(ctx) List[str]
|
|
||||||
+format_chunk(token) str
|
|
||||||
+format_stream_end(ctx, stop) List[str]
|
|
||||||
+format_response(ctx, content, stop) Dict
|
|
||||||
}
|
|
||||||
|
|
||||||
class AnthropicResponseBuilder {
|
|
||||||
+prepare(request, tokenizer) Tuple
|
|
||||||
+format_stream_start(ctx) List[str]
|
|
||||||
+format_chunk(token) str
|
|
||||||
+format_stream_end(ctx, stop) List[str]
|
|
||||||
+format_response(ctx, content, stop) Dict
|
|
||||||
}
|
|
||||||
|
|
||||||
class ProtocolHandler {
|
class ProtocolHandler {
|
||||||
|
<<abstract>>
|
||||||
+request
|
+request
|
||||||
+engine
|
+engine
|
||||||
+builder: ResponseBuilder
|
+build_prompt() str
|
||||||
+async handle() Union[StreamingResponse, Dict]
|
+create_response_id() str
|
||||||
-_handle_stream(agen, ctx, stop_sequences) StreamingResponse
|
+get_stop_sequences() List[str]
|
||||||
-async _handle_non_stream(agen, ctx, stop_sequences) Dict
|
+create_stop_checker() StopChecker
|
||||||
|
+on_token(ctx, token, stop_checker) Optional[str]
|
||||||
|
+format_stream_start(ctx) List[str]
|
||||||
|
+format_stream_token(ctx, token) str
|
||||||
|
+format_stream_end(ctx) List[str]
|
||||||
|
+format_non_stream_response(ctx, content) Dict
|
||||||
|
+handle() Union[StreamingResponse, Dict]
|
||||||
|
}
|
||||||
|
|
||||||
|
class OpenAIHandler {
|
||||||
|
+build_prompt() str
|
||||||
|
+create_response_id() str
|
||||||
|
}
|
||||||
|
|
||||||
|
class AnthropicHandler {
|
||||||
|
+build_prompt() str
|
||||||
|
+create_response_id() str
|
||||||
|
+on_token(ctx, token, stop_checker) Optional[str]
|
||||||
}
|
}
|
||||||
|
|
||||||
class StopChecker {
|
class StopChecker {
|
||||||
+__init__(sequences)
|
+has_sequences (property) bool
|
||||||
+check(text) Optional[str]
|
+check(text) Optional[str]
|
||||||
|
+trim(text, matched) str
|
||||||
}
|
}
|
||||||
|
|
||||||
class GenContext {
|
class StreamContext {
|
||||||
+str resp_id
|
+str resp_id
|
||||||
+int created
|
+int created
|
||||||
+str model
|
+str model
|
||||||
+int prompt_tokens
|
+int prompt_tokens
|
||||||
+int completion_tokens
|
+int completion_tokens
|
||||||
}
|
+str accumulated
|
||||||
|
+Optional[str] stop_matched
|
||||||
class StopInfo {
|
+str last_yield_trimmed
|
||||||
+Optional[str] matched
|
|
||||||
+str body
|
|
||||||
+str yielded
|
|
||||||
}
|
}
|
||||||
|
|
||||||
class app {
|
class app {
|
||||||
|
|
@ -898,87 +802,15 @@ classDiagram
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace protocols {
|
|
||||||
class OptimizerProtocol {
|
|
||||||
<<protocol>>
|
|
||||||
+step(closure)
|
|
||||||
+zero_grad()
|
|
||||||
+state_dict() dict
|
|
||||||
+load_state_dict(d)
|
|
||||||
}
|
|
||||||
|
|
||||||
class SchedulerProtocol {
|
|
||||||
<<protocol>>
|
|
||||||
+step()
|
|
||||||
+state_dict() dict
|
|
||||||
+load_state_dict(d)
|
|
||||||
+get_last_lr()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
namespace parallel {
|
namespace parallel {
|
||||||
class setup {
|
class Functions {
|
||||||
<<module>>
|
<<module>>
|
||||||
+spawn_parallel_fn(func, world_size, backend, master_addr, master_port, device_type, start_method, **kwargs)
|
+spawn_parallel_fn(func, world_size, backend, master_addr, master_port, device_type, start_method, **kwargs)
|
||||||
+setup_parallel(rank, world_size, backend, master_addr, master_port, device_type) contextmanager
|
+setup_parallel(rank, world_size, backend, master_addr, master_port, device_type)
|
||||||
+get_current_device() str
|
+get_current_device() str
|
||||||
+get_world_size() int
|
+get_world_size() int
|
||||||
+get_rank() int
|
+get_rank() int
|
||||||
+only_on_rank(rank, sync=False) decorator
|
+only_on_rank(rank, sync) decorator
|
||||||
}
|
|
||||||
|
|
||||||
class GradientState {
|
|
||||||
+int num_steps
|
|
||||||
+sync_gradients (property) bool
|
|
||||||
}
|
|
||||||
|
|
||||||
class AccumOptimizer {
|
|
||||||
+Optimizer optimizer
|
|
||||||
+GradientState gradient_state
|
|
||||||
+param_groups (property)
|
|
||||||
+step(closure)
|
|
||||||
+zero_grad()
|
|
||||||
+state_dict() dict
|
|
||||||
+load_state_dict(d)
|
|
||||||
}
|
|
||||||
|
|
||||||
class AccumScheduler {
|
|
||||||
+LRScheduler scheduler
|
|
||||||
+GradientState gradient_state
|
|
||||||
+step()
|
|
||||||
+state_dict() dict
|
|
||||||
+load_state_dict(d)
|
|
||||||
+get_last_lr()
|
|
||||||
}
|
|
||||||
|
|
||||||
class BaseExecutor {
|
|
||||||
+GradientState gradient_state
|
|
||||||
+prepare(model, optimizer, dataloader, scheduler) tuple
|
|
||||||
+accumulate(model) context manager
|
|
||||||
+backward(loss)
|
|
||||||
+unwrap_model(model) dict
|
|
||||||
+sync_gradients (property) bool
|
|
||||||
+grad_accum_steps (property) int
|
|
||||||
}
|
|
||||||
|
|
||||||
class NoneExecutor {
|
|
||||||
}
|
|
||||||
|
|
||||||
class DDPExecutor {
|
|
||||||
-_prepare_model(model) nn.Module
|
|
||||||
-_no_sync(model) context manager
|
|
||||||
+unwrap_model(model) dict
|
|
||||||
}
|
|
||||||
|
|
||||||
class FSDPExecutor {
|
|
||||||
-_prepare_model(model) nn.Module
|
|
||||||
+unwrap_model(model) dict
|
|
||||||
}
|
|
||||||
|
|
||||||
class ExecutorFactory {
|
|
||||||
+Registry _registry
|
|
||||||
+register(name) decorator
|
|
||||||
+create(parallel_mode, **kwargs) BaseExecutor
|
|
||||||
}
|
}
|
||||||
|
|
||||||
class ParallelModel {
|
class ParallelModel {
|
||||||
|
|
@ -988,25 +820,11 @@ classDiagram
|
||||||
}
|
}
|
||||||
|
|
||||||
class ColumnParallelLinear {
|
class ColumnParallelLinear {
|
||||||
+int in_features
|
|
||||||
+int out_features
|
|
||||||
+int out_features_per_rank
|
|
||||||
+bool gather_results
|
|
||||||
+Parameter weight
|
|
||||||
+Optional[Parameter] bias
|
|
||||||
+forward(x) Tensor
|
+forward(x) Tensor
|
||||||
+load_state_dict(state_dict)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
class RowParallelLinear {
|
class RowParallelLinear {
|
||||||
+int in_features
|
|
||||||
+int out_features
|
|
||||||
+int in_features_per_rank
|
|
||||||
+bool reduce_results
|
|
||||||
+Parameter weight
|
|
||||||
+Optional[Parameter] bias
|
|
||||||
+forward(x) Tensor
|
+forward(x) Tensor
|
||||||
+load_state_dict(state_dict)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -1024,13 +842,12 @@ classDiagram
|
||||||
TrainCallback <|-- CheckpointCallback
|
TrainCallback <|-- CheckpointCallback
|
||||||
TrainCallback <|-- ProgressBarCallback
|
TrainCallback <|-- ProgressBarCallback
|
||||||
TrainCallback <|-- MetricLoggerCallback
|
TrainCallback <|-- MetricLoggerCallback
|
||||||
TrainCallback <|-- ValidationCallback
|
|
||||||
BaseDataset <|-- SEQDataset
|
BaseDataset <|-- SEQDataset
|
||||||
BaseDataset <|-- SFTDataset
|
BaseDataset <|-- SFTDataset
|
||||||
BaseDataset <|-- DPODataset
|
BaseDataset <|-- DPODataset
|
||||||
BaseDataset <|-- GRPODataset
|
BaseDataset <|-- GRPODataset
|
||||||
Store <|-- H5Store
|
BaseStorage <|-- H5Storage
|
||||||
Store <|-- MmapStore
|
BaseStorage <|-- JSONStorage
|
||||||
BaseSamplingStrategy <|-- TemperatureStrategy
|
BaseSamplingStrategy <|-- TemperatureStrategy
|
||||||
BaseSamplingStrategy <|-- TopKStrategy
|
BaseSamplingStrategy <|-- TopKStrategy
|
||||||
BaseSamplingStrategy <|-- TopPStrategy
|
BaseSamplingStrategy <|-- TopPStrategy
|
||||||
|
|
@ -1041,10 +858,6 @@ classDiagram
|
||||||
AutoModel <|-- EmbeddingEncoder
|
AutoModel <|-- EmbeddingEncoder
|
||||||
BaseConfig <|-- BaseModelConfig
|
BaseConfig <|-- BaseModelConfig
|
||||||
BaseConfig <|-- TrainConfig
|
BaseConfig <|-- TrainConfig
|
||||||
BaseConfig <|-- InputConfig
|
|
||||||
BaseConfig <|-- ProcessingConfig
|
|
||||||
BaseConfig <|-- OutputConfig
|
|
||||||
BaseConfig <|-- PipelineConfig
|
|
||||||
BaseModelConfig <|-- AutoRegressiveLMConfig
|
BaseModelConfig <|-- AutoRegressiveLMConfig
|
||||||
BaseModelConfig <|-- EncoderConfig
|
BaseModelConfig <|-- EncoderConfig
|
||||||
BaseFactory <|-- AutoModel
|
BaseFactory <|-- AutoModel
|
||||||
|
|
@ -1054,23 +867,18 @@ classDiagram
|
||||||
BaseFactory <|-- StrategyFactory
|
BaseFactory <|-- StrategyFactory
|
||||||
BaseFactory <|-- SchedulerFactory
|
BaseFactory <|-- SchedulerFactory
|
||||||
BaseFactory <|-- CallbackFactory
|
BaseFactory <|-- CallbackFactory
|
||||||
BaseFactory <|-- StoreFactory
|
BaseFactory <|-- StorageFactory
|
||||||
BaseFactory <|-- ExecutorFactory
|
|
||||||
BaseFactory <|-- ConfigFactory
|
BaseFactory <|-- ConfigFactory
|
||||||
BaseFactory <|-- MaskBuilderFactory
|
TrainCallback <|-- ValidationCallback
|
||||||
BaseExecutor <|-- NoneExecutor
|
ProtocolHandler <|-- OpenAIHandler
|
||||||
BaseExecutor <|-- DDPExecutor
|
ProtocolHandler <|-- AnthropicHandler
|
||||||
BaseExecutor <|-- FSDPExecutor
|
|
||||||
ResponseBuilder <|-- OpenAIResponseBuilder
|
|
||||||
ResponseBuilder <|-- AnthropicResponseBuilder
|
|
||||||
BaseMaskBuilder <|-- ChatMaskBuilder
|
|
||||||
BaseMaskBuilder <|-- InstructionMaskBuilder
|
|
||||||
BaseMaskBuilder <|-- TextMaskBuilder
|
|
||||||
|
|
||||||
%% --- Composition (strong ownership, part destroyed with whole) ---
|
%% --- Composition (strong ownership, part destroyed with whole) ---
|
||||||
KVCache *-- PagePool
|
KVCache *-- PagePool
|
||||||
KVCache *-- Storage
|
KVCache *-- Storage
|
||||||
KVCache *-- TaskTable
|
KVCache *-- TaskTable
|
||||||
|
PagePool *-- Allocator
|
||||||
|
PagePool *-- PrefixCache
|
||||||
InferenceEngine *-- InferenceScheduler
|
InferenceEngine *-- InferenceScheduler
|
||||||
InferenceScheduler *-- KVCache
|
InferenceScheduler *-- KVCache
|
||||||
InferenceScheduler *-- Executor
|
InferenceScheduler *-- Executor
|
||||||
|
|
@ -1084,31 +892,21 @@ classDiagram
|
||||||
DecoderBlock *-- RMSNorm
|
DecoderBlock *-- RMSNorm
|
||||||
ChatCompletionRequest *-- ChatMessage
|
ChatCompletionRequest *-- ChatMessage
|
||||||
MessagesRequest *-- AnthropicMessage
|
MessagesRequest *-- AnthropicMessage
|
||||||
|
AutoTokenizer *-- ChatTemplate
|
||||||
BaseFactory *-- Registry
|
BaseFactory *-- Registry
|
||||||
BaseExecutor *-- GradientState
|
|
||||||
AccumOptimizer o-- GradientState
|
|
||||||
AccumScheduler o-- GradientState
|
|
||||||
|
|
||||||
%% --- Aggregation (weak ownership) ---
|
%% --- Aggregation (weak ownership) ---
|
||||||
AutoModel o-- BaseModelConfig
|
AutoModel o-- BaseModelConfig
|
||||||
AutoTokenizer o-- ChatTemplate
|
|
||||||
PagePool o-- Allocator
|
|
||||||
PagePool o-- PrefixCache
|
|
||||||
Trainer o-- TrainCallback
|
Trainer o-- TrainCallback
|
||||||
TrainContext o-- BaseStrategy
|
TrainContext o-- BaseStrategy
|
||||||
TrainContext o-- BaseScheduler
|
TrainContext o-- BaseScheduler
|
||||||
TrainContext o-- Checkpoint
|
TrainContext o-- Checkpoint
|
||||||
TrainContext o-- BaseExecutor
|
|
||||||
KvcacheView o-- Storage
|
KvcacheView o-- Storage
|
||||||
SamplingPipeline o-- BaseSamplingStrategy
|
SamplingPipeline o-- BaseSamplingStrategy
|
||||||
BaseDataset o-- Store
|
BaseDataset o-- BaseStorage
|
||||||
Pipeline o-- PipelineConfig
|
|
||||||
Pipeline o-- BaseMaskBuilder
|
|
||||||
|
|
||||||
%% --- Dependency (uses temporarily) ---
|
%% --- Dependency (uses temporarily) ---
|
||||||
TrainConfig ..> BaseStrategy : selects
|
TrainConfig ..> BaseStrategy : selects
|
||||||
PipelineConfig ..> MaskBuilderFactory : selects
|
|
||||||
MaskBuilderFactory ..> BaseMaskBuilder : creates
|
|
||||||
StrategyFactory ..> BaseStrategy : creates
|
StrategyFactory ..> BaseStrategy : creates
|
||||||
SchedulerFactory ..> BaseScheduler : creates
|
SchedulerFactory ..> BaseScheduler : creates
|
||||||
DatasetFactory ..> BaseDataset : creates
|
DatasetFactory ..> BaseDataset : creates
|
||||||
|
|
@ -1119,14 +917,10 @@ classDiagram
|
||||||
FFNFactory ..> DeepSeekMoE : creates
|
FFNFactory ..> DeepSeekMoE : creates
|
||||||
DecoderBlock ..> AttnFactory : uses
|
DecoderBlock ..> AttnFactory : uses
|
||||||
DecoderBlock ..> FFNFactory : uses
|
DecoderBlock ..> FFNFactory : uses
|
||||||
StoreFactory ..> H5Store : creates
|
StorageFactory ..> H5Storage : creates
|
||||||
StoreFactory ..> MmapStore : creates
|
StorageFactory ..> JSONStorage : creates
|
||||||
ConfigFactory ..> AutoRegressiveLMConfig : creates
|
ConfigFactory ..> AutoRegressiveLMConfig : creates
|
||||||
ConfigFactory ..> EncoderConfig : creates
|
ConfigFactory ..> EncoderConfig : creates
|
||||||
ExecutorFactory ..> NoneExecutor : creates
|
|
||||||
ExecutorFactory ..> DDPExecutor : creates
|
|
||||||
ExecutorFactory ..> FSDPExecutor : creates
|
|
||||||
TrainContextBuilder ..> ExecutorFactory : creates
|
|
||||||
Trainer ..> TrainContextBuilder : uses
|
Trainer ..> TrainContextBuilder : uses
|
||||||
TrainContextBuilder ..> TrainContext : creates
|
TrainContextBuilder ..> TrainContext : creates
|
||||||
Trainer ..> Functions : spawns
|
Trainer ..> Functions : spawns
|
||||||
|
|
@ -1137,10 +931,10 @@ classDiagram
|
||||||
KVCache ..> KvcacheView : binds
|
KVCache ..> KvcacheView : binds
|
||||||
InferenceEngine ..> GenerationRequest : uses
|
InferenceEngine ..> GenerationRequest : uses
|
||||||
InferenceEngine ..> GenerateResult : creates
|
InferenceEngine ..> GenerateResult : creates
|
||||||
OpenAIResponseBuilder ..> ChatCompletionRequest : receives
|
OpenAIHandler ..> ChatCompletionRequest : receives
|
||||||
AnthropicResponseBuilder ..> MessagesRequest : receives
|
AnthropicHandler ..> MessagesRequest : receives
|
||||||
ProtocolHandler ..> StopChecker : creates
|
ProtocolHandler ..> StopChecker : creates
|
||||||
ProtocolHandler ..> GenContext : creates
|
ProtocolHandler ..> StreamContext : creates
|
||||||
|
|
||||||
%% --- Association (general usage) ---
|
%% --- Association (general usage) ---
|
||||||
Trainer --> TrainConfig
|
Trainer --> TrainConfig
|
||||||
|
|
@ -1153,6 +947,8 @@ classDiagram
|
||||||
Executor --> AutoModel
|
Executor --> AutoModel
|
||||||
Executor --> AutoTokenizer
|
Executor --> AutoTokenizer
|
||||||
TaskManager --> AutoTokenizer
|
TaskManager --> AutoTokenizer
|
||||||
|
MultiSegmentFetcher --> BaseSegmentFetcher
|
||||||
|
ResumableDistributedSampler --> BaseDataset
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
@ -1161,48 +957,43 @@ classDiagram
|
||||||
|
|
||||||
| Module | Components | Description |
|
| Module | Components | Description |
|
||||||
|--------|------------|-------------|
|
|--------|------------|-------------|
|
||||||
| **astrai.config** | BaseConfig, BaseModelConfig, AutoRegressiveLMConfig, EncoderConfig, ConfigFactory, TrainConfig, PipelineConfig, InputConfig, ProcessingConfig, OutputConfig | Configuration management (to_dict/from_dict, to_file/from_file, from_json/to_json) |
|
| **astrai.config** | BaseConfig, BaseModelConfig, AutoRegressiveLMConfig, EncoderConfig, ConfigFactory, TrainConfig | Configuration management (to_dict/from_dict, to_file/from_file) |
|
||||||
| **astrai.preprocessing** | BaseMaskBuilder, MaskBuilderFactory, ChatMaskBuilder, InstructionMaskBuilder, TextMaskBuilder, Pipeline, filter_by_length, dedup_signature | Declarative JSON-driven data preprocessing |
|
| **astrai.dataset** | BaseDataset–GRPODataset, BaseStorage–JSONStorage, StorageFactory, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory | Dataset loading and management |
|
||||||
| **astrai.dataset** | BaseDataset–GRPODataset, Store–MmapStore, StoreFactory, ResumableDistributedSampler, DatasetFactory | Dataset loading and management |
|
|
||||||
| **astrai.serialization** | Checkpoint | Model serialization |
|
| **astrai.serialization** | Checkpoint | Model serialization |
|
||||||
| **astrai.model** | AutoModel, AutoRegressiveLM, EmbeddingEncoder, DecoderBlock, GQA, MLA, MLP, DeepSeekMoE, AttnFactory, FFNFactory, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
|
| **astrai.model** | AutoModel, AutoRegressiveLM, EmbeddingEncoder, DecoderBlock, GQA, MLA, MLP, DeepSeekMoE, AttnFactory, FFNFactory, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
|
||||||
| **astrai.tokenize** | AutoTokenizer, ChatTemplate | Tokenizer and chat template |
|
| **astrai.tokenize** | AutoTokenizer, ChatTemplate | Tokenizer and chat template |
|
||||||
| **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy–GRPOStrategy, StrategyFactory, BaseScheduler–SGDRScheduler, SchedulerFactory, TrainCallback(Protocol)–ValidationCallback, CallbackFactory, Muon | Training workflow |
|
| **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy–GRPOStrategy, StrategyFactory, BaseScheduler–SGDRScheduler, SchedulerFactory, TrainCallback(Protocol)–ValidationCallback, CallbackFactory, Muon | Training workflow |
|
||||||
| **astrai.inference** | InferenceEngine, InferenceScheduler, Executor, KVCache–KvcacheView, Allocator–Storage, Task, TaskManager, TaskStatus, GenerationRequest, GenerateResult, BaseSamplingStrategy–SamplingPipeline, ProtocolHandler, ResponseBuilder, OpenAIResponseBuilder, AnthropicResponseBuilder, StopChecker, GenContext, ChatMessage–MessagesRequest, app | Inference service |
|
| **astrai.inference** | InferenceEngine, InferenceScheduler, Executor, KVCache–KvcacheView, Allocator–Storage, Task, TaskManager, TaskStatus, GenerationRequest, BaseSamplingStrategy–SamplingPipeline, ProtocolHandler–AnthropicHandler, ChatMessage–MessagesRequest, app | Inference service |
|
||||||
| **astrai.parallel** | spawn_parallel_fn, setup_parallel, get_rank/get_world_size/get_current_device, only_on_rank, BaseExecutor, ExecutorFactory, NoneExecutor, DDPExecutor, FSDPExecutor, GradientState, AccumOptimizer, AccumScheduler, ParallelModel, RowParallelLinear, ColumnParallelLinear | Distributed parallel & gradient accumulation |
|
| **astrai.parallel** | spawn_parallel_fn, setup_parallel, get_rank/get_world_size/get_current_device, only_on_rank, ParallelModel, RowParallelLinear, ColumnParallelLinear | Distributed parallel |
|
||||||
| **astrai.factory** | Registry, BaseFactory[T] | Component registration |
|
| **astrai.factory** | Registry, BaseFactory[T] | Component registration |
|
||||||
| **astrai.protocols** | OptimizerProtocol, SchedulerProtocol | Structural subtyping for optimizer/scheduler wrappers |
|
|
||||||
|
|
||||||
## Design Patterns
|
## Design Patterns
|
||||||
|
|
||||||
| Pattern | Classes | Purpose |
|
| Pattern | Classes | Purpose |
|
||||||
|---------|---------|---------|
|
|---------|---------|---------|
|
||||||
| **Factory** | `AttnFactory`, `FFNFactory`, `StrategyFactory`, `DatasetFactory`, `SchedulerFactory`, `CallbackFactory`, `StoreFactory`, `ConfigFactory`, `ExecutorFactory` | Decorator-based component creation |
|
| **Factory** | `AttnFactory`, `FFNFactory`, `StrategyFactory`, `DatasetFactory`, `SchedulerFactory`, `CallbackFactory`, `StorageFactory`, `ConfigFactory` | Decorator-based component creation |
|
||||||
| **Registry** | `BaseFactory`, `Registry` | Component registration with category/priority |
|
| **Registry** | `BaseFactory`, `Registry` | Component registration with category/priority |
|
||||||
| **Strategy** | `SEQStrategy`, `SFTStrategy`, `DPOStrategy`, `GRPOStrategy` | Training strategy switching |
|
| **Strategy** | `SEQStrategy`, `SFTStrategy`, `DPOStrategy`, `GRPOStrategy` | Training strategy switching |
|
||||||
| **Strategy (Sampling)** | `TemperatureStrategy`, `TopKStrategy`, `TopPStrategy`, `SamplingPipeline` | Composable logit transformations |
|
| **Strategy (Sampling)** | `TemperatureStrategy`, `TopKStrategy`, `TopPStrategy`, `SamplingPipeline` | Composable logit transformations |
|
||||||
| **Strategy (API)** | `ResponseBuilder`, `OpenAIResponseBuilder`, `AnthropicResponseBuilder` | HTTP API handler with format hooks |
|
| **Template Method** | `ProtocolHandler`, `OpenAIHandler`, `AnthropicHandler` | HTTP API handler with format hooks |
|
||||||
| **Builder** | `TrainContextBuilder` | Chain-building training context |
|
| **Builder** | `TrainContextBuilder` | Chain-building training context |
|
||||||
| **Observer** | `TrainCallback`, callback implementations | Training process monitoring |
|
| **Observer** | `TrainCallback`, callback implementations | Training process monitoring |
|
||||||
| **Context** | `TrainContext` | Unified training state bag |
|
| **Context** | `TrainContext` | Unified training state bag |
|
||||||
| **Object Pool** | `Allocator`, `PagePool` | Page-based KV cache with LRU eviction |
|
| **Object Pool** | `Allocator`, `PagePool` | Page-based KV cache with LRU eviction |
|
||||||
| **Executor** | `BaseExecutor`, `NoneExecutor`, `DDPExecutor`, `FSDPExecutor` | Gradient accumulation & model distribution |
|
| **Storage** | `BaseStorage`, `H5Storage`, `JSONStorage` | Format-agnostic data access |
|
||||||
| **Storage** | `Store`, `H5Store`, `MmapStore` | Format-agnostic data access with multi-segment support |
|
|
||||||
| **Producer-Consumer** | `InferenceScheduler`, `Task`, queues | Continuous batching |
|
| **Producer-Consumer** | `InferenceScheduler`, `Task`, queues | Continuous batching |
|
||||||
| **AutoModel Registry** | `AutoModel`, `AutoRegressiveLM`, `EmbeddingEncoder` | Model-type dynamic loading |
|
| **AutoModel Registry** | `AutoModel`, `AutoRegressiveLM`, `EmbeddingEncoder` | Model-type dynamic loading |
|
||||||
|
|
||||||
## Core Relationships
|
## Core Relationships
|
||||||
|
|
||||||
1. **Config → Training**: `TrainConfig` holds `model_fn`, `dataset`, `optimizer_fn`, `scheduler_fn`, `parallel_mode`, `executor_kwargs`
|
1. **Config → Training**: `TrainConfig` holds model, dataset, optimizer_fn, scheduler_fn
|
||||||
2. **Training Flow**: `Trainer` → `TrainContextBuilder` → `TrainContext`, uses `BaseStrategy` for loss, `BaseExecutor` for gradient accumulation + model distribution
|
2. **Training Flow**: `Trainer` → `TrainContextBuilder` → `TrainContext`, uses `BaseStrategy` for loss
|
||||||
3. **Strategy Selection**: `StrategyFactory` creates strategy by `train_type`
|
3. **Strategy Selection**: `StrategyFactory` creates strategy by `train_type`
|
||||||
4. **Executor Selection**: `ExecutorFactory.create(cfg.parallel_mode, grad_accum_steps=cfg.grad_accum_steps, **cfg.executor_kwargs)` → `NoneExecutor` / `DDPExecutor` / `FSDPExecutor`
|
4. **Inference Flow**: `InferenceEngine` → `InferenceScheduler` → `AutoRegressiveLM`, backed by `KVCache` + `SamplingPipeline`
|
||||||
5. **Inference Flow**: `InferenceEngine` → `InferenceScheduler` → `AutoRegressiveLM`, backed by `KVCache` + `SamplingPipeline`
|
5. **Distributed**: `spawn_parallel_fn` + `setup_parallel` for multi-process DDP
|
||||||
6. **Distributed**: `spawn_parallel_fn` + `setup_parallel` for multi-process DDP
|
6. **Dataset Loading**: `DatasetFactory` creates datasets, `BaseStorage` (H5Storage/JSONStorage) loads via `BaseSegmentFetcher` + `MultiSegmentFetcher`
|
||||||
7. **Dataset Loading**: `DatasetFactory` creates datasets, `Store` (H5Store/MmapStore) loads data with explicit `_length` and multi-segment `_data`
|
7. **Checkpoint**: `Checkpoint` saves/loads safetensors + metadata (rank-0 only)
|
||||||
8. **Checkpoint**: `Checkpoint` saves/loads safetensors + metadata (rank-0 only), extra state saved as `{key}.pt`
|
8. **Scheduler**: `SchedulerFactory` creates `CosineScheduler`/`SGDRScheduler`
|
||||||
9. **Scheduler**: `SchedulerFactory` creates `CosineScheduler`/`SGDRScheduler`
|
9. **AutoModel**: `from_pretrained()` loads `config.json` + `model.safetensors`, `_disable_random_init` replaces `nn.init.*` with no-ops
|
||||||
10. **AutoModel**: `from_pretrained()` loads `config.json` + `model.safetensors`, `_disable_random_init` replaces `nn.init.*` with no-ops
|
|
||||||
11. **Protocols**: `OptimizerProtocol` / `SchedulerProtocol` — structural subtyping for `AccumOptimizer` / `AccumScheduler` wrappers
|
|
||||||
|
|
||||||
> Document Update Time: 2026-05-30
|
> Document Update Time: 2026-05-17
|
||||||
|
|
|
||||||
|
|
@ -5,21 +5,21 @@ This document describes the data pipeline: from raw text to model input tensors.
|
||||||
## Overview
|
## Overview
|
||||||
|
|
||||||
```
|
```
|
||||||
Raw Text → AutoTokenizer → Token IDs → .h5/.bin → Store.load() → Store.fetch() → Dataset → Sampler → DataLoader → Training/Inference
|
Raw Text → AutoTokenizer → Token IDs → .h5/.json → Dataset → Sampler → DataLoader → Training/Inference
|
||||||
```
|
```
|
||||||
|
|
||||||
## Data Preparation
|
## Data Preparation
|
||||||
|
|
||||||
Raw text is tokenized via `AutoTokenizer.encode()` and saved as HDF5 (`.h5`) or binary (`.bin` + `meta.json`) files with keyed tensor groups.
|
Raw text is tokenized via `AutoTokenizer.encode()` and saved as HDF5 (`.h5`) or JSON (`.json`/`.jsonl`) files with keyed tensor groups.
|
||||||
|
|
||||||
Storage format is auto-detected by `detect_format()`; backends are dispatched via registry:
|
Storage format is auto-detected by `detect_format()`; backends are dispatched via registry:
|
||||||
|
|
||||||
```
|
```
|
||||||
StoreFactory.create("h5") → H5Store
|
StorageFactory.create("h5") → H5Storage
|
||||||
StoreFactory.create("bin") → MmapStore
|
StorageFactory.create("json") → JSONStorage
|
||||||
```
|
```
|
||||||
|
|
||||||
H5 backend supports shared memory via `.share_memory_()`. Bin (mmap) uses OS page-cache sharing natively.
|
Both support shared memory via `.share_memory_()`.
|
||||||
|
|
||||||
## Data Keys by Training Type
|
## Data Keys by Training Type
|
||||||
|
|
||||||
|
|
@ -33,21 +33,14 @@ H5 backend supports shared memory via `.share_memory_()`. Bin (mmap) uses OS pag
|
||||||
## Dataset Architecture
|
## Dataset Architecture
|
||||||
|
|
||||||
```
|
```
|
||||||
DatasetFactory.load(train_type, load_path, window_size, stride=None, storage_type=None)
|
DatasetFactory.load(train_type, path, window_size, stride)
|
||||||
→ BaseDataset.load(load_path, storage_type=None)
|
→ StorageFactory.create(detect_format(path))
|
||||||
→ detect_format(load_path)
|
→ MultiSegmentFetcher(BaseSegmentFetcher per key)
|
||||||
→ StoreFactory.create(storage_type)
|
→ BaseDataset.__getitem__(idx)
|
||||||
→ Store.load(load_path)
|
→ sliding window [begin, end) via get_index(idx)
|
||||||
→ H5Store._normalize() / MmapStore._normalize()
|
|
||||||
→ Store._data[Dict[str, List[Tensor]]] + _cum[Dict[str, List[int]]]
|
|
||||||
→ BaseDataset.__getitem__(idx)
|
|
||||||
→ get_index(idx) → [begin, end)
|
|
||||||
→ Store.fetch(begin, end, keys) → Tensor / Dict[str, Tensor]
|
|
||||||
```
|
```
|
||||||
|
|
||||||
`window_size` = max input length, `stride` = step between consecutive samples (defaults to `window_size`, optional). `storage_type` defaults to `None` (auto-detect via `detect_format`).
|
`window_size` = max input length, `stride` = step between consecutive samples.
|
||||||
|
|
||||||
`Store.fetch(begin, end, keys)` accepts a single key (`str`) returning a `Tensor`, or a list of keys returning `Dict[str, Tensor]`. Internally uses `bisect` across multi-segment tensors. Raises `RuntimeError("Store not loaded")` if called before `load()`.
|
|
||||||
|
|
||||||
## Sampler
|
## Sampler
|
||||||
|
|
||||||
|
|
@ -61,4 +54,4 @@ DatasetFactory.load(train_type, load_path, window_size, stride=None, storage_typ
|
||||||
|
|
||||||
Standard PyTorch `DataLoader` with configurable `batch_size`, `num_workers`, `pin_memory`, `prefetch_factor`. Sampler produces indices; dataloader fetches tensor batches via `__getitem__`.
|
Standard PyTorch `DataLoader` with configurable `batch_size`, `num_workers`, `pin_memory`, `prefetch_factor`. Sampler produces indices; dataloader fetches tensor batches via `__getitem__`.
|
||||||
|
|
||||||
> Document Update Time: 2026-05-30
|
> Document Update Time: 2026-05-17
|
||||||
|
|
|
||||||
|
|
@ -12,16 +12,16 @@ RoPE is applied **before** KV cache write, not after — otherwise position enco
|
||||||
|
|
||||||
## KVCache System
|
## KVCache System
|
||||||
|
|
||||||
Six classes (plus two helpers) working together:
|
Six classes working together:
|
||||||
|
|
||||||
```
|
```
|
||||||
KVCache (facade)
|
KVCache (facade)
|
||||||
├── PagePool orchestrates page allocation + prefix matching
|
├── Allocator bitmask-based page allocator + ref-count + LRU eviction
|
||||||
│ ├── Allocator bitmask-based page allocator + ref-count + LRU eviction (inside PagePool)
|
├── PrefixCache hash-based prefix matching (page_hash via rolling hash)
|
||||||
│ └── PrefixCache hash-based prefix matching (page_hash via polynomial hash) (inside PagePool)
|
├── PagePool orchestrates Allocator + PrefixCache
|
||||||
├── TaskTable maps task_id → page_table + cached token count
|
├── TaskTable maps task_id → page_table + cached token count
|
||||||
├── Storage k_cache / v_cache tensors (n_layers × n_pages × page_size × n_kv_heads × head_dim)
|
├── Storage k_cache / v_cache tensors (n_layers × n_pages × page_size × n_kv_heads × head_dim)
|
||||||
└── KvcacheView bundles Storage + page_table + total_len for attention layers (returned by bind())
|
└── KvcacheView bundles Storage + page_table + total_len for attention layers
|
||||||
```
|
```
|
||||||
|
|
||||||
`KVCache.bind(page_table, total_len)` returns a `KvcacheView` used by attention layers via `write()` / `gather()`.
|
`KVCache.bind(page_table, total_len)` returns a `KvcacheView` used by attention layers via `write()` / `gather()`.
|
||||||
|
|
@ -40,33 +40,26 @@ KVCache (facade)
|
||||||
## Sampling (Strategy Pattern)
|
## Sampling (Strategy Pattern)
|
||||||
|
|
||||||
```
|
```
|
||||||
BaseSamplingStrategy (ABC)
|
BaseSamplingStrategy → TemperatureStrategy → TopKStrategy → TopPStrategy
|
||||||
├── TemperatureStrategy
|
|
||||||
├── TopKStrategy
|
|
||||||
├── TopPStrategy
|
|
||||||
└── SamplingPipeline
|
|
||||||
```
|
```
|
||||||
|
|
||||||
`SamplingPipeline` composes them: Temperature → Top-K → Top-P → softmax → multinomial.
|
`SamplingPipeline` composes them: Temperature → Top-K → Top-P → softmax → multinomial.
|
||||||
`sample()` is a convenience shortcut for one-shot usage.
|
`sample()` is a convenience shortcut for one-shot usage.
|
||||||
|
|
||||||
## Protocol Handlers (Strategy Pattern)
|
## Protocol Handlers (Template Method)
|
||||||
|
|
||||||
```python
|
```python
|
||||||
class ProtocolHandler: # concrete orchestrator
|
class ProtocolHandler(ABC):
|
||||||
def __init__(self, request, engine, builder): ...
|
def handle(self):
|
||||||
async def handle(self):
|
ctx = StreamContext(...)
|
||||||
prompt, ctx, stops = builder.prepare(request, engine)
|
|
||||||
agen = engine.generate_async(prompt, ...)
|
agen = engine.generate_async(prompt, ...)
|
||||||
if stream: self._handle_stream(agen, ctx, stops)
|
if stream: self._handle_stream(agen, ctx)
|
||||||
else: return await self._handle_non_stream(agen, ctx, stops)
|
else: self._handle_non_stream(agen, ctx)
|
||||||
```
|
```
|
||||||
|
|
||||||
`ResponseBuilder` (ABC): `prepare()`, `format_stream_start()`, `format_chunk()`, `format_stream_end()`, `format_response()`.
|
Subclass hooks: `build_prompt()`, `create_response_id()`, `format_stream_start/token/end()`, `format_non_stream_response()`.
|
||||||
|
|
||||||
`OpenAIResponseBuilder` → `/v1/chat/completions`, `AnthropicResponseBuilder` → `/v1/messages`.
|
`OpenAIHandler` → `/v1/chat/completions`, `AnthropicHandler` → `/v1/messages`.
|
||||||
|
|
||||||
Adding a protocol = one builder file, no handler subclassing needed.
|
|
||||||
|
|
||||||
## Engine & GenerateResult
|
## Engine & GenerateResult
|
||||||
|
|
||||||
|
|
@ -74,9 +67,7 @@ Adding a protocol = one builder file, no handler subclassing needed.
|
||||||
InferenceEngine
|
InferenceEngine
|
||||||
├── generate(prompt, stream, ...) → str | List[str] | Generator
|
├── generate(prompt, stream, ...) → str | List[str] | Generator
|
||||||
├── generate_with_request(req) → same
|
├── generate_with_request(req) → same
|
||||||
├── generate_async(prompt, ...) → AsyncGenerator
|
└── generate_async(prompt, ...) → AsyncGenerator
|
||||||
├── get_stats() → Dict
|
|
||||||
└── shutdown()
|
|
||||||
```
|
```
|
||||||
|
|
||||||
`GenerateResult` uses `Condition` for non-streaming (`wait_completion()`) and `Event` for streaming (`wait()`). Stream callback is `cb(token)`.
|
`GenerateResult` uses `Condition` for non-streaming (`wait_completion()`) and `Event` for streaming (`wait()`). Stream callback is `cb(token)`.
|
||||||
|
|
@ -103,14 +94,12 @@ Response:
|
||||||
{
|
{
|
||||||
"id": "chatcmpl-abc123",
|
"id": "chatcmpl-abc123",
|
||||||
"object": "chat.completion",
|
"object": "chat.completion",
|
||||||
"created": 1717000000,
|
"choices": [{"message": {"role": "assistant", "content": "Hello!"}, "finish_reason": "stop"}],
|
||||||
"model": "astrai",
|
|
||||||
"choices": [{"index": 0, "message": {"role": "assistant", "content": "Hello!"}, "finish_reason": "stop"}],
|
|
||||||
"usage": {"prompt_tokens": 5, "completion_tokens": 10, "total_tokens": 15}
|
"usage": {"prompt_tokens": 5, "completion_tokens": 10, "total_tokens": 15}
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
Streaming SSE: `object: "chat.completion.chunk"` — starts with role delta, then token chunks, ends with finish chunk + usage stats, then `data: [DONE]`.
|
Streaming SSE: `data: {"choices":[{"delta":{"role":"assistant"}}]}` → token chunks → `data: [DONE]`
|
||||||
|
|
||||||
### Anthropic
|
### Anthropic
|
||||||
|
|
||||||
|
|
@ -127,10 +116,10 @@ Supports `stop_sequences` and streaming via `event: content_block_delta`.
|
||||||
| Param | Type | Default | Description |
|
| Param | Type | Default | Description |
|
||||||
|-------|------|---------|-------------|
|
|-------|------|---------|-------------|
|
||||||
| `messages` | List[dict] | required | Chat messages (role, content) |
|
| `messages` | List[dict] | required | Chat messages (role, content) |
|
||||||
| `top_k` | int | 50 | Top-k count |
|
| `temperature` | float | 1.0 | Sampling temperature (0.0–2.0) |
|
||||||
| `top_p` | float | 1.0 | Nucleus threshold |
|
| `top_p` | float | 1.0 | Nucleus threshold |
|
||||||
| `temperature` | float | 1.0 | Sampling temperature (> 0.0) |
|
| `top_k` | int | 50 | Top-k count |
|
||||||
| `max_tokens` | Optional[int] | None | Max generation length |
|
| `max_tokens` | int | None | Max generation length |
|
||||||
| `stream` | bool | False | Stream output |
|
| `stream` | bool | False | Stream output |
|
||||||
|
|
||||||
## Engine API
|
## Engine API
|
||||||
|
|
@ -145,8 +134,7 @@ engine.generate("Hello", stream=True) # -> Generator[str]
|
||||||
engine.generate(["A", "B"], stream=True) # -> Generator[Tuple[int, str]]
|
engine.generate(["A", "B"], stream=True) # -> Generator[Tuple[int, str]]
|
||||||
|
|
||||||
# Async
|
# Async
|
||||||
async for token in engine.generate_async("Hello", ...): # -> AsyncGenerator[str]
|
await engine.generate_async("Hello", ...) # -> AsyncGenerator[str]
|
||||||
print(token)
|
|
||||||
```
|
```
|
||||||
|
|
||||||
> Document Update Time: 2026-05-30
|
> Document Update Time: 2026-05-17
|
||||||
|
|
|
||||||
|
|
@ -53,9 +53,7 @@
|
||||||
| Parameter | Description | Default |
|
| Parameter | Description | Default |
|
||||||
|-----------|-------------|---------|
|
|-----------|-------------|---------|
|
||||||
| `--nprocs` | Number of GPUs / processes | 1 |
|
| `--nprocs` | Number of GPUs / processes | 1 |
|
||||||
| `--parallel_mode` | Parallel strategy (`none`, `ddp`, or `fsdp`) | none |
|
|
||||||
| `--device_type` | Device type | cuda |
|
| `--device_type` | Device type | cuda |
|
||||||
| `--start_method` | Multiprocessing start method (`spawn`, `fork`, `forkserver`) | spawn |
|
|
||||||
|
|
||||||
### Strategy-specific
|
### Strategy-specific
|
||||||
|
|
||||||
|
|
@ -75,7 +73,6 @@ export CUDA_VISIBLE_DEVICES=0,1,2,3
|
||||||
|
|
||||||
nohup python scripts/tools/train.py \
|
nohup python scripts/tools/train.py \
|
||||||
--nprocs=4 \
|
--nprocs=4 \
|
||||||
--parallel_mode=ddp \
|
|
||||||
--train_type=seq \
|
--train_type=seq \
|
||||||
--data_root_path=/path/to/dataset \
|
--data_root_path=/path/to/dataset \
|
||||||
--param_path=/path/to/model \
|
--param_path=/path/to/model \
|
||||||
|
|
@ -97,4 +94,4 @@ nohup python scripts/tools/train.py \
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
> Document Update Time: 2026-05-24
|
> Document Update Time: 2026-05-17
|
||||||
|
|
@ -1,346 +0,0 @@
|
||||||
# Preprocessing Pipeline
|
|
||||||
|
|
||||||
Declarative JSON-driven data preprocessing. One `SectionedMaskBuilder` handles all formats via `input.sections` (single-output) or `input.sources` (multi-output).
|
|
||||||
|
|
||||||
## Philosophy
|
|
||||||
|
|
||||||
| Component | Responsibility |
|
|
||||||
|-----------|---------------|
|
|
||||||
| `tokenizer_config.json` (`chat_template`) | Formatting -- how roles become tokens |
|
|
||||||
| `pipeline.json` (`mask`) | Masking -- which roles participate in training |
|
|
||||||
|
|
||||||
A single config file captures the entire pipeline, reusable and version-controllable.
|
|
||||||
|
|
||||||
## Config Structure
|
|
||||||
|
|
||||||
```json
|
|
||||||
{
|
|
||||||
"input": {}, // sections (single) or sources (multi)
|
|
||||||
"mask": {}, // role → "train" | "mask"
|
|
||||||
"mask_default": "mask",
|
|
||||||
"preprocessing": {},
|
|
||||||
"output": {}
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
### Section Fields
|
|
||||||
|
|
||||||
| Field | Type | Default | Description |
|
|
||||||
|-------|------|---------|-------------|
|
|
||||||
| `field` | str | -- | JSONL key to read |
|
|
||||||
| `action` | str | -- | `"train"` / `"mask"` / `"$role"` |
|
|
||||||
| `template` | bool | `false` | Apply `chat_template` per message |
|
|
||||||
| `add_special_tokens` | bool | `true` for first non-template section | Add special tokens during encode |
|
|
||||||
|
|
||||||
### Source Fields (multi-output mode)
|
|
||||||
|
|
||||||
| Field | Type | Default | Description |
|
|
||||||
|-------|------|---------|-------------|
|
|
||||||
| `sections` | list[dict] | -- | Same as single-output section list |
|
|
||||||
| `list_field` | bool | `false` | JSONL field holds a list; tokenise each element |
|
|
||||||
| `mask_key` | str | `"{key}_mask"` | Explicit output key for loss mask |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Quick Start
|
|
||||||
|
|
||||||
### SFT Chat
|
|
||||||
|
|
||||||
Input JSONL:
|
|
||||||
|
|
||||||
```json
|
|
||||||
{"messages": [{"role": "system", "content": "You are helpful."}, {"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello!"}]}
|
|
||||||
```
|
|
||||||
|
|
||||||
Config:
|
|
||||||
|
|
||||||
```json
|
|
||||||
{
|
|
||||||
"input": {
|
|
||||||
"sections": [
|
|
||||||
{"field": "messages", "action": "$role", "template": true}
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"mask": {
|
|
||||||
"system": "mask",
|
|
||||||
"user": "mask",
|
|
||||||
"assistant": "train"
|
|
||||||
},
|
|
||||||
"mask_default": "mask",
|
|
||||||
"preprocessing": {
|
|
||||||
"max_seq_len": 2048
|
|
||||||
},
|
|
||||||
"output": {
|
|
||||||
"storage_format": "bin",
|
|
||||||
"dtype": {"loss_mask": "bool"}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
Output keys: `sequence` (int32), `loss_mask` (bool)
|
|
||||||
|
|
||||||
### SFT Instruction
|
|
||||||
|
|
||||||
Input JSONL:
|
|
||||||
|
|
||||||
```json
|
|
||||||
{"prompt": "Translate to French: Hello", "response": "Bonjour"}
|
|
||||||
```
|
|
||||||
|
|
||||||
Config:
|
|
||||||
|
|
||||||
```json
|
|
||||||
{
|
|
||||||
"input": {
|
|
||||||
"sections": [
|
|
||||||
{"field": "prompt", "action": "mask", "add_special_tokens": true},
|
|
||||||
{"field": "response", "action": "train"}
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"mask_default": "mask",
|
|
||||||
"preprocessing": {
|
|
||||||
"max_seq_len": 2048
|
|
||||||
}
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
Output keys: `sequence`, `loss_mask`
|
|
||||||
|
|
||||||
### Pretrain
|
|
||||||
|
|
||||||
Input JSONL:
|
|
||||||
|
|
||||||
```json
|
|
||||||
{"text": "Artificial Intelligence is a field of computer science..."}
|
|
||||||
```
|
|
||||||
|
|
||||||
Config:
|
|
||||||
|
|
||||||
```json
|
|
||||||
{
|
|
||||||
"input": {
|
|
||||||
"sections": [
|
|
||||||
{"field": "text", "action": "train"}
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"preprocessing": {
|
|
||||||
"max_seq_len": 8192,
|
|
||||||
"min_chars": 100
|
|
||||||
}
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
Output keys: `sequence` (no `loss_mask` — all tokens trained)
|
|
||||||
|
|
||||||
### DPO
|
|
||||||
|
|
||||||
Input JSONL:
|
|
||||||
|
|
||||||
```json
|
|
||||||
{"chosen": [{"role": "user", "content": "What is 2+2?"}, {"role": "assistant", "content": "4"}], "rejected": [{"role": "user", "content": "What is 2+2?"}, {"role": "assistant", "content": "5"}]}
|
|
||||||
```
|
|
||||||
|
|
||||||
Config:
|
|
||||||
|
|
||||||
```json
|
|
||||||
{
|
|
||||||
"input": {
|
|
||||||
"sources": {
|
|
||||||
"chosen": {
|
|
||||||
"sections": [
|
|
||||||
{"field": "chosen", "action": "$role", "template": true}
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"rejected": {
|
|
||||||
"sections": [
|
|
||||||
{"field": "rejected", "action": "$role", "template": true}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"mask": {
|
|
||||||
"user": "mask",
|
|
||||||
"assistant": "train"
|
|
||||||
},
|
|
||||||
"mask_default": "mask"
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
Output keys: `chosen`, `chosen_mask`, `rejected`, `rejected_mask`
|
|
||||||
|
|
||||||
### GRPO
|
|
||||||
|
|
||||||
Input JSONL:
|
|
||||||
|
|
||||||
```json
|
|
||||||
{"prompt": [{"role": "user", "content": "What is 2+2?"}], "responses": ["4", "Five", "Four"], "rewards": [1.0, 0.3, 0.8]}
|
|
||||||
```
|
|
||||||
|
|
||||||
Config:
|
|
||||||
|
|
||||||
```json
|
|
||||||
{
|
|
||||||
"input": {
|
|
||||||
"sources": {
|
|
||||||
"prompts": {
|
|
||||||
"sections": [
|
|
||||||
{"field": "prompt", "action": "mask", "template": true}
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"responses": {
|
|
||||||
"sections": [
|
|
||||||
{"field": "responses", "action": "train"}
|
|
||||||
],
|
|
||||||
"list_field": true,
|
|
||||||
"mask_key": "masks"
|
|
||||||
},
|
|
||||||
"rewards": {
|
|
||||||
"sections": [
|
|
||||||
{"field": "rewards", "action": "value"}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"mask": {
|
|
||||||
"user": "mask",
|
|
||||||
"assistant": "train"
|
|
||||||
},
|
|
||||||
"mask_default": "mask"
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
Output keys: `prompts`, `responses`, `masks`, `rewards` (float32)
|
|
||||||
|
|
||||||
- `action: "value"` — extract raw values from JSONL without tokenisation
|
|
||||||
- `list_field: true` — tokenise each list element independently, then concatenate
|
|
||||||
- `mask_key: "masks"` — rename the auto-generated mask key (default: `responses_mask`)
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Configuration Reference
|
|
||||||
|
|
||||||
### `input`
|
|
||||||
|
|
||||||
| Field | Type | Default | Description |
|
|
||||||
|-------|------|---------|-------------|
|
|
||||||
| `sections` | list[dict] or null | `null` | Section specs for single-output mode |
|
|
||||||
| `sources` | dict[str, dict] or null | `null` | Source specs for multi-output mode (DPO/GRPO) |
|
|
||||||
|
|
||||||
When `sources` is set, `sections` is ignored.
|
|
||||||
|
|
||||||
### `mask`
|
|
||||||
|
|
||||||
| Field | Type | Default | Description |
|
|
||||||
|-------|------|---------|-------------|
|
|
||||||
| `mask` | dict | `{}` | `{role: "train" \| "mask"}` |
|
|
||||||
| `mask_default` | str | `"mask"` | Default action for unlisted roles |
|
|
||||||
|
|
||||||
### `preprocessing`
|
|
||||||
|
|
||||||
| Field | Type | Default | Description |
|
|
||||||
|-------|------|---------|-------------|
|
|
||||||
| `max_seq_len` | int | `2048` | Truncate sequences to this length |
|
|
||||||
| `min_chars` | int | `50` | Skip text-mode items shorter than this |
|
|
||||||
| `max_chars` | int | `2000000` | Skip text-mode items longer than this |
|
|
||||||
| `max_items` | int or null | `null` | Stop after N documents |
|
|
||||||
|
|
||||||
### `output`
|
|
||||||
|
|
||||||
| Field | Type | Default | Description |
|
|
||||||
|-------|------|---------|-------------|
|
|
||||||
| `domain_key` | str or null | `null` | JSONL key for domain grouping |
|
|
||||||
| `storage_format` | str | `"bin"` | `"bin"` (mmap) or `"h5"` |
|
|
||||||
| `max_tokens_per_shard` | int | `100000000` | Flush threshold in cumulative tokens |
|
|
||||||
| `dtype` | dict[str, str] | `{}` | Per-key tensor dtype override (e.g. `{"loss_mask": "bool"}`) |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Mask Algorithm
|
|
||||||
|
|
||||||
### Template mode (`template: true`)
|
|
||||||
|
|
||||||
For each message in the field's array:
|
|
||||||
|
|
||||||
1. Prepend BOS token (masked)
|
|
||||||
2. Render through `chat_template` for that single message
|
|
||||||
3. Encode rendered text
|
|
||||||
4. Apply mask rule for the message's role
|
|
||||||
|
|
||||||
### Non-template mode
|
|
||||||
|
|
||||||
Encode the field value as text. Mask value is 1 (train) or 0 (mask) per the section's `action`.
|
|
||||||
|
|
||||||
### Text config detection
|
|
||||||
|
|
||||||
When no section uses `template` and all sections have `action: "train"`, the builder skips mask generation entirely — all tokens are trained.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Output Layout
|
|
||||||
|
|
||||||
### Single-Shard (`bin`)
|
|
||||||
|
|
||||||
```
|
|
||||||
output/
|
|
||||||
__default__/
|
|
||||||
meta.json
|
|
||||||
sequence.bin
|
|
||||||
loss_mask.bin
|
|
||||||
wiki/
|
|
||||||
meta.json
|
|
||||||
sequence.bin
|
|
||||||
loss_mask.bin
|
|
||||||
```
|
|
||||||
|
|
||||||
### Multi-Shard (`bin`)
|
|
||||||
|
|
||||||
When `max_tokens_per_shard` is exceeded:
|
|
||||||
|
|
||||||
```
|
|
||||||
output/
|
|
||||||
__default__/
|
|
||||||
shard_0000/
|
|
||||||
meta.json
|
|
||||||
sequence.bin
|
|
||||||
loss_mask.bin
|
|
||||||
shard_0001/
|
|
||||||
meta.json
|
|
||||||
sequence.bin
|
|
||||||
loss_mask.bin
|
|
||||||
```
|
|
||||||
|
|
||||||
`MmapStore` discovers all shards under the domain directory via `rglob("meta.json")`.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## CLI
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# SFT
|
|
||||||
python scripts/tools/preprocess.py data/sft/*.jsonl -o output/sft/ -c configs/sft_chat.json
|
|
||||||
|
|
||||||
# DPO
|
|
||||||
python scripts/tools/preprocess.py data/dpo/*.jsonl -o output/dpo/ -c configs/dpo.json --tokenizer_path params
|
|
||||||
|
|
||||||
# GRPO
|
|
||||||
python scripts/tools/preprocess.py data/grpo/*.jsonl -o output/grpo/ -c configs/grpo.json
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Python API
|
|
||||||
|
|
||||||
```python
|
|
||||||
from astrai.preprocessing.pipeline import Pipeline
|
|
||||||
from astrai.config.preprocess_config import PipelineConfig
|
|
||||||
|
|
||||||
config = PipelineConfig.from_json("sft.json")
|
|
||||||
Pipeline(
|
|
||||||
config,
|
|
||||||
["data_part1.jsonl", "data_part2.jsonl"],
|
|
||||||
output_dir="output/",
|
|
||||||
tokenizer_path="params",
|
|
||||||
).run()
|
|
||||||
```
|
|
||||||
|
|
||||||
> Document Update Time: 2026-06-03
|
|
||||||
|
|
@ -1,5 +1,38 @@
|
||||||
# Training
|
# Training
|
||||||
|
|
||||||
|
## Model Architecture
|
||||||
|
|
||||||
|
The model uses a decoder-only Transformer with **GQA** (Grouped Query Attention) and optional **MLA** (Multi-head Latent Attention). 1.0 billion parameters, Chinese–English bilingual.
|
||||||
|
|
||||||
|
```mermaid
|
||||||
|
flowchart TB
|
||||||
|
subgraph Layers["Transformer Layers"]
|
||||||
|
direction TB
|
||||||
|
A[Input Embedding] --> B[Transformer Block\nLayer 1]
|
||||||
|
B --> C[Transformer Block\nLayer ...]
|
||||||
|
C --> D[Transformer Block\nLayer ...]
|
||||||
|
D --> E[RMSNorm]
|
||||||
|
E --> F[Linear]
|
||||||
|
F --> G[SoftMax]
|
||||||
|
end
|
||||||
|
|
||||||
|
subgraph TransformerBlock["Transformer Block"]
|
||||||
|
direction TB
|
||||||
|
H[x] --> I[RMSNorm]
|
||||||
|
I --> J[Linear → Q/K/V]
|
||||||
|
J --> K[Q]; J --> L[K]; J --> M[V]
|
||||||
|
K --> N[RoPE]; L --> O[RoPE]
|
||||||
|
N --> P["Q @ K^T / sqrt(d)"]; O --> P
|
||||||
|
P --> Q[Masked SoftMax]; Q --> R[S @ V]; M --> R
|
||||||
|
R --> S[Linear]; S --> T[+]; H --> T
|
||||||
|
T --> U[RMSNorm]
|
||||||
|
U --> V["Linear (gate)"]; U --> W["Linear (up)"]
|
||||||
|
V --> X[SiLU]; X --> Y[×]; W --> Y
|
||||||
|
Y --> Z["Linear (down)"]; Z --> AA[+]; T --> AA
|
||||||
|
AA --> BB[x']
|
||||||
|
end
|
||||||
|
```
|
||||||
|
|
||||||
### Autoregression
|
### Autoregression
|
||||||
|
|
||||||
Given a token sequence, the model predicts the probability of the next token. Each generated token is appended to the input and fed back, repeating until an end-of-sequence token or max length.
|
Given a token sequence, the model predicts the probability of the next token. Each generated token is appended to the input and fed back, repeating until an end-of-sequence token or max length.
|
||||||
|
|
@ -36,24 +69,20 @@ Two-level loop: **epoch** → **batch**. Optimizer step fires every `grad_accum_
|
||||||
|
|
||||||
```
|
```
|
||||||
on_train_begin
|
on_train_begin
|
||||||
model.train()
|
|
||||||
on_epoch_begin
|
on_epoch_begin
|
||||||
for batch in dataloader:
|
for batch in dataloader:
|
||||||
on_batch_begin
|
on_batch_begin
|
||||||
with executor.accumulate(model):
|
loss = strategy(batch)
|
||||||
loss = strategy.compute_loss(batch)
|
(loss / grad_accum_steps).backward()
|
||||||
context.loss = loss.item()
|
iteration += 1
|
||||||
stand_loss = loss / executor.grad_accum_steps
|
on_batch_end
|
||||||
executor.backward(stand_loss)
|
|
||||||
context.iteration += 1
|
|
||||||
on_batch_end
|
|
||||||
|
|
||||||
if executor.sync_gradients:
|
if iteration % grad_accum_steps == 0:
|
||||||
on_optimizer_step
|
on_step_begin
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
if scheduler:
|
on_step_end
|
||||||
scheduler.step()
|
scheduler.step()
|
||||||
on_epoch_end
|
on_epoch_end
|
||||||
on_train_end
|
on_train_end
|
||||||
```
|
```
|
||||||
|
|
@ -63,15 +92,12 @@ on_train_end
|
||||||
| Hook | Fires | Default callback |
|
| Hook | Fires | Default callback |
|
||||||
|------|-------|-----------------|
|
|------|-------|-----------------|
|
||||||
| `on_train_begin` | Before training starts | `GradientCheckpointingCallback` |
|
| `on_train_begin` | Before training starts | `GradientCheckpointingCallback` |
|
||||||
| `on_epoch_begin` | Start of each epoch | `ProgressBarCallback` |
|
| `on_step_begin` | Every accumulation window | `GradientClippingCallback` |
|
||||||
| `on_batch_begin` | Every batch | — |
|
|
||||||
| `on_optimizer_step` | Every accumulation window | `GradientClippingCallback`, `ValidationCallback` |
|
|
||||||
| `on_batch_end` | Every batch | `CheckpointCallback`, `MetricLoggerCallback`, `ProgressBarCallback` |
|
| `on_batch_end` | Every batch | `CheckpointCallback`, `MetricLoggerCallback`, `ProgressBarCallback` |
|
||||||
| `on_epoch_end` | End of each epoch | `ProgressBarCallback` |
|
| `on_step_end` | Every accumulation window | `ValidationCallback` |
|
||||||
| `on_error` | On exception during training | `CheckpointCallback`, `MetricLoggerCallback` |
|
| `on_train_end` | Training ends | `CheckpointCallback`, `MetricLoggerCallback` (final save) |
|
||||||
| `on_train_end` | Training ends (always via finally) | `CheckpointCallback`, `MetricLoggerCallback`, `GradientCheckpointingCallback` |
|
|
||||||
|
|
||||||
Default callbacks (in order): `gradient_checkpointing` (activation checkpointing, optional), `checkpoint` (safetensors, rank-0), `metric_logger` (JSONL, rank-0), `progress_bar` (tqdm), `gradient_clipping`, `validation` (periodic validation on val_dataset).
|
Default callbacks: `gradient_checkpointing` (activation checkpointing, optional), `progress_bar` (tqdm), `checkpoint` (safetensors, rank-0), `metric_logger` (JSONL, rank-0), `gradient_clipping`, `validation` (periodic validation on val_dataset).
|
||||||
|
|
||||||
## Strategies
|
## Strategies
|
||||||
|
|
||||||
|
|
@ -83,7 +109,7 @@ $$
|
||||||
L_{\text{PT}} = -\sum_{t=1}^{T} \log P(x_t \mid x_{\lt t}; \theta)
|
L_{\text{PT}} = -\sum_{t=1}^{T} \log P(x_t \mid x_{\lt t}; \theta)
|
||||||
$$
|
$$
|
||||||
|
|
||||||
Keys: `input_ids`, `target_ids`. Optional: `label_smoothing`.
|
Keys: `input_ids`, `target_ids`
|
||||||
|
|
||||||
### SFT (Supervised Fine-Tuning)
|
### SFT (Supervised Fine-Tuning)
|
||||||
|
|
||||||
|
|
@ -93,7 +119,7 @@ $$
|
||||||
L_{\text{SFT}} = -\sum_{t=P+1}^{P+L} \log P(s_t \mid s_{\lt t}; \theta)
|
L_{\text{SFT}} = -\sum_{t=P+1}^{P+L} \log P(s_t \mid s_{\lt t}; \theta)
|
||||||
$$
|
$$
|
||||||
|
|
||||||
Keys: `input_ids`, `target_ids`, `loss_mask`. Optional: `label_smoothing`.
|
Keys: `input_ids`, `target_ids`, `loss_mask`
|
||||||
|
|
||||||
### DPO (Direct Preference Optimization)
|
### DPO (Direct Preference Optimization)
|
||||||
|
|
||||||
|
|
@ -103,7 +129,7 @@ $$
|
||||||
L_{\text{DPO}} = -\mathbb{E}\left[\log\sigma\left(\beta\log\frac{\pi_\theta(y_w\mid x)}{\pi_{\text{ref}}(y_w\mid x)} - \beta\log\frac{\pi_\theta(y_l\mid x)}{\pi_{\text{ref}}(y_l\mid x)}\right)\right]
|
L_{\text{DPO}} = -\mathbb{E}\left[\log\sigma\left(\beta\log\frac{\pi_\theta(y_w\mid x)}{\pi_{\text{ref}}(y_w\mid x)} - \beta\log\frac{\pi_\theta(y_l\mid x)}{\pi_{\text{ref}}(y_l\mid x)}\right)\right]
|
||||||
$$
|
$$
|
||||||
|
|
||||||
Parameters: `beta=0.1`, `reduction="mean"`. Keys: `chosen`, `rejected`, `chosen_mask`, `rejected_mask`.
|
Parameters: `beta=0.1`. Keys: `chosen`, `rejected`, `chosen_mask`, `rejected_mask`.
|
||||||
|
|
||||||
### GRPO (Group Relative Policy Optimization)
|
### GRPO (Group Relative Policy Optimization)
|
||||||
|
|
||||||
|
|
@ -117,7 +143,7 @@ $$
|
||||||
L_{\text{GRPO}} = -\mathbb{E}\left[\min\left(\frac{\pi_\theta}{\pi_{\text{ref}}}A,\; \text{clip}\left(\frac{\pi_\theta}{\pi_{\text{ref}}}, 1-\epsilon, 1+\epsilon\right)A\right)\right] + \lambda \cdot \mathbb{E}\left[(\log\pi_\theta - \log\pi_{\text{ref}})^2\right]
|
L_{\text{GRPO}} = -\mathbb{E}\left[\min\left(\frac{\pi_\theta}{\pi_{\text{ref}}}A,\; \text{clip}\left(\frac{\pi_\theta}{\pi_{\text{ref}}}, 1-\epsilon, 1+\epsilon\right)A\right)\right] + \lambda \cdot \mathbb{E}\left[(\log\pi_\theta - \log\pi_{\text{ref}})^2\right]
|
||||||
$$
|
$$
|
||||||
|
|
||||||
Parameters: `group_size=4`, `clip_eps=0.2`, `kl_coef=0.01`, `sync_interval=200`, `reduction="mean"`.
|
Parameters: `group_size=4`, `clip_eps=0.2`, `kl_coef=0.01`, `sync_interval=200`.
|
||||||
|
|
||||||
Keys: `prompts`, `responses`, `masks`, `rewards`.
|
Keys: `prompts`, `responses`, `masks`, `rewards`.
|
||||||
|
|
||||||
|
|
@ -128,7 +154,7 @@ Keys: `prompts`, `responses`, `masks`, `rewards`.
|
||||||
| Cosine | `CosineScheduler` | Linear warmup → cosine decay to `min_rate` |
|
| Cosine | `CosineScheduler` | Linear warmup → cosine decay to `min_rate` |
|
||||||
| SGDR | `SGDRScheduler` | Cosine annealing with warm restarts (`t_mult=2`) |
|
| SGDR | `SGDRScheduler` | Cosine annealing with warm restarts (`t_mult=2`) |
|
||||||
|
|
||||||
Created by `SchedulerFactory.create(optimizer, schedule_type, **kwargs)`. Valid types: `"cosine"`, `"sgdr"`. Omit to use no scheduler.
|
Created by `SchedulerFactory.create(optimizer, schedule_type, **kwargs)`.
|
||||||
|
|
||||||
## Gradient Checkpointing
|
## Gradient Checkpointing
|
||||||
|
|
||||||
|
|
@ -144,30 +170,29 @@ Callback wraps each `DecoderBlock.forward` with `torch.utils.checkpoint.checkpoi
|
||||||
## Checkpoint
|
## Checkpoint
|
||||||
|
|
||||||
```
|
```
|
||||||
Checkpoint(state_dict, epoch, iteration, extra, meta, config)
|
Checkpoint(state_dict, epoch, iteration, extra, meta)
|
||||||
├── save(save_dir) rank-0 only: meta.json (epoch/iteration/timestamp) + config.json (model config) + model.safetensors + optional {key}.pt (optimizer.pt, scheduler.pt)
|
├── save(save_dir) rank-0 only: meta.json (includes training config) + state_dict.safetensors + optional extra.pt
|
||||||
└── load(save_dir, broadcast=False) loads from local disk; set broadcast=True to broadcast metadata from rank-0
|
└── load(save_dir) broadcasts metadata from rank-0
|
||||||
```
|
```
|
||||||
|
|
||||||
Optimizer/scheduler state persisted by default via `Checkpoint.extra`.
|
Optimizer/scheduler state persisted by default via `Checkpoint.extra`.
|
||||||
Model config (`context.model_config`) saved into `config.json` during training via `CheckpointCallback`.
|
Training config (`TrainConfig.to_dict()`) saved into `meta.json` during training via `CheckpointCallback`.
|
||||||
|
|
||||||
## TrainContextBuilder (Builder Pattern)
|
## TrainContextBuilder (Builder Pattern)
|
||||||
|
|
||||||
```python
|
```python
|
||||||
context = (
|
context = (
|
||||||
TrainContextBuilder(config)
|
TrainContextBuilder(config)
|
||||||
.with_resume_dir(resume_dir)
|
.with_checkpoint(checkpoint)
|
||||||
.build()
|
.build()
|
||||||
)
|
)
|
||||||
# Returns TrainContext with model, strategy, optimizer, scheduler, dataloader, checkpoint
|
# Returns TrainContext with model, strategy, optimizer, scheduler, dataloader, checkpoint
|
||||||
```
|
```
|
||||||
|
|
||||||
- Loads checkpoint weights if provided
|
- Loads checkpoint weights if provided
|
||||||
- Creates executor via `ExecutorFactory.create(cfg.parallel_mode, grad_accum_steps=cfg.grad_accum_steps, **cfg.executor_kwargs)`
|
- Wraps model with `parallel_wrapper` if `nprocs > 1`
|
||||||
- Calls `executor.prepare(model, optimizer, dataloader, scheduler)` for model distribution (e.g. DDP) + gradient accumulation wrappers
|
|
||||||
- Creates `ResumableDistributedSampler` for shuffle+resume
|
- Creates `ResumableDistributedSampler` for shuffle+resume
|
||||||
- Builds strategy via `StrategyFactory.create(train_type, model, device, **kwargs)`
|
- Builds strategy via `StrategyFactory.create(train_type, ...)`
|
||||||
|
|
||||||
## Training CLI
|
## Training CLI
|
||||||
|
|
||||||
|
|
@ -176,7 +201,6 @@ export CUDA_VISIBLE_DEVICES=0,1,2,3
|
||||||
|
|
||||||
nohup python scripts/tools/train.py \
|
nohup python scripts/tools/train.py \
|
||||||
--nprocs=4 \
|
--nprocs=4 \
|
||||||
--parallel_mode=ddp \
|
|
||||||
--train_type=seq \
|
--train_type=seq \
|
||||||
--data_root_path=/path/to/dataset \
|
--data_root_path=/path/to/dataset \
|
||||||
--param_path=/path/to/model \
|
--param_path=/path/to/model \
|
||||||
|
|
@ -198,4 +222,4 @@ nohup python scripts/tools/train.py \
|
||||||
|
|
||||||
Full parameter reference at [params.md](params.md).
|
Full parameter reference at [params.md](params.md).
|
||||||
|
|
||||||
> Document Update Time: 2026-05-30
|
> Document Update Time: 2026-05-17
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
__version__ = "1.3.7"
|
__version__ = "1.3.6"
|
||||||
__author__ = "ViperEkura"
|
__author__ = "ViperEkura"
|
||||||
|
|
||||||
from astrai.config import (
|
from astrai.config import (
|
||||||
|
|
|
||||||
|
|
@ -4,22 +4,13 @@ from astrai.config.model_config import (
|
||||||
ConfigFactory,
|
ConfigFactory,
|
||||||
EncoderConfig,
|
EncoderConfig,
|
||||||
)
|
)
|
||||||
from astrai.config.preprocess_config import (
|
|
||||||
InputConfig,
|
|
||||||
OutputConfig,
|
|
||||||
PipelineConfig,
|
|
||||||
ProcessingConfig,
|
|
||||||
)
|
|
||||||
from astrai.config.train_config import TrainConfig
|
from astrai.config.train_config import TrainConfig
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
# Model configuration
|
||||||
"BaseModelConfig",
|
"BaseModelConfig",
|
||||||
"AutoRegressiveLMConfig",
|
"AutoRegressiveLMConfig",
|
||||||
"EncoderConfig",
|
"EncoderConfig",
|
||||||
"ConfigFactory",
|
"ConfigFactory",
|
||||||
"TrainConfig",
|
"TrainConfig",
|
||||||
"InputConfig",
|
|
||||||
"OutputConfig",
|
|
||||||
"PipelineConfig",
|
|
||||||
"ProcessingConfig",
|
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,6 @@
|
||||||
import json
|
import json
|
||||||
from dataclasses import MISSING, dataclass, fields
|
from dataclasses import MISSING, dataclass, fields
|
||||||
from pathlib import Path
|
from typing import Any, Dict, Optional, Self, get_type_hints
|
||||||
from typing import Any, Dict, Optional, Self, Union, get_type_hints
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
@ -14,21 +13,12 @@ class BaseConfig:
|
||||||
d[fld.name] = v
|
d[fld.name] = v
|
||||||
elif v is None:
|
elif v is None:
|
||||||
d[fld.name] = None
|
d[fld.name] = None
|
||||||
elif isinstance(v, (dict, list, tuple)):
|
elif isinstance(v, (dict, list)):
|
||||||
try:
|
try:
|
||||||
val = list(v) if isinstance(v, tuple) else v
|
json.dumps(v)
|
||||||
json.dumps(val)
|
d[fld.name] = v
|
||||||
d[fld.name] = val
|
|
||||||
except (TypeError, ValueError):
|
except (TypeError, ValueError):
|
||||||
pass
|
pass
|
||||||
elif isinstance(v, BaseConfig):
|
|
||||||
d[fld.name] = v.to_dict()
|
|
||||||
elif hasattr(v, "__dataclass_fields__"):
|
|
||||||
sub = {}
|
|
||||||
for f in fields(v):
|
|
||||||
a = getattr(v, f.name)
|
|
||||||
sub[f.name] = list(a) if isinstance(a, tuple) else a
|
|
||||||
d[fld.name] = sub
|
|
||||||
return d
|
return d
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
@ -84,15 +74,4 @@ class BaseConfig:
|
||||||
return value
|
return value
|
||||||
if isinstance(value, target_type):
|
if isinstance(value, target_type):
|
||||||
return value
|
return value
|
||||||
if isinstance(value, dict) and issubclass(target_type, BaseConfig):
|
|
||||||
return target_type.from_dict(value)
|
|
||||||
raise TypeError
|
raise TypeError
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_json(cls, path: Union[str, Path]) -> Self:
|
|
||||||
with open(path, "r", encoding="utf-8") as f:
|
|
||||||
return cls.from_dict(json.load(f))
|
|
||||||
|
|
||||||
def to_json(self, path: Union[str, Path]):
|
|
||||||
with open(path, "w", encoding="utf-8") as f:
|
|
||||||
json.dump(self.to_dict(), f, indent=2, ensure_ascii=False)
|
|
||||||
|
|
|
||||||
|
|
@ -49,7 +49,6 @@ class AutoRegressiveLMConfig(BaseModelConfig):
|
||||||
|
|
||||||
max_len: Optional[int] = None
|
max_len: Optional[int] = None
|
||||||
rope_theta: Optional[float] = None
|
rope_theta: Optional[float] = None
|
||||||
rope_scaling: Optional[dict] = None
|
|
||||||
|
|
||||||
attn_type: str = "gqa"
|
attn_type: str = "gqa"
|
||||||
n_heads: Optional[int] = None
|
n_heads: Optional[int] = None
|
||||||
|
|
@ -81,7 +80,6 @@ class EncoderConfig(BaseModelConfig):
|
||||||
|
|
||||||
max_len: Optional[int] = None
|
max_len: Optional[int] = None
|
||||||
rope_theta: Optional[float] = None
|
rope_theta: Optional[float] = None
|
||||||
rope_scaling: Optional[dict] = None
|
|
||||||
|
|
||||||
n_heads: Optional[int] = None
|
n_heads: Optional[int] = None
|
||||||
n_kv_heads: Optional[int] = None
|
n_kv_heads: Optional[int] = None
|
||||||
|
|
|
||||||
|
|
@ -1,109 +0,0 @@
|
||||||
"""Pipeline configuration for JSONL preprocessing.
|
|
||||||
|
|
||||||
Supports single-sequence (SFT/pretrain) and multi-output (DPO/GRPO)
|
|
||||||
modes, both driven declaratively through ``input.sections`` or
|
|
||||||
``input.sources``.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from typing import Dict, List, Optional
|
|
||||||
|
|
||||||
from astrai.config.base import BaseConfig
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class InputConfig(BaseConfig):
|
|
||||||
"""Declarative input mapping.
|
|
||||||
|
|
||||||
Single-output mode (backward-compatible)::
|
|
||||||
|
|
||||||
{"input": {"sections": [{"field": "messages", ...}]}}
|
|
||||||
|
|
||||||
Multi-output mode (DPO / GRPO)::
|
|
||||||
|
|
||||||
{"input": {"sources": {
|
|
||||||
"chosen": {"sections": [{"field": "chosen", ...}]},
|
|
||||||
"rejected": {"sections": [{"field": "rejected", ...}]},
|
|
||||||
}}}
|
|
||||||
"""
|
|
||||||
|
|
||||||
sections: Optional[List[Dict]] = None
|
|
||||||
sources: Optional[Dict[str, Dict]] = None
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ProcessingConfig(BaseConfig):
|
|
||||||
"""Processing configuration.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
max_seq_len : int
|
|
||||||
Maximum sequence length (default: 2048).
|
|
||||||
min_chars : int
|
|
||||||
Minimum number of characters to keep (default: 50).
|
|
||||||
max_chars : int
|
|
||||||
Maximum number of characters to keep (default: 2_000_000).
|
|
||||||
max_items : Optional[int]
|
|
||||||
Maximum number of items to process (default: None, unlimited).
|
|
||||||
packing_strategy : str
|
|
||||||
How to pack sequences into a contiguous stream.
|
|
||||||
|
|
||||||
- ``"simple"``: sequential concatenation (default, backward compatible).
|
|
||||||
- ``"bfd"``: best-fit decreasing bin packing, minimises wasted tokens.
|
|
||||||
- ``"bfd_split"``: BFD with over-length sequences split into chunks.
|
|
||||||
max_packed_len : int
|
|
||||||
Maximum length of a packed bin. Sequences longer than this are
|
|
||||||
truncated or split depending on ``packing_strategy`` (default: 8192).
|
|
||||||
truncation_mode : str
|
|
||||||
How to truncate sequences longer than ``max_packed_len``.
|
|
||||||
|
|
||||||
- ``"keep_start"``: keep the first ``max_packed_len`` tokens (default).
|
|
||||||
- ``"keep_end"``: keep the last ``max_packed_len`` tokens.
|
|
||||||
"""
|
|
||||||
|
|
||||||
max_seq_len: int = 2048
|
|
||||||
min_chars: int = 50
|
|
||||||
max_chars: int = 2_000_000
|
|
||||||
max_items: Optional[int] = None
|
|
||||||
packing_strategy: str = "simple"
|
|
||||||
max_packed_len: int = 8192
|
|
||||||
truncation_mode: str = "keep_start"
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class OutputConfig(BaseConfig):
|
|
||||||
"""Output configuration.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
domain_key : Optional[str]
|
|
||||||
Domain key for the output store (default: None).
|
|
||||||
storage_format : str
|
|
||||||
Storage format, one of ``"bin"``, ``"jsonl"`` (default: ``"bin"``).
|
|
||||||
max_tokens_per_shard : int
|
|
||||||
Maximum tokens per shard before splitting (default: 100_000_000).
|
|
||||||
dtype : Dict[str, str]
|
|
||||||
Per-key dtype overrides, e.g. ``{"input_ids": "int32"}`` (default: {}).
|
|
||||||
position_ids_mode : Optional[str]
|
|
||||||
How to compute position_ids in packed sequences.
|
|
||||||
|
|
||||||
- ``None`` / ``"none"``: do not generate (backward compatible).
|
|
||||||
- ``"doc_reset"``: reset to 0 at each document boundary.
|
|
||||||
- ``"continuous"``: sequential 0, 1, 2, ... (pretrain, single doc).
|
|
||||||
"""
|
|
||||||
|
|
||||||
domain_key: Optional[str] = None
|
|
||||||
storage_format: str = "bin"
|
|
||||||
max_tokens_per_shard: int = 100_000_000
|
|
||||||
dtype: Dict[str, str] = field(default_factory=dict)
|
|
||||||
position_ids_mode: Optional[str] = None
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class PipelineConfig(BaseConfig):
|
|
||||||
version: int = 1
|
|
||||||
input: InputConfig = field(default_factory=InputConfig)
|
|
||||||
mask: Dict[str, str] = field(default_factory=dict)
|
|
||||||
mask_default: str = "mask"
|
|
||||||
preprocessing: ProcessingConfig = field(default_factory=ProcessingConfig)
|
|
||||||
output: OutputConfig = field(default_factory=OutputConfig)
|
|
||||||
|
|
@ -7,7 +7,6 @@ from torch.optim.lr_scheduler import LRScheduler
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
|
|
||||||
from astrai.config.base import BaseConfig
|
from astrai.config.base import BaseConfig
|
||||||
from astrai.model.components.lora import LoRAConfig
|
|
||||||
|
|
||||||
|
|
||||||
def required(**kw):
|
def required(**kw):
|
||||||
|
|
@ -17,8 +16,8 @@ def required(**kw):
|
||||||
@dataclass
|
@dataclass
|
||||||
class TrainConfig(BaseConfig):
|
class TrainConfig(BaseConfig):
|
||||||
# basic setting
|
# basic setting
|
||||||
model_fn: Callable[[], nn.Module] = field(
|
model: nn.Module = field(
|
||||||
default=None, metadata=required(help="Model factory for training.")
|
default=None, metadata=required(help="Model for training.")
|
||||||
)
|
)
|
||||||
strategy: str = field(default=None, metadata=required(help="Training strategy."))
|
strategy: str = field(default=None, metadata=required(help="Training strategy."))
|
||||||
dataset: Dataset = field(
|
dataset: Dataset = field(
|
||||||
|
|
@ -57,12 +56,6 @@ class TrainConfig(BaseConfig):
|
||||||
default=5000, metadata={"help": "Number of iterations between checkpoints."}
|
default=5000, metadata={"help": "Number of iterations between checkpoints."}
|
||||||
)
|
)
|
||||||
|
|
||||||
# lora setting
|
|
||||||
lora: Optional[LoRAConfig] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "LoRA config. None means full fine-tuning."},
|
|
||||||
)
|
|
||||||
|
|
||||||
# metric setting
|
# metric setting
|
||||||
log_dir: str = field(
|
log_dir: str = field(
|
||||||
default="./checkpoint/logs", metadata={"help": "Directory for metric logs."}
|
default="./checkpoint/logs", metadata={"help": "Directory for metric logs."}
|
||||||
|
|
@ -102,9 +95,11 @@ class TrainConfig(BaseConfig):
|
||||||
master_port: str = field(
|
master_port: str = field(
|
||||||
default="29500", metadata={"help": "Master port for distributed training."}
|
default="29500", metadata={"help": "Master port for distributed training."}
|
||||||
)
|
)
|
||||||
parallel_mode: str = field(
|
parallel_wrapper: Optional[Callable] = field(
|
||||||
default="none",
|
default=None, metadata={"help": "Parallel function for training."}
|
||||||
metadata={"help": "Parallel strategy: none, ddp, fsdp."},
|
)
|
||||||
|
state_dict_fn: Optional[Callable] = field(
|
||||||
|
default=None, metadata={"help": "Parallel function for state dict saving."}
|
||||||
)
|
)
|
||||||
start_method: str = field(
|
start_method: str = field(
|
||||||
default="spawn",
|
default="spawn",
|
||||||
|
|
@ -118,21 +113,11 @@ class TrainConfig(BaseConfig):
|
||||||
val_dataset: Optional[Dataset] = field(
|
val_dataset: Optional[Dataset] = field(
|
||||||
default=None, metadata={"help": "Dataset for validation."}
|
default=None, metadata={"help": "Dataset for validation."}
|
||||||
)
|
)
|
||||||
val_split: Optional[float] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={
|
|
||||||
"help": "Ratio to split from training dataset for validation (e.g. 0.05). Ignored if val_dataset is set."
|
|
||||||
},
|
|
||||||
)
|
|
||||||
val_step: int = field(
|
val_step: int = field(
|
||||||
default=1000,
|
default=1000,
|
||||||
metadata={"help": "Number of optimizer steps between validation runs."},
|
metadata={"help": "Number of optimizer steps between validation runs."},
|
||||||
)
|
)
|
||||||
|
|
||||||
executor_kwargs: dict = field(
|
|
||||||
default_factory=dict,
|
|
||||||
metadata={"help": "Extra kwargs passed to ExecutorFactory.create()."},
|
|
||||||
)
|
|
||||||
extra_kwargs: dict = field(
|
extra_kwargs: dict = field(
|
||||||
default_factory=dict, metadata={"help": "Other arguments."}
|
default_factory=dict, metadata={"help": "Other arguments."}
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -4,28 +4,32 @@ from astrai.dataset.dataset import (
|
||||||
)
|
)
|
||||||
from astrai.dataset.sampler import ResumableDistributedSampler
|
from astrai.dataset.sampler import ResumableDistributedSampler
|
||||||
from astrai.dataset.storage import (
|
from astrai.dataset.storage import (
|
||||||
H5Store,
|
BaseSegmentFetcher,
|
||||||
MmapStore,
|
BaseStorage,
|
||||||
Store,
|
H5Storage,
|
||||||
StoreFactory,
|
JSONStorage,
|
||||||
|
MultiSegmentFetcher,
|
||||||
|
StorageFactory,
|
||||||
detect_format,
|
detect_format,
|
||||||
load_bin,
|
|
||||||
load_h5,
|
load_h5,
|
||||||
save_bin,
|
load_json,
|
||||||
save_h5,
|
save_h5,
|
||||||
|
save_json,
|
||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"BaseDataset",
|
"BaseDataset",
|
||||||
"DatasetFactory",
|
"DatasetFactory",
|
||||||
"Store",
|
"BaseSegmentFetcher",
|
||||||
"StoreFactory",
|
"MultiSegmentFetcher",
|
||||||
"H5Store",
|
"BaseStorage",
|
||||||
"MmapStore",
|
"H5Storage",
|
||||||
|
"JSONStorage",
|
||||||
|
"StorageFactory",
|
||||||
"detect_format",
|
"detect_format",
|
||||||
"save_h5",
|
"save_h5",
|
||||||
"load_h5",
|
"load_h5",
|
||||||
"save_bin",
|
"save_json",
|
||||||
"load_bin",
|
"load_json",
|
||||||
"ResumableDistributedSampler",
|
"ResumableDistributedSampler",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -8,8 +8,8 @@ from torch import Tensor
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
|
|
||||||
from astrai.dataset.storage import (
|
from astrai.dataset.storage import (
|
||||||
Store,
|
BaseStorage,
|
||||||
StoreFactory,
|
StorageFactory,
|
||||||
detect_format,
|
detect_format,
|
||||||
)
|
)
|
||||||
from astrai.factory import BaseFactory
|
from astrai.factory import BaseFactory
|
||||||
|
|
@ -26,7 +26,7 @@ class BaseDataset(Dataset, ABC):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.window_size = window_size
|
self.window_size = window_size
|
||||||
self.stride = stride
|
self.stride = stride
|
||||||
self.storage: Optional[Store] = None
|
self.storage: Optional[BaseStorage] = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def required_keys(self) -> List[str]:
|
def required_keys(self) -> List[str]:
|
||||||
|
|
@ -48,26 +48,37 @@ class BaseDataset(Dataset, ABC):
|
||||||
f"Missing: {missing}"
|
f"Missing: {missing}"
|
||||||
)
|
)
|
||||||
|
|
||||||
def load(self, load_path: str, storage_type: Optional[str] = None):
|
def load(self, load_path: str, storage_type: Optional[str] = None, tokenizer=None):
|
||||||
"""Load dataset from the given path.
|
"""Load dataset from the given path.
|
||||||
|
|
||||||
Auto-detects the storage format if not specified.
|
Auto-detects the storage format if not specified.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
load_path: Path to the data directory or file
|
load_path: Path to the data directory or file
|
||||||
storage_type: Force a specific storage type ("h5", "bin"),
|
storage_type: Force a specific storage type ("h5", "json"),
|
||||||
or None for auto-detection
|
or None for auto-detection
|
||||||
|
tokenizer: Callable str -> List[int], used to tokenize raw text
|
||||||
|
in JSON files. Ignored for HDF5.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
KeyError: If the loaded storage is missing required keys.
|
KeyError: If the loaded storage is missing required keys.
|
||||||
"""
|
"""
|
||||||
if storage_type is None:
|
if storage_type is None:
|
||||||
storage_type = detect_format(load_path)
|
storage_type = detect_format(load_path)
|
||||||
self.storage = StoreFactory.create(storage_type)
|
self.storage = StorageFactory.create(storage_type)
|
||||||
self._load_path = load_path
|
self._load_path = load_path
|
||||||
self.storage.load(load_path)
|
self.storage.load(load_path, tokenizer=tokenizer)
|
||||||
self._validate_keys()
|
self._validate_keys()
|
||||||
|
|
||||||
|
def load_json(self, load_path: str, tokenizer=None):
|
||||||
|
"""Load dataset from JSON files explicitly.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
load_path: Path to the JSON data file or directory
|
||||||
|
tokenizer: Optional tokenizer callable for raw text JSON.
|
||||||
|
"""
|
||||||
|
self.load(load_path, storage_type="json", tokenizer=tokenizer)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def count(self) -> int:
|
def count(self) -> int:
|
||||||
"""Return the total number of raw elements (tokens) in the dataset."""
|
"""Return the total number of raw elements (tokens) in the dataset."""
|
||||||
|
|
@ -137,7 +148,7 @@ class DatasetFactory(BaseFactory["BaseDataset"]):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _validate_component(cls, dataset_cls: type):
|
def _validate_component(cls, dataset_cls: type) -> None:
|
||||||
"""Validate that the dataset class inherits from BaseDataset."""
|
"""Validate that the dataset class inherits from BaseDataset."""
|
||||||
if not issubclass(dataset_cls, BaseDataset):
|
if not issubclass(dataset_cls, BaseDataset):
|
||||||
raise TypeError(f"{dataset_cls.__name__} must inherit from BaseDataset")
|
raise TypeError(f"{dataset_cls.__name__} must inherit from BaseDataset")
|
||||||
|
|
@ -164,6 +175,7 @@ class DatasetFactory(BaseFactory["BaseDataset"]):
|
||||||
window_size: int,
|
window_size: int,
|
||||||
stride: Optional[int] = None,
|
stride: Optional[int] = None,
|
||||||
storage_type: Optional[str] = None,
|
storage_type: Optional[str] = None,
|
||||||
|
tokenizer=None,
|
||||||
) -> "BaseDataset":
|
) -> "BaseDataset":
|
||||||
"""Create and load a dataset in one step.
|
"""Create and load a dataset in one step.
|
||||||
|
|
||||||
|
|
@ -172,7 +184,8 @@ class DatasetFactory(BaseFactory["BaseDataset"]):
|
||||||
load_path: Path to the data file
|
load_path: Path to the data file
|
||||||
window_size: Window size for data sampling
|
window_size: Window size for data sampling
|
||||||
stride: Stride between consecutive samples (default: same as window_size)
|
stride: Stride between consecutive samples (default: same as window_size)
|
||||||
storage_type: Storage type ("h5", "bin") or None for auto-detection
|
storage_type: Storage type ("h5", "json") or None for auto-detection
|
||||||
|
tokenizer: Callable str -> List[int] for raw text JSON tokenization
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Loaded dataset instance
|
Loaded dataset instance
|
||||||
|
|
@ -181,7 +194,7 @@ class DatasetFactory(BaseFactory["BaseDataset"]):
|
||||||
stride = window_size
|
stride = window_size
|
||||||
|
|
||||||
dataset = cls.create(train_type, window_size, stride)
|
dataset = cls.create(train_type, window_size, stride)
|
||||||
dataset.load(load_path, storage_type=storage_type)
|
dataset.load(load_path, storage_type=storage_type, tokenizer=tokenizer)
|
||||||
|
|
||||||
return dataset
|
return dataset
|
||||||
|
|
||||||
|
|
@ -223,7 +236,7 @@ class SFTDataset(BaseDataset):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def required_keys(self) -> List[str]:
|
def required_keys(self) -> List[str]:
|
||||||
return ["sequence", "loss_mask", "position_ids"]
|
return ["sequence", "loss_mask"]
|
||||||
|
|
||||||
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
|
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
|
||||||
return self.storage.fetch(begin_idx, end_idx, key)
|
return self.storage.fetch(begin_idx, end_idx, key)
|
||||||
|
|
@ -231,17 +244,15 @@ class SFTDataset(BaseDataset):
|
||||||
def __getitem__(self, index):
|
def __getitem__(self, index):
|
||||||
begin_idx, end_idx = self.get_index(index)
|
begin_idx, end_idx = self.get_index(index)
|
||||||
|
|
||||||
x = self._fetch_data(begin_idx, end_idx, "sequence")
|
x = self._fetch_data(begin_idx, end_idx, "sequence").to(dtype=torch.long)
|
||||||
y = self._fetch_data(begin_idx + 1, end_idx + 1, "sequence")
|
y = self._fetch_data(begin_idx + 1, end_idx + 1, "sequence").to(
|
||||||
position_ids = self._fetch_data(begin_idx, end_idx, "position_ids")
|
dtype=torch.long
|
||||||
loss_mask = self._fetch_data(begin_idx + 1, end_idx + 1, "loss_mask")
|
)
|
||||||
|
loss_mask = self._fetch_data(begin_idx + 1, end_idx + 1, "loss_mask").to(
|
||||||
|
dtype=torch.bool
|
||||||
|
)
|
||||||
|
|
||||||
return {
|
return {"input_ids": x, "target_ids": y, "loss_mask": loss_mask}
|
||||||
"input_ids": x.to(dtype=torch.long),
|
|
||||||
"target_ids": y.to(dtype=torch.long),
|
|
||||||
"position_ids": position_ids.to(dtype=torch.long),
|
|
||||||
"loss_mask": loss_mask.to(dtype=torch.bool),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@DatasetFactory.register("dpo")
|
@DatasetFactory.register("dpo")
|
||||||
|
|
@ -295,11 +306,9 @@ class GRPODataset(BaseDataset):
|
||||||
def __getitem__(self, index: int) -> Dict[str, Tensor]:
|
def __getitem__(self, index: int) -> Dict[str, Tensor]:
|
||||||
begin_idx, end_idx = self.get_index(index)
|
begin_idx, end_idx = self.get_index(index)
|
||||||
|
|
||||||
prompts = self._fetch_data(begin_idx, end_idx, "prompts").to(dtype=torch.long)
|
prompts = self._fetch_data(begin_idx, end_idx, "prompts")
|
||||||
responses = self._fetch_data(begin_idx, end_idx, "responses").to(
|
responses = self._fetch_data(begin_idx, end_idx, "responses")
|
||||||
dtype=torch.long
|
masks = self._fetch_data(begin_idx, end_idx, "masks")
|
||||||
)
|
|
||||||
masks = self._fetch_data(begin_idx, end_idx, "masks").to(dtype=torch.bool)
|
|
||||||
rewards = self._fetch_data(begin_idx, end_idx, "rewards")
|
rewards = self._fetch_data(begin_idx, end_idx, "rewards")
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
|
|
||||||
|
|
@ -43,7 +43,6 @@ class ResumableDistributedSampler(Sampler[int]):
|
||||||
offset = 0 if drop_last else self.num_replicas - 1
|
offset = 0 if drop_last else self.num_replicas - 1
|
||||||
self.num_samples_per_replica = (self.num_samples + offset) // self.num_replicas
|
self.num_samples_per_replica = (self.num_samples + offset) // self.num_replicas
|
||||||
self.total_size = self.num_samples_per_replica * self.num_replicas
|
self.total_size = self.num_samples_per_replica * self.num_replicas
|
||||||
self.iter = self.iter % self.num_samples_per_replica
|
|
||||||
|
|
||||||
self._indices = None
|
self._indices = None
|
||||||
|
|
||||||
|
|
@ -75,10 +74,5 @@ class ResumableDistributedSampler(Sampler[int]):
|
||||||
self.epoch += 1
|
self.epoch += 1
|
||||||
self._indices = None
|
self._indices = None
|
||||||
|
|
||||||
@property
|
|
||||||
def _remaining(self):
|
|
||||||
remaining = self.num_samples_per_replica - self.iter
|
|
||||||
return max(remaining, 0)
|
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return self._remaining
|
return self.num_samples_per_replica
|
||||||
|
|
|
||||||
|
|
@ -1,32 +1,17 @@
|
||||||
"""Storage backends for different data formats.
|
"""Storage backends for different data formats.
|
||||||
|
|
||||||
Layers:
|
Each storage handles format-specific loading (HDF5, JSON, etc.) and provides
|
||||||
- I/O layer: save_* / load_* functions, read/write raw files (HDF5/bin)
|
a uniform interface for data access and length observation via fetchers.
|
||||||
return Dict[str, List[Tensor]] — format-specific, no state
|
|
||||||
- Store (ABC): central abstraction, normalizes multi-segment into
|
|
||||||
Dict[str, List[Tensor]] per key via _normalize(),
|
|
||||||
fetch() uses bisect across segments — no forced concat
|
|
||||||
- Dataset layer: BaseDataset owns a Store, only calls store.fetch(begin, end, key)
|
|
||||||
|
|
||||||
Key properties:
|
|
||||||
- Multi-segment: segments kept as-is, no forced concatenation — safe for
|
|
||||||
datasets larger than RAM
|
|
||||||
- Explicit length: _length = min(total elements across keys), set at load,
|
|
||||||
__len__ returns O(1)
|
|
||||||
- Zero-copy mmap: MmapStore wraps np.memmap(mode="r"), all DataLoader
|
|
||||||
workers share OS page-cache pages
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import bisect
|
import bisect
|
||||||
import glob
|
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Union
|
from typing import Callable, Dict, List, Optional, Union
|
||||||
|
|
||||||
import h5py
|
import h5py
|
||||||
import numpy as np
|
|
||||||
import torch
|
import torch
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
|
|
@ -69,30 +54,54 @@ def load_h5(file_path: str, share_memory=True) -> Dict[str, List[Tensor]]:
|
||||||
return tensor_group
|
return tensor_group
|
||||||
|
|
||||||
|
|
||||||
def save_bin(file_path: str, tensor_group: Dict[str, List[Tensor]]):
|
def save_json(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]):
|
||||||
os.makedirs(file_path, exist_ok=True)
|
os.makedirs(file_path, exist_ok=True)
|
||||||
meta = {}
|
full_file_path = os.path.join(file_path, f"{file_name}.json")
|
||||||
|
json_data = {}
|
||||||
for key, tensors in tensor_group.items():
|
for key, tensors in tensor_group.items():
|
||||||
cat = torch.cat(tensors, dim=0)
|
json_data[key] = [tensor.tolist() for tensor in tensors]
|
||||||
meta[key] = {"shape": list(cat.shape), "dtype": str(cat.dtype).split(".")[-1]}
|
with open(full_file_path, "w", encoding="utf-8") as f:
|
||||||
np.asarray(cat.cpu().numpy()).tofile(os.path.join(file_path, f"{key}.bin"))
|
json.dump(json_data, f, ensure_ascii=False)
|
||||||
with open(os.path.join(file_path, "meta.json"), "w") as f:
|
|
||||||
json.dump(meta, f)
|
|
||||||
|
|
||||||
|
|
||||||
def load_bin(file_path: str) -> Dict[str, List[Tensor]]:
|
def load_json(
|
||||||
with open(os.path.join(file_path, "meta.json"), "r") as f:
|
file_path: str,
|
||||||
meta = json.load(f)
|
share_memory: bool = True,
|
||||||
segments: Dict[str, List[Tensor]] = {}
|
tokenizer: Optional[Callable[[str], List[int]]] = None,
|
||||||
for key, info in meta.items():
|
) -> Dict[str, List[Tensor]]:
|
||||||
arr = np.memmap(
|
"""Load tensor data from JSON files.
|
||||||
os.path.join(file_path, f"{key}.bin"),
|
|
||||||
dtype=info["dtype"],
|
Supports two modes:
|
||||||
mode="r+",
|
- Pre-tokenized: JSON values are List[List[int]] (token IDs), loaded as-is.
|
||||||
shape=tuple(info["shape"]),
|
- Raw text: JSON values are List[str], tokenized via ``tokenizer`` callable
|
||||||
)
|
at load time. A ``tokenizer`` receives a str and returns List[int].
|
||||||
segments[key] = [torch.from_numpy(arr)]
|
|
||||||
return segments
|
Non-data JSON files (e.g. config.json) with scalar/object values are
|
||||||
|
silently skipped.
|
||||||
|
"""
|
||||||
|
tensor_group: Dict[str, List[Tensor]] = {}
|
||||||
|
root_path = Path(file_path)
|
||||||
|
json_files = list(root_path.rglob("*.json")) + list(root_path.rglob("*.jsonl"))
|
||||||
|
for json_file in json_files:
|
||||||
|
with open(json_file, "r", encoding="utf-8") as f:
|
||||||
|
data = json.load(f)
|
||||||
|
if not isinstance(data, dict):
|
||||||
|
continue
|
||||||
|
for key, sequences in data.items():
|
||||||
|
if not isinstance(sequences, list):
|
||||||
|
continue
|
||||||
|
tensors = []
|
||||||
|
for seq in sequences:
|
||||||
|
if tokenizer is not None and isinstance(seq, str):
|
||||||
|
seq = tokenizer(seq)
|
||||||
|
tensor = torch.tensor(seq, dtype=torch.long)
|
||||||
|
if share_memory:
|
||||||
|
tensor = tensor.share_memory_()
|
||||||
|
tensors.append(tensor)
|
||||||
|
if tensor_group.get(key) is None:
|
||||||
|
tensor_group[key] = []
|
||||||
|
tensor_group[key].extend(tensors)
|
||||||
|
return tensor_group
|
||||||
|
|
||||||
|
|
||||||
def detect_format(load_path: str) -> str:
|
def detect_format(load_path: str) -> str:
|
||||||
|
|
@ -102,7 +111,7 @@ def detect_format(load_path: str) -> str:
|
||||||
load_path: Directory or file path
|
load_path: Directory or file path
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Format string ("h5" or "bin")
|
Format string ("h5" or "json")
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
FileNotFoundError: If no supported data files are found
|
FileNotFoundError: If no supported data files are found
|
||||||
|
|
@ -112,160 +121,181 @@ def detect_format(load_path: str) -> str:
|
||||||
suffix = root.suffix.lower()
|
suffix = root.suffix.lower()
|
||||||
if suffix in (".h5", ".hdf5"):
|
if suffix in (".h5", ".hdf5"):
|
||||||
return "h5"
|
return "h5"
|
||||||
|
if suffix in (".json", ".jsonl"):
|
||||||
|
return "json"
|
||||||
raise ValueError(f"Unsupported file format: {suffix}")
|
raise ValueError(f"Unsupported file format: {suffix}")
|
||||||
|
|
||||||
h5_files = [
|
h5_files = list(root.rglob("*.h5")) + list(root.rglob("*.hdf5"))
|
||||||
Path(p)
|
|
||||||
for pattern in ("*.h5", "*.hdf5")
|
|
||||||
for p in glob.glob(str(root / "**" / pattern), recursive=True)
|
|
||||||
]
|
|
||||||
if h5_files:
|
if h5_files:
|
||||||
return "h5"
|
return "h5"
|
||||||
bin_files = [Path(p) for p in glob.glob(str(root / "**" / "*.bin"), recursive=True)]
|
json_files = list(root.rglob("*.json")) + list(root.rglob("*.jsonl"))
|
||||||
if bin_files:
|
if json_files:
|
||||||
has_meta = (root / "meta.json").exists() or len(
|
return "json"
|
||||||
[Path(p) for p in glob.glob(str(root / "**" / "meta.json"), recursive=True)]
|
|
||||||
) > 0
|
|
||||||
if has_meta:
|
|
||||||
return "bin"
|
|
||||||
raise FileNotFoundError(f"No supported data files found at {load_path}")
|
raise FileNotFoundError(f"No supported data files found at {load_path}")
|
||||||
|
|
||||||
|
|
||||||
class Store(ABC):
|
class BaseSegmentFetcher:
|
||||||
"""String keys -> segmented tensors with ``fetch(begin, end, keys)``.
|
"""Fetches data segments across multiple tensor segments.
|
||||||
|
|
||||||
Each key maps to one or more tensor segments (no forced concatenation).
|
Maintains cumulative lengths for efficient range queries across
|
||||||
``len(store)`` returns ``self._length`` (explicit, O(1)), the minimum
|
multiple discontinuous segments.
|
||||||
total element count across all keys.
|
"""
|
||||||
|
|
||||||
Subclasses fill ``self._data`` and ``self._cum`` during ``load()``
|
def __init__(self, segments: List[Tensor]):
|
||||||
via ``_normalize()``.
|
self.segments = segments
|
||||||
|
self.cum_lengths = []
|
||||||
|
|
||||||
|
total = 0
|
||||||
|
for seg in segments:
|
||||||
|
total += torch.numel(seg)
|
||||||
|
self.cum_lengths.append(total)
|
||||||
|
|
||||||
|
self.total_length = total
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
return self.total_length
|
||||||
|
|
||||||
|
def fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
|
||||||
|
"""Fetch data in the range [begin_idx, end_idx)."""
|
||||||
|
if not (
|
||||||
|
0 <= begin_idx < self.total_length and 0 <= end_idx <= self.total_length
|
||||||
|
):
|
||||||
|
raise ValueError("begin_idx or end_idx out of bounds")
|
||||||
|
if begin_idx >= end_idx:
|
||||||
|
return torch.tensor([], dtype=torch.long)
|
||||||
|
|
||||||
|
seg_start_idx = bisect.bisect_right(self.cum_lengths, begin_idx)
|
||||||
|
seg_end_idx = bisect.bisect_left(self.cum_lengths, end_idx)
|
||||||
|
|
||||||
|
result_segments = []
|
||||||
|
|
||||||
|
for i in range(seg_start_idx, seg_end_idx + 1):
|
||||||
|
prev_cum = self.cum_lengths[i - 1] if i > 0 else 0
|
||||||
|
start = max(begin_idx - prev_cum, 0)
|
||||||
|
end = min(end_idx - prev_cum, len(self.segments[i]))
|
||||||
|
result_segments.append(self.segments[i][start:end])
|
||||||
|
|
||||||
|
return torch.cat(result_segments, dim=0)
|
||||||
|
|
||||||
|
|
||||||
|
class MultiSegmentFetcher:
|
||||||
|
"""Manages multiple segment fetchers for different data keys."""
|
||||||
|
|
||||||
|
def __init__(self, multi_segments: Dict):
|
||||||
|
self.multi_keys = list(multi_segments.keys())
|
||||||
|
self.multi_fetchers = {
|
||||||
|
key: BaseSegmentFetcher(segments)
|
||||||
|
for key, segments in multi_segments.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
"""Returns the minimum length across all fetchers."""
|
||||||
|
if not self.multi_fetchers:
|
||||||
|
return 0
|
||||||
|
len_list = [len(seg) for seg in self.multi_fetchers.values()]
|
||||||
|
return min(len_list)
|
||||||
|
|
||||||
|
def key_fetch(
|
||||||
|
self, begin_idx: int, end_idx: int, keys: Union[str, List[str]]
|
||||||
|
) -> Dict:
|
||||||
|
"""Fetch data for specific keys."""
|
||||||
|
fetch_dict = {}
|
||||||
|
keys = [keys] if isinstance(keys, str) else keys
|
||||||
|
|
||||||
|
for key in keys:
|
||||||
|
fetcher = self.multi_fetchers[key]
|
||||||
|
fetch_tensor = fetcher.fetch_data(begin_idx, end_idx)
|
||||||
|
fetch_dict[key] = fetch_tensor
|
||||||
|
|
||||||
|
return fetch_dict if len(keys) > 1 else fetch_dict[keys[0]]
|
||||||
|
|
||||||
|
def fetch_data(self, begin_idx: int, end_idx: int) -> Dict:
|
||||||
|
"""Fetch all keys."""
|
||||||
|
return self.key_fetch(begin_idx, end_idx, self.multi_keys)
|
||||||
|
|
||||||
|
|
||||||
|
class BaseStorage(ABC):
|
||||||
|
"""Abstract storage backend for loading and dispatching data.
|
||||||
|
|
||||||
|
Storage encapsulates format-specific loading and provides a uniform
|
||||||
|
interface for data access and length observation. Subclasses handle
|
||||||
|
different data formats (HDF5, JSON, etc.) while exposing the same
|
||||||
|
fetch interface.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self._data: Dict[str, List[Tensor]] = {}
|
self._fetcher: Optional[MultiSegmentFetcher] = None
|
||||||
self._cum: Dict[str, List[int]] = {}
|
|
||||||
self._length: int = 0
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def load(self, path: str) -> None:
|
def load(self, load_path: str, tokenizer=None) -> None:
|
||||||
|
"""Load data from the given path into internal fetcher."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
"""Total number of raw elements (tokens) in storage."""
|
||||||
|
if self._fetcher is None:
|
||||||
|
return 0
|
||||||
|
return len(self._fetcher)
|
||||||
|
|
||||||
|
def fetch(self, begin_idx: int, end_idx: int, keys: Union[str, List[str]]):
|
||||||
|
"""Fetch data for the given keys and index range.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
begin_idx: Starting index (inclusive)
|
||||||
|
end_idx: Ending index (exclusive)
|
||||||
|
keys: Single key or list of keys to fetch
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor if single key, Dict[str, Tensor] if multiple keys
|
||||||
|
"""
|
||||||
|
if self._fetcher is None:
|
||||||
|
raise RuntimeError("Storage not loaded")
|
||||||
|
return self._fetcher.key_fetch(begin_idx, end_idx, keys)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def keys(self) -> List[str]:
|
def keys(self) -> List[str]:
|
||||||
return list(self._data.keys())
|
"""Return the data keys available in this storage."""
|
||||||
|
if self._fetcher is None:
|
||||||
def __len__(self) -> int:
|
return []
|
||||||
return self._length
|
return self._fetcher.multi_keys
|
||||||
|
|
||||||
def fetch(
|
|
||||||
self,
|
|
||||||
begin: int,
|
|
||||||
end: int,
|
|
||||||
keys: Union[str, List[str]],
|
|
||||||
):
|
|
||||||
if not self._data:
|
|
||||||
raise RuntimeError("Store not loaded")
|
|
||||||
if not (0 <= begin < self._length and 0 <= end <= self._length):
|
|
||||||
raise ValueError(
|
|
||||||
f"Index out of bounds: begin={begin}, end={end}, length={self._length}"
|
|
||||||
)
|
|
||||||
if isinstance(keys, str):
|
|
||||||
return self._fetch_key(keys, begin, end)
|
|
||||||
return {k: self._fetch_key(k, begin, end) for k in keys}
|
|
||||||
|
|
||||||
def _fetch_key(self, key: str, begin: int, end: int) -> Tensor:
|
|
||||||
"""Fetch slice [begin, end) across potentially multiple segments."""
|
|
||||||
segments = self._data[key]
|
|
||||||
cum = self._cum[key]
|
|
||||||
seg_start = bisect.bisect_right(cum, begin)
|
|
||||||
seg_end = bisect.bisect_left(cum, end)
|
|
||||||
|
|
||||||
results = []
|
|
||||||
for i in range(seg_start, seg_end + 1):
|
|
||||||
prev = cum[i - 1] if i > 0 else 0
|
|
||||||
s = max(begin - prev, 0)
|
|
||||||
e = min(end - prev, segments[i].shape[0])
|
|
||||||
results.append(segments[i][s:e])
|
|
||||||
|
|
||||||
return results[0] if len(results) == 1 else torch.cat(results, dim=0)
|
|
||||||
|
|
||||||
def _normalize(self, raw: Dict[str, List[Tensor]]):
|
|
||||||
"""Register segments and pre-compute cumulative lengths.
|
|
||||||
|
|
||||||
Does NOT concatenate — segments are kept as-is to avoid OOM on
|
|
||||||
large datasets. Sets ``self._length`` to the minimum total
|
|
||||||
element count across all keys.
|
|
||||||
"""
|
|
||||||
for key, tensors in raw.items():
|
|
||||||
self._data[key] = tensors
|
|
||||||
cum = []
|
|
||||||
total = 0
|
|
||||||
for t in tensors:
|
|
||||||
total += t.shape[0]
|
|
||||||
cum.append(total)
|
|
||||||
self._cum[key] = cum
|
|
||||||
self._length = (
|
|
||||||
min((cum[-1] if cum else 0) for cum in self._cum.values())
|
|
||||||
if self._cum
|
|
||||||
else 0
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class StoreFactory(BaseFactory["Store"]):
|
class StorageFactory(BaseFactory["BaseStorage"]):
|
||||||
"""Factory for creating Store instances by type name.
|
"""Factory for creating storage backends by type name.
|
||||||
|
|
||||||
Example::
|
Example:
|
||||||
|
@StorageFactory.register("custom")
|
||||||
@StoreFactory.register("custom")
|
class CustomStorage(BaseStorage):
|
||||||
class CustomStore(Store):
|
|
||||||
...
|
...
|
||||||
|
|
||||||
|
storage = StorageFactory.create("custom")
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _validate_component(cls, store_cls: type):
|
def _validate_component(cls, storage_cls: type) -> None:
|
||||||
if not issubclass(store_cls, Store):
|
if not issubclass(storage_cls, BaseStorage):
|
||||||
raise TypeError(f"{store_cls.__name__} must inherit from Store")
|
raise TypeError(f"{storage_cls.__name__} must inherit from BaseStorage")
|
||||||
|
|
||||||
|
|
||||||
@StoreFactory.register("h5")
|
@StorageFactory.register("h5")
|
||||||
class H5Store(Store):
|
class H5Storage(BaseStorage):
|
||||||
"""HDF5-based storage backend (pre-tokenized data)."""
|
"""HDF5-based storage backend (pre-tokenized data)."""
|
||||||
|
|
||||||
def load(self, path: str):
|
def load(self, load_path: str, tokenizer=None) -> None:
|
||||||
self._normalize(load_h5(path))
|
segments = load_h5(load_path)
|
||||||
|
self._fetcher = MultiSegmentFetcher(segments)
|
||||||
|
|
||||||
|
|
||||||
@StoreFactory.register("bin")
|
@StorageFactory.register("json")
|
||||||
class MmapStore(Store):
|
class JSONStorage(BaseStorage):
|
||||||
"""Memory-mapped binary storage backend.
|
"""JSON-based storage backend.
|
||||||
|
|
||||||
Each key is a single .bin file backed by ``np.memmap(mode="r")``.
|
Supports two modes:
|
||||||
No per-process memory duplication — all DataLoader workers share the
|
- Pre-tokenized: JSON values are List[List[int]], loaded as-is.
|
||||||
same OS page-cache pages.
|
- Raw text: JSON values are List[str], tokenized via ``tokenizer``
|
||||||
|
callable (str -> List[int]) at load time.
|
||||||
Format on disk::
|
|
||||||
|
|
||||||
data_root/
|
|
||||||
meta.json # {key: {shape, dtype}, ...}
|
|
||||||
<key>.bin # raw numpy array, one per key
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def load(self, path: str):
|
def load(self, load_path: str, tokenizer=None) -> None:
|
||||||
self._mmap_refs = []
|
segments = load_json(load_path, tokenizer=tokenizer)
|
||||||
root = Path(path)
|
self._fetcher = MultiSegmentFetcher(segments)
|
||||||
all_raw: Dict[str, List[Tensor]] = {}
|
|
||||||
meta_paths = [
|
|
||||||
Path(p) for p in glob.glob(str(root / "**" / "meta.json"), recursive=True)
|
|
||||||
]
|
|
||||||
for meta_path in meta_paths:
|
|
||||||
raw = load_bin(str(meta_path.parent))
|
|
||||||
for key, tensors in raw.items():
|
|
||||||
if key not in all_raw:
|
|
||||||
all_raw[key] = []
|
|
||||||
all_raw[key].extend(tensors)
|
|
||||||
if not meta_paths:
|
|
||||||
raise FileNotFoundError(f"No meta.json found under {path}")
|
|
||||||
self._normalize(all_raw)
|
|
||||||
for tensors in self._data.values():
|
|
||||||
self._mmap_refs.extend(tensors)
|
|
||||||
|
|
|
||||||
|
|
@ -23,7 +23,7 @@ class Registry:
|
||||||
component_cls: Type,
|
component_cls: Type,
|
||||||
category: Optional[str] = None,
|
category: Optional[str] = None,
|
||||||
priority: int = 0,
|
priority: int = 0,
|
||||||
):
|
) -> None:
|
||||||
"""Register a component class with optional category and priority."""
|
"""Register a component class with optional category and priority."""
|
||||||
if name in self._entries:
|
if name in self._entries:
|
||||||
raise ValueError(f"Component '{name}' is already registered")
|
raise ValueError(f"Component '{name}' is already registered")
|
||||||
|
|
@ -158,7 +158,7 @@ class BaseFactory(ABC, Generic[T]):
|
||||||
return component_cls(*args, **kwargs)
|
return component_cls(*args, **kwargs)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _validate_component(cls, component_cls: Type[T]):
|
def _validate_component(cls, component_cls: Type[T]) -> None:
|
||||||
"""Validate that the component class is valid for this factory.
|
"""Validate that the component class is valid for this factory.
|
||||||
|
|
||||||
Override this method in subclasses to add custom validation.
|
Override this method in subclasses to add custom validation.
|
||||||
|
|
|
||||||
|
|
@ -1,27 +1,25 @@
|
||||||
"""Inference module for continuous batching.
|
"""Inference module for continuous batching.
|
||||||
|
|
||||||
Layers:
|
Layers:
|
||||||
- core/: Core inference loop (cache, executor, scheduler, task)
|
- core/: Core inference loop (cache, executor, scheduler, task)
|
||||||
- api/: HTTP orchestration (ProtocolHandler, server)
|
- api/: HTTP protocol handlers (OpenAI, Anthropic)
|
||||||
- protocols/: Response builders (OpenAI, Anthropic)
|
- engine.py: Facade (InferenceEngine), Value Object (GenerationRequest)
|
||||||
- transport/: SSE transport utilities
|
- sample.py: Strategy pattern (TemperatureStrategy, TopKStrategy, TopPStrategy)
|
||||||
- engine.py: Facade (InferenceEngine), Value Object (GenerationRequest)
|
|
||||||
- sample.py: Strategy pattern (TemperatureStrategy, TopKStrategy, TopPStrategy)
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from astrai.inference.api import (
|
from astrai.inference.api import (
|
||||||
|
AnthropicHandler,
|
||||||
AnthropicMessage,
|
AnthropicMessage,
|
||||||
ChatCompletionRequest,
|
ChatCompletionRequest,
|
||||||
ChatMessage,
|
ChatMessage,
|
||||||
GenContext,
|
|
||||||
MessagesRequest,
|
MessagesRequest,
|
||||||
|
OpenAIHandler,
|
||||||
ProtocolHandler,
|
ProtocolHandler,
|
||||||
StopChecker,
|
StopChecker,
|
||||||
get_app,
|
StreamContext,
|
||||||
|
app,
|
||||||
run_server,
|
run_server,
|
||||||
)
|
)
|
||||||
from astrai.inference.api.anthropic import AnthropicResponseBuilder
|
|
||||||
from astrai.inference.api.openai import OpenAIResponseBuilder
|
|
||||||
from astrai.inference.core import (
|
from astrai.inference.core import (
|
||||||
STOP,
|
STOP,
|
||||||
Allocator,
|
Allocator,
|
||||||
|
|
@ -38,7 +36,10 @@ from astrai.inference.core import (
|
||||||
TaskTable,
|
TaskTable,
|
||||||
page_hash,
|
page_hash,
|
||||||
)
|
)
|
||||||
from astrai.inference.engine import GenerationRequest, InferenceEngine
|
from astrai.inference.engine import (
|
||||||
|
GenerationRequest,
|
||||||
|
InferenceEngine,
|
||||||
|
)
|
||||||
from astrai.inference.sample import (
|
from astrai.inference.sample import (
|
||||||
BaseSamplingStrategy,
|
BaseSamplingStrategy,
|
||||||
SamplingPipeline,
|
SamplingPipeline,
|
||||||
|
|
@ -49,14 +50,17 @@ from astrai.inference.sample import (
|
||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
# Engine / Requests
|
||||||
"InferenceEngine",
|
"InferenceEngine",
|
||||||
"GenerationRequest",
|
"GenerationRequest",
|
||||||
|
# Core scheduler
|
||||||
"InferenceScheduler",
|
"InferenceScheduler",
|
||||||
"Executor",
|
"Executor",
|
||||||
"STOP",
|
"STOP",
|
||||||
"Task",
|
"Task",
|
||||||
"TaskManager",
|
"TaskManager",
|
||||||
"TaskStatus",
|
"TaskStatus",
|
||||||
|
# Core cache
|
||||||
"Allocator",
|
"Allocator",
|
||||||
"KVCache",
|
"KVCache",
|
||||||
"KvcacheView",
|
"KvcacheView",
|
||||||
|
|
@ -65,21 +69,24 @@ __all__ = [
|
||||||
"Storage",
|
"Storage",
|
||||||
"TaskTable",
|
"TaskTable",
|
||||||
"page_hash",
|
"page_hash",
|
||||||
|
# Sampling (Strategy pattern)
|
||||||
"sample",
|
"sample",
|
||||||
"BaseSamplingStrategy",
|
"BaseSamplingStrategy",
|
||||||
"TemperatureStrategy",
|
"TemperatureStrategy",
|
||||||
"TopKStrategy",
|
"TopKStrategy",
|
||||||
"TopPStrategy",
|
"TopPStrategy",
|
||||||
"SamplingPipeline",
|
"SamplingPipeline",
|
||||||
|
# Protocol
|
||||||
"ProtocolHandler",
|
"ProtocolHandler",
|
||||||
"StopChecker",
|
"StopChecker",
|
||||||
"GenContext",
|
"StreamContext",
|
||||||
"OpenAIResponseBuilder",
|
"AnthropicHandler",
|
||||||
"AnthropicResponseBuilder",
|
"OpenAIHandler",
|
||||||
|
# Server
|
||||||
"ChatMessage",
|
"ChatMessage",
|
||||||
"ChatCompletionRequest",
|
"ChatCompletionRequest",
|
||||||
"AnthropicMessage",
|
"AnthropicMessage",
|
||||||
"MessagesRequest",
|
"MessagesRequest",
|
||||||
"get_app",
|
"app",
|
||||||
"run_server",
|
"run_server",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -1,27 +1,31 @@
|
||||||
"""Inference API: protocol handler, stop checker, and FastAPI server.
|
"""Inference API: protocol handlers and FastAPI server."""
|
||||||
|
|
||||||
``app`` is no longer a module-level global. Use :func:`get_app` to access the
|
from astrai.inference.api.protocol import (
|
||||||
lazy singleton FastAPI instance.
|
AnthropicHandler,
|
||||||
"""
|
OpenAIHandler,
|
||||||
|
ProtocolHandler,
|
||||||
from astrai.inference.api.protocol import GenContext, ProtocolHandler, StopChecker
|
StopChecker,
|
||||||
|
StreamContext,
|
||||||
|
)
|
||||||
from astrai.inference.api.server import (
|
from astrai.inference.api.server import (
|
||||||
AnthropicMessage,
|
AnthropicMessage,
|
||||||
ChatCompletionRequest,
|
ChatCompletionRequest,
|
||||||
ChatMessage,
|
ChatMessage,
|
||||||
MessagesRequest,
|
MessagesRequest,
|
||||||
get_app,
|
app,
|
||||||
run_server,
|
run_server,
|
||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
"AnthropicHandler",
|
||||||
|
"OpenAIHandler",
|
||||||
"ProtocolHandler",
|
"ProtocolHandler",
|
||||||
"StopChecker",
|
"StopChecker",
|
||||||
"GenContext",
|
"StreamContext",
|
||||||
"AnthropicMessage",
|
"AnthropicMessage",
|
||||||
"ChatCompletionRequest",
|
"ChatCompletionRequest",
|
||||||
"ChatMessage",
|
"ChatMessage",
|
||||||
"MessagesRequest",
|
"MessagesRequest",
|
||||||
"get_app",
|
"app",
|
||||||
"run_server",
|
"run_server",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -1,141 +0,0 @@
|
||||||
"""Anthropic message completion response builder."""
|
|
||||||
|
|
||||||
import time
|
|
||||||
import uuid
|
|
||||||
from typing import Any, Dict, List, Tuple, Union
|
|
||||||
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from astrai.inference.api.protocol import (
|
|
||||||
GenContext,
|
|
||||||
ResponseBuilder,
|
|
||||||
StopInfo,
|
|
||||||
sse_event,
|
|
||||||
)
|
|
||||||
from astrai.inference.engine import InferenceEngine
|
|
||||||
|
|
||||||
|
|
||||||
def _extract_text(content: Union[str, List[Dict[str, Any]]]) -> str:
|
|
||||||
if isinstance(content, str):
|
|
||||||
return content
|
|
||||||
if isinstance(content, list):
|
|
||||||
for block in content:
|
|
||||||
if isinstance(block, dict) and block.get("type") == "text":
|
|
||||||
return block.get("text", "")
|
|
||||||
return ""
|
|
||||||
|
|
||||||
|
|
||||||
class AnthropicResponseBuilder(ResponseBuilder):
|
|
||||||
def prepare(
|
|
||||||
self, request: BaseModel, engine: InferenceEngine
|
|
||||||
) -> Tuple[str, GenContext, List[str]]:
|
|
||||||
messages: List[Dict[str, str]] = []
|
|
||||||
system = getattr(request, "system", None)
|
|
||||||
if system:
|
|
||||||
messages.append({"role": "system", "content": system})
|
|
||||||
for m in request.messages:
|
|
||||||
text = _extract_text(m.content)
|
|
||||||
if text:
|
|
||||||
messages.append({"role": m.role, "content": text})
|
|
||||||
prompt = engine.tokenizer.apply_chat_template(messages, tokenize=False)
|
|
||||||
ctx = GenContext(
|
|
||||||
resp_id=f"msg_{uuid.uuid4().hex[:24]}",
|
|
||||||
created=int(time.time()),
|
|
||||||
model=request.model,
|
|
||||||
prompt_tokens=0,
|
|
||||||
)
|
|
||||||
stop_sequences = getattr(request, "stop_sequences", None) or []
|
|
||||||
return prompt, ctx, stop_sequences
|
|
||||||
|
|
||||||
def format_stream_start(self, ctx: GenContext) -> List[str]:
|
|
||||||
return [
|
|
||||||
sse_event(
|
|
||||||
{
|
|
||||||
"type": "message_start",
|
|
||||||
"message": {
|
|
||||||
"id": ctx.resp_id,
|
|
||||||
"type": "message",
|
|
||||||
"role": "assistant",
|
|
||||||
"model": ctx.model,
|
|
||||||
"content": [],
|
|
||||||
"usage": {"input_tokens": ctx.prompt_tokens},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
event="message_start",
|
|
||||||
),
|
|
||||||
sse_event(
|
|
||||||
{
|
|
||||||
"type": "content_block_start",
|
|
||||||
"index": 0,
|
|
||||||
"content_block": {"type": "text", "text": ""},
|
|
||||||
},
|
|
||||||
event="content_block_start",
|
|
||||||
),
|
|
||||||
]
|
|
||||||
|
|
||||||
def format_chunk(self, token: str) -> str:
|
|
||||||
return sse_event(
|
|
||||||
{
|
|
||||||
"type": "content_block_delta",
|
|
||||||
"index": 0,
|
|
||||||
"delta": {"type": "text_delta", "text": token},
|
|
||||||
},
|
|
||||||
event="content_block_delta",
|
|
||||||
)
|
|
||||||
|
|
||||||
def format_stream_end(self, ctx: GenContext, stop: StopInfo) -> List[str]:
|
|
||||||
events: List[str] = []
|
|
||||||
if stop.matched:
|
|
||||||
trimmed = stop.body[: stop.body.rfind(stop.matched)]
|
|
||||||
unyielded = trimmed[len(stop.yielded) :]
|
|
||||||
if unyielded:
|
|
||||||
events.append(
|
|
||||||
sse_event(
|
|
||||||
{
|
|
||||||
"type": "content_block_delta",
|
|
||||||
"index": 0,
|
|
||||||
"delta": {"type": "text_delta", "text": unyielded},
|
|
||||||
},
|
|
||||||
event="content_block_delta",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
events.append(
|
|
||||||
sse_event(
|
|
||||||
{"type": "content_block_stop", "index": 0},
|
|
||||||
event="content_block_stop",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
events.append(
|
|
||||||
sse_event(
|
|
||||||
{
|
|
||||||
"type": "message_delta",
|
|
||||||
"delta": {
|
|
||||||
"stop_reason": "stop_sequence" if stop.matched else "end_turn",
|
|
||||||
"stop_sequence": stop.matched,
|
|
||||||
},
|
|
||||||
"usage": {"output_tokens": ctx.completion_tokens},
|
|
||||||
},
|
|
||||||
event="message_delta",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
events.append(sse_event({"type": "message_stop"}, event="message_stop"))
|
|
||||||
return events
|
|
||||||
|
|
||||||
def format_response(
|
|
||||||
self, ctx: GenContext, content: str, stop: StopInfo
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
if stop.matched:
|
|
||||||
content = content[: content.rfind(stop.matched)]
|
|
||||||
return {
|
|
||||||
"id": ctx.resp_id,
|
|
||||||
"type": "message",
|
|
||||||
"role": "assistant",
|
|
||||||
"model": ctx.model,
|
|
||||||
"content": [{"type": "text", "text": content}],
|
|
||||||
"stop_reason": "stop_sequence" if stop.matched else "end_turn",
|
|
||||||
"stop_sequence": stop.matched,
|
|
||||||
"usage": {
|
|
||||||
"input_tokens": ctx.prompt_tokens,
|
|
||||||
"output_tokens": ctx.completion_tokens,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
@ -1,140 +0,0 @@
|
||||||
"""OpenAI chat completion response builder."""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import time
|
|
||||||
import uuid
|
|
||||||
from typing import Any, Dict, List, Tuple
|
|
||||||
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from astrai.inference.api.protocol import (
|
|
||||||
GenContext,
|
|
||||||
ResponseBuilder,
|
|
||||||
StopInfo,
|
|
||||||
sse_event,
|
|
||||||
)
|
|
||||||
from astrai.inference.engine import InferenceEngine
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
_UNSUPPORTED_PARAMS = (
|
|
||||||
"n",
|
|
||||||
"presence_penalty",
|
|
||||||
"frequency_penalty",
|
|
||||||
"logit_bias",
|
|
||||||
"user",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class OpenAIResponseBuilder(ResponseBuilder):
|
|
||||||
def prepare(
|
|
||||||
self, request: BaseModel, engine: InferenceEngine
|
|
||||||
) -> Tuple[str, GenContext, List[str]]:
|
|
||||||
messages = [{"role": m.role, "content": m.content} for m in request.messages]
|
|
||||||
prompt = engine.tokenizer.apply_chat_template(messages, tokenize=False)
|
|
||||||
|
|
||||||
self._resp_id = f"chatcmpl-{uuid.uuid4().hex[:12]}"
|
|
||||||
self._model = request.model
|
|
||||||
|
|
||||||
for param in _UNSUPPORTED_PARAMS:
|
|
||||||
value = getattr(request, param, None)
|
|
||||||
fields = getattr(type(request), "model_fields", {})
|
|
||||||
default = fields[param].default if param in fields else None
|
|
||||||
if value is not None and value != default:
|
|
||||||
logger.warning(
|
|
||||||
"ChatCompletionRequest param '%s'=%r is not supported and will be ignored",
|
|
||||||
param,
|
|
||||||
value,
|
|
||||||
)
|
|
||||||
if value is not None and value != default:
|
|
||||||
logger.warning(
|
|
||||||
"ChatCompletionRequest param '%s'=%r is not supported and will be ignored",
|
|
||||||
param,
|
|
||||||
value,
|
|
||||||
)
|
|
||||||
|
|
||||||
ctx = GenContext(
|
|
||||||
resp_id=self._resp_id,
|
|
||||||
created=int(time.time()),
|
|
||||||
model=self._model,
|
|
||||||
prompt_tokens=0,
|
|
||||||
)
|
|
||||||
stop = request.stop
|
|
||||||
stop_sequences = (
|
|
||||||
[] if stop is None else [stop] if isinstance(stop, str) else stop
|
|
||||||
)
|
|
||||||
return prompt, ctx, stop_sequences
|
|
||||||
|
|
||||||
def format_stream_start(self, ctx: GenContext) -> List[str]:
|
|
||||||
return [
|
|
||||||
sse_event(
|
|
||||||
{
|
|
||||||
"id": self._resp_id,
|
|
||||||
"object": "chat.completion.chunk",
|
|
||||||
"created": ctx.created,
|
|
||||||
"model": self._model,
|
|
||||||
"choices": [
|
|
||||||
{
|
|
||||||
"index": 0,
|
|
||||||
"delta": {"role": "assistant"},
|
|
||||||
"finish_reason": None,
|
|
||||||
}
|
|
||||||
],
|
|
||||||
}
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
def format_chunk(self, token: str) -> str:
|
|
||||||
return sse_event(
|
|
||||||
{
|
|
||||||
"id": self._resp_id,
|
|
||||||
"object": "chat.completion.chunk",
|
|
||||||
"created": 0,
|
|
||||||
"model": self._model,
|
|
||||||
"choices": [
|
|
||||||
{"index": 0, "delta": {"content": token}, "finish_reason": None}
|
|
||||||
],
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
def format_stream_end(self, ctx: GenContext, stop: StopInfo) -> List[str]:
|
|
||||||
return [
|
|
||||||
sse_event(
|
|
||||||
{
|
|
||||||
"id": self._resp_id,
|
|
||||||
"object": "chat.completion.chunk",
|
|
||||||
"created": ctx.created,
|
|
||||||
"model": self._model,
|
|
||||||
"choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
|
|
||||||
}
|
|
||||||
),
|
|
||||||
sse_event(
|
|
||||||
{
|
|
||||||
"prompt_tokens": ctx.prompt_tokens,
|
|
||||||
"completion_tokens": ctx.completion_tokens,
|
|
||||||
"total_tokens": ctx.prompt_tokens + ctx.completion_tokens,
|
|
||||||
}
|
|
||||||
),
|
|
||||||
]
|
|
||||||
|
|
||||||
def format_response(
|
|
||||||
self, ctx: GenContext, content: str, stop: StopInfo
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
return {
|
|
||||||
"id": self._resp_id,
|
|
||||||
"object": "chat.completion",
|
|
||||||
"created": ctx.created,
|
|
||||||
"model": self._model,
|
|
||||||
"choices": [
|
|
||||||
{
|
|
||||||
"index": 0,
|
|
||||||
"message": {"role": "assistant", "content": content},
|
|
||||||
"finish_reason": "stop",
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"usage": {
|
|
||||||
"prompt_tokens": ctx.prompt_tokens,
|
|
||||||
"completion_tokens": ctx.completion_tokens,
|
|
||||||
"total_tokens": ctx.prompt_tokens + ctx.completion_tokens,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
@ -1,13 +1,15 @@
|
||||||
"""Orchestration layer: ProtocolHandler, StopChecker, GenContext, StopInfo, ResponseBuilder, SSE utils.
|
"""Protocol handlers for OpenAI and Anthropic chat completion APIs.
|
||||||
|
|
||||||
ProtocolHandler orchestrates the async generation loop and delegates
|
Template Method + Builder patterns eliminate the 45% code duplication between
|
||||||
protocol-specific formatting to a ResponseBuilder.
|
stream/non-stream branches and across protocol adapters.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
@ -15,7 +17,7 @@ from pydantic import BaseModel
|
||||||
from astrai.inference.engine import InferenceEngine
|
from astrai.inference.engine import InferenceEngine
|
||||||
|
|
||||||
|
|
||||||
def sse_event(data: Dict[str, Any], event: Optional[str] = None) -> str:
|
def _sse_event(data: Dict[str, Any], event: Optional[str] = None) -> str:
|
||||||
lines: List[str] = []
|
lines: List[str] = []
|
||||||
if event:
|
if event:
|
||||||
lines.append(f"event: {event}")
|
lines.append(f"event: {event}")
|
||||||
|
|
@ -24,28 +26,22 @@ def sse_event(data: Dict[str, Any], event: Optional[str] = None) -> str:
|
||||||
return "\n".join(lines)
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
def sse_done() -> str:
|
def _sse_done() -> str:
|
||||||
return "data: [DONE]\n\n"
|
return "data: [DONE]\n\n"
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class GenContext:
|
class StreamContext:
|
||||||
"""Per-generation metadata passed to builder format methods."""
|
"""Shared state across the streaming generation lifecycle."""
|
||||||
|
|
||||||
resp_id: str
|
resp_id: str
|
||||||
created: int
|
created: int
|
||||||
model: str
|
model: str
|
||||||
prompt_tokens: int
|
prompt_tokens: int
|
||||||
completion_tokens: int = 0
|
completion_tokens: int = 0
|
||||||
|
accumulated: str = ""
|
||||||
|
stop_matched: Optional[str] = None
|
||||||
@dataclass
|
last_yield_trimmed: str = ""
|
||||||
class StopInfo:
|
|
||||||
"""Stop-check result passed to format_stream_end / format_response."""
|
|
||||||
|
|
||||||
matched: Optional[str] = None
|
|
||||||
body: str = ""
|
|
||||||
yielded: str = ""
|
|
||||||
|
|
||||||
|
|
||||||
class StopChecker:
|
class StopChecker:
|
||||||
|
|
@ -60,60 +56,95 @@ class StopChecker:
|
||||||
return seq
|
return seq
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def trim(self, text: str, matched: str) -> str:
|
||||||
|
idx = text.rfind(matched)
|
||||||
|
return text[:idx] if idx != -1 else text
|
||||||
|
|
||||||
class ResponseBuilder(ABC):
|
@property
|
||||||
"""Interface for protocol-specific response formatting.
|
def has_sequences(self) -> bool:
|
||||||
|
return len(self._sequences) > 0
|
||||||
|
|
||||||
A new protocol requires one concrete builder implementing 5 methods.
|
|
||||||
|
class ProtocolHandler(ABC):
|
||||||
|
"""Template-method base for API protocol handlers.
|
||||||
|
|
||||||
|
Subclasses implement format hooks; the base class orchestrates the
|
||||||
|
generate-async loop and SSE/JSON response construction.
|
||||||
|
|
||||||
|
Lifecycle::
|
||||||
|
|
||||||
|
handle()
|
||||||
|
├─ build_prompt() # protocol-specific prompt assembly
|
||||||
|
├─ create_response_id() # unique response identifier
|
||||||
|
├─ [stream]
|
||||||
|
│ ├─ format_stream_start()
|
||||||
|
│ ├─ format_stream_token() × N
|
||||||
|
│ │ └─ on_token() hook for stop-sequence interception
|
||||||
|
│ └─ format_stream_end()
|
||||||
|
└─ [non-stream]
|
||||||
|
├─ (accumulate tokens)
|
||||||
|
└─ format_non_stream_response()
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
request_model: type[BaseModel]
|
||||||
def prepare(
|
|
||||||
self, request: BaseModel, engine: InferenceEngine
|
|
||||||
) -> Tuple[str, GenContext, List[str]]:
|
|
||||||
"""Return (prompt, ctx, stop_sequences) for a generation request."""
|
|
||||||
|
|
||||||
@abstractmethod
|
def __init__(self, request: BaseModel, engine: InferenceEngine):
|
||||||
def format_stream_start(self, ctx: GenContext) -> List[str]:
|
|
||||||
"""SSE events that open the stream."""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def format_chunk(self, token: str) -> str:
|
|
||||||
"""SSE event for a single generated token."""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def format_stream_end(self, ctx: GenContext, stop: StopInfo) -> List[str]:
|
|
||||||
"""SSE events that close the stream."""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def format_response(
|
|
||||||
self, ctx: GenContext, content: str, stop: StopInfo
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""JSON response body for non-streaming mode."""
|
|
||||||
|
|
||||||
|
|
||||||
class ProtocolHandler:
|
|
||||||
"""Orchestrates the generation loop, delegates formatting to a builder.
|
|
||||||
|
|
||||||
Usage::
|
|
||||||
|
|
||||||
handler = ProtocolHandler(request, engine, OpenAIResponseBuilder())
|
|
||||||
response = await handler.handle()
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self, request: BaseModel, engine: InferenceEngine, builder: ResponseBuilder
|
|
||||||
):
|
|
||||||
self.request = request
|
self.request = request
|
||||||
self.engine = engine
|
self.engine = engine
|
||||||
self.builder = builder
|
|
||||||
|
@abstractmethod
|
||||||
|
def build_prompt(self) -> str:
|
||||||
|
"""Build the full prompt string from the request messages."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def create_response_id(self) -> str:
|
||||||
|
"""Generate a unique response ID following the protocol convention."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def format_stream_start(self, ctx: StreamContext) -> List[str]:
|
||||||
|
"""Yield SSE events that open the stream (role marker, metadata)."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def format_stream_token(self, ctx: StreamContext, token: str) -> str:
|
||||||
|
"""Yield an SSE event for a single generated token."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def format_stream_end(self, ctx: StreamContext) -> List[str]:
|
||||||
|
"""Yield SSE events that close the stream (finish reason, usage stats)."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def format_non_stream_response(
|
||||||
|
self, ctx: StreamContext, content: str
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Build the JSON response body for non-streaming mode."""
|
||||||
|
|
||||||
|
def get_stop_sequences(self) -> List[str]:
|
||||||
|
return []
|
||||||
|
|
||||||
|
def create_stop_checker(self) -> StopChecker:
|
||||||
|
return StopChecker(self.get_stop_sequences())
|
||||||
|
|
||||||
|
def on_token(
|
||||||
|
self, ctx: StreamContext, token: str, stop_checker: StopChecker
|
||||||
|
) -> Optional[str]:
|
||||||
|
"""Hook after each token is appended to accumulated.
|
||||||
|
|
||||||
|
Return a matched stop-sequence string to break the loop,
|
||||||
|
or None to continue.
|
||||||
|
|
||||||
|
"""
|
||||||
|
return None
|
||||||
|
|
||||||
async def handle(self) -> Union[StreamingResponse, Dict[str, Any]]:
|
async def handle(self) -> Union[StreamingResponse, Dict[str, Any]]:
|
||||||
prompt, ctx, stop_sequences = self.builder.prepare(self.request, self.engine)
|
ctx = StreamContext(
|
||||||
ctx.prompt_tokens = len(self.engine.tokenizer.encode(prompt))
|
resp_id=self.create_response_id(),
|
||||||
|
created=int(time.time()),
|
||||||
|
model=self.request.model,
|
||||||
|
prompt_tokens=self._count_prompt_tokens(),
|
||||||
|
)
|
||||||
|
|
||||||
agen = self.engine.generate_async(
|
agen = self.engine.generate_async(
|
||||||
prompt=prompt,
|
prompt=self.build_prompt(),
|
||||||
max_tokens=self.request.max_tokens,
|
max_tokens=self.request.max_tokens,
|
||||||
temperature=self.request.temperature,
|
temperature=self.request.temperature,
|
||||||
top_p=self.request.top_p,
|
top_p=self.request.top_p,
|
||||||
|
|
@ -121,37 +152,33 @@ class ProtocolHandler:
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.request.stream:
|
if self.request.stream:
|
||||||
return self._handle_stream(agen, ctx, stop_sequences)
|
return self._handle_stream(agen, ctx)
|
||||||
else:
|
else:
|
||||||
return await self._handle_non_stream(agen, ctx, stop_sequences)
|
return await self._handle_non_stream(agen, ctx)
|
||||||
|
|
||||||
def _handle_stream(
|
def _count_prompt_tokens(self) -> int:
|
||||||
self, agen: AsyncGenerator, ctx: GenContext, stop_sequences: List[str]
|
return len(self.engine.tokenizer.encode(self.build_prompt()))
|
||||||
) -> StreamingResponse:
|
|
||||||
checker = StopChecker(stop_sequences)
|
def _handle_stream(self, agen, ctx: StreamContext) -> StreamingResponse:
|
||||||
|
stop_checker = self.create_stop_checker()
|
||||||
|
|
||||||
async def event_stream():
|
async def event_stream():
|
||||||
for event in self.builder.format_stream_start(ctx):
|
for event in self.format_stream_start(ctx):
|
||||||
yield event
|
yield event
|
||||||
|
|
||||||
body = ""
|
|
||||||
yielded = ""
|
|
||||||
matched = None
|
|
||||||
async for token in agen:
|
async for token in agen:
|
||||||
body += token
|
ctx.completion_tokens += 1
|
||||||
|
ctx.accumulated += token
|
||||||
|
|
||||||
matched = checker.check(body)
|
matched = self.on_token(ctx, token, stop_checker)
|
||||||
if matched:
|
if matched:
|
||||||
break
|
break
|
||||||
|
|
||||||
ctx.completion_tokens += 1
|
yield self.format_stream_token(ctx, token)
|
||||||
yield self.builder.format_chunk(token)
|
|
||||||
yielded += token
|
|
||||||
|
|
||||||
stop = StopInfo(matched=matched, body=body, yielded=yielded)
|
for event in self.format_stream_end(ctx):
|
||||||
for event in self.builder.format_stream_end(ctx, stop):
|
|
||||||
yield event
|
yield event
|
||||||
yield sse_done()
|
yield _sse_done()
|
||||||
|
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
event_stream(),
|
event_stream(),
|
||||||
|
|
@ -159,24 +186,260 @@ class ProtocolHandler:
|
||||||
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
|
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _handle_non_stream(
|
async def _handle_non_stream(self, agen, ctx: StreamContext) -> Dict[str, Any]:
|
||||||
self, agen: AsyncGenerator, ctx: GenContext, stop_sequences: List[str]
|
stop_checker = self.create_stop_checker()
|
||||||
) -> Dict[str, Any]:
|
|
||||||
checker = StopChecker(stop_sequences)
|
|
||||||
chunks: List[str] = []
|
chunks: List[str] = []
|
||||||
body = ""
|
|
||||||
matched = None
|
|
||||||
|
|
||||||
async for token in agen:
|
async for token in agen:
|
||||||
|
ctx.completion_tokens += 1
|
||||||
|
ctx.accumulated += token
|
||||||
chunks.append(token)
|
chunks.append(token)
|
||||||
body += token
|
|
||||||
|
|
||||||
matched = checker.check(body)
|
matched = self.on_token(ctx, token, stop_checker)
|
||||||
if matched:
|
if matched:
|
||||||
break
|
break
|
||||||
|
|
||||||
ctx.completion_tokens += 1
|
|
||||||
|
|
||||||
content = "".join(chunks)
|
content = "".join(chunks)
|
||||||
stop = StopInfo(matched=matched, body=body)
|
return self.format_non_stream_response(ctx, content)
|
||||||
return self.builder.format_response(ctx, content, stop)
|
|
||||||
|
|
||||||
|
def _extract_text_content(content: Union[str, List[Dict[str, Any]]]) -> str:
|
||||||
|
"""Extract plain text from an Anthropic content block (string or list)."""
|
||||||
|
if isinstance(content, str):
|
||||||
|
return content
|
||||||
|
if isinstance(content, list):
|
||||||
|
for block in content:
|
||||||
|
if isinstance(block, dict) and block.get("type") == "text":
|
||||||
|
return block.get("text", "")
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIHandler(ProtocolHandler):
|
||||||
|
"""OpenAI-compatible /v1/chat/completions handler."""
|
||||||
|
|
||||||
|
def build_prompt(self) -> str:
|
||||||
|
messages = [
|
||||||
|
{"role": m.role, "content": m.content} for m in self.request.messages
|
||||||
|
]
|
||||||
|
return self.engine.tokenizer.apply_chat_template(messages, tokenize=False)
|
||||||
|
|
||||||
|
def create_response_id(self) -> str:
|
||||||
|
return f"chatcmpl-{uuid.uuid4().hex[:12]}"
|
||||||
|
|
||||||
|
def get_stop_sequences(self) -> List[str]:
|
||||||
|
stop = self.request.stop
|
||||||
|
if stop is None:
|
||||||
|
return []
|
||||||
|
return [stop] if isinstance(stop, str) else stop
|
||||||
|
|
||||||
|
def on_token(
|
||||||
|
self, ctx: StreamContext, token: str, stop_checker: StopChecker
|
||||||
|
) -> Optional[str]:
|
||||||
|
return stop_checker.check(ctx.accumulated)
|
||||||
|
|
||||||
|
def format_stream_start(self, ctx: StreamContext) -> List[str]:
|
||||||
|
return [
|
||||||
|
_sse_event(
|
||||||
|
{
|
||||||
|
"id": ctx.resp_id,
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"created": ctx.created,
|
||||||
|
"model": ctx.model,
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"delta": {"role": "assistant"},
|
||||||
|
"finish_reason": None,
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
def format_stream_token(self, ctx: StreamContext, token: str) -> str:
|
||||||
|
return _sse_event(
|
||||||
|
{
|
||||||
|
"id": ctx.resp_id,
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"created": ctx.created,
|
||||||
|
"model": ctx.model,
|
||||||
|
"choices": [
|
||||||
|
{"index": 0, "delta": {"content": token}, "finish_reason": None}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def format_stream_end(self, ctx: StreamContext) -> List[str]:
|
||||||
|
return [
|
||||||
|
_sse_event(
|
||||||
|
{
|
||||||
|
"id": ctx.resp_id,
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"created": ctx.created,
|
||||||
|
"model": ctx.model,
|
||||||
|
"choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
|
||||||
|
}
|
||||||
|
),
|
||||||
|
_sse_event(
|
||||||
|
{
|
||||||
|
"prompt_tokens": ctx.prompt_tokens,
|
||||||
|
"completion_tokens": ctx.completion_tokens,
|
||||||
|
"total_tokens": ctx.prompt_tokens + ctx.completion_tokens,
|
||||||
|
}
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
def format_non_stream_response(
|
||||||
|
self, ctx: StreamContext, content: str
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"id": ctx.resp_id,
|
||||||
|
"object": "chat.completion",
|
||||||
|
"created": ctx.created,
|
||||||
|
"model": ctx.model,
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"message": {"role": "assistant", "content": content},
|
||||||
|
"finish_reason": "stop",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"usage": {
|
||||||
|
"prompt_tokens": ctx.prompt_tokens,
|
||||||
|
"completion_tokens": ctx.completion_tokens,
|
||||||
|
"total_tokens": ctx.prompt_tokens + ctx.completion_tokens,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class AnthropicHandler(ProtocolHandler):
|
||||||
|
"""Anthropic-compatible /v1/messages handler."""
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self._yielded = ""
|
||||||
|
|
||||||
|
def build_prompt(self) -> str:
|
||||||
|
messages: List[Dict[str, str]] = []
|
||||||
|
system = getattr(self.request, "system", None)
|
||||||
|
if system:
|
||||||
|
messages.append({"role": "system", "content": system})
|
||||||
|
for m in self.request.messages:
|
||||||
|
content = _extract_text_content(m.content)
|
||||||
|
if content:
|
||||||
|
messages.append({"role": m.role, "content": content})
|
||||||
|
return self.engine.tokenizer.apply_chat_template(messages, tokenize=False)
|
||||||
|
|
||||||
|
def create_response_id(self) -> str:
|
||||||
|
return f"msg_{uuid.uuid4().hex[:24]}"
|
||||||
|
|
||||||
|
def get_stop_sequences(self) -> List[str]:
|
||||||
|
return getattr(self.request, "stop_sequences", None) or []
|
||||||
|
|
||||||
|
def on_token(
|
||||||
|
self, ctx: StreamContext, token: str, stop_checker: StopChecker
|
||||||
|
) -> Optional[str]:
|
||||||
|
matched = stop_checker.check(ctx.accumulated)
|
||||||
|
if not matched:
|
||||||
|
return None
|
||||||
|
|
||||||
|
ctx.stop_matched = matched
|
||||||
|
trimmed = ctx.accumulated[: ctx.accumulated.rfind(matched)]
|
||||||
|
unyielded = trimmed[len(self._yielded) :]
|
||||||
|
if unyielded:
|
||||||
|
ctx.last_yield_trimmed = unyielded
|
||||||
|
return matched
|
||||||
|
|
||||||
|
def format_stream_start(self, ctx: StreamContext) -> List[str]:
|
||||||
|
return [
|
||||||
|
_sse_event(
|
||||||
|
{
|
||||||
|
"type": "message_start",
|
||||||
|
"message": {
|
||||||
|
"id": ctx.resp_id,
|
||||||
|
"type": "message",
|
||||||
|
"role": "assistant",
|
||||||
|
"model": ctx.model,
|
||||||
|
"content": [],
|
||||||
|
"usage": {"input_tokens": ctx.prompt_tokens},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
event="message_start",
|
||||||
|
),
|
||||||
|
_sse_event(
|
||||||
|
{
|
||||||
|
"type": "content_block_start",
|
||||||
|
"index": 0,
|
||||||
|
"content_block": {"type": "text", "text": ""},
|
||||||
|
},
|
||||||
|
event="content_block_start",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
def format_stream_token(self, ctx: StreamContext, token: str) -> str:
|
||||||
|
self._yielded += token
|
||||||
|
return _sse_event(
|
||||||
|
{
|
||||||
|
"type": "content_block_delta",
|
||||||
|
"index": 0,
|
||||||
|
"delta": {"type": "text_delta", "text": token},
|
||||||
|
},
|
||||||
|
event="content_block_delta",
|
||||||
|
)
|
||||||
|
|
||||||
|
def format_stream_end(self, ctx: StreamContext) -> List[str]:
|
||||||
|
matched = ctx.stop_matched
|
||||||
|
events: List[str] = []
|
||||||
|
last_yielded = ctx.last_yield_trimmed
|
||||||
|
if last_yielded:
|
||||||
|
events.append(
|
||||||
|
_sse_event(
|
||||||
|
{
|
||||||
|
"type": "content_block_delta",
|
||||||
|
"index": 0,
|
||||||
|
"delta": {"type": "text_delta", "text": last_yielded},
|
||||||
|
},
|
||||||
|
event="content_block_delta",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
events.append(
|
||||||
|
_sse_event(
|
||||||
|
{"type": "content_block_stop", "index": 0},
|
||||||
|
event="content_block_stop",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
events.append(
|
||||||
|
_sse_event(
|
||||||
|
{
|
||||||
|
"type": "message_delta",
|
||||||
|
"delta": {
|
||||||
|
"stop_reason": "stop_sequence" if matched else "end_turn",
|
||||||
|
"stop_sequence": matched,
|
||||||
|
},
|
||||||
|
"usage": {"output_tokens": ctx.completion_tokens},
|
||||||
|
},
|
||||||
|
event="message_delta",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
events.append(_sse_event({"type": "message_stop"}, event="message_stop"))
|
||||||
|
return events
|
||||||
|
|
||||||
|
def format_non_stream_response(
|
||||||
|
self, ctx: StreamContext, content: str
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
matched = ctx.stop_matched
|
||||||
|
if matched:
|
||||||
|
content = content[: content.rfind(matched)]
|
||||||
|
return {
|
||||||
|
"id": ctx.resp_id,
|
||||||
|
"type": "message",
|
||||||
|
"role": "assistant",
|
||||||
|
"model": ctx.model,
|
||||||
|
"content": [{"type": "text", "text": content}],
|
||||||
|
"stop_reason": "stop_sequence" if matched else "end_turn",
|
||||||
|
"stop_sequence": matched,
|
||||||
|
"usage": {
|
||||||
|
"input_tokens": ctx.prompt_tokens,
|
||||||
|
"output_tokens": ctx.completion_tokens,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -3,9 +3,6 @@ OpenAI / Anthropic-compatible chat completion server backed by continuous-batchi
|
||||||
|
|
||||||
Protocol-specific formatting is delegated to ``astrai.inference.protocol``.
|
Protocol-specific formatting is delegated to ``astrai.inference.protocol``.
|
||||||
This module owns the FastAPI app, request/response schemas, and dependency wiring.
|
This module owns the FastAPI app, request/response schemas, and dependency wiring.
|
||||||
|
|
||||||
``app`` is lazily constructed — importing this module does NOT create a FastAPI instance.
|
|
||||||
Use :func:`get_app` to access the singleton.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
@ -15,19 +12,17 @@ from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import uvicorn
|
import uvicorn
|
||||||
from fastapi import APIRouter, FastAPI, HTTPException
|
from fastapi import FastAPI, HTTPException
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from astrai.inference.api.anthropic import AnthropicResponseBuilder
|
from astrai.inference.api.protocol import AnthropicHandler, OpenAIHandler
|
||||||
from astrai.inference.api.openai import OpenAIResponseBuilder
|
|
||||||
from astrai.inference.api.protocol import ProtocolHandler
|
|
||||||
from astrai.inference.engine import InferenceEngine
|
from astrai.inference.engine import InferenceEngine
|
||||||
from astrai.model import AutoModel
|
from astrai.model import AutoModel
|
||||||
from astrai.tokenize import AutoTokenizer
|
from astrai.tokenize import AutoTokenizer
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
_app_instance: Optional[FastAPI] = None
|
_project_root = Path(__file__).parent.parent.parent
|
||||||
|
|
||||||
|
|
||||||
class ChatMessage(BaseModel):
|
class ChatMessage(BaseModel):
|
||||||
|
|
@ -87,15 +82,17 @@ async def lifespan(app: FastAPI):
|
||||||
logger.info("Inference engine shutdown complete")
|
logger.info("Inference engine shutdown complete")
|
||||||
|
|
||||||
|
|
||||||
router = APIRouter()
|
app = FastAPI(title="AstrAI Inference Server", version="0.2.0", lifespan=lifespan)
|
||||||
|
|
||||||
|
|
||||||
def _create_engine(
|
def _create_engine(
|
||||||
param_path: Path,
|
param_path: Optional[Path] = None,
|
||||||
device: str = "cuda",
|
device: str = "cuda",
|
||||||
dtype: torch.dtype = torch.bfloat16,
|
dtype: torch.dtype = torch.bfloat16,
|
||||||
max_batch_size: int = 16,
|
max_batch_size: int = 16,
|
||||||
) -> InferenceEngine:
|
) -> InferenceEngine:
|
||||||
|
if param_path is None:
|
||||||
|
param_path = _project_root / "params"
|
||||||
if not param_path.exists():
|
if not param_path.exists():
|
||||||
raise FileNotFoundError(f"Parameter directory not found: {param_path}")
|
raise FileNotFoundError(f"Parameter directory not found: {param_path}")
|
||||||
|
|
||||||
|
|
@ -113,66 +110,49 @@ def _create_engine(
|
||||||
return engine
|
return engine
|
||||||
|
|
||||||
|
|
||||||
def get_app() -> FastAPI:
|
|
||||||
"""Return the singleton FastAPI instance (lazily created on first call)."""
|
|
||||||
global _app_instance
|
|
||||||
if _app_instance is None:
|
|
||||||
_app_instance = FastAPI(
|
|
||||||
title="AstrAI Inference Server",
|
|
||||||
version="0.2.0",
|
|
||||||
lifespan=lifespan,
|
|
||||||
)
|
|
||||||
_app_instance.include_router(router)
|
|
||||||
_app_instance.state.server_config = {}
|
|
||||||
_app_instance.state.engine = None
|
|
||||||
return _app_instance
|
|
||||||
|
|
||||||
|
|
||||||
def _get_engine() -> InferenceEngine:
|
def _get_engine() -> InferenceEngine:
|
||||||
engine = get_app().state.engine
|
engine = app.state.engine
|
||||||
if engine is None:
|
if engine is None:
|
||||||
raise HTTPException(status_code=503, detail="Engine not initialized")
|
raise HTTPException(status_code=503, detail="Engine not initialized")
|
||||||
return engine
|
return engine
|
||||||
|
|
||||||
|
|
||||||
@router.get("/health")
|
@app.get("/health")
|
||||||
async def health():
|
async def health():
|
||||||
app = get_app()
|
|
||||||
return {
|
return {
|
||||||
"status": "ok",
|
"status": "ok",
|
||||||
"model_loaded": app.state.engine is not None,
|
"model_loaded": app.state.engine is not None,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@router.get("/stats")
|
@app.get("/stats")
|
||||||
async def get_stats():
|
async def get_stats():
|
||||||
return _get_engine().get_stats()
|
return _get_engine().get_stats()
|
||||||
|
|
||||||
|
|
||||||
@router.post("/v1/chat/completions")
|
@app.post("/v1/chat/completions")
|
||||||
async def chat_completion(request: ChatCompletionRequest):
|
async def chat_completion(request: ChatCompletionRequest):
|
||||||
engine = _get_engine()
|
engine = _get_engine()
|
||||||
handler = ProtocolHandler(request, engine, OpenAIResponseBuilder())
|
handler = OpenAIHandler(request, engine)
|
||||||
return await handler.handle()
|
return await handler.handle()
|
||||||
|
|
||||||
|
|
||||||
@router.post("/v1/messages")
|
@app.post("/v1/messages")
|
||||||
async def create_message(request: MessagesRequest):
|
async def create_message(request: MessagesRequest):
|
||||||
engine = _get_engine()
|
engine = _get_engine()
|
||||||
handler = ProtocolHandler(request, engine, AnthropicResponseBuilder())
|
handler = AnthropicHandler(request, engine)
|
||||||
return await handler.handle()
|
return await handler.handle()
|
||||||
|
|
||||||
|
|
||||||
def run_server(
|
def run_server(
|
||||||
param_path: Path,
|
|
||||||
host: str = "0.0.0.0",
|
host: str = "0.0.0.0",
|
||||||
port: int = 8000,
|
port: int = 8000,
|
||||||
reload: bool = False,
|
reload: bool = False,
|
||||||
device: str = "cuda",
|
device: str = "cuda",
|
||||||
dtype: torch.dtype = torch.bfloat16,
|
dtype: torch.dtype = torch.bfloat16,
|
||||||
|
param_path: Optional[Path] = None,
|
||||||
max_batch_size: int = 16,
|
max_batch_size: int = 16,
|
||||||
):
|
):
|
||||||
app = get_app()
|
|
||||||
app.state.server_config = {
|
app.state.server_config = {
|
||||||
"device": device,
|
"device": device,
|
||||||
"dtype": dtype,
|
"dtype": dtype,
|
||||||
|
|
|
||||||
|
|
@ -42,7 +42,7 @@ class Allocator:
|
||||||
return idx
|
return idx
|
||||||
return -1
|
return -1
|
||||||
|
|
||||||
def free(self, idx: int, keep_cached: bool = False):
|
def free(self, idx: int, keep_cached: bool = False) -> None:
|
||||||
with self._lock:
|
with self._lock:
|
||||||
self._refs[idx] -= 1
|
self._refs[idx] -= 1
|
||||||
if self._refs[idx] == 0:
|
if self._refs[idx] == 0:
|
||||||
|
|
@ -51,7 +51,7 @@ class Allocator:
|
||||||
else:
|
else:
|
||||||
self._free_mask |= 1 << idx
|
self._free_mask |= 1 << idx
|
||||||
|
|
||||||
def inc_ref(self, idx: int):
|
def inc_ref(self, idx: int) -> None:
|
||||||
with self._lock:
|
with self._lock:
|
||||||
self._refs[idx] += 1
|
self._refs[idx] += 1
|
||||||
self._lru.pop(idx, None)
|
self._lru.pop(idx, None)
|
||||||
|
|
@ -60,7 +60,7 @@ class Allocator:
|
||||||
with self._lock:
|
with self._lock:
|
||||||
return self._refs[idx]
|
return self._refs[idx]
|
||||||
|
|
||||||
def touch(self, idx: int):
|
def touch(self, idx: int) -> None:
|
||||||
with self._lock:
|
with self._lock:
|
||||||
self._lru.move_to_end(idx)
|
self._lru.move_to_end(idx)
|
||||||
|
|
||||||
|
|
@ -74,7 +74,7 @@ class PrefixCache:
|
||||||
self._hash_to_page: Dict[int, int] = {}
|
self._hash_to_page: Dict[int, int] = {}
|
||||||
self._lock = threading.Lock()
|
self._lock = threading.Lock()
|
||||||
|
|
||||||
def evict(self, idx: int):
|
def evict(self, idx: int) -> None:
|
||||||
with self._lock:
|
with self._lock:
|
||||||
h = self._page_to_hash.pop(idx, None)
|
h = self._page_to_hash.pop(idx, None)
|
||||||
if h is not None:
|
if h is not None:
|
||||||
|
|
@ -96,7 +96,9 @@ class PrefixCache:
|
||||||
hits.append(p)
|
hits.append(p)
|
||||||
return hits
|
return hits
|
||||||
|
|
||||||
def record(self, page_idx: int, token_ids: List[int], logical_page_idx: int):
|
def record(
|
||||||
|
self, page_idx: int, token_ids: List[int], logical_page_idx: int
|
||||||
|
) -> None:
|
||||||
with self._lock:
|
with self._lock:
|
||||||
h = page_hash(token_ids, logical_page_idx, self._page_size)
|
h = page_hash(token_ids, logical_page_idx, self._page_size)
|
||||||
old_h = self._page_to_hash.pop(page_idx, None)
|
old_h = self._page_to_hash.pop(page_idx, None)
|
||||||
|
|
@ -125,13 +127,13 @@ class PagePool:
|
||||||
def alloc(self) -> int:
|
def alloc(self) -> int:
|
||||||
return self._alloc.alloc()
|
return self._alloc.alloc()
|
||||||
|
|
||||||
def free(self, idx: int):
|
def free(self, idx: int) -> None:
|
||||||
keep = self._prefix.has_page(idx)
|
keep = self._prefix.has_page(idx)
|
||||||
self._alloc.free(idx, keep_cached=keep)
|
self._alloc.free(idx, keep_cached=keep)
|
||||||
if not keep:
|
if not keep:
|
||||||
self._prefix.evict(idx)
|
self._prefix.evict(idx)
|
||||||
|
|
||||||
def inc_ref(self, idx: int):
|
def inc_ref(self, idx: int) -> None:
|
||||||
self._alloc.inc_ref(idx)
|
self._alloc.inc_ref(idx)
|
||||||
|
|
||||||
def lookup(self, token_ids: List[int]) -> List[int]:
|
def lookup(self, token_ids: List[int]) -> List[int]:
|
||||||
|
|
@ -140,7 +142,9 @@ class PagePool:
|
||||||
self._alloc.touch(p)
|
self._alloc.touch(p)
|
||||||
return hits
|
return hits
|
||||||
|
|
||||||
def record(self, page_idx: int, token_ids: List[int], logical_page_idx: int):
|
def record(
|
||||||
|
self, page_idx: int, token_ids: List[int], logical_page_idx: int
|
||||||
|
) -> None:
|
||||||
self._prefix.record(page_idx, token_ids, logical_page_idx)
|
self._prefix.record(page_idx, token_ids, logical_page_idx)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -153,7 +157,7 @@ class TaskTable:
|
||||||
self._cached: Dict[str, int] = {}
|
self._cached: Dict[str, int] = {}
|
||||||
self._lock = threading.Lock()
|
self._lock = threading.Lock()
|
||||||
|
|
||||||
def set(self, task_id: str, page_table: List[int], cached: int):
|
def set(self, task_id: str, page_table: List[int], cached: int) -> None:
|
||||||
with self._lock:
|
with self._lock:
|
||||||
self._pages[task_id] = page_table
|
self._pages[task_id] = page_table
|
||||||
self._cached[task_id] = cached
|
self._cached[task_id] = cached
|
||||||
|
|
@ -216,7 +220,7 @@ class Storage:
|
||||||
start_pos: int,
|
start_pos: int,
|
||||||
k: Tensor,
|
k: Tensor,
|
||||||
v: Tensor,
|
v: Tensor,
|
||||||
):
|
) -> None:
|
||||||
seq_len = k.size(1)
|
seq_len = k.size(1)
|
||||||
if seq_len == 0:
|
if seq_len == 0:
|
||||||
return
|
return
|
||||||
|
|
@ -282,7 +286,7 @@ class KvcacheView:
|
||||||
self._page_table = page_table
|
self._page_table = page_table
|
||||||
self._total_len = total_len
|
self._total_len = total_len
|
||||||
|
|
||||||
def write(self, layer_id: int, k: Tensor, v: Tensor):
|
def write(self, layer_id: int, k: Tensor, v: Tensor) -> None:
|
||||||
start_pos = self._total_len - k.size(1)
|
start_pos = self._total_len - k.size(1)
|
||||||
self._storage.write(layer_id, self._page_table, start_pos, k, v)
|
self._storage.write(layer_id, self._page_table, start_pos, k, v)
|
||||||
|
|
||||||
|
|
@ -335,7 +339,7 @@ class KVCache:
|
||||||
self._table.set(task_id, hits + new_pages, cached)
|
self._table.set(task_id, hits + new_pages, cached)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def task_free(self, task_id: str):
|
def task_free(self, task_id: str) -> None:
|
||||||
page_table, _ = self._table.pop(task_id)
|
page_table, _ = self._table.pop(task_id)
|
||||||
for idx in page_table:
|
for idx in page_table:
|
||||||
self._pool.free(idx)
|
self._pool.free(idx)
|
||||||
|
|
@ -355,7 +359,7 @@ class KVCache:
|
||||||
|
|
||||||
def task_record_hashes(
|
def task_record_hashes(
|
||||||
self, task_id: str, prompt_ids: List[int], start_logical_page: int = 0
|
self, task_id: str, prompt_ids: List[int], start_logical_page: int = 0
|
||||||
):
|
) -> None:
|
||||||
page_table = self._table.get(task_id)
|
page_table = self._table.get(task_id)
|
||||||
full_pages = len(prompt_ids) // self.page_size
|
full_pages = len(prompt_ids) // self.page_size
|
||||||
for i in range(start_logical_page, full_pages):
|
for i in range(start_logical_page, full_pages):
|
||||||
|
|
|
||||||
|
|
@ -29,7 +29,9 @@ class Executor:
|
||||||
self.device = device or next(model.parameters()).device
|
self.device = device or next(model.parameters()).device
|
||||||
self.dtype = dtype or next(model.parameters()).dtype
|
self.dtype = dtype or next(model.parameters()).dtype
|
||||||
|
|
||||||
def execute_prefill(self, tasks: List[Task], prompt_len: int, start_pos: int = 0):
|
def execute_prefill(
|
||||||
|
self, tasks: List[Task], prompt_len: int, start_pos: int = 0
|
||||||
|
) -> None:
|
||||||
if start_pos >= prompt_len:
|
if start_pos >= prompt_len:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -71,19 +71,18 @@ class InferenceScheduler:
|
||||||
)
|
)
|
||||||
|
|
||||||
self._running = False
|
self._running = False
|
||||||
self._fatal_error: Optional[Exception] = None
|
|
||||||
|
|
||||||
def add_task(self, prompt: str, **kwargs) -> str:
|
def add_task(self, prompt: str, **kwargs) -> str:
|
||||||
return self._task_mgr.add_task(prompt, **kwargs)
|
return self._task_mgr.add_task(prompt, **kwargs)
|
||||||
|
|
||||||
def remove_task(self, task_id: str):
|
def remove_task(self, task_id: str) -> None:
|
||||||
for task in self._task_mgr.remove_task(task_id):
|
for task in self._task_mgr.remove_task(task_id):
|
||||||
self._page_cache.task_free(task.task_id)
|
self._page_cache.task_free(task.task_id)
|
||||||
|
|
||||||
def get_stats(self) -> Dict[str, Any]:
|
def get_stats(self) -> Dict[str, Any]:
|
||||||
return self._task_mgr.get_stats()
|
return self._task_mgr.get_stats()
|
||||||
|
|
||||||
def _run_generation_loop(self):
|
def _run_generation_loop(self) -> None:
|
||||||
stop_ids = self._task_mgr.tokenizer.stop_ids
|
stop_ids = self._task_mgr.tokenizer.stop_ids
|
||||||
try:
|
try:
|
||||||
while self._running:
|
while self._running:
|
||||||
|
|
@ -109,10 +108,7 @@ class InferenceScheduler:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
to_prefill = [
|
to_prefill = [
|
||||||
t
|
t for t in self._task_mgr.get_active_tasks() if t.output_tokens == 0
|
||||||
for t in self._task_mgr.get_active_tasks()
|
|
||||||
if t.output_tokens == 0
|
|
||||||
and self._page_cache.task_cached(t.task_id) < len(t.prompt_ids)
|
|
||||||
]
|
]
|
||||||
if to_prefill:
|
if to_prefill:
|
||||||
for t in to_prefill:
|
for t in to_prefill:
|
||||||
|
|
@ -160,15 +156,11 @@ class InferenceScheduler:
|
||||||
t.output_ids.append(ntok)
|
t.output_ids.append(ntok)
|
||||||
t.output_tokens += 1
|
t.output_tokens += 1
|
||||||
pos = t.input_tokens + t.output_tokens
|
pos = t.input_tokens + t.output_tokens
|
||||||
extend_ok = self._page_cache.task_extend(t.task_id, pos)
|
self._page_cache.task_extend(t.task_id, pos)
|
||||||
if t.stream_callback:
|
if t.stream_callback:
|
||||||
t.stream_callback(
|
t.stream_callback(
|
||||||
self._task_mgr.tokenizer.decode([ntok])
|
self._task_mgr.tokenizer.decode([ntok])
|
||||||
)
|
)
|
||||||
if not extend_ok:
|
|
||||||
t.status = TaskStatus.ABORTED
|
|
||||||
if t.stream_callback:
|
|
||||||
t.stream_callback(STOP)
|
|
||||||
|
|
||||||
for t in valid:
|
for t in valid:
|
||||||
if t.is_finished(stop_ids):
|
if t.is_finished(stop_ids):
|
||||||
|
|
@ -176,37 +168,28 @@ class InferenceScheduler:
|
||||||
t.stream_callback(STOP)
|
t.stream_callback(STOP)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self._fatal_error = e
|
|
||||||
self._running = False
|
|
||||||
logger.error(f"Scheduler loop crashed: {e}", exc_info=True)
|
logger.error(f"Scheduler loop crashed: {e}", exc_info=True)
|
||||||
for task in self._task_mgr.get_active_tasks():
|
for task in self._task_mgr.get_active_tasks():
|
||||||
if task.stream_callback:
|
if task.stream_callback:
|
||||||
task.stream_callback(STOP)
|
task.stream_callback(STOP)
|
||||||
self._page_cache.task_free(task.task_id)
|
self._page_cache.task_free(task.task_id)
|
||||||
for task in self._task_mgr.get_waiting_tasks():
|
|
||||||
if task.stream_callback:
|
|
||||||
task.stream_callback(STOP)
|
|
||||||
self._task_mgr.clear_queues()
|
self._task_mgr.clear_queues()
|
||||||
|
raise
|
||||||
|
|
||||||
def start(self):
|
def start(self) -> None:
|
||||||
if not self._running:
|
if not self._running:
|
||||||
self._running = True
|
self._running = True
|
||||||
t = threading.Thread(target=self._run_generation_loop, daemon=True)
|
t = threading.Thread(target=self._run_generation_loop, daemon=True)
|
||||||
t.start()
|
t.start()
|
||||||
self._loop_thread = t
|
self._loop_thread = t
|
||||||
|
|
||||||
def stop(self):
|
def stop(self) -> None:
|
||||||
self._running = False
|
self._running = False
|
||||||
self._task_mgr.wake()
|
self._task_mgr.wake()
|
||||||
if hasattr(self, "_loop_thread"):
|
if hasattr(self, "_loop_thread"):
|
||||||
self._loop_thread.join(timeout=2.0)
|
self._loop_thread.join(timeout=2.0)
|
||||||
for task in self._task_mgr.get_active_tasks():
|
for task in self._task_mgr.get_active_tasks():
|
||||||
if task.stream_callback:
|
|
||||||
task.stream_callback(STOP)
|
|
||||||
self._page_cache.task_free(task.task_id)
|
self._page_cache.task_free(task.task_id)
|
||||||
for task in self._task_mgr.get_waiting_tasks():
|
|
||||||
if task.stream_callback:
|
|
||||||
task.stream_callback(STOP)
|
|
||||||
self._task_mgr.clear_queues()
|
self._task_mgr.clear_queues()
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
|
||||||
|
|
@ -172,12 +172,12 @@ class TaskManager:
|
||||||
to_add.append(self.waiting_queue.popleft())
|
to_add.append(self.waiting_queue.popleft())
|
||||||
return to_add
|
return to_add
|
||||||
|
|
||||||
def activate(self, task: Task):
|
def activate(self, task: Task) -> None:
|
||||||
task.status = TaskStatus.RUNNING
|
task.status = TaskStatus.RUNNING
|
||||||
with self._lock:
|
with self._lock:
|
||||||
self.active_tasks.append(task)
|
self.active_tasks.append(task)
|
||||||
|
|
||||||
def return_to_waiting(self, tasks: List[Task]):
|
def return_to_waiting(self, tasks: List[Task]) -> None:
|
||||||
with self._lock:
|
with self._lock:
|
||||||
for task in reversed(tasks):
|
for task in reversed(tasks):
|
||||||
self.waiting_queue.appendleft(task)
|
self.waiting_queue.appendleft(task)
|
||||||
|
|
@ -185,25 +185,18 @@ class TaskManager:
|
||||||
def has_work(self) -> bool:
|
def has_work(self) -> bool:
|
||||||
return bool(self.active_tasks or self.waiting_queue)
|
return bool(self.active_tasks or self.waiting_queue)
|
||||||
|
|
||||||
def wait_for_tasks(self, timeout: float = 1.0):
|
def wait_for_tasks(self, timeout: float = 1.0) -> None:
|
||||||
with self._lock:
|
self._task_event.clear()
|
||||||
if self.waiting_queue or self.active_tasks:
|
|
||||||
return
|
|
||||||
self._task_event.clear()
|
|
||||||
self._task_event.wait(timeout=timeout)
|
self._task_event.wait(timeout=timeout)
|
||||||
|
|
||||||
def get_active_tasks(self) -> List[Task]:
|
def get_active_tasks(self) -> List[Task]:
|
||||||
with self._lock:
|
with self._lock:
|
||||||
return list(self.active_tasks)
|
return list(self.active_tasks)
|
||||||
|
|
||||||
def get_waiting_tasks(self) -> List[Task]:
|
def clear_queues(self) -> None:
|
||||||
with self._lock:
|
|
||||||
return list(self.waiting_queue)
|
|
||||||
|
|
||||||
def clear_queues(self):
|
|
||||||
with self._lock:
|
with self._lock:
|
||||||
self.waiting_queue.clear()
|
self.waiting_queue.clear()
|
||||||
self.active_tasks.clear()
|
self.active_tasks.clear()
|
||||||
|
|
||||||
def wake(self):
|
def wake(self) -> None:
|
||||||
self._task_event.set()
|
self._task_event.set()
|
||||||
|
|
|
||||||
|
|
@ -13,6 +13,17 @@ from astrai.inference.core.task import STOP
|
||||||
from astrai.tokenize import AutoTokenizer
|
from astrai.tokenize import AutoTokenizer
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_sampling_params(
|
||||||
|
top_k: int, top_p: float, temperature: float, max_tokens: Optional[int] = None
|
||||||
|
):
|
||||||
|
if not (isinstance(top_k, int) and top_k >= 0):
|
||||||
|
raise ValueError("top_k must be a non-negative integer")
|
||||||
|
if not (0.0 <= top_p <= 1.0):
|
||||||
|
raise ValueError("top_p must be a float between 0.0 and 1.0")
|
||||||
|
if not (isinstance(temperature, (int, float)) and temperature >= 0):
|
||||||
|
raise ValueError("temperature must be a non-negative number")
|
||||||
|
|
||||||
|
|
||||||
class GenerateResult:
|
class GenerateResult:
|
||||||
"""Thread-safe token accumulator for streaming and non-streaming modes."""
|
"""Thread-safe token accumulator for streaming and non-streaming modes."""
|
||||||
|
|
||||||
|
|
@ -48,7 +59,7 @@ class GenerateResult:
|
||||||
def wait(self, timeout: Optional[float] = None) -> bool:
|
def wait(self, timeout: Optional[float] = None) -> bool:
|
||||||
return self._event.wait(timeout=timeout)
|
return self._event.wait(timeout=timeout)
|
||||||
|
|
||||||
def wait_completion(self, timeout: float = 300.0):
|
def wait_completion(self, timeout: float = 300.0) -> None:
|
||||||
with self._cond:
|
with self._cond:
|
||||||
if not self._cond.wait_for(
|
if not self._cond.wait_for(
|
||||||
lambda: self._completed >= self._total, timeout=timeout
|
lambda: self._completed >= self._total, timeout=timeout
|
||||||
|
|
@ -75,12 +86,7 @@ class GenerationRequest:
|
||||||
max_tokens: Optional[int] = None,
|
max_tokens: Optional[int] = None,
|
||||||
stream: bool = False,
|
stream: bool = False,
|
||||||
):
|
):
|
||||||
if not (isinstance(top_k, int) and top_k >= 0):
|
_validate_sampling_params(top_k, top_p, temperature, max_tokens)
|
||||||
raise ValueError("top_k must be a non-negative integer")
|
|
||||||
if not (0.0 <= top_p <= 1.0):
|
|
||||||
raise ValueError("top_p must be a float between 0.0 and 1.0")
|
|
||||||
if not (isinstance(temperature, (int, float)) and temperature > 0):
|
|
||||||
raise ValueError("temperature must be a positive number")
|
|
||||||
|
|
||||||
self.messages = messages
|
self.messages = messages
|
||||||
self.top_k = top_k
|
self.top_k = top_k
|
||||||
|
|
@ -131,6 +137,7 @@ class InferenceEngine:
|
||||||
top_p: float = 1.0,
|
top_p: float = 1.0,
|
||||||
top_k: int = 50,
|
top_k: int = 50,
|
||||||
) -> Union[Generator, str, List[str]]:
|
) -> Union[Generator, str, List[str]]:
|
||||||
|
_validate_sampling_params(top_k, top_p, temperature, max_tokens)
|
||||||
is_batch = isinstance(prompt, list)
|
is_batch = isinstance(prompt, list)
|
||||||
prompts = prompt if is_batch else [prompt]
|
prompts = prompt if is_batch else [prompt]
|
||||||
|
|
||||||
|
|
@ -151,6 +158,7 @@ class InferenceEngine:
|
||||||
top_p: float = 1.0,
|
top_p: float = 1.0,
|
||||||
top_k: int = 50,
|
top_k: int = 50,
|
||||||
) -> AsyncGenerator[str, None]:
|
) -> AsyncGenerator[str, None]:
|
||||||
|
_validate_sampling_params(top_k, top_p, temperature, max_tokens)
|
||||||
sync_gen = self._generate_streaming(
|
sync_gen = self._generate_streaming(
|
||||||
[prompt], False, max_tokens, temperature, top_p, top_k
|
[prompt], False, max_tokens, temperature, top_p, top_k
|
||||||
)
|
)
|
||||||
|
|
@ -281,7 +289,7 @@ class InferenceEngine:
|
||||||
def get_stats(self) -> Dict[str, Any]:
|
def get_stats(self) -> Dict[str, Any]:
|
||||||
return self.scheduler.get_stats()
|
return self.scheduler.get_stats()
|
||||||
|
|
||||||
def shutdown(self):
|
def shutdown(self) -> None:
|
||||||
self.scheduler.stop()
|
self.scheduler.stop()
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
|
||||||
|
|
@ -44,12 +44,10 @@ class TemperatureStrategy(BaseSamplingStrategy):
|
||||||
def apply(self, logits, filter_value=-float("inf")):
|
def apply(self, logits, filter_value=-float("inf")):
|
||||||
t = self.temperature
|
t = self.temperature
|
||||||
if isinstance(t, Tensor):
|
if isinstance(t, Tensor):
|
||||||
t = t.to(logits.device, non_blocking=True).view(-1, 1)
|
|
||||||
t = torch.clamp(t, min=1e-8)
|
|
||||||
if (t != 1.0).any():
|
if (t != 1.0).any():
|
||||||
logits = logits / t
|
logits = logits / t.to(logits.device, non_blocking=True).view(-1, 1)
|
||||||
elif t != 1.0:
|
elif t != 1.0:
|
||||||
logits = logits / max(t, 1e-8)
|
logits = logits / t
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -2,13 +2,6 @@ from astrai.model.automodel import AutoModel
|
||||||
from astrai.model.components.attention import GQA
|
from astrai.model.components.attention import GQA
|
||||||
from astrai.model.components.decoder_block import DecoderBlock
|
from astrai.model.components.decoder_block import DecoderBlock
|
||||||
from astrai.model.components.linear import Linear
|
from astrai.model.components.linear import Linear
|
||||||
from astrai.model.components.lora import (
|
|
||||||
LoRAConfig,
|
|
||||||
inject_lora,
|
|
||||||
load_lora,
|
|
||||||
merge_lora,
|
|
||||||
save_lora,
|
|
||||||
)
|
|
||||||
from astrai.model.components.mlp import MLP
|
from astrai.model.components.mlp import MLP
|
||||||
from astrai.model.components.norm import RMSNorm
|
from astrai.model.components.norm import RMSNorm
|
||||||
from astrai.model.encoder import EmbeddingEncoder
|
from astrai.model.encoder import EmbeddingEncoder
|
||||||
|
|
@ -25,10 +18,4 @@ __all__ = [
|
||||||
"AutoRegressiveLM",
|
"AutoRegressiveLM",
|
||||||
"EmbeddingEncoder",
|
"EmbeddingEncoder",
|
||||||
"AutoModel",
|
"AutoModel",
|
||||||
# LoRA
|
|
||||||
"LoRAConfig",
|
|
||||||
"inject_lora",
|
|
||||||
"merge_lora",
|
|
||||||
"save_lora",
|
|
||||||
"load_lora",
|
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -2,24 +2,21 @@
|
||||||
AutoModel base class for model loading and saving.
|
AutoModel base class for model loading and saving.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Self, Union
|
from typing import Self, Union
|
||||||
|
|
||||||
|
import safetensors.torch as st
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from astrai.config.model_config import BaseModelConfig, ConfigFactory
|
from astrai.config.model_config import BaseModelConfig, ConfigFactory
|
||||||
from astrai.factory import BaseFactory
|
from astrai.factory import BaseFactory
|
||||||
from astrai.serialization import load_model_config, load_model_weights, save_model
|
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def _disable_random_init(enable: bool = True):
|
def _disable_random_init(enable: bool = True):
|
||||||
if not enable:
|
init_functions = [
|
||||||
yield
|
|
||||||
return
|
|
||||||
|
|
||||||
names = (
|
|
||||||
"xavier_normal_",
|
"xavier_normal_",
|
||||||
"xavier_uniform_",
|
"xavier_uniform_",
|
||||||
"kaiming_normal_",
|
"kaiming_normal_",
|
||||||
|
|
@ -29,15 +26,18 @@ def _disable_random_init(enable: bool = True):
|
||||||
"constant_",
|
"constant_",
|
||||||
"normal_",
|
"normal_",
|
||||||
"uniform_",
|
"uniform_",
|
||||||
)
|
]
|
||||||
orig = {n: getattr(nn.init, n) for n in names if hasattr(nn.init, n)}
|
original_funcs = {}
|
||||||
for n in orig:
|
for name in init_functions:
|
||||||
setattr(nn.init, n, lambda *a, **kw: None)
|
if enable and hasattr(nn.init, name):
|
||||||
|
original_funcs[name] = getattr(nn.init, name)
|
||||||
|
setattr(nn.init, name, lambda *args, **kwargs: None)
|
||||||
try:
|
try:
|
||||||
yield
|
yield
|
||||||
finally:
|
finally:
|
||||||
for n, fn in orig.items():
|
if enable:
|
||||||
setattr(nn.init, n, fn)
|
for name, orig_func in original_funcs.items():
|
||||||
|
setattr(nn.init, name, orig_func)
|
||||||
|
|
||||||
|
|
||||||
class AutoModel(BaseFactory["AutoModel"], nn.Module):
|
class AutoModel(BaseFactory["AutoModel"], nn.Module):
|
||||||
|
|
@ -60,22 +60,25 @@ class AutoModel(BaseFactory["AutoModel"], nn.Module):
|
||||||
|
|
||||||
model_path = Path(path)
|
model_path = Path(path)
|
||||||
|
|
||||||
|
# Load config
|
||||||
config_path = model_path / "config.json"
|
config_path = model_path / "config.json"
|
||||||
if not config_path.exists():
|
if config_path.exists():
|
||||||
|
with open(config_path, "r") as f:
|
||||||
|
raw = json.load(f)
|
||||||
|
config = ConfigFactory.load(raw)
|
||||||
|
model_type = config.model_type or "autoregressive_lm"
|
||||||
|
else:
|
||||||
raise FileNotFoundError(f"Config file not found: {config_path}")
|
raise FileNotFoundError(f"Config file not found: {config_path}")
|
||||||
|
|
||||||
raw = load_model_config(str(model_path))
|
|
||||||
config = ConfigFactory.load(raw)
|
|
||||||
model_type = config.model_type or "autoregressive_lm"
|
|
||||||
|
|
||||||
actual_cls = AutoModel.get_component_class(model_type)
|
actual_cls = AutoModel.get_component_class(model_type)
|
||||||
|
|
||||||
with _disable_random_init(enable=disable_random_init):
|
with _disable_random_init(enable=disable_random_init):
|
||||||
model = actual_cls(config)
|
model = actual_cls(config)
|
||||||
|
|
||||||
|
# Load weights
|
||||||
weights_path = model_path / "model.safetensors"
|
weights_path = model_path / "model.safetensors"
|
||||||
if weights_path.exists():
|
if weights_path.exists():
|
||||||
state_dict = load_model_weights(str(model_path))
|
state_dict = st.load_file(str(weights_path))
|
||||||
model.load_state_dict(state_dict, strict=strict)
|
model.load_state_dict(state_dict, strict=strict)
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
@ -83,12 +86,15 @@ class AutoModel(BaseFactory["AutoModel"], nn.Module):
|
||||||
def save_pretrained(
|
def save_pretrained(
|
||||||
self,
|
self,
|
||||||
save_directory: Union[str, Path],
|
save_directory: Union[str, Path],
|
||||||
):
|
) -> None:
|
||||||
save_model(
|
save_path = Path(save_directory)
|
||||||
config=self.config.to_dict(),
|
save_path.mkdir(parents=True, exist_ok=True)
|
||||||
state_dict=self.state_dict(),
|
|
||||||
save_directory=str(save_directory),
|
# Save config
|
||||||
)
|
self.config.to_file(str(save_path / "config.json"))
|
||||||
|
|
||||||
|
# Save weights
|
||||||
|
st.save_file(self.state_dict(), str(save_path / "model.safetensors"))
|
||||||
|
|
||||||
def to(self, *args, **kwargs) -> Self:
|
def to(self, *args, **kwargs) -> Self:
|
||||||
"""Move model to device/dtype."""
|
"""Move model to device/dtype."""
|
||||||
|
|
|
||||||
|
|
@ -1,194 +0,0 @@
|
||||||
import logging
|
|
||||||
from dataclasses import asdict, dataclass
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Optional, Set
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
from astrai.model.components.linear import Linear
|
|
||||||
from astrai.serialization import (
|
|
||||||
load_json,
|
|
||||||
load_safetensors,
|
|
||||||
save_json,
|
|
||||||
save_safetensors,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
TARGET_MODULES_ATTN = {"q_proj", "k_proj", "v_proj", "o_proj"}
|
|
||||||
TARGET_MODULES_FFN = {"up", "gate", "down"}
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class LoRAConfig:
|
|
||||||
r: int = 16
|
|
||||||
alpha: int = 32
|
|
||||||
target_modules: tuple = ("q_proj", "v_proj")
|
|
||||||
|
|
||||||
|
|
||||||
class LoRALinear(nn.Module):
|
|
||||||
def __init__(self, base: Linear, r: int = 16, alpha: int = 32):
|
|
||||||
super().__init__()
|
|
||||||
self.register_parameter("weight", base.weight)
|
|
||||||
self.weight.requires_grad_(False)
|
|
||||||
self.bias = base.bias
|
|
||||||
if self.bias is not None:
|
|
||||||
self.bias.requires_grad_(False)
|
|
||||||
|
|
||||||
self.r = r
|
|
||||||
self.scaling = alpha / r
|
|
||||||
self.lora_A = nn.Parameter(torch.randn(r, self.weight.shape[1]) / r)
|
|
||||||
self.lora_B = nn.Parameter(torch.zeros(self.weight.shape[0], r))
|
|
||||||
self._merged = False
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
out = F.linear(x, self.weight, self.bias)
|
|
||||||
if not self._merged:
|
|
||||||
out += (F.linear(x, self.lora_A) @ self.lora_B.T) * self.scaling
|
|
||||||
return out
|
|
||||||
|
|
||||||
def merge(self):
|
|
||||||
if self._merged:
|
|
||||||
return
|
|
||||||
self.weight.data += (self.lora_B @ self.lora_A) * self.scaling
|
|
||||||
self._merged = True
|
|
||||||
del self.lora_A
|
|
||||||
del self.lora_B
|
|
||||||
|
|
||||||
|
|
||||||
def _collect_lora_info(model: nn.Module) -> dict:
|
|
||||||
names = {}
|
|
||||||
for n, m in model.named_modules():
|
|
||||||
if isinstance(m, Linear):
|
|
||||||
_, _, child = n.rpartition(".")
|
|
||||||
names.setdefault(child, []).append(n)
|
|
||||||
return names
|
|
||||||
|
|
||||||
|
|
||||||
def _get_lora_count(model: nn.Module) -> int:
|
|
||||||
return sum(1 for m in model.modules() if isinstance(m, LoRALinear))
|
|
||||||
|
|
||||||
|
|
||||||
def inject_lora(
|
|
||||||
model: nn.Module,
|
|
||||||
r: int = 16,
|
|
||||||
alpha: int = 32,
|
|
||||||
target_modules: Optional[Set[str]] = None,
|
|
||||||
) -> LoRAConfig:
|
|
||||||
if target_modules is None:
|
|
||||||
target_modules = TARGET_MODULES_ATTN
|
|
||||||
|
|
||||||
available = _collect_lora_info(model)
|
|
||||||
injected = 0
|
|
||||||
|
|
||||||
for name, module in list(model.named_modules()):
|
|
||||||
if not isinstance(module, Linear):
|
|
||||||
continue
|
|
||||||
parent_name, _, child_name = name.rpartition(".")
|
|
||||||
if child_name not in target_modules:
|
|
||||||
continue
|
|
||||||
parent = model.get_submodule(parent_name) if parent_name else model
|
|
||||||
setattr(parent, child_name, LoRALinear(module, r=r, alpha=alpha))
|
|
||||||
injected += 1
|
|
||||||
|
|
||||||
if injected == 0:
|
|
||||||
logger.warning(
|
|
||||||
"No LoRA layers injected. Available Linear child names: %s. "
|
|
||||||
"target_modules: %s. Check model type and target_modules.",
|
|
||||||
sorted(available),
|
|
||||||
sorted(target_modules),
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logger.info("LoRA injected: %d layers (r=%d, alpha=%d)", injected, r, alpha)
|
|
||||||
|
|
||||||
return LoRAConfig(r=r, alpha=alpha, target_modules=tuple(target_modules))
|
|
||||||
|
|
||||||
|
|
||||||
def merge_lora(model: nn.Module):
|
|
||||||
n = 0
|
|
||||||
for module in model.modules():
|
|
||||||
if isinstance(module, LoRALinear):
|
|
||||||
module.merge()
|
|
||||||
n += 1
|
|
||||||
if n == 0:
|
|
||||||
logger.warning("No LoRA layers to merge.")
|
|
||||||
else:
|
|
||||||
logger.info("Merged %d LoRA layers", n)
|
|
||||||
|
|
||||||
|
|
||||||
def save_lora(model: nn.Module, save_dir: str, config: LoRAConfig):
|
|
||||||
lora_sd = {
|
|
||||||
k: v
|
|
||||||
for k, v in model.state_dict().items()
|
|
||||||
if k.endswith((".lora_A", ".lora_B"))
|
|
||||||
}
|
|
||||||
if not lora_sd:
|
|
||||||
raise RuntimeError(
|
|
||||||
"No LoRA parameters found in model. "
|
|
||||||
"The model may not have been injected or was already merged."
|
|
||||||
)
|
|
||||||
|
|
||||||
path = Path(save_dir)
|
|
||||||
path.mkdir(parents=True, exist_ok=True)
|
|
||||||
save_safetensors(lora_sd, path / "adapter_model.safetensors")
|
|
||||||
save_json(asdict(config), path / "adapter_config.json")
|
|
||||||
logger.info("LoRA adapter saved to %s (%d keys)", save_dir, len(lora_sd))
|
|
||||||
|
|
||||||
|
|
||||||
def load_lora(model: nn.Module, load_dir: str) -> LoRAConfig:
|
|
||||||
path = Path(load_dir)
|
|
||||||
raw = load_json(path / "adapter_config.json")
|
|
||||||
config = LoRAConfig(
|
|
||||||
r=raw["r"], alpha=raw["alpha"], target_modules=tuple(raw["target_modules"])
|
|
||||||
)
|
|
||||||
|
|
||||||
existing = _get_lora_count(model)
|
|
||||||
if existing > 0:
|
|
||||||
logger.warning(
|
|
||||||
"Model already has %d LoRA layers. Skipping injection, "
|
|
||||||
"loading weights onto existing layers only.",
|
|
||||||
existing,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
inject_lora(
|
|
||||||
model,
|
|
||||||
r=config.r,
|
|
||||||
alpha=config.alpha,
|
|
||||||
target_modules=set(config.target_modules),
|
|
||||||
)
|
|
||||||
|
|
||||||
weights = load_safetensors(path / "adapter_model.safetensors")
|
|
||||||
try:
|
|
||||||
missing, unexpected = model.load_state_dict(weights, strict=False)
|
|
||||||
except RuntimeError as e:
|
|
||||||
msg = str(e)
|
|
||||||
if "size mismatch" in msg:
|
|
||||||
raise RuntimeError(
|
|
||||||
f"LoRA weight shapes do not match the model. "
|
|
||||||
f"The adapter config (r={config.r}) may not match the injected layers. "
|
|
||||||
f"Original error: {msg}"
|
|
||||||
) from e
|
|
||||||
raise
|
|
||||||
|
|
||||||
injected = _get_lora_count(model)
|
|
||||||
if injected == 0:
|
|
||||||
raise RuntimeError(
|
|
||||||
"No LoRA layers found after loading. "
|
|
||||||
"Inject LoRA before calling load_lora, or check the adapter config."
|
|
||||||
)
|
|
||||||
|
|
||||||
if missing:
|
|
||||||
lora_missing = [k for k in missing if "lora" in k]
|
|
||||||
if lora_missing:
|
|
||||||
raise RuntimeError(
|
|
||||||
f"LoRA weight keys not found in model: {lora_missing}. "
|
|
||||||
f"The adapter config (r={config.r}) may not match the model."
|
|
||||||
)
|
|
||||||
logger.debug("LoRA load: %d missing base-weight keys (expected)", len(missing))
|
|
||||||
if unexpected:
|
|
||||||
logger.warning("LoRA load: %d unexpected keys", len(unexpected))
|
|
||||||
|
|
||||||
logger.info("LoRA adapter loaded from %s", load_dir)
|
|
||||||
return config
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
from typing import Dict, Optional
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
@ -19,10 +19,6 @@ def get_rotary_emb(
|
||||||
return torch.complex(cos, sin)
|
return torch.complex(cos, sin)
|
||||||
|
|
||||||
|
|
||||||
def ntk_base(base: float, dim: int, factor: float) -> float:
|
|
||||||
return base * (factor ** (dim / (dim - 2)))
|
|
||||||
|
|
||||||
|
|
||||||
def apply_rotary_emb(x: torch.Tensor, freqs_cis: Tensor) -> Tensor:
|
def apply_rotary_emb(x: torch.Tensor, freqs_cis: Tensor) -> Tensor:
|
||||||
dtype = x.dtype
|
dtype = x.dtype
|
||||||
x_ = x.float().reshape(*x.shape[:-1], -1, 2)
|
x_ = x.float().reshape(*x.shape[:-1], -1, 2)
|
||||||
|
|
@ -34,25 +30,11 @@ def apply_rotary_emb(x: torch.Tensor, freqs_cis: Tensor) -> Tensor:
|
||||||
|
|
||||||
|
|
||||||
class RotaryEmbedding(nn.Module):
|
class RotaryEmbedding(nn.Module):
|
||||||
def __init__(
|
def __init__(self, dim: int, max_len: int, base: float = 10000):
|
||||||
self,
|
|
||||||
dim: int,
|
|
||||||
max_len: int,
|
|
||||||
base: float = 10000,
|
|
||||||
rope_scaling: Optional[Dict] = None,
|
|
||||||
):
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.dim = dim
|
self.dim = dim
|
||||||
self.max_len = max_len
|
self.max_len = max_len
|
||||||
self.base = base
|
self.base = base
|
||||||
self.rope_scaling = rope_scaling
|
|
||||||
|
|
||||||
if rope_scaling is not None:
|
|
||||||
scaling_type = rope_scaling.get("type", "ntk")
|
|
||||||
factor = rope_scaling.get("factor", 1.0)
|
|
||||||
if scaling_type == "ntk":
|
|
||||||
self.base = ntk_base(base, dim, factor)
|
|
||||||
|
|
||||||
self._set_rotary_buffer(self.max_len)
|
self._set_rotary_buffer(self.max_len)
|
||||||
|
|
||||||
def _set_rotary_buffer(self, max_len: int):
|
def _set_rotary_buffer(self, max_len: int):
|
||||||
|
|
|
||||||
|
|
@ -20,9 +20,7 @@ class EmbeddingEncoder(AutoModel):
|
||||||
self.config = config
|
self.config = config
|
||||||
rope_dim = config.dim // config.n_heads
|
rope_dim = config.dim // config.n_heads
|
||||||
rope_base = config.rope_theta if config.rope_theta is not None else 10000
|
rope_base = config.rope_theta if config.rope_theta is not None else 10000
|
||||||
self.rotary_embedding = RotaryEmbedding(
|
self.rotary_embedding = RotaryEmbedding(rope_dim, config.max_len, rope_base)
|
||||||
rope_dim, config.max_len, rope_base, rope_scaling=config.rope_scaling
|
|
||||||
)
|
|
||||||
self.embed_tokens = Embedding(config.vocab_size, config.dim)
|
self.embed_tokens = Embedding(config.vocab_size, config.dim)
|
||||||
|
|
||||||
self.layers = nn.ModuleList(
|
self.layers = nn.ModuleList(
|
||||||
|
|
@ -68,6 +66,9 @@ class EmbeddingEncoder(AutoModel):
|
||||||
|
|
||||||
x = self.embed_tokens(input_ids)
|
x = self.embed_tokens(input_ids)
|
||||||
|
|
||||||
|
if position_ids is None:
|
||||||
|
position_ids = torch.arange(S, device=x.device).unsqueeze(0).expand(B, -1)
|
||||||
|
|
||||||
rotary_emb = self.rotary_embedding(x, position_ids)
|
rotary_emb = self.rotary_embedding(x, position_ids)
|
||||||
attn_mask = process_attention_mask(x, position_ids, input_mask, is_causal=False)
|
attn_mask = process_attention_mask(x, position_ids, input_mask, is_causal=False)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
from typing import Any, Dict, Mapping, Optional
|
from typing import Any, Mapping, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
@ -26,21 +26,24 @@ def process_attention_mask(
|
||||||
return input_mask
|
return input_mask
|
||||||
|
|
||||||
device = input_tensor.device
|
device = input_tensor.device
|
||||||
B = input_tensor.size(0)
|
dtype = input_tensor.dtype
|
||||||
|
B, S = input_tensor.size()[:2]
|
||||||
T = position_ids.max().item() + 1
|
T = position_ids.max().item() + 1
|
||||||
|
|
||||||
if input_mask is None:
|
if input_mask is None:
|
||||||
if position_ids.min().item() == 0 and is_causal:
|
if position_ids.min().item() == 0 and is_causal:
|
||||||
return None
|
return None
|
||||||
attend = torch.ones(B, 1, T, dtype=torch.bool, device=device)
|
pad = torch.ones(B, T, dtype=torch.bool, device=device)
|
||||||
else:
|
else:
|
||||||
attend = input_mask[:, :T].to(device=device, dtype=torch.bool).unsqueeze(1)
|
pad = input_mask[:, :T].to(device=device, dtype=torch.bool)
|
||||||
|
|
||||||
|
attend = pad.view(B, 1, T).expand(B, S, T).clone()
|
||||||
if is_causal:
|
if is_causal:
|
||||||
causal = position_ids.unsqueeze(-1) >= torch.arange(T, device=device)
|
attend &= position_ids.unsqueeze(-1) >= torch.arange(T, device=device)
|
||||||
attend = attend & causal
|
|
||||||
|
|
||||||
return attend.unsqueeze(1)
|
return torch.full(
|
||||||
|
(B, 1, S, T), -torch.finfo(dtype).max / 2, dtype=dtype, device=device
|
||||||
|
).masked_fill_(attend.unsqueeze(1), 0.0)
|
||||||
|
|
||||||
|
|
||||||
@AutoModel.register("autoregressive_lm")
|
@AutoModel.register("autoregressive_lm")
|
||||||
|
|
@ -56,9 +59,7 @@ class AutoRegressiveLM(AutoModel):
|
||||||
else config.dim // config.n_heads
|
else config.dim // config.n_heads
|
||||||
)
|
)
|
||||||
rope_base = config.rope_theta if config.rope_theta is not None else 10000
|
rope_base = config.rope_theta if config.rope_theta is not None else 10000
|
||||||
self.rotary_embedding = RotaryEmbedding(
|
self.rotary_embedding = RotaryEmbedding(rope_dim, config.max_len, rope_base)
|
||||||
rope_dim, config.max_len, rope_base, rope_scaling=config.rope_scaling
|
|
||||||
)
|
|
||||||
self.embed_tokens = Embedding(config.vocab_size, config.dim)
|
self.embed_tokens = Embedding(config.vocab_size, config.dim)
|
||||||
|
|
||||||
self.layers = nn.ModuleList(
|
self.layers = nn.ModuleList(
|
||||||
|
|
@ -133,7 +134,7 @@ class AutoRegressiveLM(AutoModel):
|
||||||
input_mask: Optional[Tensor] = None,
|
input_mask: Optional[Tensor] = None,
|
||||||
paged_cache: Optional[KvcacheView] = None,
|
paged_cache: Optional[KvcacheView] = None,
|
||||||
position_ids: Optional[Tensor] = None,
|
position_ids: Optional[Tensor] = None,
|
||||||
) -> Dict[str, Tensor]:
|
) -> Tensor:
|
||||||
assert input_ids.ndim == 2
|
assert input_ids.ndim == 2
|
||||||
|
|
||||||
x = self.embed_tokens(input_ids)
|
x = self.embed_tokens(input_ids)
|
||||||
|
|
|
||||||
|
|
@ -1,13 +1,3 @@
|
||||||
from astrai.parallel.executor import (
|
|
||||||
AccumOptimizer,
|
|
||||||
AccumScheduler,
|
|
||||||
BaseExecutor,
|
|
||||||
DDPExecutor,
|
|
||||||
ExecutorFactory,
|
|
||||||
FSDPExecutor,
|
|
||||||
GradientState,
|
|
||||||
NoneExecutor,
|
|
||||||
)
|
|
||||||
from astrai.parallel.module import ColumnParallelLinear, RowParallelLinear
|
from astrai.parallel.module import ColumnParallelLinear, RowParallelLinear
|
||||||
from astrai.parallel.setup import (
|
from astrai.parallel.setup import (
|
||||||
get_current_device,
|
get_current_device,
|
||||||
|
|
@ -27,12 +17,4 @@ __all__ = [
|
||||||
"spawn_parallel_fn",
|
"spawn_parallel_fn",
|
||||||
"RowParallelLinear",
|
"RowParallelLinear",
|
||||||
"ColumnParallelLinear",
|
"ColumnParallelLinear",
|
||||||
"ExecutorFactory",
|
|
||||||
"BaseExecutor",
|
|
||||||
"GradientState",
|
|
||||||
"AccumOptimizer",
|
|
||||||
"AccumScheduler",
|
|
||||||
"NoneExecutor",
|
|
||||||
"DDPExecutor",
|
|
||||||
"FSDPExecutor",
|
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -1,272 +0,0 @@
|
||||||
"""Unified training executor — parallel strategy + gradient accumulation."""
|
|
||||||
|
|
||||||
import contextlib
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
from contextlib import contextmanager
|
|
||||||
from typing import Optional, Tuple
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
from torch.distributed.fsdp import FullStateDictConfig, StateDictType
|
|
||||||
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
|
||||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
|
||||||
from torch.optim import Optimizer
|
|
||||||
from torch.optim.lr_scheduler import LRScheduler
|
|
||||||
from torch.utils.data import DataLoader
|
|
||||||
|
|
||||||
from astrai.factory import BaseFactory
|
|
||||||
from astrai.parallel.setup import get_rank, get_world_size
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class GradientState:
|
|
||||||
def __init__(self, grad_accum_steps: int = 1):
|
|
||||||
self.num_steps = max(grad_accum_steps, 1)
|
|
||||||
self._step: int = 0
|
|
||||||
self._sync_gradients: bool = True
|
|
||||||
|
|
||||||
@property
|
|
||||||
def sync_gradients(self) -> bool:
|
|
||||||
return self._sync_gradients
|
|
||||||
|
|
||||||
def _do_sync(self):
|
|
||||||
self._step += 1
|
|
||||||
self._sync_gradients = self._step % self.num_steps == 0
|
|
||||||
|
|
||||||
|
|
||||||
class AccumOptimizer:
|
|
||||||
def __init__(self, optimizer: Optimizer, gradient_state: GradientState):
|
|
||||||
self.optimizer = optimizer
|
|
||||||
self.gradient_state = gradient_state
|
|
||||||
|
|
||||||
def step(self, closure=None):
|
|
||||||
if self.gradient_state.sync_gradients:
|
|
||||||
self.optimizer.step(closure)
|
|
||||||
|
|
||||||
def zero_grad(self):
|
|
||||||
if self.gradient_state.sync_gradients:
|
|
||||||
self.optimizer.zero_grad()
|
|
||||||
|
|
||||||
@property
|
|
||||||
def param_groups(self):
|
|
||||||
return self.optimizer.param_groups
|
|
||||||
|
|
||||||
def state_dict(self):
|
|
||||||
return self.optimizer.state_dict()
|
|
||||||
|
|
||||||
def load_state_dict(self, d):
|
|
||||||
self.optimizer.load_state_dict(d)
|
|
||||||
|
|
||||||
|
|
||||||
class AccumScheduler:
|
|
||||||
def __init__(self, scheduler: LRScheduler, gradient_state: GradientState):
|
|
||||||
self.scheduler = scheduler
|
|
||||||
self.gradient_state = gradient_state
|
|
||||||
|
|
||||||
def step(self):
|
|
||||||
if self.gradient_state.sync_gradients:
|
|
||||||
self.scheduler.step()
|
|
||||||
|
|
||||||
def state_dict(self):
|
|
||||||
return self.scheduler.state_dict()
|
|
||||||
|
|
||||||
def load_state_dict(self, d):
|
|
||||||
self.scheduler.load_state_dict(d)
|
|
||||||
|
|
||||||
def get_last_lr(self):
|
|
||||||
return self.scheduler.get_last_lr()
|
|
||||||
|
|
||||||
|
|
||||||
class BaseExecutor:
|
|
||||||
def __init__(self, grad_accum_steps: int = 1):
|
|
||||||
self.gradient_state = GradientState(grad_accum_steps)
|
|
||||||
|
|
||||||
def prepare(
|
|
||||||
self,
|
|
||||||
model: nn.Module,
|
|
||||||
optimizer: Optional[Optimizer] = None,
|
|
||||||
dataloader: Optional[DataLoader] = None,
|
|
||||||
scheduler: Optional[LRScheduler] = None,
|
|
||||||
) -> Tuple[
|
|
||||||
nn.Module, Optional[Optimizer], Optional[DataLoader], Optional[LRScheduler]
|
|
||||||
]:
|
|
||||||
model = self._prepare_model(model)
|
|
||||||
if optimizer is not None:
|
|
||||||
optimizer = AccumOptimizer(optimizer, self.gradient_state)
|
|
||||||
if scheduler is not None:
|
|
||||||
scheduler = AccumScheduler(scheduler, self.gradient_state)
|
|
||||||
return model, optimizer, dataloader, scheduler
|
|
||||||
|
|
||||||
def _prepare_model(self, model: nn.Module) -> nn.Module:
|
|
||||||
return model
|
|
||||||
|
|
||||||
def _no_sync(self, model: nn.Module):
|
|
||||||
return contextlib.nullcontext()
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def accumulate(self, model: nn.Module):
|
|
||||||
self.gradient_state._do_sync()
|
|
||||||
if not self.gradient_state.sync_gradients:
|
|
||||||
with self._no_sync(model):
|
|
||||||
yield
|
|
||||||
else:
|
|
||||||
yield
|
|
||||||
|
|
||||||
def backward(self, loss: torch.Tensor):
|
|
||||||
loss.backward()
|
|
||||||
|
|
||||||
def unwrap_model(self, model: nn.Module):
|
|
||||||
return model.state_dict()
|
|
||||||
|
|
||||||
@property
|
|
||||||
def use_distributed(self) -> bool:
|
|
||||||
return get_world_size() > 1
|
|
||||||
|
|
||||||
@property
|
|
||||||
def sync_gradients(self) -> bool:
|
|
||||||
return self.gradient_state.sync_gradients
|
|
||||||
|
|
||||||
@property
|
|
||||||
def grad_accum_steps(self) -> int:
|
|
||||||
return self.gradient_state.num_steps
|
|
||||||
|
|
||||||
|
|
||||||
class ExecutorFactory(BaseFactory[BaseExecutor]):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@ExecutorFactory.register("none")
|
|
||||||
class NoneExecutor(BaseExecutor):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@ExecutorFactory.register("ddp")
|
|
||||||
class DDPExecutor(BaseExecutor):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
grad_accum_steps: int = 1,
|
|
||||||
dim: int = 0,
|
|
||||||
broadcast_buffers: bool = True,
|
|
||||||
init_sync: bool = True,
|
|
||||||
process_group=None,
|
|
||||||
bucket_cap_mb: int = 25,
|
|
||||||
find_unused_parameters: bool = False,
|
|
||||||
check_reduction: bool = False,
|
|
||||||
gradient_as_bucket_view: bool = False,
|
|
||||||
static_graph: bool = False,
|
|
||||||
delay_all_reduce_named_params=None,
|
|
||||||
param_to_hook_all_reduce=None,
|
|
||||||
mixed_precision=None,
|
|
||||||
device_mesh=None,
|
|
||||||
):
|
|
||||||
super().__init__(grad_accum_steps=grad_accum_steps)
|
|
||||||
self._ddp_kwargs = dict(
|
|
||||||
dim=dim,
|
|
||||||
broadcast_buffers=broadcast_buffers,
|
|
||||||
init_sync=init_sync,
|
|
||||||
process_group=process_group,
|
|
||||||
bucket_cap_mb=bucket_cap_mb,
|
|
||||||
find_unused_parameters=find_unused_parameters,
|
|
||||||
check_reduction=check_reduction,
|
|
||||||
gradient_as_bucket_view=gradient_as_bucket_view,
|
|
||||||
static_graph=static_graph,
|
|
||||||
delay_all_reduce_named_params=delay_all_reduce_named_params,
|
|
||||||
param_to_hook_all_reduce=param_to_hook_all_reduce,
|
|
||||||
mixed_precision=mixed_precision,
|
|
||||||
device_mesh=device_mesh,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _prepare_model(self, model: nn.Module) -> nn.Module:
|
|
||||||
if not self.use_distributed:
|
|
||||||
logger.warning("DDP backend selected but world_size=1, model not wrapped")
|
|
||||||
return model
|
|
||||||
local_rank = int(os.environ.get("LOCAL_RANK", get_rank()))
|
|
||||||
model = DDP(
|
|
||||||
model,
|
|
||||||
device_ids=[local_rank],
|
|
||||||
output_device=local_rank,
|
|
||||||
**self._ddp_kwargs,
|
|
||||||
)
|
|
||||||
logger.info("Model wrapped with DDP (world_size=%d)", get_world_size())
|
|
||||||
return model
|
|
||||||
|
|
||||||
def _no_sync(self, model: nn.Module):
|
|
||||||
if isinstance(model, DDP):
|
|
||||||
return model.no_sync()
|
|
||||||
return contextlib.nullcontext()
|
|
||||||
|
|
||||||
def unwrap_model(self, model: nn.Module):
|
|
||||||
if isinstance(model, DDP):
|
|
||||||
return model.module.state_dict()
|
|
||||||
return model.state_dict()
|
|
||||||
|
|
||||||
|
|
||||||
@ExecutorFactory.register("fsdp")
|
|
||||||
class FSDPExecutor(BaseExecutor):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
grad_accum_steps: int = 1,
|
|
||||||
process_group=None,
|
|
||||||
sharding_strategy=None,
|
|
||||||
cpu_offload=None,
|
|
||||||
auto_wrap_policy=None,
|
|
||||||
backward_prefetch=None,
|
|
||||||
mixed_precision=None,
|
|
||||||
ignored_modules=None,
|
|
||||||
param_init_fn=None,
|
|
||||||
sync_module_states: bool = False,
|
|
||||||
forward_prefetch: bool = False,
|
|
||||||
limit_all_gathers: bool = True,
|
|
||||||
ignored_states=None,
|
|
||||||
device_mesh=None,
|
|
||||||
):
|
|
||||||
super().__init__(grad_accum_steps=grad_accum_steps)
|
|
||||||
self._fsdp_kwargs = {
|
|
||||||
k: v
|
|
||||||
for k, v in dict(
|
|
||||||
process_group=process_group,
|
|
||||||
sharding_strategy=sharding_strategy,
|
|
||||||
cpu_offload=cpu_offload,
|
|
||||||
auto_wrap_policy=auto_wrap_policy,
|
|
||||||
backward_prefetch=backward_prefetch,
|
|
||||||
mixed_precision=mixed_precision,
|
|
||||||
ignored_modules=ignored_modules,
|
|
||||||
param_init_fn=param_init_fn,
|
|
||||||
sync_module_states=sync_module_states,
|
|
||||||
forward_prefetch=forward_prefetch,
|
|
||||||
limit_all_gathers=limit_all_gathers,
|
|
||||||
use_orig_params=True,
|
|
||||||
ignored_states=ignored_states,
|
|
||||||
device_mesh=device_mesh,
|
|
||||||
).items()
|
|
||||||
if v is not None
|
|
||||||
}
|
|
||||||
self._original_model: Optional[nn.Module] = None
|
|
||||||
|
|
||||||
def _prepare_model(self, model: nn.Module) -> nn.Module:
|
|
||||||
if not self.use_distributed:
|
|
||||||
logger.warning("FSDP backend selected but world_size=1, model not wrapped")
|
|
||||||
return model
|
|
||||||
self._original_model = model
|
|
||||||
device_id = torch.device("cuda", get_rank())
|
|
||||||
model = FSDP(model, device_id=device_id, **self._fsdp_kwargs)
|
|
||||||
logger.info("Model wrapped with FSDP (world_size=%d)", get_world_size())
|
|
||||||
return model
|
|
||||||
|
|
||||||
def _no_sync(self, model: nn.Module):
|
|
||||||
if isinstance(model, FSDP):
|
|
||||||
return model.no_sync()
|
|
||||||
return contextlib.nullcontext()
|
|
||||||
|
|
||||||
def unwrap_model(self, model: nn.Module):
|
|
||||||
if isinstance(model, FSDP) and self.use_distributed:
|
|
||||||
with FSDP.state_dict_type(
|
|
||||||
model,
|
|
||||||
StateDictType.FULL_STATE_DICT,
|
|
||||||
FullStateDictConfig(offload_to_cpu=True, rank0_only=False),
|
|
||||||
):
|
|
||||||
return model.state_dict()
|
|
||||||
|
|
||||||
return model.state_dict()
|
|
||||||
|
|
@ -1,5 +1,4 @@
|
||||||
import os
|
import os
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
|
|
@ -31,7 +30,6 @@ def get_rank() -> int:
|
||||||
def setup_parallel(
|
def setup_parallel(
|
||||||
rank: int,
|
rank: int,
|
||||||
world_size: int,
|
world_size: int,
|
||||||
local_rank: int,
|
|
||||||
backend: str = "nccl",
|
backend: str = "nccl",
|
||||||
master_addr: str = "localhost",
|
master_addr: str = "localhost",
|
||||||
master_port: str = "29500",
|
master_port: str = "29500",
|
||||||
|
|
@ -43,18 +41,14 @@ def setup_parallel(
|
||||||
return
|
return
|
||||||
|
|
||||||
if world_size <= 1:
|
if world_size <= 1:
|
||||||
device_id = torch.device(device_type, local_rank)
|
|
||||||
os.environ["LOCAL_RANK"] = str(local_rank)
|
|
||||||
os.environ["WORLD_SIZE"] = "1"
|
|
||||||
os.environ["LOCAL_DEVICE"] = str(device_id)
|
|
||||||
yield None
|
yield None
|
||||||
return
|
return
|
||||||
|
|
||||||
device_id = torch.device(device_type, local_rank)
|
device_id = torch.device(device_type, rank)
|
||||||
|
|
||||||
os.environ["MASTER_ADDR"] = master_addr
|
os.environ["MASTER_ADDR"] = master_addr
|
||||||
os.environ["MASTER_PORT"] = master_port
|
os.environ["MASTER_PORT"] = master_port
|
||||||
os.environ["LOCAL_RANK"] = str(local_rank)
|
os.environ["LOCAL_RANK"] = str(rank)
|
||||||
os.environ["WORLD_SIZE"] = str(world_size)
|
os.environ["WORLD_SIZE"] = str(world_size)
|
||||||
os.environ["LOCAL_DEVICE"] = str(device_id)
|
os.environ["LOCAL_DEVICE"] = str(device_id)
|
||||||
|
|
||||||
|
|
@ -96,7 +90,7 @@ def only_on_rank(rank, sync=False):
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
def _run_single_rank(
|
def wrapper_spawn_func(
|
||||||
rank: int,
|
rank: int,
|
||||||
world_size: int,
|
world_size: int,
|
||||||
backend: str,
|
backend: str,
|
||||||
|
|
@ -106,108 +100,20 @@ def _run_single_rank(
|
||||||
func: Callable,
|
func: Callable,
|
||||||
kwargs: dict,
|
kwargs: dict,
|
||||||
):
|
):
|
||||||
with setup_parallel(
|
try:
|
||||||
rank=rank,
|
|
||||||
world_size=world_size,
|
|
||||||
local_rank=rank,
|
|
||||||
backend=backend,
|
|
||||||
master_addr=master_addr,
|
|
||||||
master_port=master_port,
|
|
||||||
device_type=device_type,
|
|
||||||
):
|
|
||||||
func(**kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
class LaunchStrategy(ABC):
|
|
||||||
"""Strategy for launching a function in a distributed context."""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
world_size: int,
|
|
||||||
backend: str,
|
|
||||||
master_addr: str,
|
|
||||||
master_port: str,
|
|
||||||
device_type: str,
|
|
||||||
start_method: str,
|
|
||||||
):
|
|
||||||
self.world_size = world_size
|
|
||||||
self.backend = backend
|
|
||||||
self.master_addr = master_addr
|
|
||||||
self.master_port = master_port
|
|
||||||
self.device_type = device_type
|
|
||||||
self.start_method = start_method
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def launch(self, func: Callable, **kwargs):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
|
|
||||||
class TorchrunStrategy(LaunchStrategy):
|
|
||||||
"""External orchestrator (torchrun, SLURM, K8s) — env vars pre-set."""
|
|
||||||
|
|
||||||
def launch(self, func: Callable, **kwargs):
|
|
||||||
rank = int(os.environ["RANK"])
|
|
||||||
world_size = int(os.environ["WORLD_SIZE"])
|
|
||||||
local_rank = int(os.environ.get("LOCAL_RANK", rank))
|
|
||||||
with setup_parallel(
|
with setup_parallel(
|
||||||
rank=rank,
|
rank=rank,
|
||||||
world_size=world_size,
|
world_size=world_size,
|
||||||
local_rank=local_rank,
|
backend=backend,
|
||||||
backend=self.backend,
|
master_addr=master_addr,
|
||||||
master_addr=os.environ.get("MASTER_ADDR", self.master_addr),
|
master_port=master_port,
|
||||||
master_port=os.environ.get("MASTER_PORT", self.master_port),
|
device_type=device_type,
|
||||||
device_type=self.device_type,
|
|
||||||
):
|
):
|
||||||
func(**kwargs)
|
func(**kwargs)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
class LocalStrategy(LaunchStrategy):
|
print(f"Error in rank {rank}: {e}")
|
||||||
"""Local launcher — single-process or mp.start_processes."""
|
raise
|
||||||
|
|
||||||
def launch(self, func: Callable, **kwargs):
|
|
||||||
args = (
|
|
||||||
self.world_size,
|
|
||||||
self.backend,
|
|
||||||
self.master_addr,
|
|
||||||
self.master_port,
|
|
||||||
self.device_type,
|
|
||||||
func,
|
|
||||||
kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.world_size == 1:
|
|
||||||
_run_single_rank(0, *args)
|
|
||||||
return
|
|
||||||
|
|
||||||
ctx = mp.start_processes(
|
|
||||||
_run_single_rank,
|
|
||||||
args=args,
|
|
||||||
nprocs=self.world_size,
|
|
||||||
start_method=self.start_method,
|
|
||||||
join=False,
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
while not ctx.join():
|
|
||||||
pass
|
|
||||||
except BaseException:
|
|
||||||
for p in ctx.processes:
|
|
||||||
p.terminate()
|
|
||||||
ctx.join()
|
|
||||||
raise
|
|
||||||
|
|
||||||
|
|
||||||
def _detect_launcher() -> str:
|
|
||||||
"""Detect the distributed launcher from environment.
|
|
||||||
|
|
||||||
Returns one of: "torchelastic", "torchrun", "external", "local".
|
|
||||||
"""
|
|
||||||
if dist.is_torchelastic_launched():
|
|
||||||
return "torchelastic"
|
|
||||||
if "LOCAL_WORLD_SIZE" in os.environ:
|
|
||||||
return "torchrun"
|
|
||||||
if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
|
|
||||||
return "external"
|
|
||||||
return "local"
|
|
||||||
|
|
||||||
|
|
||||||
def spawn_parallel_fn(
|
def spawn_parallel_fn(
|
||||||
|
|
@ -220,13 +126,41 @@ def spawn_parallel_fn(
|
||||||
start_method: str = "spawn",
|
start_method: str = "spawn",
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
launcher = _detect_launcher()
|
# clear environment variables
|
||||||
if launcher in ("torchelastic", "torchrun", "external"):
|
for key in [
|
||||||
strategy = TorchrunStrategy(
|
"MASTER_ADDR",
|
||||||
world_size, backend, master_addr, master_port, device_type, start_method
|
"MASTER_PORT",
|
||||||
)
|
"RANK",
|
||||||
else:
|
"WORLD_SIZE",
|
||||||
strategy = LocalStrategy(
|
"LOCAL_RANK",
|
||||||
world_size, backend, master_addr, master_port, device_type, start_method
|
"LOCAL_DEVICE",
|
||||||
)
|
]:
|
||||||
strategy.launch(func, **kwargs)
|
if key in os.environ:
|
||||||
|
del os.environ[key]
|
||||||
|
|
||||||
|
if world_size == 1:
|
||||||
|
device_id = torch.device(device_type, 0)
|
||||||
|
os.environ["LOCAL_RANK"] = "0"
|
||||||
|
os.environ["WORLD_SIZE"] = "1"
|
||||||
|
os.environ["LOCAL_DEVICE"] = str(device_id)
|
||||||
|
|
||||||
|
func(**kwargs)
|
||||||
|
return
|
||||||
|
|
||||||
|
wrapper_spawn_func_args = (
|
||||||
|
world_size,
|
||||||
|
backend,
|
||||||
|
master_addr,
|
||||||
|
master_port,
|
||||||
|
device_type,
|
||||||
|
func,
|
||||||
|
kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
mp.start_processes(
|
||||||
|
wrapper_spawn_func,
|
||||||
|
args=wrapper_spawn_func_args,
|
||||||
|
nprocs=world_size,
|
||||||
|
start_method=start_method,
|
||||||
|
join=True,
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -1,14 +0,0 @@
|
||||||
from astrai.preprocessing.builder import (
|
|
||||||
BaseMaskBuilder,
|
|
||||||
MaskBuilderFactory,
|
|
||||||
SectionedMaskBuilder,
|
|
||||||
)
|
|
||||||
from astrai.preprocessing.pipeline import Pipeline, filter_by_length
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"BaseMaskBuilder",
|
|
||||||
"MaskBuilderFactory",
|
|
||||||
"SectionedMaskBuilder",
|
|
||||||
"Pipeline",
|
|
||||||
"filter_by_length",
|
|
||||||
]
|
|
||||||
|
|
@ -1,338 +0,0 @@
|
||||||
"""Mask building strategies for preprocessing pipeline.
|
|
||||||
|
|
||||||
The single :class:`SectionedMaskBuilder` handles all input formats
|
|
||||||
(single-sequence / DPO / GRPO) via declarative config: ``input.sections``
|
|
||||||
for single-output or ``input.sources`` for multi-output.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from astrai.factory import BaseFactory
|
|
||||||
|
|
||||||
|
|
||||||
class BaseMaskBuilder(ABC):
|
|
||||||
"""Convert a JSONL item into token ids and optional loss_mask."""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def build(self, item: dict, config, tokenizer) -> Optional[dict]:
|
|
||||||
"""Build ``{ids, loss_mask?, domain}`` from a JSONL record.
|
|
||||||
|
|
||||||
Returns ``None`` to skip the item entirely.
|
|
||||||
"""
|
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
class MaskBuilderFactory(BaseFactory["BaseMaskBuilder"]):
|
|
||||||
@classmethod
|
|
||||||
def _validate_component(cls, component_cls: type):
|
|
||||||
if not issubclass(component_cls, BaseMaskBuilder):
|
|
||||||
raise TypeError(
|
|
||||||
f"{component_cls.__name__} must inherit from BaseMaskBuilder"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _extract_domain(item: dict, domain_key: Optional[str]) -> str:
|
|
||||||
if not domain_key:
|
|
||||||
return "__default__"
|
|
||||||
val = item.get(domain_key, "__default__")
|
|
||||||
return val if isinstance(val, str) else "__default__"
|
|
||||||
|
|
||||||
|
|
||||||
def _resolve_action(action: str, role: str, config) -> str:
|
|
||||||
"""Resolve action to "train" or "mask".
|
|
||||||
|
|
||||||
- ``"train"`` / ``"mask"`` → literal
|
|
||||||
- ``"$role"`` → look up ``role`` in ``config.mask``, fall back to ``config.mask_default``
|
|
||||||
"""
|
|
||||||
if action == "$role":
|
|
||||||
return config.mask.get(role, config.mask_default)
|
|
||||||
return action
|
|
||||||
|
|
||||||
|
|
||||||
@MaskBuilderFactory.register("sectioned")
|
|
||||||
class SectionedMaskBuilder(BaseMaskBuilder):
|
|
||||||
"""Config-driven builder supporting single and multi-output modes.
|
|
||||||
|
|
||||||
Single-output (backward-compatible)::
|
|
||||||
|
|
||||||
{"input": {"sections": [
|
|
||||||
{"field": "messages", "action": "$role", "template": true}
|
|
||||||
]}}
|
|
||||||
→ {"sequence": [...], "loss_mask": [...], "domain": "..."}
|
|
||||||
|
|
||||||
Multi-output (DPO / GRPO)::
|
|
||||||
|
|
||||||
{"input": {"sources": {
|
|
||||||
"chosen": {"sections": [
|
|
||||||
{"field": "chosen", "action": "$role", "template": true}
|
|
||||||
]},
|
|
||||||
"rejected": {"sections": [
|
|
||||||
{"field": "rejected", "action": "$role", "template": true}
|
|
||||||
]}
|
|
||||||
}}}
|
|
||||||
→ {"chosen": [...], "chosen_mask": [...],
|
|
||||||
"rejected": [...], "rejected_mask": [...], "domain": "..."}
|
|
||||||
|
|
||||||
Output spec fields::
|
|
||||||
|
|
||||||
sections – list of section specs (same format as single-output)
|
|
||||||
list_field – True when the JSONL field holds a list of values to
|
|
||||||
tokenise individually and concatenate (GRPO responses)
|
|
||||||
mask_key – explicit output key for the loss mask
|
|
||||||
(default: ``"{output_key}_mask"``)
|
|
||||||
dtype – explicit tensor dtype for this output key
|
|
||||||
(default: "int32")
|
|
||||||
"""
|
|
||||||
|
|
||||||
def build(self, item: dict, config, tokenizer) -> Optional[dict]:
|
|
||||||
sources_spec = getattr(config.input, "sources", None)
|
|
||||||
if sources_spec:
|
|
||||||
return self._build_multi(item, sources_spec, config, tokenizer)
|
|
||||||
return self._build_single(item, config, tokenizer)
|
|
||||||
|
|
||||||
def _build_single(self, item: dict, config, tokenizer) -> Optional[dict]:
|
|
||||||
sections = config.input.sections
|
|
||||||
if not sections:
|
|
||||||
return None
|
|
||||||
|
|
||||||
ids, mask = self._process_sections(
|
|
||||||
item, sections, config, tokenizer, is_top_level=True
|
|
||||||
)
|
|
||||||
if ids is None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
result: dict = {
|
|
||||||
"sequence": ids,
|
|
||||||
"domain": _extract_domain(item, config.output.domain_key),
|
|
||||||
}
|
|
||||||
if not all(m == 1 for m in mask):
|
|
||||||
result["loss_mask"] = mask
|
|
||||||
return result
|
|
||||||
|
|
||||||
def _build_multi(
|
|
||||||
self, item: dict, sources_spec: dict, config, tokenizer
|
|
||||||
) -> Optional[dict]:
|
|
||||||
result: dict = {}
|
|
||||||
any_output = False
|
|
||||||
|
|
||||||
for output_key, spec in sources_spec.items():
|
|
||||||
sections = spec.get("sections", [])
|
|
||||||
if not sections:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if self._is_value_section(sections):
|
|
||||||
ids = self._extract_raw_value(item, sections)
|
|
||||||
if ids is None:
|
|
||||||
continue
|
|
||||||
result[output_key] = ids
|
|
||||||
any_output = True
|
|
||||||
continue
|
|
||||||
|
|
||||||
list_field = spec.get("list_field", False)
|
|
||||||
mask_key = spec.get("mask_key", f"{output_key}_mask")
|
|
||||||
|
|
||||||
if list_field:
|
|
||||||
ids, mask = self._process_list_field(item, sections, config, tokenizer)
|
|
||||||
else:
|
|
||||||
ids, mask = self._process_sections(
|
|
||||||
item, sections, config, tokenizer, is_top_level=True
|
|
||||||
)
|
|
||||||
|
|
||||||
if ids is None:
|
|
||||||
continue
|
|
||||||
|
|
||||||
result[output_key] = ids
|
|
||||||
if not all(m == 1 for m in mask):
|
|
||||||
result[mask_key] = mask
|
|
||||||
elif "mask_key" in spec:
|
|
||||||
result[mask_key] = mask
|
|
||||||
|
|
||||||
any_output = True
|
|
||||||
|
|
||||||
if not any_output:
|
|
||||||
return None
|
|
||||||
|
|
||||||
result["domain"] = _extract_domain(item, config.output.domain_key)
|
|
||||||
return result
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _is_value_section(sections: list) -> bool:
|
|
||||||
return len(sections) == 1 and sections[0].get("action") == "value"
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _extract_raw_value(item: dict, sections: list):
|
|
||||||
"""Extract a raw value from a JSONL field without tokenisation.
|
|
||||||
|
|
||||||
Used for GRPO rewards where the field contains float values.
|
|
||||||
"""
|
|
||||||
sec = sections[0]
|
|
||||||
field = sec["field"]
|
|
||||||
raw = item.get(field)
|
|
||||||
if raw is None:
|
|
||||||
return None
|
|
||||||
if isinstance(raw, list):
|
|
||||||
return [float(v) for v in raw]
|
|
||||||
return [float(raw)]
|
|
||||||
|
|
||||||
def _process_sections(
|
|
||||||
self,
|
|
||||||
item: dict,
|
|
||||||
sections: list,
|
|
||||||
config,
|
|
||||||
tokenizer,
|
|
||||||
*,
|
|
||||||
is_top_level: bool = False,
|
|
||||||
):
|
|
||||||
"""Process a list of sections into ``(ids, loss_mask)``.
|
|
||||||
|
|
||||||
Returns ``(None, None)`` if the item should be skipped.
|
|
||||||
"""
|
|
||||||
all_ids: list[int] = []
|
|
||||||
loss_mask: list[int] = []
|
|
||||||
|
|
||||||
has_template = any(s.get("template") for s in sections)
|
|
||||||
is_text_config = not has_template and all(
|
|
||||||
s["action"] == "train" for s in sections
|
|
||||||
)
|
|
||||||
|
|
||||||
if is_top_level and has_template and tokenizer.bos_token_id is not None:
|
|
||||||
all_ids.append(tokenizer.bos_token_id)
|
|
||||||
loss_mask.append(0)
|
|
||||||
|
|
||||||
first_section = True
|
|
||||||
for sec in sections:
|
|
||||||
field = sec["field"]
|
|
||||||
action = sec["action"]
|
|
||||||
use_template = sec.get("template", False)
|
|
||||||
add_special = sec.get(
|
|
||||||
"add_special_tokens", not use_template and first_section
|
|
||||||
)
|
|
||||||
|
|
||||||
if use_template:
|
|
||||||
success = self._append_template_section(
|
|
||||||
item, field, action, tokenizer, config, all_ids, loss_mask
|
|
||||||
)
|
|
||||||
if not success:
|
|
||||||
continue
|
|
||||||
else:
|
|
||||||
success = self._append_text_section(
|
|
||||||
item,
|
|
||||||
field,
|
|
||||||
action,
|
|
||||||
tokenizer,
|
|
||||||
add_special,
|
|
||||||
is_text_config,
|
|
||||||
config,
|
|
||||||
all_ids,
|
|
||||||
loss_mask,
|
|
||||||
)
|
|
||||||
if not success:
|
|
||||||
continue
|
|
||||||
|
|
||||||
first_section = False
|
|
||||||
|
|
||||||
max_len = config.preprocessing.max_seq_len
|
|
||||||
all_ids = all_ids[:max_len]
|
|
||||||
loss_mask = loss_mask[: len(all_ids)]
|
|
||||||
|
|
||||||
if not all_ids:
|
|
||||||
return None, None
|
|
||||||
|
|
||||||
if is_top_level and has_template and len(all_ids) <= 1:
|
|
||||||
return None, None
|
|
||||||
|
|
||||||
return all_ids, loss_mask
|
|
||||||
|
|
||||||
def _append_template_section(
|
|
||||||
self, item, field, action, tokenizer, config, all_ids, loss_mask
|
|
||||||
):
|
|
||||||
messages = item.get(field)
|
|
||||||
if not isinstance(messages, list) or not messages:
|
|
||||||
return False
|
|
||||||
for msg in messages:
|
|
||||||
role = msg.get("role", "")
|
|
||||||
act = _resolve_action(action, role, config)
|
|
||||||
rendered = tokenizer.apply_chat_template(
|
|
||||||
[msg], tokenize=False, add_generation_prompt=False
|
|
||||||
)
|
|
||||||
ids = tokenizer.encode(rendered, add_special_tokens=False)
|
|
||||||
all_ids.extend(ids)
|
|
||||||
val = 1 if act == "train" else 0
|
|
||||||
loss_mask.extend([val] * len(ids))
|
|
||||||
return True
|
|
||||||
|
|
||||||
def _append_text_section(
|
|
||||||
self,
|
|
||||||
item,
|
|
||||||
field,
|
|
||||||
action,
|
|
||||||
tokenizer,
|
|
||||||
add_special,
|
|
||||||
is_text_config,
|
|
||||||
config,
|
|
||||||
all_ids,
|
|
||||||
loss_mask,
|
|
||||||
):
|
|
||||||
text = str(item.get(field, ""))
|
|
||||||
if not text.strip():
|
|
||||||
return False
|
|
||||||
if is_text_config:
|
|
||||||
pp = config.preprocessing
|
|
||||||
if pp.min_chars > 0 and len(text) < pp.min_chars:
|
|
||||||
return False
|
|
||||||
if len(text) > pp.max_chars:
|
|
||||||
return False
|
|
||||||
ids = tokenizer.encode(text, add_special_tokens=add_special)
|
|
||||||
all_ids.extend(ids)
|
|
||||||
val = 1 if action == "train" else 0
|
|
||||||
loss_mask.extend([val] * len(ids))
|
|
||||||
return True
|
|
||||||
|
|
||||||
def _process_list_field(self, item: dict, sections: list, config, tokenizer):
|
|
||||||
all_ids: list[int] = []
|
|
||||||
loss_mask: list[int] = []
|
|
||||||
|
|
||||||
for sec in sections:
|
|
||||||
field = sec["field"]
|
|
||||||
action = sec["action"]
|
|
||||||
use_template = sec.get("template", False)
|
|
||||||
|
|
||||||
values = item.get(field)
|
|
||||||
if not isinstance(values, list):
|
|
||||||
continue
|
|
||||||
|
|
||||||
for val in values:
|
|
||||||
if use_template:
|
|
||||||
if isinstance(val, list):
|
|
||||||
wrapper = {field: val}
|
|
||||||
self._append_template_section(
|
|
||||||
wrapper,
|
|
||||||
field,
|
|
||||||
action,
|
|
||||||
tokenizer,
|
|
||||||
config,
|
|
||||||
all_ids,
|
|
||||||
loss_mask,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
wrapper = {field: str(val)}
|
|
||||||
self._append_text_section(
|
|
||||||
wrapper,
|
|
||||||
field,
|
|
||||||
action,
|
|
||||||
tokenizer,
|
|
||||||
False,
|
|
||||||
False,
|
|
||||||
config,
|
|
||||||
all_ids,
|
|
||||||
loss_mask,
|
|
||||||
)
|
|
||||||
|
|
||||||
max_len = config.preprocessing.max_seq_len
|
|
||||||
all_ids = all_ids[:max_len]
|
|
||||||
loss_mask = loss_mask[: len(all_ids)]
|
|
||||||
|
|
||||||
if not all_ids:
|
|
||||||
return None, None
|
|
||||||
return all_ids, loss_mask
|
|
||||||
|
|
@ -1,257 +0,0 @@
|
||||||
"""Config-driven JSONL preprocessing pipeline.
|
|
||||||
|
|
||||||
Composes a :class:`BaseMaskBuilder` (selected by ``input.type``) with
|
|
||||||
sharding and flush to ``.h5`` / ``.bin`` storage.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
from collections import defaultdict
|
|
||||||
from itertools import chain
|
|
||||||
from typing import List, Optional, Tuple
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import tqdm
|
|
||||||
|
|
||||||
from astrai.config.preprocess_config import PipelineConfig
|
|
||||||
from astrai.dataset.storage import save_bin, save_h5
|
|
||||||
from astrai.preprocessing.builder import SectionedMaskBuilder
|
|
||||||
from astrai.tokenize import AutoTokenizer
|
|
||||||
|
|
||||||
_STR_TO_DTYPE: dict[str, torch.dtype] = {
|
|
||||||
"bool": torch.bool,
|
|
||||||
"uint8": torch.uint8,
|
|
||||||
"int8": torch.int8,
|
|
||||||
"int16": torch.int16,
|
|
||||||
"int32": torch.int32,
|
|
||||||
"int64": torch.int64,
|
|
||||||
"float16": torch.float16,
|
|
||||||
"float32": torch.float32,
|
|
||||||
"float64": torch.float64,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def filter_by_length(text: str, min_len: int = 50, max_len: int = 2_000_000) -> bool:
|
|
||||||
return min_len <= len(text) <= max_len
|
|
||||||
|
|
||||||
|
|
||||||
def _truncate(seq: list, max_len: int, mode: str) -> list:
|
|
||||||
if len(seq) <= max_len:
|
|
||||||
return seq
|
|
||||||
if mode == "keep_end":
|
|
||||||
return seq[-max_len:]
|
|
||||||
return seq[:max_len]
|
|
||||||
|
|
||||||
|
|
||||||
def pack_sequences(
|
|
||||||
sequences: List[list],
|
|
||||||
max_packed_len: int,
|
|
||||||
strategy: str,
|
|
||||||
truncation_mode: str,
|
|
||||||
) -> List[Tuple[int, int]]:
|
|
||||||
"""Pack *sequences* into bins and return a reorder plan.
|
|
||||||
|
|
||||||
Returns a list of ``(orig_idx, truncated_length)`` in flush order.
|
|
||||||
All keys (sequence, loss_mask, …) must be reordered and truncated
|
|
||||||
identically according to this plan.
|
|
||||||
|
|
||||||
Supported *strategy* values:
|
|
||||||
|
|
||||||
- ``"simple"``: sequential, no reordering.
|
|
||||||
- ``"bfd"``: best-fit decreasing bin packing.
|
|
||||||
"""
|
|
||||||
n = len(sequences)
|
|
||||||
if strategy == "simple":
|
|
||||||
return [(i, min(len(sequences[i]), max_packed_len)) for i in range(n)]
|
|
||||||
|
|
||||||
order = sorted(range(n), key=lambda i: len(sequences[i]), reverse=True)
|
|
||||||
bins: List[List[int]] = []
|
|
||||||
bin_lengths: List[int] = []
|
|
||||||
|
|
||||||
for orig_idx in order:
|
|
||||||
seq_len = min(len(sequences[orig_idx]), max_packed_len)
|
|
||||||
|
|
||||||
best_bin = None
|
|
||||||
best_remain = max_packed_len + 1
|
|
||||||
for i, bl in enumerate(bin_lengths):
|
|
||||||
remain = max_packed_len - bl
|
|
||||||
if seq_len <= remain < best_remain:
|
|
||||||
best_remain = remain
|
|
||||||
best_bin = i
|
|
||||||
|
|
||||||
if best_bin is not None:
|
|
||||||
bins[best_bin].append(orig_idx)
|
|
||||||
bin_lengths[best_bin] += seq_len
|
|
||||||
else:
|
|
||||||
bins.append([orig_idx])
|
|
||||||
bin_lengths.append(seq_len)
|
|
||||||
|
|
||||||
plan: List[Tuple[int, int]] = []
|
|
||||||
for bin_indices in bins:
|
|
||||||
for orig_idx in bin_indices:
|
|
||||||
plan.append((orig_idx, min(len(sequences[orig_idx]), max_packed_len)))
|
|
||||||
|
|
||||||
return plan
|
|
||||||
|
|
||||||
|
|
||||||
class Pipeline:
|
|
||||||
"""Tokenization pipeline driven by a declarative :class:`PipelineConfig`.
|
|
||||||
|
|
||||||
Usage::
|
|
||||||
|
|
||||||
config = PipelineConfig.from_json("sft_pipeline.json")
|
|
||||||
Pipeline(config, ["data.jsonl"], output_dir="out", tokenizer_path="params").run()
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
config: PipelineConfig,
|
|
||||||
input_paths: list[str],
|
|
||||||
output_dir: str,
|
|
||||||
tokenizer_path: str,
|
|
||||||
):
|
|
||||||
os.makedirs(output_dir, exist_ok=True)
|
|
||||||
self.config = config
|
|
||||||
self.paths = input_paths
|
|
||||||
self.output_dir = output_dir
|
|
||||||
self.tokenizer_path = tokenizer_path
|
|
||||||
|
|
||||||
self.mask_builder = SectionedMaskBuilder()
|
|
||||||
|
|
||||||
def transform(self, item: dict) -> Optional[dict]:
|
|
||||||
return self.mask_builder.build(item, self.config, self._tokenizer)
|
|
||||||
|
|
||||||
def run(self):
|
|
||||||
self._tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_path)
|
|
||||||
domains: dict = defaultdict(lambda: defaultdict(list))
|
|
||||||
total_tokens = 0
|
|
||||||
shard_idx: dict[str, int] = defaultdict(int)
|
|
||||||
count = 0
|
|
||||||
|
|
||||||
pp = self.config.preprocessing
|
|
||||||
|
|
||||||
for item in tqdm.tqdm(
|
|
||||||
self._iter_items(), desc="Tokenizing", unit="docs", mininterval=0.5
|
|
||||||
):
|
|
||||||
if pp.max_items and count >= pp.max_items:
|
|
||||||
break
|
|
||||||
|
|
||||||
result = self.transform(item)
|
|
||||||
if result is None:
|
|
||||||
continue
|
|
||||||
|
|
||||||
domain = result.pop("domain", "__default__")
|
|
||||||
|
|
||||||
is_multi = bool(getattr(self.config.input, "sources", None))
|
|
||||||
if is_multi:
|
|
||||||
ids = self._primary_ids(result)
|
|
||||||
else:
|
|
||||||
ids = result.pop("sequence")
|
|
||||||
result["sequence"] = ids
|
|
||||||
|
|
||||||
if not ids:
|
|
||||||
continue
|
|
||||||
|
|
||||||
bucket = domains[domain]
|
|
||||||
self._align_bucket(bucket, result, ids, is_multi)
|
|
||||||
for key, val in result.items():
|
|
||||||
bucket[key].append(val)
|
|
||||||
|
|
||||||
count += 1
|
|
||||||
total_tokens += len(ids)
|
|
||||||
|
|
||||||
if total_tokens >= self.config.output.max_tokens_per_shard:
|
|
||||||
self._flush(domains, shard_idx)
|
|
||||||
domains.clear()
|
|
||||||
total_tokens = 0
|
|
||||||
|
|
||||||
if total_tokens > 0:
|
|
||||||
self._flush(domains, shard_idx)
|
|
||||||
|
|
||||||
print(f"Done. {count} documents tokenized.")
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _primary_ids(result: dict) -> list:
|
|
||||||
"""Return the first list-valued entry in *result* as the primary id
|
|
||||||
sequence for token counting."""
|
|
||||||
for val in result.values():
|
|
||||||
if isinstance(val, list) and val and isinstance(val[0], int):
|
|
||||||
return val
|
|
||||||
return []
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _align_bucket(bucket: dict, result: dict, ids: list, is_multi: bool):
|
|
||||||
"""Pad previously-accumulated keys that are missing from *result*."""
|
|
||||||
for key in list(bucket.keys()):
|
|
||||||
if key in result:
|
|
||||||
continue
|
|
||||||
if is_multi:
|
|
||||||
pad = bucket[key][-1] if bucket[key] else [1] * len(ids)
|
|
||||||
bucket[key].append(pad)
|
|
||||||
else:
|
|
||||||
bucket[key].append([1] * len(ids))
|
|
||||||
|
|
||||||
def _iter_items(self):
|
|
||||||
for path in self.paths:
|
|
||||||
with open(path, "r", encoding="utf-8") as f:
|
|
||||||
for line in f:
|
|
||||||
line = line.strip()
|
|
||||||
if not line:
|
|
||||||
continue
|
|
||||||
yield json.loads(line)
|
|
||||||
|
|
||||||
def _flush(self, domains, shard_idx):
|
|
||||||
for domain, keys in domains.items():
|
|
||||||
idx = shard_idx[domain]
|
|
||||||
chunk_dir = os.path.join(self.output_dir, domain)
|
|
||||||
|
|
||||||
pp = self.config.preprocessing
|
|
||||||
if pp.packing_strategy != "simple" and "sequence" in keys:
|
|
||||||
plan = pack_sequences(
|
|
||||||
keys["sequence"],
|
|
||||||
pp.max_packed_len,
|
|
||||||
pp.packing_strategy,
|
|
||||||
pp.truncation_mode,
|
|
||||||
)
|
|
||||||
reordered = defaultdict(list)
|
|
||||||
for orig_idx, truncated_len in plan:
|
|
||||||
for k, vals in keys.items():
|
|
||||||
reordered[k].append(
|
|
||||||
_truncate(
|
|
||||||
vals[orig_idx], pp.max_packed_len, pp.truncation_mode
|
|
||||||
)
|
|
||||||
)
|
|
||||||
keys = reordered
|
|
||||||
|
|
||||||
tensors = {}
|
|
||||||
for key, ids_list in keys.items():
|
|
||||||
dt = _STR_TO_DTYPE.get(
|
|
||||||
self.config.output.dtype.get(key, "int32"), torch.int32
|
|
||||||
)
|
|
||||||
tensors[key] = [
|
|
||||||
torch.tensor(list(chain.from_iterable(ids_list)), dtype=dt)
|
|
||||||
]
|
|
||||||
|
|
||||||
pid_mode = self.config.output.position_ids_mode
|
|
||||||
if pid_mode and pid_mode != "none" and "sequence" in tensors:
|
|
||||||
pos_ids = []
|
|
||||||
if pid_mode == "doc_reset":
|
|
||||||
for item in keys["sequence"]:
|
|
||||||
pos_ids.extend(range(len(item)))
|
|
||||||
else:
|
|
||||||
total = sum(len(item) for item in keys["sequence"])
|
|
||||||
pos_ids = list(range(total))
|
|
||||||
tensors["position_ids"] = [torch.tensor(pos_ids, dtype=torch.int32)]
|
|
||||||
|
|
||||||
shard_path = os.path.join(chunk_dir, f"shard_{idx:04d}")
|
|
||||||
fmt = self.config.output.storage_format
|
|
||||||
if fmt == "bin":
|
|
||||||
save_bin(shard_path, tensors)
|
|
||||||
else:
|
|
||||||
save_h5(chunk_dir, f"data_{idx:04d}", tensors)
|
|
||||||
shard_idx[domain] = idx + 1
|
|
||||||
first_key = "sequence" if "sequence" in tensors else next(iter(tensors))
|
|
||||||
tqdm.tqdm.write(
|
|
||||||
f" saved {domain}/shard_{idx:04d} "
|
|
||||||
f"({tensors[first_key][0].numel():,} tokens)"
|
|
||||||
)
|
|
||||||
|
|
@ -1,21 +0,0 @@
|
||||||
"""Training component protocols — structural subtyping for optimizer/scheduler wrappers."""
|
|
||||||
|
|
||||||
from typing import Any, Protocol, runtime_checkable
|
|
||||||
|
|
||||||
|
|
||||||
@runtime_checkable
|
|
||||||
class OptimizerProtocol(Protocol):
|
|
||||||
def step(self, closure=None): ...
|
|
||||||
def zero_grad(self): ...
|
|
||||||
@property
|
|
||||||
def param_groups(self) -> Any: ...
|
|
||||||
def state_dict(self) -> dict: ...
|
|
||||||
def load_state_dict(self, d: dict): ...
|
|
||||||
|
|
||||||
|
|
||||||
@runtime_checkable
|
|
||||||
class SchedulerProtocol(Protocol):
|
|
||||||
def step(self): ...
|
|
||||||
def state_dict(self) -> dict: ...
|
|
||||||
def load_state_dict(self, d: dict): ...
|
|
||||||
def get_last_lr(self): ...
|
|
||||||
|
|
@ -1,9 +1,7 @@
|
||||||
import io
|
|
||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, Union
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
import safetensors.torch as st
|
import safetensors.torch as st
|
||||||
import torch
|
import torch
|
||||||
|
|
@ -11,172 +9,75 @@ import torch.distributed as dist
|
||||||
|
|
||||||
from astrai.parallel.setup import get_rank
|
from astrai.parallel.setup import get_rank
|
||||||
|
|
||||||
_META_FILE = "meta.json"
|
|
||||||
_CONFIG_FILE = "config.json"
|
|
||||||
_WEIGHTS_FILE = "model.safetensors"
|
|
||||||
|
|
||||||
|
|
||||||
def save_safetensors(state_dict: dict, path: Union[str, Path]):
|
|
||||||
st.save_file(state_dict, str(path))
|
|
||||||
|
|
||||||
|
|
||||||
def load_safetensors(path: Union[str, Path], broadcast: bool = False) -> dict:
|
|
||||||
if not broadcast or not dist.is_initialized():
|
|
||||||
return st.load_file(str(path))
|
|
||||||
|
|
||||||
rank = get_rank()
|
|
||||||
if rank == 0:
|
|
||||||
state_dict = st.load_file(str(path))
|
|
||||||
else:
|
|
||||||
state_dict = {}
|
|
||||||
tmp = [state_dict]
|
|
||||||
dist.broadcast_object_list(tmp, src=0)
|
|
||||||
return tmp[0]
|
|
||||||
|
|
||||||
|
|
||||||
def save_json(data: dict, path: Union[str, Path]):
|
|
||||||
with open(str(path), "w") as f:
|
|
||||||
json.dump(data, f, indent=2)
|
|
||||||
|
|
||||||
|
|
||||||
def load_json(path: Union[str, Path], broadcast: bool = False) -> dict:
|
|
||||||
if not broadcast or not dist.is_initialized():
|
|
||||||
with open(str(path), "r") as f:
|
|
||||||
return json.load(f)
|
|
||||||
|
|
||||||
rank = get_rank()
|
|
||||||
if rank == 0:
|
|
||||||
with open(str(path), "r") as f:
|
|
||||||
data = json.load(f)
|
|
||||||
else:
|
|
||||||
data = {}
|
|
||||||
tmp = [data]
|
|
||||||
dist.broadcast_object_list(tmp, src=0)
|
|
||||||
return tmp[0]
|
|
||||||
|
|
||||||
|
|
||||||
def save_torch(obj: Any, path: Union[str, Path]):
|
|
||||||
torch.save(obj, str(path))
|
|
||||||
|
|
||||||
|
|
||||||
def load_torch(path: Union[str, Path], broadcast: bool = False) -> Any:
|
|
||||||
if not broadcast or not dist.is_initialized():
|
|
||||||
return torch.load(str(path), map_location="cpu", weights_only=False)
|
|
||||||
|
|
||||||
path = Path(path)
|
|
||||||
rank = get_rank()
|
|
||||||
|
|
||||||
if rank == 0:
|
|
||||||
with open(path, "rb") as f:
|
|
||||||
raw = f.read()
|
|
||||||
data_tensor = torch.frombuffer(bytearray(raw), dtype=torch.uint8)
|
|
||||||
num_bytes = torch.tensor([len(raw)], dtype=torch.long)
|
|
||||||
else:
|
|
||||||
num_bytes = torch.tensor([0], dtype=torch.long)
|
|
||||||
|
|
||||||
dist.broadcast(num_bytes, src=0)
|
|
||||||
|
|
||||||
if rank != 0:
|
|
||||||
data_tensor = torch.empty(num_bytes.item(), dtype=torch.uint8)
|
|
||||||
|
|
||||||
dist.broadcast(data_tensor, src=0)
|
|
||||||
|
|
||||||
buf = io.BytesIO(data_tensor.numpy().tobytes())
|
|
||||||
return torch.load(buf, map_location="cpu", weights_only=False)
|
|
||||||
|
|
||||||
|
|
||||||
def save_model(config: dict, state_dict: dict, save_directory: str):
|
|
||||||
save_path = Path(save_directory)
|
|
||||||
save_path.mkdir(parents=True, exist_ok=True)
|
|
||||||
save_json(config, save_path / _CONFIG_FILE)
|
|
||||||
save_safetensors(state_dict, save_path / _WEIGHTS_FILE)
|
|
||||||
|
|
||||||
|
|
||||||
def load_model_config(save_directory: str) -> dict:
|
|
||||||
return load_json(Path(save_directory) / _CONFIG_FILE)
|
|
||||||
|
|
||||||
|
|
||||||
def load_model_weights(save_directory: str) -> dict:
|
|
||||||
return load_state_dict(Path(save_directory) / _WEIGHTS_FILE)
|
|
||||||
|
|
||||||
|
|
||||||
def load_state_dict(path: Union[str, Path], broadcast: bool = False) -> dict:
|
|
||||||
path = Path(path)
|
|
||||||
if not broadcast or not dist.is_initialized():
|
|
||||||
return load_safetensors(path)
|
|
||||||
|
|
||||||
rank = get_rank()
|
|
||||||
if rank == 0:
|
|
||||||
state_dict = load_safetensors(path)
|
|
||||||
specs = [
|
|
||||||
(k, list(state_dict[k].shape), str(state_dict[k].dtype).split(".")[-1])
|
|
||||||
for k in sorted(state_dict)
|
|
||||||
]
|
|
||||||
else:
|
|
||||||
state_dict = {}
|
|
||||||
specs = []
|
|
||||||
|
|
||||||
specs_list = [specs]
|
|
||||||
dist.broadcast_object_list(specs_list, src=0)
|
|
||||||
specs = specs_list[0]
|
|
||||||
|
|
||||||
for key, shape, dtype_name in specs:
|
|
||||||
dtype = getattr(torch, dtype_name)
|
|
||||||
if rank != 0:
|
|
||||||
tensor = torch.empty(shape, dtype=dtype, device="cpu")
|
|
||||||
else:
|
|
||||||
tensor = state_dict[key].contiguous().cpu()
|
|
||||||
dist.broadcast(tensor, src=0)
|
|
||||||
if rank != 0:
|
|
||||||
state_dict[key] = tensor
|
|
||||||
return state_dict
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Checkpoint:
|
class Checkpoint:
|
||||||
state_dict: Dict[str, Any] = field(default_factory=dict)
|
def __init__(
|
||||||
epoch: int = 0
|
self,
|
||||||
iteration: int = 0
|
state_dict: Dict[str, Any],
|
||||||
extra: Dict[str, Any] = field(default_factory=dict)
|
epoch: int = 0,
|
||||||
meta: Dict[str, Any] = field(default_factory=dict)
|
iteration: int = 0,
|
||||||
config: Dict[str, Any] = field(default_factory=dict)
|
extra: Optional[Dict[str, Any]] = None,
|
||||||
|
meta: Optional[Dict[str, Any]] = None,
|
||||||
|
):
|
||||||
|
self.state_dict = state_dict
|
||||||
|
self.epoch = epoch
|
||||||
|
self.iteration = iteration
|
||||||
|
self.extra = extra or {}
|
||||||
|
self.meta = meta or {}
|
||||||
|
|
||||||
|
def save(
|
||||||
|
self,
|
||||||
|
save_dir: str,
|
||||||
|
) -> None:
|
||||||
|
|
||||||
def save(self, save_dir: str):
|
|
||||||
save_path = Path(save_dir)
|
save_path = Path(save_dir)
|
||||||
save_path.mkdir(parents=True, exist_ok=True)
|
save_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
if get_rank() != 0:
|
rank = get_rank()
|
||||||
return
|
if rank == 0:
|
||||||
|
meta = {
|
||||||
|
"epoch": self.epoch,
|
||||||
|
"iteration": self.iteration,
|
||||||
|
"timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"),
|
||||||
|
}
|
||||||
|
meta.update(self.meta)
|
||||||
|
with open(save_path / "meta.json", "w") as f:
|
||||||
|
json.dump(meta, f, indent=2)
|
||||||
|
|
||||||
meta = {
|
st.save_file(self.state_dict, save_path / "state_dict.safetensors")
|
||||||
"epoch": self.epoch,
|
if self.extra:
|
||||||
"iteration": self.iteration,
|
for key, value in self.extra.items():
|
||||||
"timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"),
|
torch.save(value, save_path / f"{key}.pt")
|
||||||
**self.meta,
|
|
||||||
}
|
|
||||||
save_json(meta, save_path / _META_FILE)
|
|
||||||
save_json(self.config, save_path / _CONFIG_FILE)
|
|
||||||
save_safetensors(self.state_dict, save_path / _WEIGHTS_FILE)
|
|
||||||
for key, value in self.extra.items():
|
|
||||||
save_torch(value, save_path / f"{key}.pt")
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def load(cls, save_dir: str, broadcast: bool = False) -> "Checkpoint":
|
def load(
|
||||||
|
cls,
|
||||||
|
save_dir: str,
|
||||||
|
) -> "Checkpoint":
|
||||||
|
|
||||||
|
rank = get_rank()
|
||||||
save_path = Path(save_dir)
|
save_path = Path(save_dir)
|
||||||
|
|
||||||
meta = load_json(save_path / _META_FILE, broadcast)
|
meta = {}
|
||||||
config = load_json(save_path / _CONFIG_FILE, broadcast)
|
if rank == 0:
|
||||||
state_dict = load_state_dict(save_path / _WEIGHTS_FILE, broadcast=broadcast)
|
with open(Path(save_dir) / "meta.json", "r") as f:
|
||||||
|
meta = json.load(f)
|
||||||
|
|
||||||
|
if dist.is_initialized():
|
||||||
|
meta_list = [meta]
|
||||||
|
dist.broadcast_object_list(meta_list, src=0)
|
||||||
|
meta = meta_list[0]
|
||||||
|
|
||||||
|
state_dict = st.load_file(save_path / "state_dict.safetensors")
|
||||||
|
|
||||||
extra = {}
|
extra = {}
|
||||||
for f in sorted(save_path.iterdir()):
|
for f in save_path.iterdir():
|
||||||
if f.suffix == ".pt":
|
if f.suffix == ".pt" and f.stem not in ("meta",):
|
||||||
extra[f.stem] = load_torch(f, broadcast=broadcast)
|
extra[f.stem] = torch.load(f, map_location="cpu", weights_only=False)
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
state_dict=state_dict,
|
state_dict=state_dict,
|
||||||
epoch=meta.get("epoch", 0),
|
epoch=meta["epoch"],
|
||||||
iteration=meta.get("iteration", 0),
|
iteration=meta["iteration"],
|
||||||
extra=extra,
|
extra=extra or None,
|
||||||
config=config,
|
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,13 @@
|
||||||
|
from dataclasses import dataclass
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from jinja2 import Template
|
from jinja2 import Template
|
||||||
|
|
||||||
|
# Message type for chat messages
|
||||||
type MessageType = Dict[str, Any]
|
type MessageType = Dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
class ChatTemplate:
|
class ChatTemplate:
|
||||||
"""A chat template with Jinja2 rendering support.
|
"""A chat template with Jinja2 rendering support.
|
||||||
|
|
||||||
|
|
@ -12,24 +15,23 @@ class ChatTemplate:
|
||||||
name: Unique identifier for the template.
|
name: Unique identifier for the template.
|
||||||
template_str: Jinja2 template string.
|
template_str: Jinja2 template string.
|
||||||
description: Optional description.
|
description: Optional description.
|
||||||
default_variables: Optional dictionary of default variable values.
|
default_variables: Optional dictionary of default variable values
|
||||||
|
that will be passed to the template if not overridden during rendering.
|
||||||
special_tokens: Optional dictionary mapping token names to their string values.
|
special_tokens: Optional dictionary mapping token names to their string values.
|
||||||
|
These tokens are automatically added to the template variables.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
name: str
|
||||||
self,
|
template_str: str
|
||||||
name: str = "",
|
description: str = ""
|
||||||
template_str: str = "",
|
default_variables: Dict[str, Any] = None
|
||||||
description: str = "",
|
special_tokens: Dict[str, str] = None
|
||||||
default_variables: Optional[Dict[str, Any]] = None,
|
|
||||||
special_tokens: Optional[Dict[str, str]] = None,
|
def __post_init__(self):
|
||||||
):
|
if self.default_variables is None:
|
||||||
self.name = name
|
self.default_variables = {}
|
||||||
self.template_str = template_str
|
if self.special_tokens is None:
|
||||||
self.description = description
|
self.special_tokens = {}
|
||||||
self.default_variables = default_variables or {}
|
|
||||||
self.special_tokens = special_tokens or {}
|
|
||||||
self._compiled: Template = Template(template_str)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_string(
|
def from_string(
|
||||||
|
|
@ -41,7 +43,7 @@ class ChatTemplate:
|
||||||
) -> "ChatTemplate":
|
) -> "ChatTemplate":
|
||||||
"""Create a ChatTemplate instance directly from a template string."""
|
"""Create a ChatTemplate instance directly from a template string."""
|
||||||
return cls(
|
return cls(
|
||||||
name="",
|
name="", # empty name for ad‑hoc templates
|
||||||
template_str=template_str,
|
template_str=template_str,
|
||||||
description=description,
|
description=description,
|
||||||
default_variables=default_variables,
|
default_variables=default_variables,
|
||||||
|
|
@ -71,4 +73,5 @@ class ChatTemplate:
|
||||||
if system_prompt is not None:
|
if system_prompt is not None:
|
||||||
variables["system_prompt"] = system_prompt
|
variables["system_prompt"] = system_prompt
|
||||||
|
|
||||||
return self._compiled.render(**variables)
|
jinja_template = Template(self.template_str)
|
||||||
|
return jinja_template.render(**variables)
|
||||||
|
|
|
||||||
|
|
@ -4,17 +4,17 @@ from torch.optim import Optimizer
|
||||||
|
|
||||||
def _zeropower_via_newtonschulz(G: torch.Tensor, steps: int = 5):
|
def _zeropower_via_newtonschulz(G: torch.Tensor, steps: int = 5):
|
||||||
assert G.ndim == 2
|
assert G.ndim == 2
|
||||||
X = G
|
X = G.bfloat16()
|
||||||
scale = max(1, G.size(0) / G.size(1)) ** 0.5
|
scale = max(1, G.size(0) / G.size(1)) ** 0.5
|
||||||
X = X / (X.norm() + 1e-7) * scale
|
X = X / (X.norm() + 1e-7) * scale
|
||||||
if steps == 0:
|
if steps == 0:
|
||||||
return X
|
return X.type_as(G)
|
||||||
a, b, c = (3.4445, -4.7750, 2.0315)
|
a, b, c = (3.4445, -4.7750, 2.0315)
|
||||||
for _ in range(steps):
|
for _ in range(steps):
|
||||||
A = X @ X.T
|
A = X @ X.T
|
||||||
B = A @ X
|
B = A @ X
|
||||||
X = a * X + b * B + c * (A @ B)
|
X = a * X + b * B + c * (A @ B)
|
||||||
return X
|
return X.type_as(G)
|
||||||
|
|
||||||
|
|
||||||
class Muon(Optimizer):
|
class Muon(Optimizer):
|
||||||
|
|
@ -50,94 +50,64 @@ class Muon(Optimizer):
|
||||||
if closure is not None:
|
if closure is not None:
|
||||||
with torch.enable_grad():
|
with torch.enable_grad():
|
||||||
loss = closure()
|
loss = closure()
|
||||||
|
|
||||||
for group in self.param_groups:
|
for group in self.param_groups:
|
||||||
params_2d, params_1d = [], []
|
|
||||||
grads_2d, grads_1d = [], []
|
|
||||||
|
|
||||||
for p in group["params"]:
|
for p in group["params"]:
|
||||||
if p.grad is None:
|
if p.grad is None:
|
||||||
continue
|
continue
|
||||||
if p.grad.is_sparse:
|
grad = p.grad
|
||||||
|
if grad.is_sparse:
|
||||||
raise RuntimeError("Muon does not support sparse gradients")
|
raise RuntimeError("Muon does not support sparse gradients")
|
||||||
if p.ndim >= 2:
|
if p.ndim >= 2:
|
||||||
params_2d.append(p)
|
self._muon_update(p, grad, group)
|
||||||
grads_2d.append(p.grad)
|
|
||||||
else:
|
else:
|
||||||
params_1d.append(p)
|
self._adamw_update(p, grad, group)
|
||||||
grads_1d.append(p.grad)
|
|
||||||
|
|
||||||
if params_2d:
|
|
||||||
self._muon_update_foreach(params_2d, grads_2d, group)
|
|
||||||
if params_1d:
|
|
||||||
self._adamw_update_foreach(params_1d, grads_1d, group)
|
|
||||||
|
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
def _muon_update_foreach(self, params_2d, grads_2d, group):
|
def _muon_update(self, p, grad, group):
|
||||||
lr = group["lr"]
|
lr = group["lr"]
|
||||||
momentum = group["momentum"]
|
momentum = group["momentum"]
|
||||||
wd = group["weight_decay"]
|
wd = group["weight_decay"]
|
||||||
nesterov = group["nesterov"]
|
nesterov = group["nesterov"]
|
||||||
ns_steps = group["ns_steps"]
|
ns_steps = group["ns_steps"]
|
||||||
|
state = self.state[p]
|
||||||
|
|
||||||
if wd != 0:
|
p.mul_(1 - lr * wd)
|
||||||
torch._foreach_mul_(params_2d, 1 - lr * wd)
|
|
||||||
|
|
||||||
if nesterov:
|
if nesterov:
|
||||||
grads_2d = torch._foreach_add(grads_2d, params_2d, alpha=wd)
|
grad = grad.add(p, alpha=wd)
|
||||||
|
|
||||||
bufs = []
|
if "momentum_buffer" not in state:
|
||||||
for p, grad in zip(params_2d, grads_2d):
|
state["momentum_buffer"] = torch.zeros_like(grad)
|
||||||
state = self.state[p]
|
buf = state["momentum_buffer"]
|
||||||
if "momentum_buffer" not in state:
|
buf.lerp_(grad, 1 - momentum)
|
||||||
state["momentum_buffer"] = torch.zeros_like(grad)
|
|
||||||
bufs.append(state["momentum_buffer"])
|
|
||||||
|
|
||||||
torch._foreach_lerp_(bufs, grads_2d, 1 - momentum)
|
update = _zeropower_via_newtonschulz(buf, steps=ns_steps)
|
||||||
|
scale = max(1, p.size(0) / p.size(1)) ** 0.5
|
||||||
|
p.add_(update, alpha=-lr * scale)
|
||||||
|
|
||||||
for p, buf in zip(params_2d, bufs):
|
def _adamw_update(self, p, grad, group):
|
||||||
update = _zeropower_via_newtonschulz(buf, steps=ns_steps)
|
|
||||||
scale = max(1, p.size(0) / p.size(1)) ** 0.5
|
|
||||||
p.add_(update, alpha=-lr * scale)
|
|
||||||
|
|
||||||
def _adamw_update_foreach(self, params_1d, grads_1d, group):
|
|
||||||
lr = group["adamw_lr"]
|
lr = group["adamw_lr"]
|
||||||
betas = group["adamw_betas"]
|
betas = group["adamw_betas"]
|
||||||
eps = group["adamw_eps"]
|
eps = group["adamw_eps"]
|
||||||
wd = group["adamw_wd"]
|
wd = group["adamw_wd"]
|
||||||
|
state = self.state[p]
|
||||||
|
|
||||||
steps: list[int] = []
|
if not state:
|
||||||
exp_avgs, exp_avg_sqs = [], []
|
state["step"] = 0
|
||||||
has_state = []
|
state["exp_avg"] = torch.zeros_like(p)
|
||||||
for p in params_1d:
|
state["exp_avg_sq"] = torch.zeros_like(p)
|
||||||
state = self.state[p]
|
|
||||||
if not state:
|
|
||||||
state["step"] = 0
|
|
||||||
state["exp_avg"] = torch.zeros_like(p)
|
|
||||||
state["exp_avg_sq"] = torch.zeros_like(p)
|
|
||||||
has_state.append(False)
|
|
||||||
else:
|
|
||||||
has_state.append(True)
|
|
||||||
state["step"] += 1
|
|
||||||
steps.append(state["step"])
|
|
||||||
exp_avgs.append(state["exp_avg"])
|
|
||||||
exp_avg_sqs.append(state["exp_avg_sq"])
|
|
||||||
|
|
||||||
|
state["step"] += 1
|
||||||
|
exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
|
||||||
beta1, beta2 = betas
|
beta1, beta2 = betas
|
||||||
|
|
||||||
torch._foreach_lerp_(exp_avgs, grads_1d, 1 - beta1)
|
exp_avg.lerp_(grad, 1 - beta1)
|
||||||
grads_sq = torch._foreach_mul(grads_1d, grads_1d)
|
exp_avg_sq.lerp_(grad.square(), 1 - beta2)
|
||||||
torch._foreach_lerp_(exp_avg_sqs, grads_sq, 1 - beta2)
|
|
||||||
|
|
||||||
bias_correction1 = [1 - beta1**s for s in steps]
|
step = state["step"]
|
||||||
bias_correction2 = [1 - beta2**s for s in steps]
|
bias1 = 1 - beta1**step
|
||||||
|
bias2 = 1 - beta2**step
|
||||||
|
|
||||||
if wd != 0:
|
p.mul_(1 - lr * wd)
|
||||||
torch._foreach_mul_(params_1d, 1 - lr * wd)
|
denom = exp_avg_sq.sqrt().div_(bias2**0.5).add_(eps)
|
||||||
|
p.addcdiv_(exp_avg / bias1, denom, value=-lr)
|
||||||
exp_avg_corrected = torch._foreach_div(exp_avgs, bias_correction1)
|
|
||||||
denom = torch._foreach_div(exp_avg_sqs, bias_correction2)
|
|
||||||
denom = torch._foreach_sqrt(denom)
|
|
||||||
torch._foreach_add_(denom, eps)
|
|
||||||
torch._foreach_addcdiv_(params_1d, exp_avg_corrected, denom, value=-lr)
|
|
||||||
|
|
|
||||||
|
|
@ -42,7 +42,7 @@ class SchedulerFactory(BaseFactory["BaseScheduler"]):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _validate_component(cls, scheduler_cls: Type[BaseScheduler]):
|
def _validate_component(cls, scheduler_cls: Type[BaseScheduler]) -> None:
|
||||||
"""Validate that the scheduler class inherits from BaseScheduler."""
|
"""Validate that the scheduler class inherits from BaseScheduler."""
|
||||||
if not issubclass(scheduler_cls, BaseScheduler):
|
if not issubclass(scheduler_cls, BaseScheduler):
|
||||||
raise TypeError(f"{scheduler_cls.__name__} must inherit from BaseScheduler")
|
raise TypeError(f"{scheduler_cls.__name__} must inherit from BaseScheduler")
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
"""Training strategy implementations with factory pattern."""
|
"""Training strategy implementations with factory pattern."""
|
||||||
|
|
||||||
|
import copy
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any, Callable, Dict, Union
|
from typing import Any, Callable, Dict, Union
|
||||||
|
|
||||||
|
|
@ -7,14 +8,26 @@ import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
|
|
||||||
from astrai.factory import BaseFactory
|
from astrai.factory import BaseFactory
|
||||||
|
|
||||||
|
|
||||||
def create_ref_model(model_fn, state_dict: dict) -> nn.Module:
|
def unwrap_model(model: nn.Module) -> nn.Module:
|
||||||
"""Create a frozen reference model from model_fn + full state dict."""
|
"""Unwrap DDP wrapper if present to get the original model."""
|
||||||
ref_model = model_fn()
|
if isinstance(model, DDP):
|
||||||
ref_model.load_state_dict(state_dict)
|
return model.module
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def create_ref_model(model: nn.Module) -> nn.Module:
|
||||||
|
"""Create a reference model for DPO/GRPO training.
|
||||||
|
|
||||||
|
Handles DDP-wrapped models safely by unwrapping first,
|
||||||
|
then creating a deep copy with frozen gradients.
|
||||||
|
"""
|
||||||
|
original_model = unwrap_model(model)
|
||||||
|
ref_model = copy.deepcopy(original_model)
|
||||||
ref_model.requires_grad_(False)
|
ref_model.requires_grad_(False)
|
||||||
ref_model.eval()
|
ref_model.eval()
|
||||||
return ref_model
|
return ref_model
|
||||||
|
|
@ -68,22 +81,6 @@ def get_logprobs(
|
||||||
return token_logprobs * shifted_mask
|
return token_logprobs * shifted_mask
|
||||||
|
|
||||||
|
|
||||||
def make_doc_boundary_mask(position_ids: Tensor) -> Tensor:
|
|
||||||
S = position_ids.size(1)
|
|
||||||
device = position_ids.device
|
|
||||||
boundaries = position_ids[:, 1:] <= position_ids[:, :-1]
|
|
||||||
doc_ids = torch.cat(
|
|
||||||
[
|
|
||||||
torch.zeros(position_ids.size(0), 1, dtype=torch.long, device=device),
|
|
||||||
boundaries.long().cumsum(dim=1),
|
|
||||||
],
|
|
||||||
dim=1,
|
|
||||||
)
|
|
||||||
same_doc = doc_ids.unsqueeze(-1) == doc_ids.unsqueeze(-2)
|
|
||||||
causal = torch.tril(torch.ones(S, S, dtype=torch.bool, device=device))
|
|
||||||
return (same_doc & causal).unsqueeze(1)
|
|
||||||
|
|
||||||
|
|
||||||
class BaseStrategy(ABC):
|
class BaseStrategy(ABC):
|
||||||
"""Abstract base class for training strategies."""
|
"""Abstract base class for training strategies."""
|
||||||
|
|
||||||
|
|
@ -92,8 +89,6 @@ class BaseStrategy(ABC):
|
||||||
):
|
):
|
||||||
self.model = model
|
self.model = model
|
||||||
self.device = device
|
self.device = device
|
||||||
self.executor = kwargs.pop("executor", None)
|
|
||||||
self.model_fn = kwargs.pop("model_fn", None)
|
|
||||||
self.extra_kwargs = kwargs
|
self.extra_kwargs = kwargs
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
|
@ -128,7 +123,7 @@ class StrategyFactory(BaseFactory["BaseStrategy"]):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _validate_component(cls, strategy_cls: type):
|
def _validate_component(cls, strategy_cls: type) -> None:
|
||||||
"""Validate that the strategy class inherits from BaseStrategy."""
|
"""Validate that the strategy class inherits from BaseStrategy."""
|
||||||
if not issubclass(strategy_cls, BaseStrategy):
|
if not issubclass(strategy_cls, BaseStrategy):
|
||||||
raise TypeError(f"{strategy_cls.__name__} must inherit from BaseStrategy")
|
raise TypeError(f"{strategy_cls.__name__} must inherit from BaseStrategy")
|
||||||
|
|
@ -196,19 +191,15 @@ class SFTStrategy(BaseStrategy):
|
||||||
|
|
||||||
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
|
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
|
||||||
batch = move_to_device(batch, self.device)
|
batch = move_to_device(batch, self.device)
|
||||||
input_ids, target_ids, position_ids, loss_mask = (
|
input_ids, target_ids, loss_mask = (
|
||||||
batch["input_ids"],
|
batch["input_ids"],
|
||||||
batch["target_ids"],
|
batch["target_ids"],
|
||||||
batch["position_ids"],
|
|
||||||
batch["loss_mask"],
|
batch["loss_mask"],
|
||||||
)
|
)
|
||||||
|
|
||||||
ignore_index = -100
|
ignore_index = -100
|
||||||
input_mask = make_doc_boundary_mask(position_ids)
|
logits = self.model(input_ids=input_ids)["logits"]
|
||||||
target_ids = target_ids.masked_fill(loss_mask == 0, ignore_index)
|
target_ids = target_ids.masked_fill(loss_mask == 0, ignore_index)
|
||||||
logits = self.model(
|
|
||||||
input_ids=input_ids, position_ids=position_ids, input_mask=input_mask
|
|
||||||
)["logits"]
|
|
||||||
|
|
||||||
loss = F.cross_entropy(
|
loss = F.cross_entropy(
|
||||||
input=logits.flatten(0, 1).float(),
|
input=logits.flatten(0, 1).float(),
|
||||||
|
|
@ -237,9 +228,7 @@ class DPOStrategy(BaseStrategy):
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
super().__init__(model, device, **kwargs)
|
super().__init__(model, device, **kwargs)
|
||||||
self.ref_model = create_ref_model(
|
self.ref_model = create_ref_model(model)
|
||||||
self.model_fn, self.executor.unwrap_model(model)
|
|
||||||
).to(device=self.device)
|
|
||||||
self.beta = beta
|
self.beta = beta
|
||||||
self.reduction = reduction
|
self.reduction = reduction
|
||||||
|
|
||||||
|
|
@ -293,9 +282,7 @@ class GRPOStrategy(BaseStrategy):
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
super().__init__(model, device, **kwargs)
|
super().__init__(model, device, **kwargs)
|
||||||
self.ref_model = create_ref_model(
|
self.ref_model = create_ref_model(model)
|
||||||
self.model_fn, self.executor.unwrap_model(model)
|
|
||||||
).to(device=self.device)
|
|
||||||
self.clip_eps = clip_eps
|
self.clip_eps = clip_eps
|
||||||
self.kl_coef = kl_coef
|
self.kl_coef = kl_coef
|
||||||
self.group_size = group_size
|
self.group_size = group_size
|
||||||
|
|
@ -305,7 +292,8 @@ class GRPOStrategy(BaseStrategy):
|
||||||
|
|
||||||
def sync_ref_model(self):
|
def sync_ref_model(self):
|
||||||
"""Copy current model weights to ref model."""
|
"""Copy current model weights to ref model."""
|
||||||
self.ref_model.load_state_dict(self.executor.unwrap_model(self.model))
|
ref_state = self.model.state_dict()
|
||||||
|
self.ref_model.load_state_dict(ref_state)
|
||||||
|
|
||||||
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
|
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
|
||||||
self._step += 1
|
self._step += 1
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,7 @@ from tqdm import tqdm
|
||||||
|
|
||||||
from astrai.factory import BaseFactory
|
from astrai.factory import BaseFactory
|
||||||
from astrai.parallel import only_on_rank
|
from astrai.parallel import only_on_rank
|
||||||
from astrai.parallel.setup import get_current_device, get_rank
|
from astrai.parallel.setup import get_current_device
|
||||||
from astrai.serialization import Checkpoint
|
from astrai.serialization import Checkpoint
|
||||||
from astrai.trainer.metric_util import (
|
from astrai.trainer.metric_util import (
|
||||||
ctx_get_grad_max,
|
ctx_get_grad_max,
|
||||||
|
|
@ -51,15 +51,18 @@ class TrainCallback(Protocol):
|
||||||
def on_epoch_end(self, context: TrainContext):
|
def on_epoch_end(self, context: TrainContext):
|
||||||
"""Called at the end of each epoch."""
|
"""Called at the end of each epoch."""
|
||||||
|
|
||||||
|
def on_step_begin(self, context: TrainContext):
|
||||||
|
"""Called at the beginning of each step."""
|
||||||
|
|
||||||
|
def on_step_end(self, context: TrainContext):
|
||||||
|
"""Called at the end of each step."""
|
||||||
|
|
||||||
def on_batch_begin(self, context: TrainContext):
|
def on_batch_begin(self, context: TrainContext):
|
||||||
"""Called at the beginning of each batch."""
|
"""Called at the beginning of each batch."""
|
||||||
|
|
||||||
def on_batch_end(self, context: TrainContext):
|
def on_batch_end(self, context: TrainContext):
|
||||||
"""Called at the end of each batch."""
|
"""Called at the end of each batch."""
|
||||||
|
|
||||||
def on_optimizer_step(self, context: TrainContext):
|
|
||||||
"""Called on every optimizer step (sync step only)."""
|
|
||||||
|
|
||||||
def on_error(self, context: TrainContext):
|
def on_error(self, context: TrainContext):
|
||||||
"""Called when an error occurs during training."""
|
"""Called when an error occurs during training."""
|
||||||
|
|
||||||
|
|
@ -85,7 +88,7 @@ class GradientClippingCallback(TrainCallback):
|
||||||
def __init__(self, max_grad_norm: float):
|
def __init__(self, max_grad_norm: float):
|
||||||
self.max_grad_norm = max_grad_norm
|
self.max_grad_norm = max_grad_norm
|
||||||
|
|
||||||
def on_optimizer_step(self, context: TrainContext):
|
def on_step_begin(self, context: TrainContext):
|
||||||
clip_grad_norm_(context.model.parameters(), self.max_grad_norm)
|
clip_grad_norm_(context.model.parameters(), self.max_grad_norm)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -137,31 +140,44 @@ class CheckpointCallback(TrainCallback):
|
||||||
save_dir: str,
|
save_dir: str,
|
||||||
interval: int,
|
interval: int,
|
||||||
weight_only: bool = False,
|
weight_only: bool = False,
|
||||||
|
state_dict_fn: Optional[Callable[[nn.Module], dict]] = None,
|
||||||
save_extra_fn: Optional[Callable[["TrainContext"], dict]] = None,
|
save_extra_fn: Optional[Callable[["TrainContext"], dict]] = None,
|
||||||
|
load_extra_fn: Optional[Callable[[dict, "TrainContext"], None]] = None,
|
||||||
):
|
):
|
||||||
self.save_dir = save_dir
|
self.save_dir = save_dir
|
||||||
self.interval = interval
|
self.interval = interval
|
||||||
self.weight_only = weight_only
|
self.weight_only = weight_only
|
||||||
|
self.state_dict_fn = state_dict_fn
|
||||||
self.save_extra_fn = save_extra_fn or CheckpointCallback.save_extra
|
self.save_extra_fn = save_extra_fn or CheckpointCallback.save_extra
|
||||||
|
self.load_extra_fn = load_extra_fn or CheckpointCallback.load_extra
|
||||||
self.last_ckpt_iter = 0
|
self.last_ckpt_iter = 0
|
||||||
|
|
||||||
|
@only_on_rank(0)
|
||||||
def _save_checkpoint(self, context: TrainContext):
|
def _save_checkpoint(self, context: TrainContext):
|
||||||
state_dict = context.executor.unwrap_model(context.model)
|
save_path = os.path.join(
|
||||||
|
self.save_dir, f"epoch_{context.epoch}_iter_{context.iteration}"
|
||||||
|
)
|
||||||
|
state_dict = (
|
||||||
|
self.state_dict_fn(context.model)
|
||||||
|
if self.state_dict_fn
|
||||||
|
else context.model.state_dict()
|
||||||
|
)
|
||||||
|
|
||||||
|
extra = self.save_extra_fn(context)
|
||||||
|
context.checkpoint = Checkpoint(
|
||||||
|
state_dict=state_dict,
|
||||||
|
epoch=context.epoch,
|
||||||
|
iteration=context.iteration,
|
||||||
|
extra=extra,
|
||||||
|
meta=context.config.to_dict(),
|
||||||
|
)
|
||||||
|
|
||||||
|
context.checkpoint.save(save_path)
|
||||||
self.last_ckpt_iter = context.iteration
|
self.last_ckpt_iter = context.iteration
|
||||||
|
|
||||||
if get_rank() == 0:
|
def on_train_begin(self, context: TrainContext):
|
||||||
save_path = os.path.join(
|
if context.checkpoint and context.checkpoint.extra:
|
||||||
self.save_dir, f"epoch_{context.epoch}_iter_{context.iteration}"
|
self.load_extra_fn(context.checkpoint.extra, context)
|
||||||
)
|
|
||||||
extra = self.save_extra_fn(context)
|
|
||||||
context.checkpoint = Checkpoint(
|
|
||||||
state_dict=state_dict,
|
|
||||||
epoch=context.epoch,
|
|
||||||
iteration=context.iteration,
|
|
||||||
extra=extra,
|
|
||||||
config=context.model_config,
|
|
||||||
)
|
|
||||||
context.checkpoint.save(save_path)
|
|
||||||
|
|
||||||
def on_batch_end(self, context: TrainContext):
|
def on_batch_end(self, context: TrainContext):
|
||||||
if context.iteration - self.last_ckpt_iter >= self.interval:
|
if context.iteration - self.last_ckpt_iter >= self.interval:
|
||||||
|
|
@ -183,6 +199,12 @@ class CheckpointCallback(TrainCallback):
|
||||||
extra[name] = obj.state_dict()
|
extra[name] = obj.state_dict()
|
||||||
return extra
|
return extra
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def load_extra(extra: dict, context: TrainContext):
|
||||||
|
for name in CheckpointCallback.extra_keys:
|
||||||
|
if name in extra:
|
||||||
|
getattr(context, name).load_state_dict(extra[name])
|
||||||
|
|
||||||
|
|
||||||
@CallbackFactory.register("progress_bar")
|
@CallbackFactory.register("progress_bar")
|
||||||
class ProgressBarCallback(TrainCallback):
|
class ProgressBarCallback(TrainCallback):
|
||||||
|
|
@ -191,7 +213,7 @@ class ProgressBarCallback(TrainCallback):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, num_epoch: int, log_interval: int = 100, file: Optional[IO[str]] = None
|
self, num_epoch: int, log_interval: int = 100, file: IO[str] = sys.stdout
|
||||||
):
|
):
|
||||||
self.num_epoch = num_epoch
|
self.num_epoch = num_epoch
|
||||||
self.log_interval = log_interval
|
self.log_interval = log_interval
|
||||||
|
|
@ -204,7 +226,7 @@ class ProgressBarCallback(TrainCallback):
|
||||||
context.dataloader,
|
context.dataloader,
|
||||||
desc=f"Epoch {context.epoch + 1}/{self.num_epoch}",
|
desc=f"Epoch {context.epoch + 1}/{self.num_epoch}",
|
||||||
dynamic_ncols=True,
|
dynamic_ncols=True,
|
||||||
file=self.file or sys.stdout,
|
file=self.file,
|
||||||
)
|
)
|
||||||
|
|
||||||
@only_on_rank(0)
|
@only_on_rank(0)
|
||||||
|
|
@ -322,7 +344,7 @@ class ValidationCallback(TrainCallback):
|
||||||
f"Epoch {context.epoch + 1}, Step {step_count}, Val Loss: {avg_loss:.4f}"
|
f"Epoch {context.epoch + 1}, Step {step_count}, Val Loss: {avg_loss:.4f}"
|
||||||
)
|
)
|
||||||
|
|
||||||
def on_optimizer_step(self, context: TrainContext):
|
def on_step_end(self, context: TrainContext):
|
||||||
if context.val_dataloader is None:
|
if context.val_dataloader is None:
|
||||||
return
|
return
|
||||||
cfg = context.config
|
cfg = context.config
|
||||||
|
|
|
||||||
|
|
@ -1,18 +1,15 @@
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
|
||||||
from typing import Optional, Self
|
from typing import Optional, Self
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch.utils.data import DataLoader, random_split
|
from torch.optim import Optimizer
|
||||||
|
from torch.optim.lr_scheduler import LRScheduler
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
from astrai.config.train_config import TrainConfig
|
from astrai.config.train_config import TrainConfig
|
||||||
from astrai.dataset import ResumableDistributedSampler
|
from astrai.dataset import ResumableDistributedSampler
|
||||||
from astrai.model.components.lora import inject_lora
|
|
||||||
from astrai.parallel.executor import BaseExecutor, ExecutorFactory
|
|
||||||
from astrai.parallel.setup import get_current_device, get_rank, get_world_size
|
from astrai.parallel.setup import get_current_device, get_rank, get_world_size
|
||||||
from astrai.protocols import OptimizerProtocol, SchedulerProtocol
|
from astrai.serialization import Checkpoint
|
||||||
from astrai.serialization import Checkpoint, load_json, load_model_weights
|
|
||||||
from astrai.trainer.strategy import BaseStrategy, StrategyFactory
|
from astrai.trainer.strategy import BaseStrategy, StrategyFactory
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -21,12 +18,10 @@ class TrainContext:
|
||||||
model: nn.Module = field(default=None)
|
model: nn.Module = field(default=None)
|
||||||
strategy: BaseStrategy = field(default=None)
|
strategy: BaseStrategy = field(default=None)
|
||||||
dataloader: DataLoader = field(default=None)
|
dataloader: DataLoader = field(default=None)
|
||||||
optimizer: OptimizerProtocol = field(default=None)
|
optimizer: Optimizer = field(default=None)
|
||||||
scheduler: SchedulerProtocol = field(default=None)
|
scheduler: LRScheduler = field(default=None)
|
||||||
checkpoint: Checkpoint = field(default=None)
|
checkpoint: Checkpoint = field(default=None)
|
||||||
config: TrainConfig = field(default=None)
|
config: TrainConfig = field(default=None)
|
||||||
model_config: dict = field(default_factory=dict)
|
|
||||||
executor: BaseExecutor = field(default=None)
|
|
||||||
|
|
||||||
epoch: int = field(default=0)
|
epoch: int = field(default=0)
|
||||||
iteration: int = field(default=0)
|
iteration: int = field(default=0)
|
||||||
|
|
@ -45,91 +40,49 @@ class TrainContextBuilder:
|
||||||
config: TrainConfig,
|
config: TrainConfig,
|
||||||
):
|
):
|
||||||
self.config = config
|
self.config = config
|
||||||
self._resume_dir: Optional[str] = None
|
self._checkpoint: Optional[Checkpoint] = None
|
||||||
|
|
||||||
def with_resume_dir(self, resume_dir: Optional[str]) -> Self:
|
def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self:
|
||||||
self._resume_dir = resume_dir
|
self._checkpoint = checkpoint
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def build(self) -> TrainContext:
|
def build(self) -> TrainContext:
|
||||||
cfg = self.config
|
|
||||||
device = get_current_device()
|
|
||||||
|
|
||||||
executor = ExecutorFactory.create(
|
|
||||||
cfg.parallel_mode,
|
|
||||||
grad_accum_steps=cfg.grad_accum_steps,
|
|
||||||
**cfg.executor_kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
model = cfg.model_fn()
|
|
||||||
model = model.to(device=device)
|
|
||||||
|
|
||||||
model_config = {}
|
|
||||||
if self._resume_dir:
|
|
||||||
config_path = Path(self._resume_dir) / "config.json"
|
|
||||||
if config_path.exists():
|
|
||||||
model_config = load_json(config_path)
|
|
||||||
|
|
||||||
if not model_config and hasattr(model, "config"):
|
|
||||||
model_config = model.config.to_dict()
|
|
||||||
|
|
||||||
context = TrainContext(
|
context = TrainContext(
|
||||||
model=model,
|
model=self.config.model,
|
||||||
world_size=get_world_size(),
|
world_size=get_world_size(),
|
||||||
rank=get_rank(),
|
rank=get_rank(),
|
||||||
config=cfg,
|
config=self.config,
|
||||||
model_config=model_config,
|
|
||||||
executor=executor,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if self._resume_dir is not None:
|
device = get_current_device()
|
||||||
resume_path = Path(self._resume_dir)
|
context.model = context.model.to(device=device)
|
||||||
if (resume_path / "meta.json").exists():
|
|
||||||
checkpoint = Checkpoint.load(self._resume_dir)
|
|
||||||
state_dict = checkpoint.state_dict
|
|
||||||
if checkpoint.config:
|
|
||||||
context.model_config = checkpoint.config
|
|
||||||
else:
|
|
||||||
checkpoint = None
|
|
||||||
state_dict = load_model_weights(self._resume_dir)
|
|
||||||
model.load_state_dict(state_dict, strict=False)
|
|
||||||
if checkpoint is not None:
|
|
||||||
context.epoch = cfg.start_epoch
|
|
||||||
context.iteration = cfg.start_batch
|
|
||||||
context.checkpoint = checkpoint
|
|
||||||
|
|
||||||
if cfg.lora is not None:
|
if self.config.nprocs > 1 and self.config.parallel_wrapper:
|
||||||
inject_lora(
|
context.model = self.config.parallel_wrapper(context.model)
|
||||||
model,
|
|
||||||
r=cfg.lora.r,
|
if self._checkpoint is not None:
|
||||||
alpha=cfg.lora.alpha,
|
context.epoch = max(self._checkpoint.epoch, self.config.start_epoch)
|
||||||
target_modules=set(cfg.lora.target_modules),
|
context.iteration = max(self._checkpoint.iteration, self.config.start_batch)
|
||||||
|
context.model.load_state_dict(self._checkpoint.state_dict)
|
||||||
|
context.checkpoint = self._checkpoint
|
||||||
|
else:
|
||||||
|
context.checkpoint = Checkpoint(
|
||||||
|
state_dict=context.model.state_dict(),
|
||||||
)
|
)
|
||||||
|
|
||||||
context.optimizer = cfg.optimizer_fn(model)
|
context.optimizer = self.config.optimizer_fn(context.model)
|
||||||
context.scheduler = cfg.scheduler_fn(context.optimizer)
|
context.scheduler = self.config.scheduler_fn(context.optimizer)
|
||||||
|
|
||||||
train_dataset = cfg.dataset
|
|
||||||
val_dataset = cfg.val_dataset
|
|
||||||
|
|
||||||
if val_dataset is None and cfg.val_split is not None:
|
|
||||||
n_total = len(cfg.dataset)
|
|
||||||
n_val = max(1, int(n_total * cfg.val_split))
|
|
||||||
n_train = n_total - n_val
|
|
||||||
generator = torch.Generator().manual_seed(cfg.random_seed)
|
|
||||||
train_dataset, val_dataset = random_split(
|
|
||||||
cfg.dataset, [n_train, n_val], generator=generator
|
|
||||||
)
|
|
||||||
|
|
||||||
|
cfg = self.config
|
||||||
sampler_offset = context.iteration * cfg.batch_per_device
|
sampler_offset = context.iteration * cfg.batch_per_device
|
||||||
sampler = ResumableDistributedSampler(
|
sampler = ResumableDistributedSampler(
|
||||||
data_source=train_dataset,
|
data_source=cfg.dataset,
|
||||||
start_epoch=context.epoch,
|
start_epoch=context.epoch,
|
||||||
start_iter=sampler_offset,
|
start_iter=sampler_offset,
|
||||||
seed=cfg.random_seed,
|
seed=cfg.random_seed,
|
||||||
)
|
)
|
||||||
context.dataloader = DataLoader(
|
context.dataloader = DataLoader(
|
||||||
train_dataset,
|
cfg.dataset,
|
||||||
batch_size=cfg.batch_per_device,
|
batch_size=cfg.batch_per_device,
|
||||||
sampler=sampler,
|
sampler=sampler,
|
||||||
num_workers=cfg.num_workers,
|
num_workers=cfg.num_workers,
|
||||||
|
|
@ -137,16 +90,16 @@ class TrainContextBuilder:
|
||||||
prefetch_factor=cfg.prefetch_factor,
|
prefetch_factor=cfg.prefetch_factor,
|
||||||
)
|
)
|
||||||
|
|
||||||
if val_dataset is not None:
|
if cfg.val_dataset is not None:
|
||||||
val_sampler = ResumableDistributedSampler(
|
val_sampler = ResumableDistributedSampler(
|
||||||
data_source=val_dataset,
|
data_source=cfg.val_dataset,
|
||||||
start_epoch=0,
|
start_epoch=0,
|
||||||
start_iter=0,
|
start_iter=0,
|
||||||
seed=cfg.random_seed,
|
seed=cfg.random_seed,
|
||||||
shuffle=False,
|
shuffle=False,
|
||||||
)
|
)
|
||||||
context.val_dataloader = DataLoader(
|
context.val_dataloader = DataLoader(
|
||||||
val_dataset,
|
cfg.val_dataset,
|
||||||
batch_size=cfg.batch_per_device,
|
batch_size=cfg.batch_per_device,
|
||||||
sampler=val_sampler,
|
sampler=val_sampler,
|
||||||
num_workers=cfg.num_workers,
|
num_workers=cfg.num_workers,
|
||||||
|
|
@ -154,30 +107,11 @@ class TrainContextBuilder:
|
||||||
prefetch_factor=cfg.prefetch_factor,
|
prefetch_factor=cfg.prefetch_factor,
|
||||||
)
|
)
|
||||||
|
|
||||||
context.model, context.optimizer, context.dataloader, context.scheduler = (
|
|
||||||
executor.prepare(
|
|
||||||
model,
|
|
||||||
context.optimizer,
|
|
||||||
context.dataloader,
|
|
||||||
context.scheduler,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
if context.checkpoint and context.checkpoint.extra:
|
|
||||||
extra = context.checkpoint.extra
|
|
||||||
for name in ("optimizer", "scheduler"):
|
|
||||||
if name in extra:
|
|
||||||
obj = getattr(context, name, None)
|
|
||||||
if obj is not None:
|
|
||||||
obj.load_state_dict(extra[name])
|
|
||||||
|
|
||||||
context.strategy = StrategyFactory.create(
|
context.strategy = StrategyFactory.create(
|
||||||
model=context.model,
|
model=context.model,
|
||||||
train_type=cfg.strategy,
|
train_type=self.config.strategy,
|
||||||
device=device,
|
device=device,
|
||||||
executor=executor,
|
**self.config.extra_kwargs,
|
||||||
model_fn=cfg.model_fn,
|
|
||||||
**cfg.extra_kwargs,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return context
|
return context
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,7 @@ from typing import List, Optional
|
||||||
|
|
||||||
from astrai.config import TrainConfig
|
from astrai.config import TrainConfig
|
||||||
from astrai.parallel.setup import spawn_parallel_fn
|
from astrai.parallel.setup import spawn_parallel_fn
|
||||||
|
from astrai.serialization import Checkpoint
|
||||||
from astrai.trainer.train_callback import (
|
from astrai.trainer.train_callback import (
|
||||||
CallbackFactory,
|
CallbackFactory,
|
||||||
TrainCallback,
|
TrainCallback,
|
||||||
|
|
@ -33,6 +34,7 @@ class Trainer:
|
||||||
"checkpoint",
|
"checkpoint",
|
||||||
cfg.ckpt_dir,
|
cfg.ckpt_dir,
|
||||||
cfg.ckpt_interval,
|
cfg.ckpt_interval,
|
||||||
|
state_dict_fn=cfg.state_dict_fn,
|
||||||
),
|
),
|
||||||
CallbackFactory.create(
|
CallbackFactory.create(
|
||||||
"metric_logger",
|
"metric_logger",
|
||||||
|
|
@ -53,49 +55,47 @@ class Trainer:
|
||||||
if method:
|
if method:
|
||||||
method(context)
|
method(context)
|
||||||
|
|
||||||
def _trainer_loop(self, resume_dir: Optional[str] = None):
|
def _trainer_loop(self, checkpoint: Optional[Checkpoint] = None):
|
||||||
context = (
|
cfg = self.train_config
|
||||||
TrainContextBuilder(self.train_config).with_resume_dir(resume_dir).build()
|
context = TrainContextBuilder(cfg).with_checkpoint(checkpoint).build()
|
||||||
)
|
|
||||||
executor = context.executor
|
|
||||||
self._call_callbacks("on_train_begin", context)
|
self._call_callbacks("on_train_begin", context)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
context.model.train()
|
context.model.train()
|
||||||
|
grad_accum_steps = cfg.grad_accum_steps
|
||||||
|
|
||||||
for epoch in range(context.epoch, context.config.n_epoch):
|
for epoch in range(context.epoch, cfg.n_epoch):
|
||||||
context.epoch = epoch
|
context.epoch = epoch
|
||||||
self._call_callbacks("on_epoch_begin", context)
|
self._call_callbacks("on_epoch_begin", context)
|
||||||
|
|
||||||
for batch in context.dataloader:
|
for batch in context.dataloader:
|
||||||
self._call_callbacks("on_batch_begin", context)
|
self._call_callbacks("on_batch_begin", context)
|
||||||
|
loss = context.strategy(batch)
|
||||||
|
context.loss = loss.item()
|
||||||
|
stand_loss = loss / grad_accum_steps
|
||||||
|
stand_loss.backward()
|
||||||
|
context.iteration += 1
|
||||||
|
self._call_callbacks("on_batch_end", context)
|
||||||
|
|
||||||
with executor.accumulate(context.model):
|
if context.iteration % grad_accum_steps == 0:
|
||||||
loss = context.strategy(batch)
|
self._call_callbacks("on_step_begin", context)
|
||||||
context.loss = loss.item()
|
context.optimizer.step()
|
||||||
stand_loss = loss / executor.grad_accum_steps
|
context.optimizer.zero_grad()
|
||||||
executor.backward(stand_loss)
|
self._call_callbacks("on_step_end", context)
|
||||||
context.iteration += 1
|
|
||||||
self._call_callbacks("on_batch_end", context)
|
|
||||||
|
|
||||||
if executor.sync_gradients:
|
if context.scheduler:
|
||||||
self._call_callbacks("on_optimizer_step", context)
|
context.scheduler.step()
|
||||||
context.optimizer.step()
|
|
||||||
context.optimizer.zero_grad()
|
|
||||||
|
|
||||||
if context.scheduler:
|
|
||||||
context.scheduler.step()
|
|
||||||
|
|
||||||
self._call_callbacks("on_epoch_end", context)
|
self._call_callbacks("on_epoch_end", context)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Training failed: %s", str(e), exc_info=True)
|
logger.error(f"Training failed: {str(e)}", exc_info=True)
|
||||||
self._call_callbacks("on_error", context)
|
self._call_callbacks("on_error", context)
|
||||||
raise
|
raise
|
||||||
finally:
|
finally:
|
||||||
self._call_callbacks("on_train_end", context)
|
self._call_callbacks("on_train_end", context)
|
||||||
|
|
||||||
def train(self, resume_dir: Optional[str] = None):
|
def train(self, checkpoint: Optional[Checkpoint] = None):
|
||||||
cfg = self.train_config
|
cfg = self.train_config
|
||||||
spawn_parallel_fn(
|
spawn_parallel_fn(
|
||||||
self._trainer_loop,
|
self._trainer_loop,
|
||||||
|
|
@ -105,5 +105,5 @@ class Trainer:
|
||||||
master_port=cfg.master_port,
|
master_port=cfg.master_port,
|
||||||
device_type=cfg.device_type,
|
device_type=cfg.device_type,
|
||||||
start_method=cfg.start_method,
|
start_method=cfg.start_method,
|
||||||
resume_dir=resume_dir,
|
checkpoint=checkpoint,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -1,336 +0,0 @@
|
||||||
"""HumanEval code generation benchmark.
|
|
||||||
|
|
||||||
Generates n completions per problem, extracts function bodies, executes
|
|
||||||
against hidden tests, and computes pass@k.
|
|
||||||
|
|
||||||
Usage::
|
|
||||||
|
|
||||||
python scripts/tools/evaluate_humaneval.py --param_path ./params \
|
|
||||||
--data_path HumanEval.jsonl.gz --output results.json \
|
|
||||||
--num_samples 200 --temperature 0.8 --max_tokens 512
|
|
||||||
"""
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
import re
|
|
||||||
import signal
|
|
||||||
import sys
|
|
||||||
from math import prod
|
|
||||||
from multiprocessing import Process, Queue
|
|
||||||
from typing import Dict, List, Optional, Tuple
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
import tqdm
|
|
||||||
|
|
||||||
from astrai.inference import InferenceEngine
|
|
||||||
from astrai.model import AutoModel
|
|
||||||
from astrai.tokenize import AutoTokenizer
|
|
||||||
|
|
||||||
HUMANEVAL_URL = (
|
|
||||||
"https://github.com/openai/human-eval/raw/master/data/HumanEval.jsonl.gz"
|
|
||||||
)
|
|
||||||
|
|
||||||
_STOP_SEQUENCES = [
|
|
||||||
"\nclass ",
|
|
||||||
"\ndef ",
|
|
||||||
"\n# ",
|
|
||||||
"\nif __name__",
|
|
||||||
"\nprint(",
|
|
||||||
"\n\n\n",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def _download_humaneval(data_path: str):
|
|
||||||
if os.path.exists(data_path):
|
|
||||||
return
|
|
||||||
import gzip
|
|
||||||
import urllib.request
|
|
||||||
|
|
||||||
os.makedirs(os.path.dirname(data_path) or ".", exist_ok=True)
|
|
||||||
print(f"Downloading HumanEval from {HUMANEVAL_URL} ...")
|
|
||||||
tmp = data_path + ".tmp"
|
|
||||||
urllib.request.urlretrieve(HUMANEVAL_URL, tmp)
|
|
||||||
with gzip.open(tmp, "rb") as f_in:
|
|
||||||
with open(data_path, "wb") as f_out:
|
|
||||||
f_out.write(f_in.read())
|
|
||||||
os.remove(tmp)
|
|
||||||
print(f" saved to {data_path}")
|
|
||||||
|
|
||||||
|
|
||||||
def _load_problems(data_path: str) -> List[dict]:
|
|
||||||
problems = []
|
|
||||||
with open(data_path, "r", encoding="utf-8") as f:
|
|
||||||
for line in f:
|
|
||||||
line = line.strip()
|
|
||||||
if line:
|
|
||||||
problems.append(json.loads(line))
|
|
||||||
return problems
|
|
||||||
|
|
||||||
|
|
||||||
def _extract_function_body(code: str, entry_point: str) -> Optional[str]:
|
|
||||||
"""Extract the function body from a completion."""
|
|
||||||
pattern = rf"def\s+{re.escape(entry_point)}\b[^:]*:"
|
|
||||||
match = re.search(pattern, code)
|
|
||||||
if not match:
|
|
||||||
# Use the full code as-is if we can't find the function
|
|
||||||
return code
|
|
||||||
|
|
||||||
body_start = match.end()
|
|
||||||
lines = code[body_start:].split("\n")
|
|
||||||
body_lines = []
|
|
||||||
started = False
|
|
||||||
|
|
||||||
for line in lines:
|
|
||||||
stripped = line.rstrip()
|
|
||||||
if not stripped and not started:
|
|
||||||
continue
|
|
||||||
if not stripped and started:
|
|
||||||
body_lines.append("")
|
|
||||||
continue
|
|
||||||
if not started:
|
|
||||||
started = True
|
|
||||||
if stripped.lstrip() == stripped and started:
|
|
||||||
break
|
|
||||||
body_lines.append(stripped)
|
|
||||||
|
|
||||||
body = "\n".join(body_lines)
|
|
||||||
if not body.strip():
|
|
||||||
return None
|
|
||||||
return body
|
|
||||||
|
|
||||||
|
|
||||||
def _trim_stop_sequences(text: str) -> str:
|
|
||||||
for stop in _STOP_SEQUENCES:
|
|
||||||
idx = text.find(stop)
|
|
||||||
if idx != -1:
|
|
||||||
text = text[:idx]
|
|
||||||
return text
|
|
||||||
|
|
||||||
|
|
||||||
def _execute_code(problem: dict, completion: str, timeout: float = 3.0) -> bool:
|
|
||||||
"""Run the completion against hidden tests in a subprocess."""
|
|
||||||
|
|
||||||
def _worker(queue, full_code):
|
|
||||||
try:
|
|
||||||
namespace = {}
|
|
||||||
exec(full_code, namespace)
|
|
||||||
check = namespace.get("check")
|
|
||||||
if check is None:
|
|
||||||
queue.put(False)
|
|
||||||
return
|
|
||||||
check(namespace.get(problem["entry_point"]))
|
|
||||||
queue.put(True)
|
|
||||||
except Exception:
|
|
||||||
queue.put(False)
|
|
||||||
|
|
||||||
full_code = problem["prompt"] + completion + "\n" + problem["test"]
|
|
||||||
|
|
||||||
queue: Queue = Queue()
|
|
||||||
proc = Process(target=_worker, args=(queue, full_code))
|
|
||||||
proc.start()
|
|
||||||
proc.join(timeout)
|
|
||||||
|
|
||||||
if proc.is_alive():
|
|
||||||
proc.terminate()
|
|
||||||
proc.join()
|
|
||||||
return False
|
|
||||||
|
|
||||||
try:
|
|
||||||
return queue.get_nowait()
|
|
||||||
except Exception:
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def _pass_at_k(n: int, c: int, k: int) -> float:
|
|
||||||
"""Unbiased estimator of pass@k."""
|
|
||||||
if n - c < k:
|
|
||||||
return 1.0
|
|
||||||
return 1.0 - float(prod(1.0 - k / np.arange(n - c + 1, n + 1)))
|
|
||||||
|
|
||||||
|
|
||||||
def _deduplicate(completions: List[str]) -> List[str]:
|
|
||||||
seen = set()
|
|
||||||
unique = []
|
|
||||||
for c in completions:
|
|
||||||
if c not in seen:
|
|
||||||
seen.add(c)
|
|
||||||
unique.append(c)
|
|
||||||
return unique
|
|
||||||
|
|
||||||
|
|
||||||
def _generate(
|
|
||||||
engine: InferenceEngine,
|
|
||||||
prompt: str,
|
|
||||||
num_samples: int,
|
|
||||||
max_tokens: int,
|
|
||||||
temperature: float,
|
|
||||||
top_p: float,
|
|
||||||
top_k: int,
|
|
||||||
batch_size: int,
|
|
||||||
) -> List[str]:
|
|
||||||
batches = [prompt] * min(batch_size, num_samples)
|
|
||||||
completions = []
|
|
||||||
remaining = num_samples
|
|
||||||
|
|
||||||
while remaining > 0:
|
|
||||||
current = min(batch_size, remaining)
|
|
||||||
batch_prompts = batches[:current]
|
|
||||||
outputs = engine.generate(
|
|
||||||
prompt=batch_prompts,
|
|
||||||
stream=False,
|
|
||||||
max_tokens=max_tokens,
|
|
||||||
temperature=temperature,
|
|
||||||
top_p=top_p,
|
|
||||||
top_k=top_k,
|
|
||||||
)
|
|
||||||
if isinstance(outputs, str):
|
|
||||||
outputs = [outputs]
|
|
||||||
completions.extend(outputs)
|
|
||||||
remaining -= current
|
|
||||||
|
|
||||||
return _deduplicate(completions)
|
|
||||||
|
|
||||||
|
|
||||||
def evaluate(
|
|
||||||
engine: InferenceEngine,
|
|
||||||
problems: List[dict],
|
|
||||||
num_samples: int,
|
|
||||||
max_tokens: int,
|
|
||||||
temperature: float,
|
|
||||||
top_p: float,
|
|
||||||
top_k: int,
|
|
||||||
batch_size: int,
|
|
||||||
k_values: Tuple[int, ...] = (1, 10, 100),
|
|
||||||
) -> Dict:
|
|
||||||
results = {}
|
|
||||||
all_pass_at_k = {k: [] for k in k_values}
|
|
||||||
|
|
||||||
for problem in tqdm.tqdm(problems, desc="HumanEval", unit="problem"):
|
|
||||||
task_id = problem["task_id"]
|
|
||||||
prompt = problem["prompt"]
|
|
||||||
entry_point = problem["entry_point"]
|
|
||||||
|
|
||||||
raw_completions = _generate(
|
|
||||||
engine,
|
|
||||||
prompt,
|
|
||||||
num_samples,
|
|
||||||
max_tokens,
|
|
||||||
temperature,
|
|
||||||
top_p,
|
|
||||||
top_k,
|
|
||||||
batch_size,
|
|
||||||
)
|
|
||||||
|
|
||||||
completions = []
|
|
||||||
for raw in raw_completions:
|
|
||||||
trimmed = _trim_stop_sequences(raw)
|
|
||||||
body = _extract_function_body(trimmed, entry_point)
|
|
||||||
if body:
|
|
||||||
completions.append(body)
|
|
||||||
|
|
||||||
passed = 0
|
|
||||||
for comp in completions:
|
|
||||||
if _execute_code(problem, comp):
|
|
||||||
passed += 1
|
|
||||||
|
|
||||||
n = len(completions)
|
|
||||||
c = passed
|
|
||||||
result = {"task_id": task_id, "n": n, "passed": c}
|
|
||||||
for k in k_values:
|
|
||||||
result[f"pass@{k}"] = round(_pass_at_k(n, c, k), 4)
|
|
||||||
all_pass_at_k[k].append(_pass_at_k(n, c, k))
|
|
||||||
results[task_id] = result
|
|
||||||
|
|
||||||
summary = {}
|
|
||||||
for k in k_values:
|
|
||||||
vals = all_pass_at_k[k]
|
|
||||||
summary[f"pass@{k}"] = round(float(np.mean(vals)), 4)
|
|
||||||
results["_summary"] = summary
|
|
||||||
|
|
||||||
return results
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
parser = argparse.ArgumentParser(description="HumanEval benchmark")
|
|
||||||
parser.add_argument(
|
|
||||||
"--param_path", type=str, default="./params", help="Model directory"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--data_path",
|
|
||||||
type=str,
|
|
||||||
default="./humaneval/HumanEval.jsonl",
|
|
||||||
help="HumanEval JSONL file (auto-download if missing)",
|
|
||||||
)
|
|
||||||
parser.add_argument("--output", type=str, default=None, help="Output JSON path")
|
|
||||||
parser.add_argument(
|
|
||||||
"--num_samples",
|
|
||||||
type=int,
|
|
||||||
default=200,
|
|
||||||
help="Completions per problem",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--max_tokens", type=int, default=512, help="Max generation tokens"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--temperature", type=float, default=0.8, help="Sampling temperature"
|
|
||||||
)
|
|
||||||
parser.add_argument("--top_p", type=float, default=0.95, help="Top-p sampling")
|
|
||||||
parser.add_argument("--top_k", type=int, default=50, help="Top-k sampling")
|
|
||||||
parser.add_argument(
|
|
||||||
"--batch_size", type=int, default=1, help="Inference batch size"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--problems",
|
|
||||||
type=int,
|
|
||||||
nargs="+",
|
|
||||||
default=None,
|
|
||||||
help="Specific problem indices (0-based)",
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
_download_humaneval(args.data_path)
|
|
||||||
problems = _load_problems(args.data_path)
|
|
||||||
if args.problems:
|
|
||||||
problems = [problems[i] for i in args.problems if i < len(problems)]
|
|
||||||
|
|
||||||
model = AutoModel.from_pretrained(args.param_path)
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(args.param_path)
|
|
||||||
model.to(device="cuda", dtype=torch.bfloat16)
|
|
||||||
|
|
||||||
engine = InferenceEngine(
|
|
||||||
model=model,
|
|
||||||
tokenizer=tokenizer,
|
|
||||||
max_batch_size=args.batch_size,
|
|
||||||
)
|
|
||||||
|
|
||||||
results = evaluate(
|
|
||||||
engine=engine,
|
|
||||||
problems=problems,
|
|
||||||
num_samples=args.num_samples,
|
|
||||||
max_tokens=args.max_tokens,
|
|
||||||
temperature=args.temperature,
|
|
||||||
top_p=args.top_p,
|
|
||||||
top_k=args.top_k,
|
|
||||||
batch_size=args.batch_size,
|
|
||||||
k_values=(1, 10, 100),
|
|
||||||
)
|
|
||||||
|
|
||||||
summary = results.pop("_summary")
|
|
||||||
print(f"\n{'=' * 60}")
|
|
||||||
for k, v in summary.items():
|
|
||||||
print(f" {k}: {v:.2%}")
|
|
||||||
print(f"{'=' * 60}")
|
|
||||||
|
|
||||||
if args.output:
|
|
||||||
results["_summary"] = summary
|
|
||||||
with open(args.output, "w", encoding="utf-8") as f:
|
|
||||||
json.dump(results, f, indent=2, ensure_ascii=False)
|
|
||||||
print(f"Results saved to {args.output}")
|
|
||||||
|
|
||||||
engine.shutdown()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
|
|
@ -1,319 +0,0 @@
|
||||||
"""MMLU evaluation via log-likelihood ranking."""
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import csv
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
import shutil
|
|
||||||
import tarfile
|
|
||||||
|
|
||||||
import requests
|
|
||||||
import torch
|
|
||||||
import torch.nn.functional as F
|
|
||||||
import tqdm
|
|
||||||
|
|
||||||
from astrai.model import AutoModel
|
|
||||||
from astrai.tokenize import AutoTokenizer
|
|
||||||
|
|
||||||
MMLU_URL = "https://people.eecs.berkeley.edu/~hendrycks/data.tar"
|
|
||||||
MMLU_SUBJECTS = [
|
|
||||||
"abstract_algebra",
|
|
||||||
"anatomy",
|
|
||||||
"astronomy",
|
|
||||||
"business_ethics",
|
|
||||||
"clinical_knowledge",
|
|
||||||
"college_biology",
|
|
||||||
"college_chemistry",
|
|
||||||
"college_computer_science",
|
|
||||||
"college_mathematics",
|
|
||||||
"college_medicine",
|
|
||||||
"college_physics",
|
|
||||||
"computer_security",
|
|
||||||
"conceptual_physics",
|
|
||||||
"econometrics",
|
|
||||||
"electrical_engineering",
|
|
||||||
"elementary_mathematics",
|
|
||||||
"formal_logic",
|
|
||||||
"global_facts",
|
|
||||||
"high_school_biology",
|
|
||||||
"high_school_chemistry",
|
|
||||||
"high_school_computer_science",
|
|
||||||
"high_school_european_history",
|
|
||||||
"high_school_geography",
|
|
||||||
"high_school_government_and_politics",
|
|
||||||
"high_school_macroeconomics",
|
|
||||||
"high_school_mathematics",
|
|
||||||
"high_school_microeconomics",
|
|
||||||
"high_school_physics",
|
|
||||||
"high_school_psychology",
|
|
||||||
"high_school_statistics",
|
|
||||||
"high_school_us_history",
|
|
||||||
"high_school_world_history",
|
|
||||||
"human_aging",
|
|
||||||
"human_sexuality",
|
|
||||||
"international_law",
|
|
||||||
"jurisprudence",
|
|
||||||
"logical_fallacies",
|
|
||||||
"machine_learning",
|
|
||||||
"management",
|
|
||||||
"marketing",
|
|
||||||
"medical_genetics",
|
|
||||||
"miscellaneous",
|
|
||||||
"moral_disputes",
|
|
||||||
"moral_scenarios",
|
|
||||||
"nutrition",
|
|
||||||
"philosophy",
|
|
||||||
"prehistory",
|
|
||||||
"professional_accounting",
|
|
||||||
"professional_law",
|
|
||||||
"professional_medicine",
|
|
||||||
"professional_psychology",
|
|
||||||
"public_relations",
|
|
||||||
"security_studies",
|
|
||||||
"sociology",
|
|
||||||
"us_foreign_policy",
|
|
||||||
"virology",
|
|
||||||
"world_religions",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def _download_and_extract(url: str, data_dir: str):
|
|
||||||
tar_path = os.path.join(data_dir, "data.tar")
|
|
||||||
os.makedirs(data_dir, exist_ok=True)
|
|
||||||
print(f"Downloading MMLU data from {url}...")
|
|
||||||
resp = requests.get(url, stream=True, timeout=300)
|
|
||||||
resp.raise_for_status()
|
|
||||||
total = int(resp.headers.get("content-length", 0))
|
|
||||||
with tqdm.tqdm(total=total, unit="B", unit_scale=True, desc=" Download") as bar:
|
|
||||||
with open(tar_path, "wb") as f:
|
|
||||||
for chunk in resp.iter_content(chunk_size=8192):
|
|
||||||
f.write(chunk)
|
|
||||||
bar.update(len(chunk))
|
|
||||||
print("Extracting...")
|
|
||||||
with tarfile.open(tar_path, "r") as tf:
|
|
||||||
tf.extractall(data_dir)
|
|
||||||
os.remove(tar_path)
|
|
||||||
|
|
||||||
|
|
||||||
def download_mmlu(data_dir: str):
|
|
||||||
_download_and_extract(MMLU_URL, data_dir)
|
|
||||||
src = os.path.join(data_dir, "data")
|
|
||||||
if os.path.exists(src):
|
|
||||||
for item in os.listdir(src):
|
|
||||||
src_item = os.path.join(src, item)
|
|
||||||
dst_item = os.path.join(data_dir, item)
|
|
||||||
if os.path.exists(dst_item):
|
|
||||||
if os.path.isdir(dst_item):
|
|
||||||
shutil.rmtree(dst_item)
|
|
||||||
else:
|
|
||||||
os.remove(dst_item)
|
|
||||||
os.rename(src_item, dst_item)
|
|
||||||
os.rmdir(src)
|
|
||||||
print(f"MMLU data saved to {data_dir}")
|
|
||||||
|
|
||||||
|
|
||||||
def _strip_prefix(text: str, prefix: str) -> str:
|
|
||||||
if text.startswith(prefix):
|
|
||||||
return text[len(prefix) :].strip()
|
|
||||||
return text
|
|
||||||
|
|
||||||
|
|
||||||
def load_csv(path: str) -> list[dict]:
|
|
||||||
data = []
|
|
||||||
with open(path, "r", encoding="utf-8") as f:
|
|
||||||
for row in csv.reader(f):
|
|
||||||
if len(row) < 6:
|
|
||||||
continue
|
|
||||||
if row[0].strip().lower() == "question":
|
|
||||||
continue
|
|
||||||
data.append(
|
|
||||||
{
|
|
||||||
"question": row[0].strip(),
|
|
||||||
"A": _strip_prefix(row[1].strip(), "A)"),
|
|
||||||
"B": _strip_prefix(row[2].strip(), "B)"),
|
|
||||||
"C": _strip_prefix(row[3].strip(), "C)"),
|
|
||||||
"D": _strip_prefix(row[4].strip(), "D)"),
|
|
||||||
"answer": row[5].strip(),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
return data
|
|
||||||
|
|
||||||
|
|
||||||
def build_prompt(
|
|
||||||
question: str, choices: dict, subject: str, n_shot: int, dev_data: list[dict]
|
|
||||||
) -> str:
|
|
||||||
prompt = ""
|
|
||||||
if n_shot > 0 and dev_data:
|
|
||||||
prompt = f"The following are multiple choice questions (with answers) about {subject}.\n\n"
|
|
||||||
for item in dev_data[:n_shot]:
|
|
||||||
prompt += f"Question: {item['question']}\n"
|
|
||||||
for k in ("A", "B", "C", "D"):
|
|
||||||
prompt += f"{k}. {item[k]}\n"
|
|
||||||
prompt += f"Answer: {item['answer']}\n\n"
|
|
||||||
prompt += f"Question: {question}\n"
|
|
||||||
for k in ("A", "B", "C", "D"):
|
|
||||||
prompt += f"{k}. {choices[k]}\n"
|
|
||||||
prompt += "Answer:"
|
|
||||||
return prompt
|
|
||||||
|
|
||||||
|
|
||||||
def apply_chat(
|
|
||||||
tokenizer, raw_prompt: str, n_shot: int, dev_data: list[dict] | None
|
|
||||||
) -> str:
|
|
||||||
"""Wrap raw MMLU prompt in the model's chat template format.
|
|
||||||
|
|
||||||
For few-shot, prepend example Q&A pairs as a second user/assistant exchange.
|
|
||||||
"""
|
|
||||||
messages = []
|
|
||||||
if n_shot > 0 and dev_data:
|
|
||||||
for item in dev_data[:n_shot]:
|
|
||||||
q = f"Question: {item['question']}\n"
|
|
||||||
for k in ("A", "B", "C", "D"):
|
|
||||||
q += f"{k}. {item[k]}\n"
|
|
||||||
q += "Answer:"
|
|
||||||
messages.append({"role": "user", "content": q})
|
|
||||||
messages.append({"role": "assistant", "content": item["answer"]})
|
|
||||||
messages.append({"role": "user", "content": raw_prompt})
|
|
||||||
return tokenizer.apply_chat_template(
|
|
||||||
messages, tokenize=False, add_generation_prompt=True
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def choice_logprob(
|
|
||||||
model, tokenizer, context_ids: list[int], choice_letter: str, device: str
|
|
||||||
) -> float:
|
|
||||||
choice_text = choice_letter
|
|
||||||
choice_ids = tokenizer.encode(choice_text, add_special_tokens=False)
|
|
||||||
input_ids = context_ids + choice_ids
|
|
||||||
max_len = model.config.max_len
|
|
||||||
if len(input_ids) > max_len:
|
|
||||||
overflow = len(input_ids) - max_len
|
|
||||||
input_ids = input_ids[overflow:]
|
|
||||||
ctx_len = len(input_ids) - len(choice_ids)
|
|
||||||
else:
|
|
||||||
ctx_len = len(context_ids)
|
|
||||||
|
|
||||||
input_tensor = torch.tensor([input_ids], device=device, dtype=torch.long)
|
|
||||||
with torch.inference_mode():
|
|
||||||
logits = model(input_tensor)["logits"][0]
|
|
||||||
|
|
||||||
score = 0.0
|
|
||||||
for i, tid in enumerate(choice_ids):
|
|
||||||
pos = ctx_len - 1 + i
|
|
||||||
if pos >= len(logits):
|
|
||||||
break
|
|
||||||
score += F.log_softmax(logits[pos], dim=-1)[tid].item()
|
|
||||||
return score
|
|
||||||
|
|
||||||
|
|
||||||
def evaluate_subject(
|
|
||||||
model,
|
|
||||||
tokenizer,
|
|
||||||
subject: str,
|
|
||||||
test_data: list[dict],
|
|
||||||
dev_data: list[dict] | None,
|
|
||||||
device: str,
|
|
||||||
n_shot: int,
|
|
||||||
) -> tuple[float, int, int]:
|
|
||||||
correct = 0
|
|
||||||
total = 0
|
|
||||||
for item in tqdm.tqdm(test_data, desc=f"{subject:40s}", leave=False):
|
|
||||||
raw_prompt = build_prompt(
|
|
||||||
item["question"], item, subject, n_shot, dev_data or []
|
|
||||||
)
|
|
||||||
context = apply_chat(tokenizer, raw_prompt, n_shot, dev_data or [])
|
|
||||||
context_ids = tokenizer.encode(context)
|
|
||||||
scores = {
|
|
||||||
c: choice_logprob(model, tokenizer, context_ids, c, device)
|
|
||||||
for c in ("A", "B", "C", "D")
|
|
||||||
}
|
|
||||||
if max(scores, key=scores.get) == item["answer"]:
|
|
||||||
correct += 1
|
|
||||||
total += 1
|
|
||||||
return correct / total, correct, total
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
parser = argparse.ArgumentParser(description="MMLU evaluation")
|
|
||||||
parser.add_argument(
|
|
||||||
"--param_path", type=str, default="./params", help="Model directory"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--data_dir", type=str, default="./mmlu_data", help="MMLU data directory"
|
|
||||||
)
|
|
||||||
parser.add_argument("--download", action="store_true", help="Download MMLU data")
|
|
||||||
parser.add_argument(
|
|
||||||
"--n_shot", type=int, default=5, help="Few-shot examples (0 for zero-shot)"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--subjects", type=str, nargs="+", help="Specific subjects (default: all)"
|
|
||||||
)
|
|
||||||
parser.add_argument("--output", type=str, help="Output JSON path")
|
|
||||||
parser.add_argument("--split", type=str, default="test", choices=["test", "val"])
|
|
||||||
parser.add_argument(
|
|
||||||
"--device",
|
|
||||||
type=str,
|
|
||||||
default="cuda" if torch.cuda.is_available() else "cpu",
|
|
||||||
help="Device",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--dtype",
|
|
||||||
type=str,
|
|
||||||
default="bfloat16" if torch.cuda.is_available() else "float32",
|
|
||||||
help="Torch dtype",
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
if args.download or not os.path.exists(args.data_dir):
|
|
||||||
download_mmlu(args.data_dir)
|
|
||||||
|
|
||||||
model = AutoModel.from_pretrained(args.param_path)
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(args.param_path)
|
|
||||||
device = args.device
|
|
||||||
dtype = getattr(torch, args.dtype)
|
|
||||||
model.to(device=device, dtype=dtype)
|
|
||||||
model.eval()
|
|
||||||
|
|
||||||
subjects = args.subjects or MMLU_SUBJECTS
|
|
||||||
results = {}
|
|
||||||
total_correct = 0
|
|
||||||
total_questions = 0
|
|
||||||
|
|
||||||
for subject in subjects:
|
|
||||||
dev_path = os.path.join(args.data_dir, "dev", f"{subject}_dev.csv")
|
|
||||||
test_path = os.path.join(
|
|
||||||
args.data_dir, args.split, f"{subject}_{args.split}.csv"
|
|
||||||
)
|
|
||||||
|
|
||||||
if not os.path.exists(test_path):
|
|
||||||
print(f" Skipping {subject}: test file not found")
|
|
||||||
continue
|
|
||||||
|
|
||||||
dev_data = load_csv(dev_path) if os.path.exists(dev_path) else None
|
|
||||||
test_data = load_csv(test_path)
|
|
||||||
|
|
||||||
acc, corr, tot = evaluate_subject(
|
|
||||||
model, tokenizer, subject, test_data, dev_data, device, args.n_shot
|
|
||||||
)
|
|
||||||
results[subject] = {"accuracy": round(acc, 4), "correct": corr, "total": tot}
|
|
||||||
total_correct += corr
|
|
||||||
total_questions += tot
|
|
||||||
print(f" {subject:40s} {acc:.2%} ({corr}/{tot})")
|
|
||||||
|
|
||||||
overall = total_correct / total_questions if total_questions else 0
|
|
||||||
print(f"\n{'=' * 70}")
|
|
||||||
print(f" Overall: {overall:.2%} ({total_correct}/{total_questions})")
|
|
||||||
results["_overall"] = {
|
|
||||||
"accuracy": round(overall, 4),
|
|
||||||
"correct": total_correct,
|
|
||||||
"total": total_questions,
|
|
||||||
}
|
|
||||||
|
|
||||||
if args.output:
|
|
||||||
with open(args.output, "w", encoding="utf-8") as f:
|
|
||||||
json.dump(results, f, indent=2)
|
|
||||||
print(f"Results saved to {args.output}")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
|
|
@ -10,11 +10,11 @@ from astrai.tokenize import AutoTokenizer
|
||||||
|
|
||||||
|
|
||||||
def process_file(
|
def process_file(
|
||||||
param_path: str, input_file: str, output_file: str, batch_size: int, text_key: str
|
model_dir: str, input_file: str, output_file: str, batch_size: int, text_key: str
|
||||||
):
|
):
|
||||||
# Load model and tokenizer
|
# Load model and tokenizer
|
||||||
model = AutoModel.from_pretrained(param_path)
|
model = AutoModel.from_pretrained(model_dir)
|
||||||
tokenizer = AutoTokenizer.from_pretrained(param_path)
|
tokenizer = AutoTokenizer.from_pretrained(model_dir)
|
||||||
model.to(device="cuda", dtype=torch.bfloat16)
|
model.to(device="cuda", dtype=torch.bfloat16)
|
||||||
|
|
||||||
with open(input_file, "r", encoding="utf-8") as f:
|
with open(input_file, "r", encoding="utf-8") as f:
|
||||||
|
|
@ -44,8 +44,8 @@ def process_file(
|
||||||
|
|
||||||
for seq in batch_encoded:
|
for seq in batch_encoded:
|
||||||
pad_len = max_len - len(seq)
|
pad_len = max_len - len(seq)
|
||||||
padded_seq = seq + [tokenizer.pad_id] * pad_len
|
padded_seq = [tokenizer.pad_id] * pad_len + seq
|
||||||
mask = [True] * len(seq) + [False] * pad_len
|
mask = [False] * pad_len + [True] * len(seq)
|
||||||
padded_ids.append(padded_seq)
|
padded_ids.append(padded_seq)
|
||||||
masks.append(mask)
|
masks.append(mask)
|
||||||
|
|
||||||
|
|
@ -88,7 +88,7 @@ def process_file(
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(description="Run perplexity with a Khaosz model.")
|
parser = argparse.ArgumentParser(description="Run perplexity with a Khaosz model.")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--param_path", type=str, required=True, help="Path to the model directory."
|
"--model_dir", type=str, required=True, help="Path to the model directory."
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--input_file", type=str, required=True, help="Path to the input file."
|
"--input_file", type=str, required=True, help="Path to the input file."
|
||||||
|
|
|
||||||
|
|
@ -1,38 +0,0 @@
|
||||||
"""CLI: JSONL → tokenized .h5/.bin via config-driven Pipeline."""
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
|
|
||||||
from astrai.config.preprocess_config import PipelineConfig
|
|
||||||
from astrai.preprocessing.pipeline import Pipeline
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
parser = argparse.ArgumentParser(
|
|
||||||
description="Raw JSONL → tokenized .h5/.bin via config-driven Pipeline"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"inputs", nargs="+", metavar="JSONL", help="One or more JSONL files"
|
|
||||||
)
|
|
||||||
parser.add_argument("--output_dir", "-o", required=True, help="Output directory")
|
|
||||||
parser.add_argument(
|
|
||||||
"--config", "-c", required=True, help="Path to pipeline config JSON"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--tokenizer_path",
|
|
||||||
default="params",
|
|
||||||
help="Path to tokenizer directory (default: params)",
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
config = PipelineConfig.from_json(args.config)
|
|
||||||
|
|
||||||
Pipeline(
|
|
||||||
config=config,
|
|
||||||
input_paths=args.inputs,
|
|
||||||
output_dir=args.output_dir,
|
|
||||||
tokenizer_path=args.tokenizer_path,
|
|
||||||
).run()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
|
|
@ -18,7 +18,7 @@ def main():
|
||||||
"--reload", action="store_true", help="Enable auto-reload for development"
|
"--reload", action="store_true", help="Enable auto-reload for development"
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--param_path",
|
"--param-path",
|
||||||
type=Path,
|
type=Path,
|
||||||
default=None,
|
default=None,
|
||||||
help="Path to model parameters (default: project_root/params)",
|
help="Path to model parameters (default: project_root/params)",
|
||||||
|
|
|
||||||
|
|
@ -2,13 +2,16 @@ import argparse
|
||||||
import os
|
import os
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
|
import safetensors.torch as st
|
||||||
import torch
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
import torch.optim as optim
|
import torch.optim as optim
|
||||||
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
|
|
||||||
from astrai.config import AutoRegressiveLMConfig, TrainConfig
|
from astrai.config import AutoRegressiveLMConfig, TrainConfig
|
||||||
from astrai.dataset import DatasetFactory
|
from astrai.dataset import DatasetFactory
|
||||||
from astrai.model import AutoRegressiveLM
|
from astrai.model import AutoRegressiveLM
|
||||||
from astrai.model.components.decoder_block import DecoderBlock
|
from astrai.parallel import get_rank
|
||||||
from astrai.trainer import SchedulerFactory, Trainer
|
from astrai.trainer import SchedulerFactory, Trainer
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -116,12 +119,6 @@ def parse_args() -> argparse.Namespace:
|
||||||
default=0.05,
|
default=0.05,
|
||||||
help="cross_entropy function label smoothing parameter",
|
help="cross_entropy function label smoothing parameter",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
|
||||||
"--gradient_checkpointing",
|
|
||||||
action=argparse.BooleanOptionalAction,
|
|
||||||
default=False,
|
|
||||||
help="Enable activation checkpointing for DecoderBlock modules.",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--ckpt_interval",
|
"--ckpt_interval",
|
||||||
|
|
@ -135,36 +132,6 @@ def parse_args() -> argparse.Namespace:
|
||||||
default="checkpoint",
|
default="checkpoint",
|
||||||
help="Directory to save checkpoints.",
|
help="Directory to save checkpoints.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
|
||||||
"--val_split",
|
|
||||||
type=float,
|
|
||||||
default=None,
|
|
||||||
help="Ratio to split from training dataset for validation (e.g. 0.05).",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--val_step",
|
|
||||||
type=int,
|
|
||||||
default=1000,
|
|
||||||
help="Number of optimizer steps between validation runs.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--metrics",
|
|
||||||
nargs="*",
|
|
||||||
default=["loss", "lr"],
|
|
||||||
help="Metrics to log (e.g. --metrics loss lr val_loss). Default: loss lr.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--log_dir",
|
|
||||||
type=str,
|
|
||||||
default="checkpoint/logs",
|
|
||||||
help="Directory for metric logs.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--log_interval",
|
|
||||||
type=int,
|
|
||||||
default=100,
|
|
||||||
help="Number of batch iterations between metric logs.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--grpo_sync_interval",
|
"--grpo_sync_interval",
|
||||||
type=int,
|
type=int,
|
||||||
|
|
@ -178,32 +145,7 @@ def parse_args() -> argparse.Namespace:
|
||||||
"--start_batch", type=int, default=0, help="Start batch for training."
|
"--start_batch", type=int, default=0, help="Start batch for training."
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--master_addr",
|
|
||||||
type=str,
|
|
||||||
default="localhost",
|
|
||||||
help="Master node address for distributed training.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--master_port",
|
|
||||||
type=str,
|
|
||||||
default="29500",
|
|
||||||
help="Master node port for distributed training.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--backend",
|
|
||||||
type=str,
|
|
||||||
default="nccl",
|
|
||||||
help="Distributed training backend.",
|
|
||||||
)
|
|
||||||
parser.add_argument("--nprocs", type=int, default=1, help="Number of GPUs to use.")
|
parser.add_argument("--nprocs", type=int, default=1, help="Number of GPUs to use.")
|
||||||
parser.add_argument(
|
|
||||||
"--parallel_mode",
|
|
||||||
type=str,
|
|
||||||
default="none",
|
|
||||||
choices=["none", "ddp", "fsdp"],
|
|
||||||
help="Parallel training strategy (none, ddp, fsdp).",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--device_type", type=str, default="cuda", help="Device type to use."
|
"--device_type", type=str, default="cuda", help="Device type to use."
|
||||||
)
|
)
|
||||||
|
|
@ -220,11 +162,21 @@ def parse_args() -> argparse.Namespace:
|
||||||
return args
|
return args
|
||||||
|
|
||||||
|
|
||||||
def create_model(config):
|
def ddp_wrap(model: nn.Module):
|
||||||
return AutoRegressiveLM(config).to(dtype=torch.bfloat16)
|
local_rank = get_rank()
|
||||||
|
ddp_model = DDP(
|
||||||
|
model,
|
||||||
|
device_ids=[local_rank],
|
||||||
|
output_device=local_rank,
|
||||||
|
static_graph=True,
|
||||||
|
find_unused_parameters=False,
|
||||||
|
gradient_as_bucket_view=True,
|
||||||
|
broadcast_buffers=False,
|
||||||
|
)
|
||||||
|
return ddp_model
|
||||||
|
|
||||||
|
|
||||||
def create_optimizer(model, **kwargs) -> optim.Optimizer:
|
def create_optimizer(model: nn.Module, **kwargs) -> optim.Optimizer:
|
||||||
return optim.AdamW(model.parameters(), fused=True, **kwargs)
|
return optim.AdamW(model.parameters(), fused=True, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -234,6 +186,12 @@ def create_scheduler(
|
||||||
return SchedulerFactory.create(optimizer, **kwargs)
|
return SchedulerFactory.create(optimizer, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_checkpoint(model: nn.Module) -> dict:
|
||||||
|
if isinstance(model, DDP):
|
||||||
|
return model.module.state_dict()
|
||||||
|
return model.state_dict()
|
||||||
|
|
||||||
|
|
||||||
def compute_total_steps(
|
def compute_total_steps(
|
||||||
dataset_len: int,
|
dataset_len: int,
|
||||||
n_epoch: int,
|
n_epoch: int,
|
||||||
|
|
@ -264,11 +222,6 @@ def train(
|
||||||
warmup_ratio: float,
|
warmup_ratio: float,
|
||||||
ckpt_interval: int,
|
ckpt_interval: int,
|
||||||
ckpt_dir: str,
|
ckpt_dir: str,
|
||||||
val_split: float,
|
|
||||||
val_step: int,
|
|
||||||
metrics: list[str],
|
|
||||||
log_dir: str,
|
|
||||||
log_interval: int,
|
|
||||||
dpo_beta: float,
|
dpo_beta: float,
|
||||||
grpo_clip_eps: float,
|
grpo_clip_eps: float,
|
||||||
grpo_kl_coef: float,
|
grpo_kl_coef: float,
|
||||||
|
|
@ -282,21 +235,14 @@ def train(
|
||||||
random_seed: int,
|
random_seed: int,
|
||||||
num_workers: int,
|
num_workers: int,
|
||||||
pin_memory: bool,
|
pin_memory: bool,
|
||||||
gradient_checkpointing: bool,
|
|
||||||
window_size: int,
|
window_size: int,
|
||||||
stride: int,
|
stride: int,
|
||||||
nprocs: int,
|
nprocs: int,
|
||||||
parallel_mode: str,
|
|
||||||
device_type: str,
|
device_type: str,
|
||||||
backend: str,
|
|
||||||
master_addr: str,
|
|
||||||
master_port: str,
|
|
||||||
start_method: str,
|
start_method: str,
|
||||||
):
|
):
|
||||||
assert train_type in ["seq", "sft", "dpo", "grpo"]
|
assert train_type in ["seq", "sft", "dpo", "grpo"]
|
||||||
assert os.path.exists(param_path)
|
assert os.path.exists(param_path)
|
||||||
if nprocs > 1 and parallel_mode == "none":
|
|
||||||
raise ValueError("--nprocs > 1 requires --parallel_mode to be 'ddp' or 'fsdp'")
|
|
||||||
|
|
||||||
# Load config
|
# Load config
|
||||||
config_path = os.path.join(param_path, "config.json")
|
config_path = os.path.join(param_path, "config.json")
|
||||||
|
|
@ -305,6 +251,17 @@ def train(
|
||||||
if window_size is None:
|
if window_size is None:
|
||||||
window_size = config.max_len
|
window_size = config.max_len
|
||||||
|
|
||||||
|
# Create bare AutoRegressiveLM (for training, no tokenizer needed)
|
||||||
|
model = AutoRegressiveLM(config)
|
||||||
|
|
||||||
|
# Load weights if available
|
||||||
|
weights_path = os.path.join(param_path, "model.safetensors")
|
||||||
|
if os.path.exists(weights_path):
|
||||||
|
state_dict = st.load_file(weights_path)
|
||||||
|
model.load_state_dict(state_dict, strict=False)
|
||||||
|
|
||||||
|
model = model.to(dtype=torch.bfloat16)
|
||||||
|
|
||||||
strategy_kwargs = {
|
strategy_kwargs = {
|
||||||
"beta": dpo_beta,
|
"beta": dpo_beta,
|
||||||
"label_smoothing": label_smoothing,
|
"label_smoothing": label_smoothing,
|
||||||
|
|
@ -314,12 +271,6 @@ def train(
|
||||||
"sync_interval": grpo_sync_interval,
|
"sync_interval": grpo_sync_interval,
|
||||||
}
|
}
|
||||||
|
|
||||||
executor_kwargs = {
|
|
||||||
"gradient_as_bucket_view": True,
|
|
||||||
"broadcast_buffers": False,
|
|
||||||
}
|
|
||||||
|
|
||||||
model_fn = partial(create_model, config)
|
|
||||||
dataset = DatasetFactory.load(
|
dataset = DatasetFactory.load(
|
||||||
train_type=train_type,
|
train_type=train_type,
|
||||||
load_path=data_root_path,
|
load_path=data_root_path,
|
||||||
|
|
@ -350,10 +301,8 @@ def train(
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
grad_ckpt_modules = [DecoderBlock] if gradient_checkpointing else []
|
|
||||||
|
|
||||||
train_config = TrainConfig(
|
train_config = TrainConfig(
|
||||||
model_fn=model_fn,
|
model=model,
|
||||||
strategy=train_type,
|
strategy=train_type,
|
||||||
dataset=dataset,
|
dataset=dataset,
|
||||||
optimizer_fn=optimizer_fn,
|
optimizer_fn=optimizer_fn,
|
||||||
|
|
@ -370,24 +319,15 @@ def train(
|
||||||
num_workers=num_workers,
|
num_workers=num_workers,
|
||||||
pin_memory=pin_memory,
|
pin_memory=pin_memory,
|
||||||
nprocs=nprocs,
|
nprocs=nprocs,
|
||||||
backend=backend,
|
parallel_wrapper=ddp_wrap,
|
||||||
master_addr=master_addr,
|
state_dict_fn=prepare_checkpoint,
|
||||||
master_port=master_port,
|
|
||||||
parallel_mode=parallel_mode,
|
|
||||||
device_type=device_type,
|
device_type=device_type,
|
||||||
start_method=start_method,
|
start_method=start_method,
|
||||||
val_split=val_split,
|
|
||||||
val_step=val_step,
|
|
||||||
metrics=metrics,
|
|
||||||
log_dir=log_dir,
|
|
||||||
log_interval=log_interval,
|
|
||||||
gradient_checkpointing_modules=grad_ckpt_modules,
|
|
||||||
executor_kwargs=executor_kwargs,
|
|
||||||
extra_kwargs=strategy_kwargs,
|
extra_kwargs=strategy_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
trainer = Trainer(train_config)
|
trainer = Trainer(train_config)
|
||||||
trainer.train(resume_dir=param_path)
|
trainer.train()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
||||||
|
|
@ -1,202 +0,0 @@
|
||||||
import tempfile
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
from tokenizers import Tokenizer, models, pre_tokenizers, trainers
|
|
||||||
|
|
||||||
from astrai.config.preprocess_config import (
|
|
||||||
InputConfig,
|
|
||||||
PipelineConfig,
|
|
||||||
ProcessingConfig,
|
|
||||||
)
|
|
||||||
from astrai.tokenize import AutoTokenizer
|
|
||||||
|
|
||||||
_SPECIAL_TOKENS_CONFIG = {
|
|
||||||
"bos_token": "<|begin_of_sentence|>",
|
|
||||||
"eos_token": "<|end_of_sentence|>",
|
|
||||||
"pad_token": "<|_pad_|>",
|
|
||||||
"unk_token": "<|_unk_|>",
|
|
||||||
"im_start": "<|im_start|>",
|
|
||||||
"im_end": "<|im_end|>",
|
|
||||||
}
|
|
||||||
|
|
||||||
_SPECIAL_TOKENS = list(_SPECIAL_TOKENS_CONFIG.values())
|
|
||||||
|
|
||||||
_CHAT_TEMPLATE = (
|
|
||||||
"{% for message in messages %}"
|
|
||||||
"{% if message['role'] == 'system' %}"
|
|
||||||
"<|im_start|>system\n{{ message['content'] }}<|im_end|>\n"
|
|
||||||
"{% elif message['role'] == 'user' %}"
|
|
||||||
"<|im_start|>user\n{{ message['content'] }}<|im_end|>\n"
|
|
||||||
"{% elif message['role'] == 'assistant' %}"
|
|
||||||
"<|im_start|>assistant\n{{ message['content'] }}<|im_end|>\n"
|
|
||||||
"{% endif %}"
|
|
||||||
"{% endfor %}"
|
|
||||||
"{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"
|
|
||||||
)
|
|
||||||
|
|
||||||
_CHAT_SECTIONS = [{"field": "messages", "action": "$role", "template": True}]
|
|
||||||
|
|
||||||
_INSTRUCTION_SECTIONS = [
|
|
||||||
{"field": "prompt", "action": "mask", "add_special_tokens": True},
|
|
||||||
{"field": "response", "action": "train"},
|
|
||||||
]
|
|
||||||
|
|
||||||
_TEXT_SECTIONS = [{"field": "text", "action": "train"}]
|
|
||||||
|
|
||||||
_GRPO_RESPONSE_SECTIONS = [{"field": "responses", "action": "train"}]
|
|
||||||
|
|
||||||
|
|
||||||
def _build_chat_tokenizer():
|
|
||||||
tok = Tokenizer(models.BPE())
|
|
||||||
tok.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
|
|
||||||
tr = trainers.BpeTrainer(
|
|
||||||
vocab_size=512,
|
|
||||||
min_frequency=1,
|
|
||||||
special_tokens=_SPECIAL_TOKENS,
|
|
||||||
)
|
|
||||||
train_data = [
|
|
||||||
"hello world",
|
|
||||||
"Hi there!",
|
|
||||||
"You are helpful.",
|
|
||||||
"What is 2+2?",
|
|
||||||
"Tell me a story about dragons and knights.",
|
|
||||||
"Sure, here is a tale.",
|
|
||||||
"Translate to French: Hello",
|
|
||||||
"Bonjour",
|
|
||||||
"Artificial Intelligence is a field of computer science.",
|
|
||||||
"system",
|
|
||||||
"user",
|
|
||||||
"assistant",
|
|
||||||
"<|im_start|>",
|
|
||||||
"<|im_end|>",
|
|
||||||
*[chr(i) for i in range(32, 127)],
|
|
||||||
]
|
|
||||||
tok.train_from_iterator(train_data, tr)
|
|
||||||
|
|
||||||
auto_tok = AutoTokenizer()
|
|
||||||
auto_tok._tokenizer = tok
|
|
||||||
auto_tok._special_token_map = {
|
|
||||||
"bos_token": "<|begin_of_sentence|>",
|
|
||||||
"eos_token": "<|end_of_sentence|>",
|
|
||||||
"pad_token": "<|_pad_|>",
|
|
||||||
"unk_token": "<|_unk_|>",
|
|
||||||
}
|
|
||||||
auto_tok.set_chat_template(_CHAT_TEMPLATE)
|
|
||||||
return auto_tok
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
|
||||||
def chat_tokenizer():
|
|
||||||
return _build_chat_tokenizer()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def temp_dir():
|
|
||||||
d = tempfile.mkdtemp()
|
|
||||||
yield d
|
|
||||||
import shutil
|
|
||||||
|
|
||||||
shutil.rmtree(d, ignore_errors=True)
|
|
||||||
|
|
||||||
|
|
||||||
def make_chat_config():
|
|
||||||
return PipelineConfig(
|
|
||||||
input=InputConfig(sections=_CHAT_SECTIONS),
|
|
||||||
mask={"system": "mask", "user": "mask", "assistant": "train"},
|
|
||||||
mask_default="mask",
|
|
||||||
preprocessing=ProcessingConfig(max_seq_len=2048),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def make_instruction_config():
|
|
||||||
return PipelineConfig(
|
|
||||||
input=InputConfig(sections=_INSTRUCTION_SECTIONS),
|
|
||||||
mask={"prompt": "mask", "response": "train"},
|
|
||||||
mask_default="mask",
|
|
||||||
preprocessing=ProcessingConfig(max_seq_len=2048),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def make_text_config():
|
|
||||||
return PipelineConfig(
|
|
||||||
input=InputConfig(sections=_TEXT_SECTIONS),
|
|
||||||
preprocessing=ProcessingConfig(
|
|
||||||
max_seq_len=2048, min_chars=1, max_chars=2_000_000
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def make_dpo_chat_config():
|
|
||||||
return PipelineConfig(
|
|
||||||
input=InputConfig(
|
|
||||||
sources={
|
|
||||||
"chosen": {
|
|
||||||
"sections": [
|
|
||||||
{"field": "chosen", "action": "$role", "template": True}
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"rejected": {
|
|
||||||
"sections": [
|
|
||||||
{"field": "rejected", "action": "$role", "template": True}
|
|
||||||
]
|
|
||||||
},
|
|
||||||
}
|
|
||||||
),
|
|
||||||
mask={"user": "mask", "assistant": "train"},
|
|
||||||
mask_default="mask",
|
|
||||||
preprocessing=ProcessingConfig(max_seq_len=2048),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def make_grpo_config():
|
|
||||||
return PipelineConfig(
|
|
||||||
input=InputConfig(
|
|
||||||
sources={
|
|
||||||
"prompts": {
|
|
||||||
"sections": [
|
|
||||||
{"field": "prompt", "action": "mask", "template": True}
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"responses": {
|
|
||||||
"sections": _GRPO_RESPONSE_SECTIONS,
|
|
||||||
"list_field": True,
|
|
||||||
"mask_key": "masks",
|
|
||||||
},
|
|
||||||
"rewards": {
|
|
||||||
"sections": [{"field": "rewards", "action": "value"}],
|
|
||||||
},
|
|
||||||
}
|
|
||||||
),
|
|
||||||
mask={"user": "mask", "assistant": "train"},
|
|
||||||
mask_default="mask",
|
|
||||||
preprocessing=ProcessingConfig(max_seq_len=2048),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def make_grpo_no_template_config():
|
|
||||||
return PipelineConfig(
|
|
||||||
input=InputConfig(
|
|
||||||
sources={
|
|
||||||
"prompts": {
|
|
||||||
"sections": [
|
|
||||||
{
|
|
||||||
"field": "prompt",
|
|
||||||
"action": "mask",
|
|
||||||
"add_special_tokens": True,
|
|
||||||
}
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"responses": {
|
|
||||||
"sections": _GRPO_RESPONSE_SECTIONS,
|
|
||||||
"list_field": True,
|
|
||||||
"mask_key": "masks",
|
|
||||||
},
|
|
||||||
"rewards": {
|
|
||||||
"sections": [{"field": "rewards", "action": "value"}],
|
|
||||||
},
|
|
||||||
}
|
|
||||||
),
|
|
||||||
mask={"user": "mask", "assistant": "train"},
|
|
||||||
mask_default="mask",
|
|
||||||
preprocessing=ProcessingConfig(max_seq_len=2048),
|
|
||||||
)
|
|
||||||
|
|
@ -1,4 +1,3 @@
|
||||||
import os
|
|
||||||
import tempfile
|
import tempfile
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
@ -37,6 +36,7 @@ def test_single_process():
|
||||||
|
|
||||||
|
|
||||||
def test_checkpoint_with_extra():
|
def test_checkpoint_with_extra():
|
||||||
|
"""Verify extra keys are saved as individual .pt files and loaded back."""
|
||||||
model = torch.nn.Linear(10, 5)
|
model = torch.nn.Linear(10, 5)
|
||||||
optimizer = AdamW(model.parameters(), lr=1e-3)
|
optimizer = AdamW(model.parameters(), lr=1e-3)
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|
@ -52,6 +52,8 @@ def test_checkpoint_with_extra():
|
||||||
with tempfile.TemporaryDirectory() as tmpdir:
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
checkpoint.save(tmpdir)
|
checkpoint.save(tmpdir)
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
assert os.path.exists(os.path.join(tmpdir, "optimizer.pt"))
|
assert os.path.exists(os.path.join(tmpdir, "optimizer.pt"))
|
||||||
assert os.path.exists(os.path.join(tmpdir, "scheduler.pt"))
|
assert os.path.exists(os.path.join(tmpdir, "scheduler.pt"))
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,4 @@
|
||||||
|
import json
|
||||||
import os
|
import os
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
@ -6,11 +7,12 @@ import torch
|
||||||
|
|
||||||
from astrai.dataset.dataset import DatasetFactory, SEQDataset
|
from astrai.dataset.dataset import DatasetFactory, SEQDataset
|
||||||
from astrai.dataset.storage import (
|
from astrai.dataset.storage import (
|
||||||
H5Store,
|
BaseSegmentFetcher,
|
||||||
StoreFactory,
|
H5Storage,
|
||||||
|
MultiSegmentFetcher,
|
||||||
|
StorageFactory,
|
||||||
detect_format,
|
detect_format,
|
||||||
load_bin,
|
load_json,
|
||||||
save_bin,
|
|
||||||
save_h5,
|
save_h5,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -98,7 +100,6 @@ def test_sft_dataset_with_random_data(base_test_env):
|
||||||
dummy_data = {
|
dummy_data = {
|
||||||
"sequence": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64)],
|
"sequence": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64)],
|
||||||
"loss_mask": [torch.ones(seq_length, dtype=torch.bool)],
|
"loss_mask": [torch.ones(seq_length, dtype=torch.bool)],
|
||||||
"position_ids": [torch.arange(seq_length, dtype=torch.int32)],
|
|
||||||
}
|
}
|
||||||
|
|
||||||
save_h5(test_dir, "sft_data", dummy_data)
|
save_h5(test_dir, "sft_data", dummy_data)
|
||||||
|
|
@ -156,6 +157,111 @@ def test_dataset_with_custom_stride(base_test_env):
|
||||||
assert len(dataset) > len(default_stride_dataset)
|
assert len(dataset) > len(default_stride_dataset)
|
||||||
|
|
||||||
|
|
||||||
|
# ============== JSON Storage Tests (raw text + tokenizer) ==============
|
||||||
|
|
||||||
|
|
||||||
|
def _make_tokenizer_fn(tokenizer):
|
||||||
|
"""Wrap tokenizer.encode() as a str -> List[int] callable."""
|
||||||
|
return lambda text: tokenizer.encode(text, add_special_tokens=False)
|
||||||
|
|
||||||
|
|
||||||
|
def test_seq_dataset_from_json_text(base_test_env):
|
||||||
|
"""Test loading SEQ dataset from raw-text JSON with tokenizer"""
|
||||||
|
tokenizer = base_test_env["tokenizer"]
|
||||||
|
tokenizer_fn = _make_tokenizer_fn(tokenizer)
|
||||||
|
test_dir = base_test_env["test_dir"]
|
||||||
|
data_dir = os.path.join(test_dir, "json_text")
|
||||||
|
os.makedirs(data_dir, exist_ok=True)
|
||||||
|
|
||||||
|
texts = [
|
||||||
|
"hello world this is a test sentence for tokenizer",
|
||||||
|
"another sentence with different words and tokens",
|
||||||
|
"machine learning is fascinating and powerful",
|
||||||
|
]
|
||||||
|
|
||||||
|
json_path = os.path.join(data_dir, "seq_data.json")
|
||||||
|
with open(json_path, "w", encoding="utf-8") as f:
|
||||||
|
json.dump({"sequence": texts}, f, ensure_ascii=False)
|
||||||
|
|
||||||
|
dataset = DatasetFactory.load(
|
||||||
|
train_type="seq",
|
||||||
|
load_path=data_dir,
|
||||||
|
window_size=16,
|
||||||
|
tokenizer=tokenizer_fn,
|
||||||
|
)
|
||||||
|
assert dataset is not None
|
||||||
|
assert len(dataset) > 0
|
||||||
|
assert dataset.count > 0
|
||||||
|
assert "sequence" in dataset.keys
|
||||||
|
|
||||||
|
item = dataset[0]
|
||||||
|
assert "input_ids" in item
|
||||||
|
assert "target_ids" in item
|
||||||
|
assert item["input_ids"].shape[0] == 16
|
||||||
|
|
||||||
|
|
||||||
|
def test_sft_dataset_from_json_text(base_test_env):
|
||||||
|
"""Test loading SFT dataset from raw-text JSON with tokenizer"""
|
||||||
|
tokenizer = base_test_env["tokenizer"]
|
||||||
|
tokenizer_fn = _make_tokenizer_fn(tokenizer)
|
||||||
|
test_dir = base_test_env["test_dir"]
|
||||||
|
data_dir = os.path.join(test_dir, "json_sft")
|
||||||
|
os.makedirs(data_dir, exist_ok=True)
|
||||||
|
|
||||||
|
texts = [
|
||||||
|
"user asks a question about the weather",
|
||||||
|
"assistant provides a helpful response to the user",
|
||||||
|
]
|
||||||
|
|
||||||
|
json_path = os.path.join(data_dir, "sft_data.json")
|
||||||
|
with open(json_path, "w", encoding="utf-8") as f:
|
||||||
|
json.dump(
|
||||||
|
{"sequence": texts, "loss_mask": texts},
|
||||||
|
f,
|
||||||
|
ensure_ascii=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
dataset = DatasetFactory.load(
|
||||||
|
train_type="sft",
|
||||||
|
load_path=data_dir,
|
||||||
|
window_size=16,
|
||||||
|
tokenizer=tokenizer_fn,
|
||||||
|
)
|
||||||
|
assert dataset is not None
|
||||||
|
assert len(dataset) > 0
|
||||||
|
|
||||||
|
item = dataset[0]
|
||||||
|
assert "loss_mask" in item
|
||||||
|
|
||||||
|
|
||||||
|
def test_json_storage_explicit_tokenizer(base_test_env):
|
||||||
|
"""Test explicit JSON storage with tokenizer"""
|
||||||
|
tokenizer = base_test_env["tokenizer"]
|
||||||
|
tokenizer_fn = _make_tokenizer_fn(tokenizer)
|
||||||
|
test_dir = base_test_env["test_dir"]
|
||||||
|
data_dir = os.path.join(test_dir, "json_explicit")
|
||||||
|
os.makedirs(data_dir, exist_ok=True)
|
||||||
|
|
||||||
|
texts = ["abcdefghijklmnopqrstuvwxyz" * 10]
|
||||||
|
|
||||||
|
json_path = os.path.join(data_dir, "data.json")
|
||||||
|
with open(json_path, "w", encoding="utf-8") as f:
|
||||||
|
json.dump({"sequence": texts}, f, ensure_ascii=False)
|
||||||
|
|
||||||
|
token_count = len(tokenizer_fn(texts[0]))
|
||||||
|
|
||||||
|
dataset = DatasetFactory.load(
|
||||||
|
train_type="seq",
|
||||||
|
load_path=data_dir,
|
||||||
|
window_size=32,
|
||||||
|
storage_type="json",
|
||||||
|
tokenizer=tokenizer_fn,
|
||||||
|
)
|
||||||
|
assert dataset is not None
|
||||||
|
assert len(dataset) > 0
|
||||||
|
assert dataset.count == token_count
|
||||||
|
|
||||||
|
|
||||||
def test_dataset_count_property(base_test_env):
|
def test_dataset_count_property(base_test_env):
|
||||||
"""Test the count property returns correct raw token count"""
|
"""Test the count property returns correct raw token count"""
|
||||||
test_dir = base_test_env["test_dir"]
|
test_dir = base_test_env["test_dir"]
|
||||||
|
|
@ -212,29 +318,37 @@ def test_unloaded_dataset_len():
|
||||||
assert len(dataset) == 0
|
assert len(dataset) == 0
|
||||||
|
|
||||||
|
|
||||||
def test_store_unloaded_len():
|
def test_base_segment_fetcher_empty():
|
||||||
"""Unloaded Store has __len__ == 0"""
|
"""BaseSegmentFetcher with empty segments list"""
|
||||||
store = H5Store()
|
fetcher = BaseSegmentFetcher([])
|
||||||
assert len(store) == 0
|
assert len(fetcher) == 0
|
||||||
assert store.keys == []
|
with pytest.raises(ValueError, match="out of bounds"):
|
||||||
|
fetcher.fetch_data(0, 1)
|
||||||
|
|
||||||
|
|
||||||
def test_store_fetch_begin_equals_end(base_test_env):
|
def test_base_segment_fetcher_begin_equals_end(base_test_env):
|
||||||
"""Store.fetch with begin == end returns empty tensor"""
|
"""fetch_data with begin == end returns empty tensor"""
|
||||||
test_dir = base_test_env["test_dir"]
|
test_dir = base_test_env["test_dir"]
|
||||||
dummy = {"sequence": [torch.randint(0, 1000, (100,), dtype=torch.int64)]}
|
dummy = {"sequence": [torch.randint(0, 1000, (100,), dtype=torch.int64)]}
|
||||||
save_h5(test_dir, "empty_fetch", dummy)
|
save_h5(test_dir, "empty_fetch", dummy)
|
||||||
|
|
||||||
dataset = DatasetFactory.load("seq", test_dir, window_size=32)
|
dataset = DatasetFactory.load("seq", test_dir, window_size=32)
|
||||||
result = dataset.storage.fetch(10, 10, "sequence")
|
fetcher = dataset.storage._fetcher.multi_fetchers["sequence"]
|
||||||
|
result = fetcher.fetch_data(10, 10)
|
||||||
assert result.numel() == 0
|
assert result.numel() == 0
|
||||||
|
|
||||||
|
|
||||||
def test_store_fetch_before_load():
|
def test_multi_segment_fetcher_empty_dict():
|
||||||
"""Store.fetch before load raises RuntimeError"""
|
"""MultiSegmentFetcher with empty dict has __len__ == 0"""
|
||||||
store = H5Store()
|
fetcher = MultiSegmentFetcher({})
|
||||||
|
assert len(fetcher) == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_storage_fetch_before_load():
|
||||||
|
"""BaseStorage.fetch before load raises RuntimeError"""
|
||||||
|
storage = H5Storage()
|
||||||
with pytest.raises(RuntimeError, match="not loaded"):
|
with pytest.raises(RuntimeError, match="not loaded"):
|
||||||
store.fetch(0, 10, "sequence")
|
storage.fetch(0, 10, "sequence")
|
||||||
|
|
||||||
|
|
||||||
def test_detect_format_nonexistent_path():
|
def test_detect_format_nonexistent_path():
|
||||||
|
|
@ -253,192 +367,54 @@ def test_detect_format_unsupported_file(base_test_env):
|
||||||
detect_format(path)
|
detect_format(path)
|
||||||
|
|
||||||
|
|
||||||
def test_create_store_invalid_type():
|
def test_create_storage_invalid_type():
|
||||||
"""StoreFactory.create raises ValueError for unknown type"""
|
"""StorageFactory.create raises ValueError for unknown type"""
|
||||||
with pytest.raises(ValueError, match="Unknown component"):
|
with pytest.raises(ValueError, match="Unknown component"):
|
||||||
StoreFactory.create("parquet")
|
StorageFactory.create("parquet")
|
||||||
|
|
||||||
|
|
||||||
def test_store_multi_segment_concat(base_test_env):
|
def test_json_pretokenized_without_tokenizer(base_test_env):
|
||||||
"""Multi-segment H5 data is concatenated into single tensor at load time"""
|
"""Pre-tokenized JSON (List[List[int]]) loads without tokenizer"""
|
||||||
import os
|
|
||||||
|
|
||||||
test_dir = base_test_env["test_dir"]
|
test_dir = base_test_env["test_dir"]
|
||||||
data_dir = os.path.join(test_dir, "multi_seg")
|
data_dir = os.path.join(test_dir, "json_pretok")
|
||||||
os.makedirs(data_dir, exist_ok=True)
|
os.makedirs(data_dir, exist_ok=True)
|
||||||
|
|
||||||
|
json_path = os.path.join(data_dir, "data.json")
|
||||||
|
with open(json_path, "w", encoding="utf-8") as f:
|
||||||
|
json.dump({"sequence": [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]}, f)
|
||||||
|
|
||||||
|
dataset = DatasetFactory.load("seq", data_dir, window_size=4, storage_type="json")
|
||||||
|
assert len(dataset) > 0
|
||||||
|
assert dataset.count == 10
|
||||||
|
|
||||||
|
item = dataset[0]
|
||||||
|
assert item["input_ids"].tolist() == [1, 2, 3, 4]
|
||||||
|
assert item["target_ids"].tolist() == [2, 3, 4, 5]
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_json_skips_config_file(base_test_env):
|
||||||
|
"""load_json skips scalar-value config files"""
|
||||||
|
test_dir = base_test_env["test_dir"]
|
||||||
|
with open(os.path.join(test_dir, "config.json"), "w") as f:
|
||||||
|
json.dump({"vocab_size": 1000, "dim": 16}, f)
|
||||||
|
|
||||||
|
with open(os.path.join(test_dir, "data.json"), "w") as f:
|
||||||
|
json.dump({"sequence": [[1, 2, 3, 4, 5]]}, f)
|
||||||
|
|
||||||
|
result = load_json(test_dir)
|
||||||
|
assert "sequence" in result
|
||||||
|
assert "vocab_size" not in result
|
||||||
|
assert len(result["sequence"]) == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_base_segment_fetcher_multi_segment():
|
||||||
|
"""fetch_data across multiple segment boundaries"""
|
||||||
segs = [
|
segs = [
|
||||||
torch.tensor([1, 2, 3]),
|
torch.tensor([1, 2, 3]),
|
||||||
torch.tensor([4, 5, 6, 7]),
|
torch.tensor([4, 5, 6, 7]),
|
||||||
torch.tensor([8, 9]),
|
torch.tensor([8, 9]),
|
||||||
]
|
]
|
||||||
save_h5(data_dir, "data", {"sequence": segs})
|
fetcher = BaseSegmentFetcher(segs)
|
||||||
|
assert len(fetcher) == 9
|
||||||
store = StoreFactory.create("h5")
|
result = fetcher.fetch_data(2, 7)
|
||||||
store.load(data_dir)
|
|
||||||
assert len(store) == 9
|
|
||||||
result = store.fetch(2, 7, "sequence")
|
|
||||||
assert result.tolist() == [3, 4, 5, 6, 7]
|
assert result.tolist() == [3, 4, 5, 6, 7]
|
||||||
|
|
||||||
|
|
||||||
def test_save_load_bin_roundtrip(base_test_env):
|
|
||||||
"""save_bin + load_bin roundtrip preserves data"""
|
|
||||||
test_dir = base_test_env["test_dir"]
|
|
||||||
|
|
||||||
data = {
|
|
||||||
"sequence": [torch.tensor([1, 2, 3, 4, 5], dtype=torch.int64)],
|
|
||||||
"loss_mask": [torch.tensor([0, 1, 1, 0, 1], dtype=torch.int64)],
|
|
||||||
}
|
|
||||||
save_bin(test_dir, data)
|
|
||||||
result = load_bin(test_dir)
|
|
||||||
|
|
||||||
assert "sequence" in result
|
|
||||||
assert "loss_mask" in result
|
|
||||||
assert result["sequence"][0].tolist() == [1, 2, 3, 4, 5]
|
|
||||||
assert result["loss_mask"][0].tolist() == [0, 1, 1, 0, 1]
|
|
||||||
|
|
||||||
|
|
||||||
def test_mmap_store_load_and_fetch(base_test_env):
|
|
||||||
"""MmapStore loads bin data and fetches correctly"""
|
|
||||||
test_dir = base_test_env["test_dir"]
|
|
||||||
|
|
||||||
data = {
|
|
||||||
"sequence": [torch.randint(0, 1000, (200,), dtype=torch.int64)],
|
|
||||||
}
|
|
||||||
save_bin(test_dir, data)
|
|
||||||
|
|
||||||
store = StoreFactory.create("bin")
|
|
||||||
store.load(test_dir)
|
|
||||||
assert len(store) == 200
|
|
||||||
assert "sequence" in store.keys
|
|
||||||
|
|
||||||
result = store.fetch(10, 20, "sequence")
|
|
||||||
assert result.tolist() == data["sequence"][0][10:20].tolist()
|
|
||||||
|
|
||||||
|
|
||||||
def test_mmap_dataset_load(base_test_env):
|
|
||||||
"""DatasetFactory.load auto-detects bin format"""
|
|
||||||
test_dir = base_test_env["test_dir"]
|
|
||||||
|
|
||||||
data = {
|
|
||||||
"sequence": [torch.randint(0, 1000, (200,), dtype=torch.int64)],
|
|
||||||
}
|
|
||||||
save_bin(test_dir, data)
|
|
||||||
|
|
||||||
dataset = DatasetFactory.load("seq", test_dir, window_size=64)
|
|
||||||
assert len(dataset) > 0
|
|
||||||
assert dataset.count == 200
|
|
||||||
assert dataset[0]["input_ids"].shape[0] == 64
|
|
||||||
|
|
||||||
|
|
||||||
def test_normalize_empty_key():
|
|
||||||
"""_normalize with empty tensor list does not crash"""
|
|
||||||
store = H5Store()
|
|
||||||
store._normalize({"sequence": []})
|
|
||||||
assert len(store) == 0
|
|
||||||
assert store.keys == ["sequence"]
|
|
||||||
|
|
||||||
|
|
||||||
def test_normalize_mixed_empty_key():
|
|
||||||
"""_normalize with empty + non-empty keys returns min=0"""
|
|
||||||
store = H5Store()
|
|
||||||
store._normalize({"sequence": [torch.tensor([1, 2, 3])], "loss_mask": []})
|
|
||||||
assert len(store) == 0
|
|
||||||
assert set(store.keys) == {"sequence", "loss_mask"}
|
|
||||||
|
|
||||||
|
|
||||||
def test_grpo_dataset_dtype(base_test_env):
|
|
||||||
"""GRPODataset returns correct dtypes"""
|
|
||||||
test_dir = base_test_env["test_dir"]
|
|
||||||
|
|
||||||
seq_len = 100
|
|
||||||
data = {
|
|
||||||
"prompts": [torch.randint(0, 100, (seq_len,), dtype=torch.int32)],
|
|
||||||
"responses": [torch.randint(0, 100, (seq_len,), dtype=torch.int32)],
|
|
||||||
"masks": [torch.ones(seq_len, dtype=torch.int32)],
|
|
||||||
"rewards": [torch.ones(seq_len, dtype=torch.float32)],
|
|
||||||
}
|
|
||||||
save_h5(test_dir, "grpo_dtype", data)
|
|
||||||
|
|
||||||
dataset = DatasetFactory.load("grpo", test_dir, window_size=32)
|
|
||||||
item = dataset[0]
|
|
||||||
|
|
||||||
assert item["prompts"].dtype == torch.long
|
|
||||||
assert item["responses"].dtype == torch.long
|
|
||||||
assert item["masks"].dtype == torch.bool
|
|
||||||
assert item["rewards"].dtype == torch.float32
|
|
||||||
|
|
||||||
|
|
||||||
def test_grpo_dataset_load(base_test_env):
|
|
||||||
"""GRPODataset loads and returns correct keys"""
|
|
||||||
test_dir = base_test_env["test_dir"]
|
|
||||||
seq_len = 200
|
|
||||||
data = {
|
|
||||||
"prompts": [torch.randint(0, 1000, (seq_len,), dtype=torch.int64)],
|
|
||||||
"responses": [torch.randint(0, 1000, (seq_len,), dtype=torch.int64)],
|
|
||||||
"masks": [torch.ones(seq_len, dtype=torch.int64)],
|
|
||||||
"rewards": [torch.rand(seq_len, dtype=torch.float32)],
|
|
||||||
}
|
|
||||||
save_h5(test_dir, "grpo_test", data)
|
|
||||||
|
|
||||||
dataset = DatasetFactory.load("grpo", test_dir, window_size=64)
|
|
||||||
assert len(dataset) > 0
|
|
||||||
item = dataset[0]
|
|
||||||
assert "prompts" in item
|
|
||||||
assert "responses" in item
|
|
||||||
assert "masks" in item
|
|
||||||
assert "rewards" in item
|
|
||||||
assert item["prompts"].shape[0] == 64
|
|
||||||
assert item["responses"].shape[0] == 64
|
|
||||||
|
|
||||||
|
|
||||||
def test_detect_format_bin_dir(base_test_env):
|
|
||||||
"""detect_format returns 'bin' for directory with .bin + meta.json"""
|
|
||||||
test_dir = base_test_env["test_dir"]
|
|
||||||
save_bin(test_dir, {"sequence": [torch.randint(0, 100, (10,))]})
|
|
||||||
assert detect_format(test_dir) == "bin"
|
|
||||||
|
|
||||||
|
|
||||||
def test_store_fetch_multi_key(base_test_env):
|
|
||||||
"""Store.fetch with List[str] returns Dict[str, Tensor]"""
|
|
||||||
test_dir = base_test_env["test_dir"]
|
|
||||||
save_h5(
|
|
||||||
test_dir,
|
|
||||||
"multi_key",
|
|
||||||
{
|
|
||||||
"sequence": [torch.randint(0, 100, (100,), dtype=torch.int64)],
|
|
||||||
"loss_mask": [torch.ones(100, dtype=torch.int64)],
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
store = StoreFactory.create("h5")
|
|
||||||
store.load(test_dir)
|
|
||||||
result = store.fetch(10, 20, ["sequence", "loss_mask"])
|
|
||||||
assert isinstance(result, dict)
|
|
||||||
assert result["sequence"].shape[0] == 10
|
|
||||||
assert result["loss_mask"].shape[0] == 10
|
|
||||||
|
|
||||||
|
|
||||||
def test_store_fetch_out_of_bounds(base_test_env):
|
|
||||||
"""Store.fetch raises ValueError for out-of-bounds indices"""
|
|
||||||
test_dir = base_test_env["test_dir"]
|
|
||||||
save_h5(test_dir, "bounds", {"sequence": [torch.randint(0, 100, (50,))]})
|
|
||||||
|
|
||||||
store = StoreFactory.create("h5")
|
|
||||||
store.load(test_dir)
|
|
||||||
with pytest.raises(ValueError, match="out of bounds"):
|
|
||||||
store.fetch(-1, 10, "sequence")
|
|
||||||
with pytest.raises(ValueError, match="out of bounds"):
|
|
||||||
store.fetch(0, 51, "sequence")
|
|
||||||
with pytest.raises(ValueError, match="out of bounds"):
|
|
||||||
store.fetch(50, 50, "sequence")
|
|
||||||
|
|
||||||
|
|
||||||
def test_dataset_load_explicit_storage_type(base_test_env):
|
|
||||||
"""DatasetFactory.load with explicit storage_type bypasses auto-detect"""
|
|
||||||
test_dir = base_test_env["test_dir"]
|
|
||||||
save_h5(test_dir, "explicit", {"sequence": [torch.randint(0, 100, (200,))]})
|
|
||||||
|
|
||||||
dataset = DatasetFactory.load("seq", test_dir, window_size=64, storage_type="h5")
|
|
||||||
assert len(dataset) > 0
|
|
||||||
assert dataset.count == 200
|
|
||||||
|
|
|
||||||
|
|
@ -1,396 +0,0 @@
|
||||||
from astrai.config.preprocess_config import (
|
|
||||||
InputConfig,
|
|
||||||
OutputConfig,
|
|
||||||
PipelineConfig,
|
|
||||||
ProcessingConfig,
|
|
||||||
)
|
|
||||||
from astrai.preprocessing.builder import (
|
|
||||||
MaskBuilderFactory,
|
|
||||||
SectionedMaskBuilder,
|
|
||||||
)
|
|
||||||
from tests.data.conftest import (
|
|
||||||
_CHAT_SECTIONS,
|
|
||||||
_INSTRUCTION_SECTIONS,
|
|
||||||
_TEXT_SECTIONS,
|
|
||||||
make_chat_config,
|
|
||||||
make_dpo_chat_config,
|
|
||||||
make_grpo_config,
|
|
||||||
make_instruction_config,
|
|
||||||
make_text_config,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_chat_simple(chat_tokenizer):
|
|
||||||
config = make_chat_config()
|
|
||||||
builder = SectionedMaskBuilder()
|
|
||||||
item = {
|
|
||||||
"messages": [
|
|
||||||
{"role": "system", "content": "You are helpful."},
|
|
||||||
{"role": "user", "content": "Hello."},
|
|
||||||
{"role": "assistant", "content": "Hi there!"},
|
|
||||||
]
|
|
||||||
}
|
|
||||||
result = builder.build(item, config, chat_tokenizer)
|
|
||||||
assert result is not None
|
|
||||||
assert "sequence" in result
|
|
||||||
assert "loss_mask" in result
|
|
||||||
assert len(result["sequence"]) == len(result["loss_mask"])
|
|
||||||
|
|
||||||
ids = chat_tokenizer.decode(result["sequence"], skip_special_tokens=False)
|
|
||||||
assert "system" in ids.lower() or "<|im_start|>system" in ids
|
|
||||||
assert "assistant" in ids.lower() or "<|im_start|>assistant" in ids
|
|
||||||
|
|
||||||
total = len(result["sequence"])
|
|
||||||
trained = sum(result["loss_mask"])
|
|
||||||
assert trained > 0
|
|
||||||
assert trained < total
|
|
||||||
|
|
||||||
|
|
||||||
def test_chat_mask_only_assistant(chat_tokenizer):
|
|
||||||
config = make_chat_config()
|
|
||||||
builder = SectionedMaskBuilder()
|
|
||||||
item = {
|
|
||||||
"messages": [
|
|
||||||
{"role": "user", "content": "What is 2+2?"},
|
|
||||||
{"role": "assistant", "content": "4"},
|
|
||||||
]
|
|
||||||
}
|
|
||||||
result = builder.build(item, config, chat_tokenizer)
|
|
||||||
mask = result["loss_mask"]
|
|
||||||
ids = result["sequence"]
|
|
||||||
assert len(ids) == len(mask)
|
|
||||||
|
|
||||||
trained = [i for i, m in enumerate(mask) if m == 1]
|
|
||||||
masked = [i for i, m in enumerate(mask) if m == 0]
|
|
||||||
assert len(trained) > 0
|
|
||||||
assert len(masked) > 0
|
|
||||||
|
|
||||||
|
|
||||||
def test_chat_all_masked(chat_tokenizer):
|
|
||||||
config = PipelineConfig(
|
|
||||||
input=InputConfig(sections=_CHAT_SECTIONS),
|
|
||||||
mask={"system": "mask", "user": "mask", "assistant": "mask"},
|
|
||||||
mask_default="mask",
|
|
||||||
preprocessing=ProcessingConfig(max_seq_len=2048),
|
|
||||||
)
|
|
||||||
builder = SectionedMaskBuilder()
|
|
||||||
item = {
|
|
||||||
"messages": [
|
|
||||||
{"role": "system", "content": "You are helpful."},
|
|
||||||
{"role": "assistant", "content": "Hi there!"},
|
|
||||||
]
|
|
||||||
}
|
|
||||||
result = builder.build(item, config, chat_tokenizer)
|
|
||||||
assert sum(result["loss_mask"]) == 0
|
|
||||||
|
|
||||||
|
|
||||||
def test_chat_all_trained(chat_tokenizer):
|
|
||||||
config = PipelineConfig(
|
|
||||||
input=InputConfig(sections=_CHAT_SECTIONS),
|
|
||||||
mask={},
|
|
||||||
mask_default="train",
|
|
||||||
preprocessing=ProcessingConfig(max_seq_len=2048),
|
|
||||||
)
|
|
||||||
builder = SectionedMaskBuilder()
|
|
||||||
item = {
|
|
||||||
"messages": [
|
|
||||||
{"role": "system", "content": "You are helpful."},
|
|
||||||
{"role": "assistant", "content": "Hi there!"},
|
|
||||||
]
|
|
||||||
}
|
|
||||||
result = builder.build(item, config, chat_tokenizer)
|
|
||||||
assert sum(result["loss_mask"]) == len(result["sequence"]) - 1
|
|
||||||
|
|
||||||
|
|
||||||
def test_chat_empty_messages(chat_tokenizer):
|
|
||||||
config = make_chat_config()
|
|
||||||
builder = SectionedMaskBuilder()
|
|
||||||
assert builder.build({"messages": []}, config, chat_tokenizer) is None
|
|
||||||
assert builder.build({}, config, chat_tokenizer) is None
|
|
||||||
|
|
||||||
|
|
||||||
def test_chat_domain_extraction(chat_tokenizer):
|
|
||||||
config = PipelineConfig(
|
|
||||||
input=InputConfig(sections=_CHAT_SECTIONS),
|
|
||||||
mask={"assistant": "train"},
|
|
||||||
mask_default="mask",
|
|
||||||
preprocessing=ProcessingConfig(max_seq_len=2048),
|
|
||||||
output=OutputConfig(domain_key="source"),
|
|
||||||
)
|
|
||||||
builder = SectionedMaskBuilder()
|
|
||||||
item = {
|
|
||||||
"messages": [
|
|
||||||
{"role": "user", "content": "Hi"},
|
|
||||||
{"role": "assistant", "content": "Hello"},
|
|
||||||
],
|
|
||||||
"source": "wiki",
|
|
||||||
}
|
|
||||||
result = builder.build(item, config, chat_tokenizer)
|
|
||||||
assert result["domain"] == "wiki"
|
|
||||||
|
|
||||||
|
|
||||||
def test_chat_truncation(chat_tokenizer):
|
|
||||||
config = PipelineConfig(
|
|
||||||
input=InputConfig(sections=_CHAT_SECTIONS),
|
|
||||||
mask={"assistant": "train"},
|
|
||||||
mask_default="mask",
|
|
||||||
preprocessing=ProcessingConfig(max_seq_len=10),
|
|
||||||
)
|
|
||||||
builder = SectionedMaskBuilder()
|
|
||||||
item = {
|
|
||||||
"messages": [
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": "Tell me a very long story about dragons and knights and magic.",
|
|
||||||
},
|
|
||||||
{"role": "assistant", "content": "Sure! Here is a tale..."},
|
|
||||||
]
|
|
||||||
}
|
|
||||||
result = builder.build(item, config, chat_tokenizer)
|
|
||||||
assert len(result["sequence"]) <= 10
|
|
||||||
assert len(result["loss_mask"]) == len(result["sequence"])
|
|
||||||
|
|
||||||
|
|
||||||
def test_instruction_basic(test_tokenizer):
|
|
||||||
config = make_instruction_config()
|
|
||||||
builder = SectionedMaskBuilder()
|
|
||||||
item = {"prompt": "Translate to French: Hello", "response": "Bonjour"}
|
|
||||||
result = builder.build(item, config, test_tokenizer)
|
|
||||||
assert result is not None
|
|
||||||
assert len(result["sequence"]) == len(result["loss_mask"])
|
|
||||||
|
|
||||||
|
|
||||||
def test_instruction_prompt_masked(test_tokenizer):
|
|
||||||
config = make_instruction_config()
|
|
||||||
builder = SectionedMaskBuilder()
|
|
||||||
item = {"prompt": "hello", "response": "world"}
|
|
||||||
result = builder.build(item, config, test_tokenizer)
|
|
||||||
mask = result["loss_mask"]
|
|
||||||
ids = result["sequence"]
|
|
||||||
|
|
||||||
prompt_ids = test_tokenizer.encode("hello", add_special_tokens=True)
|
|
||||||
p_len = min(len(prompt_ids), len(ids))
|
|
||||||
assert all(m == 0 for m in mask[:p_len])
|
|
||||||
if p_len < len(ids):
|
|
||||||
assert all(m == 1 for m in mask[p_len:])
|
|
||||||
|
|
||||||
|
|
||||||
def test_instruction_train_on_prompt(test_tokenizer):
|
|
||||||
config = PipelineConfig(
|
|
||||||
input=InputConfig(
|
|
||||||
sections=[
|
|
||||||
{"field": "prompt", "action": "train", "add_special_tokens": True},
|
|
||||||
{"field": "response", "action": "mask"},
|
|
||||||
]
|
|
||||||
),
|
|
||||||
preprocessing=ProcessingConfig(max_seq_len=2048),
|
|
||||||
)
|
|
||||||
builder = SectionedMaskBuilder()
|
|
||||||
item = {"prompt": "hello", "response": "world"}
|
|
||||||
result = builder.build(item, config, test_tokenizer)
|
|
||||||
mask = result["loss_mask"]
|
|
||||||
ids = result["sequence"]
|
|
||||||
|
|
||||||
prompt_ids = test_tokenizer.encode("hello", add_special_tokens=True)
|
|
||||||
p_len = min(len(prompt_ids), len(ids))
|
|
||||||
assert all(m == 1 for m in mask[:p_len])
|
|
||||||
|
|
||||||
|
|
||||||
def test_text_basic(test_tokenizer):
|
|
||||||
config = make_text_config()
|
|
||||||
builder = SectionedMaskBuilder()
|
|
||||||
item = {"text": "Hello world. This is a test document."}
|
|
||||||
result = builder.build(item, config, test_tokenizer)
|
|
||||||
assert result is not None
|
|
||||||
assert "sequence" in result
|
|
||||||
assert len(result["sequence"]) > 0
|
|
||||||
assert "loss_mask" not in result
|
|
||||||
|
|
||||||
|
|
||||||
def test_text_empty(test_tokenizer):
|
|
||||||
config = make_text_config()
|
|
||||||
builder = SectionedMaskBuilder()
|
|
||||||
assert builder.build({"text": ""}, config, test_tokenizer) is None
|
|
||||||
assert builder.build({"text": " "}, config, test_tokenizer) is None
|
|
||||||
|
|
||||||
|
|
||||||
def test_text_too_short(test_tokenizer):
|
|
||||||
config = PipelineConfig(
|
|
||||||
input=InputConfig(sections=_TEXT_SECTIONS),
|
|
||||||
preprocessing=ProcessingConfig(min_chars=100),
|
|
||||||
)
|
|
||||||
builder = SectionedMaskBuilder()
|
|
||||||
assert builder.build({"text": "short"}, config, test_tokenizer) is None
|
|
||||||
|
|
||||||
|
|
||||||
def test_text_truncation(test_tokenizer):
|
|
||||||
config = PipelineConfig(
|
|
||||||
input=InputConfig(sections=_TEXT_SECTIONS),
|
|
||||||
preprocessing=ProcessingConfig(max_seq_len=3, min_chars=1),
|
|
||||||
)
|
|
||||||
builder = SectionedMaskBuilder()
|
|
||||||
item = {"text": "This is a very long text that should be truncated"}
|
|
||||||
result = builder.build(item, config, test_tokenizer)
|
|
||||||
assert len(result["sequence"]) <= 3
|
|
||||||
|
|
||||||
|
|
||||||
def test_sectioned_chat(chat_tokenizer):
|
|
||||||
config = PipelineConfig(
|
|
||||||
input=InputConfig(sections=_CHAT_SECTIONS),
|
|
||||||
mask={"system": "mask", "user": "mask", "assistant": "train"},
|
|
||||||
mask_default="mask",
|
|
||||||
preprocessing=ProcessingConfig(max_seq_len=2048),
|
|
||||||
)
|
|
||||||
builder = SectionedMaskBuilder()
|
|
||||||
item = {
|
|
||||||
"messages": [
|
|
||||||
{"role": "user", "content": "What is 2+2?"},
|
|
||||||
{"role": "assistant", "content": "4"},
|
|
||||||
]
|
|
||||||
}
|
|
||||||
result = builder.build(item, config, chat_tokenizer)
|
|
||||||
assert result is not None
|
|
||||||
assert len(result["sequence"]) == len(result["loss_mask"])
|
|
||||||
assert sum(result["loss_mask"]) > 0
|
|
||||||
assert 0 in result["loss_mask"]
|
|
||||||
|
|
||||||
|
|
||||||
def test_sectioned_instruction(test_tokenizer):
|
|
||||||
config = PipelineConfig(
|
|
||||||
input=InputConfig(sections=_INSTRUCTION_SECTIONS),
|
|
||||||
preprocessing=ProcessingConfig(max_seq_len=2048, min_chars=0),
|
|
||||||
)
|
|
||||||
builder = SectionedMaskBuilder()
|
|
||||||
item = {"prompt": "Q: Why?", "response": "A: Because."}
|
|
||||||
result = builder.build(item, config, test_tokenizer)
|
|
||||||
assert result is not None
|
|
||||||
mask = result["loss_mask"]
|
|
||||||
assert mask[0] == 0
|
|
||||||
assert mask[-1] == 1
|
|
||||||
|
|
||||||
|
|
||||||
def test_sectioned_text(test_tokenizer):
|
|
||||||
config = PipelineConfig(
|
|
||||||
input=InputConfig(sections=_TEXT_SECTIONS),
|
|
||||||
preprocessing=ProcessingConfig(max_seq_len=2048, min_chars=1),
|
|
||||||
)
|
|
||||||
builder = SectionedMaskBuilder()
|
|
||||||
item = {"text": "Hello world, this is a test."}
|
|
||||||
result = builder.build(item, config, test_tokenizer)
|
|
||||||
assert result is not None
|
|
||||||
assert "loss_mask" not in result
|
|
||||||
|
|
||||||
|
|
||||||
def test_sectioned_text_too_short(test_tokenizer):
|
|
||||||
config = PipelineConfig(
|
|
||||||
input=InputConfig(sections=_TEXT_SECTIONS),
|
|
||||||
preprocessing=ProcessingConfig(max_seq_len=2048, min_chars=100),
|
|
||||||
)
|
|
||||||
builder = SectionedMaskBuilder()
|
|
||||||
assert builder.build({"text": "short"}, config, test_tokenizer) is None
|
|
||||||
|
|
||||||
|
|
||||||
def test_factory_registered():
|
|
||||||
names = MaskBuilderFactory._registry.list_names()
|
|
||||||
assert "sectioned" in names
|
|
||||||
|
|
||||||
|
|
||||||
def test_factory_create():
|
|
||||||
builder = MaskBuilderFactory.create("sectioned")
|
|
||||||
assert isinstance(builder, SectionedMaskBuilder)
|
|
||||||
|
|
||||||
|
|
||||||
def test_dpo_chat_basic(chat_tokenizer):
|
|
||||||
config = make_dpo_chat_config()
|
|
||||||
builder = SectionedMaskBuilder()
|
|
||||||
item = {
|
|
||||||
"chosen": [
|
|
||||||
{"role": "user", "content": "What is 2+2?"},
|
|
||||||
{"role": "assistant", "content": "4"},
|
|
||||||
],
|
|
||||||
"rejected": [
|
|
||||||
{"role": "user", "content": "What is 2+2?"},
|
|
||||||
{"role": "assistant", "content": "5"},
|
|
||||||
],
|
|
||||||
}
|
|
||||||
result = builder.build(item, config, chat_tokenizer)
|
|
||||||
assert result is not None
|
|
||||||
assert "chosen" in result
|
|
||||||
assert "rejected" in result
|
|
||||||
assert "chosen_mask" in result
|
|
||||||
assert "rejected_mask" in result
|
|
||||||
assert "domain" in result
|
|
||||||
assert len(result["chosen"]) == len(result["chosen_mask"])
|
|
||||||
assert len(result["rejected"]) == len(result["rejected_mask"])
|
|
||||||
assert sum(result["chosen_mask"]) > 0
|
|
||||||
assert sum(result["rejected_mask"]) > 0
|
|
||||||
|
|
||||||
|
|
||||||
def test_dpo_chosen_only_trained(chat_tokenizer):
|
|
||||||
config = make_dpo_chat_config()
|
|
||||||
builder = SectionedMaskBuilder()
|
|
||||||
item = {
|
|
||||||
"chosen": [
|
|
||||||
{"role": "user", "content": "Hi"},
|
|
||||||
{"role": "assistant", "content": "Hello"},
|
|
||||||
],
|
|
||||||
"rejected": [
|
|
||||||
{"role": "user", "content": "Hi"},
|
|
||||||
{"role": "assistant", "content": "Go away"},
|
|
||||||
],
|
|
||||||
}
|
|
||||||
result = builder.build(item, config, chat_tokenizer)
|
|
||||||
assert 0 in result["chosen_mask"]
|
|
||||||
assert 1 in result["chosen_mask"]
|
|
||||||
assert 0 in result["rejected_mask"]
|
|
||||||
assert 1 in result["rejected_mask"]
|
|
||||||
|
|
||||||
|
|
||||||
def test_dpo_missing_field_is_none(chat_tokenizer):
|
|
||||||
config = make_dpo_chat_config()
|
|
||||||
builder = SectionedMaskBuilder()
|
|
||||||
assert builder.build({"chosen": [], "rejected": []}, config, chat_tokenizer) is None
|
|
||||||
|
|
||||||
|
|
||||||
def test_grpo_basic(chat_tokenizer):
|
|
||||||
config = make_grpo_config()
|
|
||||||
builder = SectionedMaskBuilder()
|
|
||||||
item = {
|
|
||||||
"prompt": [{"role": "user", "content": "What is 2+2?"}],
|
|
||||||
"responses": ["4", "The answer is four", "Four", "2+2=4"],
|
|
||||||
"rewards": [1.0, 0.5, 0.8, 0.2],
|
|
||||||
}
|
|
||||||
result = builder.build(item, config, chat_tokenizer)
|
|
||||||
assert result is not None
|
|
||||||
assert "prompts" in result
|
|
||||||
assert "responses" in result
|
|
||||||
assert "masks" in result
|
|
||||||
assert "rewards" in result
|
|
||||||
assert len(result["responses"]) == len(result["masks"])
|
|
||||||
assert result["rewards"] == [1.0, 0.5, 0.8, 0.2]
|
|
||||||
|
|
||||||
|
|
||||||
def test_grpo_response_tokens_all_trained(chat_tokenizer):
|
|
||||||
config = make_grpo_config()
|
|
||||||
builder = SectionedMaskBuilder()
|
|
||||||
item = {
|
|
||||||
"prompt": [{"role": "user", "content": "Q"}],
|
|
||||||
"responses": ["A", "B"],
|
|
||||||
"rewards": [0.8, 0.2],
|
|
||||||
}
|
|
||||||
result = builder.build(item, config, chat_tokenizer)
|
|
||||||
masks = result["masks"]
|
|
||||||
assert all(m == 1 for m in masks)
|
|
||||||
assert len(masks) == len(result["responses"])
|
|
||||||
|
|
||||||
|
|
||||||
def test_grpo_single_reward(chat_tokenizer):
|
|
||||||
config = make_grpo_config()
|
|
||||||
builder = SectionedMaskBuilder()
|
|
||||||
item = {
|
|
||||||
"prompt": [{"role": "user", "content": "Q"}],
|
|
||||||
"responses": ["A"],
|
|
||||||
"rewards": 0.9,
|
|
||||||
}
|
|
||||||
result = builder.build(item, config, chat_tokenizer)
|
|
||||||
assert result["rewards"] == [0.9]
|
|
||||||
|
|
@ -1,77 +0,0 @@
|
||||||
import os
|
|
||||||
|
|
||||||
from astrai.config.preprocess_config import (
|
|
||||||
InputConfig,
|
|
||||||
PipelineConfig,
|
|
||||||
)
|
|
||||||
from tests.data.conftest import (
|
|
||||||
_INSTRUCTION_SECTIONS,
|
|
||||||
_TEXT_SECTIONS,
|
|
||||||
make_dpo_chat_config,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_default_values():
|
|
||||||
config = PipelineConfig()
|
|
||||||
assert config.version == 1
|
|
||||||
assert config.mask == {}
|
|
||||||
assert config.mask_default == "mask"
|
|
||||||
assert config.preprocessing.max_seq_len == 2048
|
|
||||||
assert config.output.storage_format == "bin"
|
|
||||||
assert config.input.sections is None
|
|
||||||
|
|
||||||
|
|
||||||
def test_from_dict_flat():
|
|
||||||
data = {
|
|
||||||
"version": 1,
|
|
||||||
"input": {
|
|
||||||
"sections": [{"field": "messages", "action": "$role", "template": True}]
|
|
||||||
},
|
|
||||||
"mask": {"system": "mask", "assistant": "train"},
|
|
||||||
"mask_default": "mask",
|
|
||||||
"preprocessing": {"max_seq_len": 1024},
|
|
||||||
"output": {"storage_format": "h5"},
|
|
||||||
}
|
|
||||||
config = PipelineConfig.from_dict(data)
|
|
||||||
assert config.input.sections == [
|
|
||||||
{"field": "messages", "action": "$role", "template": True}
|
|
||||||
]
|
|
||||||
assert config.mask == {"system": "mask", "assistant": "train"}
|
|
||||||
assert config.preprocessing.max_seq_len == 1024
|
|
||||||
assert config.output.storage_format == "h5"
|
|
||||||
|
|
||||||
|
|
||||||
def test_to_dict_roundtrip():
|
|
||||||
config = PipelineConfig(
|
|
||||||
input=InputConfig(sections=_INSTRUCTION_SECTIONS),
|
|
||||||
mask={"prompt": "mask", "response": "train"},
|
|
||||||
mask_default="mask",
|
|
||||||
)
|
|
||||||
d = config.to_dict()
|
|
||||||
config2 = PipelineConfig.from_dict(d)
|
|
||||||
assert config2.input.sections == _INSTRUCTION_SECTIONS
|
|
||||||
assert config2.mask == {"prompt": "mask", "response": "train"}
|
|
||||||
|
|
||||||
|
|
||||||
def test_to_json_from_json(temp_dir):
|
|
||||||
config = PipelineConfig(
|
|
||||||
input=InputConfig(sections=_TEXT_SECTIONS),
|
|
||||||
mask={"text": "train"},
|
|
||||||
mask_default="mask",
|
|
||||||
)
|
|
||||||
path = os.path.join(temp_dir, "config.json")
|
|
||||||
config.to_json(path)
|
|
||||||
loaded = PipelineConfig.from_json(path)
|
|
||||||
assert loaded.input.sections == _TEXT_SECTIONS
|
|
||||||
assert loaded.mask == {"text": "train"}
|
|
||||||
|
|
||||||
|
|
||||||
def test_dpo_config_roundtrip(temp_dir):
|
|
||||||
config = make_dpo_chat_config()
|
|
||||||
path = os.path.join(temp_dir, "config.json")
|
|
||||||
config.to_json(path)
|
|
||||||
loaded = PipelineConfig.from_json(path)
|
|
||||||
assert loaded.input.sources is not None
|
|
||||||
assert "chosen" in loaded.input.sources
|
|
||||||
assert "rejected" in loaded.input.sources
|
|
||||||
assert loaded.input.sections is None
|
|
||||||
|
|
@ -1,349 +0,0 @@
|
||||||
import json
|
|
||||||
import os
|
|
||||||
|
|
||||||
from astrai.config.preprocess_config import (
|
|
||||||
InputConfig,
|
|
||||||
OutputConfig,
|
|
||||||
PipelineConfig,
|
|
||||||
ProcessingConfig,
|
|
||||||
)
|
|
||||||
from astrai.preprocessing.pipeline import Pipeline, filter_by_length
|
|
||||||
from tests.data.conftest import (
|
|
||||||
_CHAT_SECTIONS,
|
|
||||||
_CHAT_TEMPLATE,
|
|
||||||
_INSTRUCTION_SECTIONS,
|
|
||||||
_SPECIAL_TOKENS_CONFIG,
|
|
||||||
_TEXT_SECTIONS,
|
|
||||||
make_dpo_chat_config,
|
|
||||||
make_grpo_no_template_config,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_filter_by_length():
|
|
||||||
assert filter_by_length("hello world", min_len=5)
|
|
||||||
assert not filter_by_length("hi", min_len=5)
|
|
||||||
assert not filter_by_length("x" * 100, max_len=50)
|
|
||||||
assert filter_by_length("just right", min_len=5, max_len=20)
|
|
||||||
|
|
||||||
|
|
||||||
def test_full_chat_pipeline(temp_dir, chat_tokenizer):
|
|
||||||
tokenizer_dir = os.path.join(temp_dir, "tok")
|
|
||||||
os.makedirs(tokenizer_dir, exist_ok=True)
|
|
||||||
chat_tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
|
|
||||||
with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f:
|
|
||||||
json.dump(
|
|
||||||
{
|
|
||||||
"special_tokens": _SPECIAL_TOKENS_CONFIG,
|
|
||||||
"chat_template": _CHAT_TEMPLATE,
|
|
||||||
},
|
|
||||||
f,
|
|
||||||
)
|
|
||||||
|
|
||||||
jsonl_path = os.path.join(temp_dir, "chat.jsonl")
|
|
||||||
with open(jsonl_path, "w", encoding="utf-8") as f:
|
|
||||||
f.write(
|
|
||||||
json.dumps(
|
|
||||||
{
|
|
||||||
"messages": [
|
|
||||||
{"role": "system", "content": "You are helpful."},
|
|
||||||
{"role": "user", "content": "Hi."},
|
|
||||||
{"role": "assistant", "content": "Hello!"},
|
|
||||||
]
|
|
||||||
}
|
|
||||||
)
|
|
||||||
+ "\n"
|
|
||||||
)
|
|
||||||
f.write(
|
|
||||||
json.dumps(
|
|
||||||
{
|
|
||||||
"messages": [
|
|
||||||
{"role": "user", "content": "What is 2+2?"},
|
|
||||||
{"role": "assistant", "content": "4"},
|
|
||||||
]
|
|
||||||
}
|
|
||||||
)
|
|
||||||
+ "\n"
|
|
||||||
)
|
|
||||||
|
|
||||||
config = PipelineConfig(
|
|
||||||
input=InputConfig(sections=_CHAT_SECTIONS),
|
|
||||||
mask={"system": "mask", "user": "mask", "assistant": "train"},
|
|
||||||
mask_default="mask",
|
|
||||||
preprocessing=ProcessingConfig(max_seq_len=2048),
|
|
||||||
output=OutputConfig(storage_format="bin", domain_key=None),
|
|
||||||
)
|
|
||||||
|
|
||||||
out_dir = os.path.join(temp_dir, "output")
|
|
||||||
Pipeline(
|
|
||||||
config=config,
|
|
||||||
input_paths=[jsonl_path],
|
|
||||||
output_dir=out_dir,
|
|
||||||
tokenizer_path=tokenizer_dir,
|
|
||||||
).run()
|
|
||||||
|
|
||||||
meta_path = os.path.join(out_dir, "__default__", "shard_0000", "meta.json")
|
|
||||||
assert os.path.exists(meta_path)
|
|
||||||
with open(meta_path, "r") as f:
|
|
||||||
meta = json.load(f)
|
|
||||||
assert "sequence" in meta
|
|
||||||
assert "loss_mask" in meta
|
|
||||||
assert meta["sequence"]["dtype"] == "int32"
|
|
||||||
assert meta["loss_mask"]["dtype"] == "int32"
|
|
||||||
|
|
||||||
|
|
||||||
def test_full_text_pipeline(temp_dir, test_tokenizer):
|
|
||||||
tokenizer_dir = os.path.join(temp_dir, "tok")
|
|
||||||
os.makedirs(tokenizer_dir, exist_ok=True)
|
|
||||||
test_tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
|
|
||||||
with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f:
|
|
||||||
json.dump(
|
|
||||||
{
|
|
||||||
"special_tokens": {
|
|
||||||
"pad_token": "<|_pad_|>",
|
|
||||||
"unk_token": "<|_unk_|>",
|
|
||||||
}
|
|
||||||
},
|
|
||||||
f,
|
|
||||||
)
|
|
||||||
|
|
||||||
jsonl_path = os.path.join(temp_dir, "text.jsonl")
|
|
||||||
with open(jsonl_path, "w", encoding="utf-8") as f:
|
|
||||||
f.write(
|
|
||||||
json.dumps(
|
|
||||||
{
|
|
||||||
"text": "Hello world this is a test document with enough characters to pass the minimum length filter."
|
|
||||||
}
|
|
||||||
)
|
|
||||||
+ "\n"
|
|
||||||
)
|
|
||||||
f.write(
|
|
||||||
json.dumps(
|
|
||||||
{
|
|
||||||
"text": "Another document for testing purposes with sufficient length to be processed."
|
|
||||||
}
|
|
||||||
)
|
|
||||||
+ "\n"
|
|
||||||
)
|
|
||||||
|
|
||||||
config = PipelineConfig(
|
|
||||||
input=InputConfig(sections=_TEXT_SECTIONS),
|
|
||||||
preprocessing=ProcessingConfig(max_seq_len=2048, min_chars=10),
|
|
||||||
output=OutputConfig(storage_format="bin"),
|
|
||||||
)
|
|
||||||
|
|
||||||
out_dir = os.path.join(temp_dir, "output")
|
|
||||||
Pipeline(
|
|
||||||
config=config,
|
|
||||||
input_paths=[jsonl_path],
|
|
||||||
output_dir=out_dir,
|
|
||||||
tokenizer_path=tokenizer_dir,
|
|
||||||
).run()
|
|
||||||
|
|
||||||
meta_path = os.path.join(out_dir, "__default__", "shard_0000", "meta.json")
|
|
||||||
assert os.path.exists(meta_path)
|
|
||||||
with open(meta_path, "r") as f:
|
|
||||||
meta = json.load(f)
|
|
||||||
assert "sequence" in meta
|
|
||||||
assert "loss_mask" not in meta
|
|
||||||
assert meta["sequence"]["dtype"] == "int32"
|
|
||||||
|
|
||||||
|
|
||||||
def test_full_instruction_pipeline(temp_dir, test_tokenizer):
|
|
||||||
tokenizer_dir = os.path.join(temp_dir, "tok")
|
|
||||||
os.makedirs(tokenizer_dir, exist_ok=True)
|
|
||||||
test_tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
|
|
||||||
with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f:
|
|
||||||
json.dump(
|
|
||||||
{
|
|
||||||
"special_tokens": {
|
|
||||||
"pad_token": "<|_pad_|>",
|
|
||||||
"unk_token": "<|_unk_|>",
|
|
||||||
}
|
|
||||||
},
|
|
||||||
f,
|
|
||||||
)
|
|
||||||
|
|
||||||
jsonl_path = os.path.join(temp_dir, "instruct.jsonl")
|
|
||||||
with open(jsonl_path, "w", encoding="utf-8") as f:
|
|
||||||
f.write(
|
|
||||||
json.dumps(
|
|
||||||
{
|
|
||||||
"prompt": "Tell me a joke",
|
|
||||||
"response": "Why did the chicken cross the road?",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
+ "\n"
|
|
||||||
)
|
|
||||||
f.write(
|
|
||||||
json.dumps(
|
|
||||||
{
|
|
||||||
"prompt": "What is AI?",
|
|
||||||
"response": "Artificial Intelligence is a field of computer science.",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
+ "\n"
|
|
||||||
)
|
|
||||||
|
|
||||||
config = PipelineConfig(
|
|
||||||
input=InputConfig(sections=_INSTRUCTION_SECTIONS),
|
|
||||||
mask={"prompt": "mask", "response": "train"},
|
|
||||||
mask_default="mask",
|
|
||||||
preprocessing=ProcessingConfig(max_seq_len=2048),
|
|
||||||
output=OutputConfig(storage_format="bin"),
|
|
||||||
)
|
|
||||||
|
|
||||||
out_dir = os.path.join(temp_dir, "output")
|
|
||||||
Pipeline(
|
|
||||||
config=config,
|
|
||||||
input_paths=[jsonl_path],
|
|
||||||
output_dir=out_dir,
|
|
||||||
tokenizer_path=tokenizer_dir,
|
|
||||||
).run()
|
|
||||||
|
|
||||||
meta_path = os.path.join(out_dir, "__default__", "shard_0000", "meta.json")
|
|
||||||
assert os.path.exists(meta_path)
|
|
||||||
with open(meta_path, "r") as f:
|
|
||||||
meta = json.load(f)
|
|
||||||
assert "sequence" in meta
|
|
||||||
assert "loss_mask" in meta
|
|
||||||
assert meta["sequence"]["dtype"] == "int32"
|
|
||||||
assert meta["loss_mask"]["dtype"] == "int32"
|
|
||||||
|
|
||||||
|
|
||||||
def test_dtype_override(temp_dir, test_tokenizer):
|
|
||||||
tokenizer_dir = os.path.join(temp_dir, "tok")
|
|
||||||
os.makedirs(tokenizer_dir, exist_ok=True)
|
|
||||||
test_tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
|
|
||||||
with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f:
|
|
||||||
json.dump(
|
|
||||||
{
|
|
||||||
"special_tokens": {
|
|
||||||
"pad_token": "<|_pad_|>",
|
|
||||||
"unk_token": "<|_unk_|>",
|
|
||||||
}
|
|
||||||
},
|
|
||||||
f,
|
|
||||||
)
|
|
||||||
|
|
||||||
jsonl_path = os.path.join(temp_dir, "data.jsonl")
|
|
||||||
with open(jsonl_path, "w", encoding="utf-8") as f:
|
|
||||||
f.write(json.dumps({"prompt": "Q", "response": "A"}) + "\n")
|
|
||||||
|
|
||||||
config = PipelineConfig(
|
|
||||||
input=InputConfig(sections=_INSTRUCTION_SECTIONS),
|
|
||||||
mask={"prompt": "mask", "response": "train"},
|
|
||||||
mask_default="mask",
|
|
||||||
preprocessing=ProcessingConfig(max_seq_len=2048),
|
|
||||||
output=OutputConfig(storage_format="bin", dtype={"loss_mask": "bool"}),
|
|
||||||
)
|
|
||||||
|
|
||||||
out_dir = os.path.join(temp_dir, "output")
|
|
||||||
Pipeline(
|
|
||||||
config=config,
|
|
||||||
input_paths=[jsonl_path],
|
|
||||||
output_dir=out_dir,
|
|
||||||
tokenizer_path=tokenizer_dir,
|
|
||||||
).run()
|
|
||||||
|
|
||||||
meta_path = os.path.join(out_dir, "__default__", "shard_0000", "meta.json")
|
|
||||||
with open(meta_path, "r") as f:
|
|
||||||
meta = json.load(f)
|
|
||||||
assert meta["sequence"]["dtype"] == "int32"
|
|
||||||
assert meta["loss_mask"]["dtype"] == "bool"
|
|
||||||
|
|
||||||
|
|
||||||
def test_dpo_pipeline(temp_dir, chat_tokenizer):
|
|
||||||
tokenizer_dir = os.path.join(temp_dir, "tok")
|
|
||||||
os.makedirs(tokenizer_dir, exist_ok=True)
|
|
||||||
chat_tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
|
|
||||||
with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f:
|
|
||||||
json.dump(
|
|
||||||
{
|
|
||||||
"special_tokens": _SPECIAL_TOKENS_CONFIG,
|
|
||||||
"chat_template": _CHAT_TEMPLATE,
|
|
||||||
},
|
|
||||||
f,
|
|
||||||
)
|
|
||||||
|
|
||||||
jsonl_path = os.path.join(temp_dir, "dpo.jsonl")
|
|
||||||
with open(jsonl_path, "w", encoding="utf-8") as f:
|
|
||||||
f.write(
|
|
||||||
json.dumps(
|
|
||||||
{
|
|
||||||
"chosen": [
|
|
||||||
{"role": "user", "content": "Hi."},
|
|
||||||
{"role": "assistant", "content": "Hello!"},
|
|
||||||
],
|
|
||||||
"rejected": [
|
|
||||||
{"role": "user", "content": "Hi."},
|
|
||||||
{"role": "assistant", "content": "Go away."},
|
|
||||||
],
|
|
||||||
}
|
|
||||||
)
|
|
||||||
+ "\n"
|
|
||||||
)
|
|
||||||
|
|
||||||
out_dir = os.path.join(temp_dir, "output")
|
|
||||||
Pipeline(
|
|
||||||
config=make_dpo_chat_config(),
|
|
||||||
input_paths=[jsonl_path],
|
|
||||||
output_dir=out_dir,
|
|
||||||
tokenizer_path=tokenizer_dir,
|
|
||||||
).run()
|
|
||||||
|
|
||||||
meta_path = os.path.join(out_dir, "__default__", "shard_0000", "meta.json")
|
|
||||||
assert os.path.exists(meta_path)
|
|
||||||
with open(meta_path, "r") as f:
|
|
||||||
meta = json.load(f)
|
|
||||||
assert "chosen" in meta
|
|
||||||
assert "rejected" in meta
|
|
||||||
assert "chosen_mask" in meta
|
|
||||||
assert "rejected_mask" in meta
|
|
||||||
assert "sequence" not in meta
|
|
||||||
|
|
||||||
|
|
||||||
def test_grpo_pipeline(temp_dir, test_tokenizer):
|
|
||||||
tokenizer_dir = os.path.join(temp_dir, "tok")
|
|
||||||
os.makedirs(tokenizer_dir, exist_ok=True)
|
|
||||||
test_tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
|
|
||||||
with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f:
|
|
||||||
json.dump(
|
|
||||||
{
|
|
||||||
"special_tokens": {
|
|
||||||
"pad_token": "<|_pad_|>",
|
|
||||||
"unk_token": "<|_unk_|>",
|
|
||||||
}
|
|
||||||
},
|
|
||||||
f,
|
|
||||||
)
|
|
||||||
|
|
||||||
jsonl_path = os.path.join(temp_dir, "grpo.jsonl")
|
|
||||||
with open(jsonl_path, "w", encoding="utf-8") as f:
|
|
||||||
f.write(
|
|
||||||
json.dumps(
|
|
||||||
{
|
|
||||||
"prompt": "Question?",
|
|
||||||
"responses": ["Answer A", "Answer B"],
|
|
||||||
"rewards": [0.8, 0.3],
|
|
||||||
}
|
|
||||||
)
|
|
||||||
+ "\n"
|
|
||||||
)
|
|
||||||
|
|
||||||
out_dir = os.path.join(temp_dir, "output")
|
|
||||||
Pipeline(
|
|
||||||
config=make_grpo_no_template_config(),
|
|
||||||
input_paths=[jsonl_path],
|
|
||||||
output_dir=out_dir,
|
|
||||||
tokenizer_path=tokenizer_dir,
|
|
||||||
).run()
|
|
||||||
|
|
||||||
meta_path = os.path.join(out_dir, "__default__", "shard_0000", "meta.json")
|
|
||||||
assert os.path.exists(meta_path)
|
|
||||||
with open(meta_path, "r") as f:
|
|
||||||
meta = json.load(f)
|
|
||||||
assert "prompts" in meta
|
|
||||||
assert "responses" in meta
|
|
||||||
assert "masks" in meta
|
|
||||||
assert "rewards" in meta
|
|
||||||
assert "sequence" not in meta
|
|
||||||
|
|
@ -5,22 +5,21 @@ from unittest.mock import MagicMock
|
||||||
import pytest
|
import pytest
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
from astrai.inference import get_app
|
from astrai.inference import app
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def client():
|
def client():
|
||||||
"""Provide a test client for the FastAPI app."""
|
"""Provide a test client for the FastAPI app."""
|
||||||
_app = get_app()
|
app.state.server_config = {
|
||||||
_app.state.server_config = {
|
|
||||||
"device": "cpu",
|
"device": "cpu",
|
||||||
"dtype": "bfloat16",
|
"dtype": "bfloat16",
|
||||||
"param_path": None,
|
"param_path": None,
|
||||||
"max_batch_size": 1,
|
"max_batch_size": 1,
|
||||||
"_test": True,
|
"_test": True,
|
||||||
}
|
}
|
||||||
_app.state.engine = None
|
app.state.engine = None
|
||||||
return TestClient(_app)
|
return TestClient(app)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|
@ -50,5 +49,5 @@ def mock_engine():
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def loaded_model(client, mock_engine):
|
def loaded_model(client, mock_engine):
|
||||||
"""Simulate that the engine is loaded."""
|
"""Simulate that the engine is loaded."""
|
||||||
get_app().state.engine = mock_engine
|
app.state.engine = mock_engine
|
||||||
return mock_engine
|
return mock_engine
|
||||||
|
|
|
||||||
|
|
@ -1,286 +0,0 @@
|
||||||
"""Unit tests for protocol builders, StopChecker, GenContext, StopInfo."""
|
|
||||||
|
|
||||||
import json
|
|
||||||
from unittest.mock import MagicMock
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from astrai.inference.api.anthropic import AnthropicResponseBuilder
|
|
||||||
from astrai.inference.api.openai import OpenAIResponseBuilder
|
|
||||||
from astrai.inference.api.protocol import GenContext, StopChecker, StopInfo
|
|
||||||
from astrai.inference.engine import GenerationRequest
|
|
||||||
|
|
||||||
|
|
||||||
def _make_ctx(**kwargs):
|
|
||||||
defaults = {
|
|
||||||
"resp_id": "test-123",
|
|
||||||
"created": 1000,
|
|
||||||
"model": "test-model",
|
|
||||||
"prompt_tokens": 10,
|
|
||||||
"completion_tokens": 5,
|
|
||||||
}
|
|
||||||
defaults.update(kwargs)
|
|
||||||
return GenContext(**defaults)
|
|
||||||
|
|
||||||
|
|
||||||
def _sse_payloads(events):
|
|
||||||
payloads = []
|
|
||||||
for chunk in events:
|
|
||||||
for line in chunk.strip().split("\n"):
|
|
||||||
if line.startswith("data: "):
|
|
||||||
try:
|
|
||||||
payloads.append(json.loads(line[6:]))
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
pass
|
|
||||||
return payloads
|
|
||||||
|
|
||||||
|
|
||||||
class TestStopChecker:
|
|
||||||
def test_check_finds_match(self):
|
|
||||||
sc = StopChecker(["stop", "end"])
|
|
||||||
assert sc.check("hello stop world") == "stop"
|
|
||||||
|
|
||||||
def test_check_returns_none_when_no_match(self):
|
|
||||||
sc = StopChecker(["stop"])
|
|
||||||
assert sc.check("hello world") is None
|
|
||||||
|
|
||||||
def test_check_empty_sequences(self):
|
|
||||||
sc = StopChecker([])
|
|
||||||
assert sc.check("hello") is None
|
|
||||||
|
|
||||||
|
|
||||||
class TestGenContext:
|
|
||||||
def test_defaults(self):
|
|
||||||
ctx = GenContext(resp_id="a", created=1, model="m", prompt_tokens=10)
|
|
||||||
assert ctx.completion_tokens == 0
|
|
||||||
|
|
||||||
def test_fields_mutable(self):
|
|
||||||
ctx = GenContext(resp_id="a", created=1, model="m", prompt_tokens=10)
|
|
||||||
ctx.completion_tokens = 42
|
|
||||||
assert ctx.completion_tokens == 42
|
|
||||||
|
|
||||||
|
|
||||||
class TestStopInfo:
|
|
||||||
def test_defaults(self):
|
|
||||||
s = StopInfo()
|
|
||||||
assert s.matched is None
|
|
||||||
assert s.body == ""
|
|
||||||
assert s.yielded == ""
|
|
||||||
|
|
||||||
def test_with_values(self):
|
|
||||||
s = StopInfo(matched="stop", body="hello stop", yielded="hello ")
|
|
||||||
assert s.matched == "stop"
|
|
||||||
assert s.body == "hello stop"
|
|
||||||
assert s.yielded == "hello "
|
|
||||||
|
|
||||||
|
|
||||||
class TestOpenAIResponseBuilder:
|
|
||||||
@pytest.fixture
|
|
||||||
def builder(self):
|
|
||||||
builder = OpenAIResponseBuilder()
|
|
||||||
req = MagicMock()
|
|
||||||
req.messages = [MagicMock(role="user", content="Hello")]
|
|
||||||
req.stop = None
|
|
||||||
req.model = "astrai"
|
|
||||||
engine = MagicMock()
|
|
||||||
engine.tokenizer.apply_chat_template.return_value = "Hello"
|
|
||||||
builder.prepare(req, engine)
|
|
||||||
return builder
|
|
||||||
|
|
||||||
def test_prepare_returns_prompt_ctx_stops(self, builder):
|
|
||||||
req = MagicMock()
|
|
||||||
req.messages = [MagicMock(role="user", content="Hi")]
|
|
||||||
req.stop = ["END"]
|
|
||||||
req.model = "gpt"
|
|
||||||
engine = MagicMock()
|
|
||||||
engine.tokenizer.apply_chat_template.return_value = "Hi"
|
|
||||||
prompt, ctx, stops = builder.prepare(req, engine)
|
|
||||||
assert prompt == "Hi"
|
|
||||||
assert ctx.model == "gpt"
|
|
||||||
assert ctx.prompt_tokens == 0
|
|
||||||
assert stops == ["END"]
|
|
||||||
|
|
||||||
def test_prepare_no_stop_returns_empty_list(self, builder):
|
|
||||||
req = MagicMock()
|
|
||||||
req.messages = []
|
|
||||||
req.stop = None
|
|
||||||
req.model = "x"
|
|
||||||
engine = MagicMock()
|
|
||||||
engine.tokenizer.apply_chat_template.return_value = ""
|
|
||||||
_, _, stops = builder.prepare(req, engine)
|
|
||||||
assert stops == []
|
|
||||||
|
|
||||||
def test_format_stream_start(self, builder):
|
|
||||||
ctx = _make_ctx()
|
|
||||||
events = builder.format_stream_start(ctx)
|
|
||||||
payloads = _sse_payloads(events)
|
|
||||||
assert len(payloads) == 1
|
|
||||||
p = payloads[0]
|
|
||||||
assert p["object"] == "chat.completion.chunk"
|
|
||||||
assert p["choices"][0]["delta"]["role"] == "assistant"
|
|
||||||
assert p["choices"][0]["finish_reason"] is None
|
|
||||||
|
|
||||||
def test_format_chunk(self, builder):
|
|
||||||
event = builder.format_chunk("hello")
|
|
||||||
payload = json.loads(event.split("data: ", 1)[1])
|
|
||||||
assert payload["choices"][0]["delta"]["content"] == "hello"
|
|
||||||
assert payload["choices"][0]["finish_reason"] is None
|
|
||||||
|
|
||||||
def test_format_stream_end(self, builder):
|
|
||||||
ctx = _make_ctx(completion_tokens=5)
|
|
||||||
stop = StopInfo(matched="stop")
|
|
||||||
events = builder.format_stream_end(ctx, stop)
|
|
||||||
payloads = _sse_payloads(events)
|
|
||||||
finish = payloads[0]
|
|
||||||
assert finish["choices"][0]["finish_reason"] == "stop"
|
|
||||||
usage = payloads[1]
|
|
||||||
assert usage["completion_tokens"] == 5
|
|
||||||
assert usage["total_tokens"] == 15
|
|
||||||
|
|
||||||
def test_format_response(self, builder):
|
|
||||||
ctx = _make_ctx()
|
|
||||||
stop = StopInfo()
|
|
||||||
resp = builder.format_response(ctx, "hello", stop)
|
|
||||||
assert resp["object"] == "chat.completion"
|
|
||||||
assert resp["choices"][0]["message"]["content"] == "hello"
|
|
||||||
assert resp["usage"]["prompt_tokens"] == 10
|
|
||||||
|
|
||||||
|
|
||||||
class TestAnthropicResponseBuilder:
|
|
||||||
@pytest.fixture
|
|
||||||
def builder(self):
|
|
||||||
builder = AnthropicResponseBuilder()
|
|
||||||
req = MagicMock()
|
|
||||||
req.messages = [MagicMock(role="user", content="Hello")]
|
|
||||||
req.model = "claude"
|
|
||||||
engine = MagicMock()
|
|
||||||
engine.tokenizer.apply_chat_template.return_value = "Hello"
|
|
||||||
req.system = None
|
|
||||||
builder.prepare(req, engine)
|
|
||||||
return builder
|
|
||||||
|
|
||||||
def test_prepare_messages(self, builder):
|
|
||||||
req = MagicMock()
|
|
||||||
req.messages = [MagicMock(role="user", content="Hi")]
|
|
||||||
req.model = "claude"
|
|
||||||
req.system = None
|
|
||||||
req.stop_sequences = None
|
|
||||||
engine = MagicMock()
|
|
||||||
engine.tokenizer.apply_chat_template.return_value = "Hi"
|
|
||||||
prompt, ctx, stops = builder.prepare(req, engine)
|
|
||||||
assert prompt == "Hi"
|
|
||||||
assert stops == []
|
|
||||||
|
|
||||||
def test_prepare_with_stop_sequences(self, builder):
|
|
||||||
req = MagicMock()
|
|
||||||
req.messages = []
|
|
||||||
req.model = "x"
|
|
||||||
req.stop_sequences = ["stop", "end"]
|
|
||||||
req.system = None
|
|
||||||
engine = MagicMock()
|
|
||||||
engine.tokenizer.apply_chat_template.return_value = ""
|
|
||||||
_, _, stops = builder.prepare(req, engine)
|
|
||||||
assert stops == ["stop", "end"]
|
|
||||||
|
|
||||||
def test_format_stream_start(self, builder):
|
|
||||||
ctx = _make_ctx(prompt_tokens=3)
|
|
||||||
events = builder.format_stream_start(ctx)
|
|
||||||
payloads = _sse_payloads(events)
|
|
||||||
assert len(payloads) == 2
|
|
||||||
assert payloads[0]["type"] == "message_start"
|
|
||||||
assert payloads[0]["message"]["usage"]["input_tokens"] == 3
|
|
||||||
assert payloads[1]["type"] == "content_block_start"
|
|
||||||
|
|
||||||
def test_format_chunk(self, builder):
|
|
||||||
event = builder.format_chunk("tok")
|
|
||||||
payload = json.loads(event.split("data: ", 1)[1])
|
|
||||||
assert payload["type"] == "content_block_delta"
|
|
||||||
assert payload["delta"]["text"] == "tok"
|
|
||||||
|
|
||||||
def test_format_stream_end_no_stop(self, builder):
|
|
||||||
ctx = _make_ctx(completion_tokens=3)
|
|
||||||
stop = StopInfo()
|
|
||||||
events = builder.format_stream_end(ctx, stop)
|
|
||||||
payloads = _sse_payloads(events)
|
|
||||||
# content_block_stop, message_delta, message_stop
|
|
||||||
types = [p["type"] for p in payloads]
|
|
||||||
assert types == ["content_block_stop", "message_delta", "message_stop"]
|
|
||||||
assert payloads[1]["delta"]["stop_reason"] == "end_turn"
|
|
||||||
|
|
||||||
def test_format_stream_end_with_stop_trims_and_emits_remaining(self, builder):
|
|
||||||
ctx = _make_ctx(completion_tokens=7)
|
|
||||||
stop = StopInfo(
|
|
||||||
matched="END",
|
|
||||||
body="Hello world END extra",
|
|
||||||
yielded="Hello ",
|
|
||||||
)
|
|
||||||
events = builder.format_stream_end(ctx, stop)
|
|
||||||
payloads = _sse_payloads(events)
|
|
||||||
# unyielded delta, content_block_stop, message_delta, message_stop
|
|
||||||
types = [p["type"] for p in payloads]
|
|
||||||
assert types == [
|
|
||||||
"content_block_delta",
|
|
||||||
"content_block_stop",
|
|
||||||
"message_delta",
|
|
||||||
"message_stop",
|
|
||||||
]
|
|
||||||
assert payloads[0]["delta"]["text"] == "world "
|
|
||||||
assert payloads[2]["delta"]["stop_reason"] == "stop_sequence"
|
|
||||||
assert payloads[2]["delta"]["stop_sequence"] == "END"
|
|
||||||
|
|
||||||
def test_format_stream_end_stop_trimmed_already_yielded(self, builder):
|
|
||||||
ctx = _make_ctx()
|
|
||||||
stop = StopInfo(
|
|
||||||
matched="END",
|
|
||||||
body="Hello END",
|
|
||||||
yielded="Hello ",
|
|
||||||
)
|
|
||||||
events = builder.format_stream_end(ctx, stop)
|
|
||||||
payloads = _sse_payloads(events)
|
|
||||||
# No unyielded delta (everything already sent)
|
|
||||||
types = [p["type"] for p in payloads]
|
|
||||||
assert types == ["content_block_stop", "message_delta", "message_stop"]
|
|
||||||
|
|
||||||
def test_format_response_with_stop_trims_content(self, builder):
|
|
||||||
ctx = _make_ctx()
|
|
||||||
stop = StopInfo(matched="STOP", body="text STOP extra", yielded="text ")
|
|
||||||
resp = builder.format_response(ctx, "text STOP extra", stop)
|
|
||||||
assert resp["content"][0]["text"] == "text "
|
|
||||||
assert resp["stop_reason"] == "stop_sequence"
|
|
||||||
assert resp["stop_sequence"] == "STOP"
|
|
||||||
|
|
||||||
def test_format_response_no_stop(self, builder):
|
|
||||||
ctx = _make_ctx()
|
|
||||||
stop = StopInfo()
|
|
||||||
resp = builder.format_response(ctx, "full text", stop)
|
|
||||||
assert resp["content"][0]["text"] == "full text"
|
|
||||||
assert resp["stop_reason"] == "end_turn"
|
|
||||||
|
|
||||||
|
|
||||||
class TestGenerationRequestValidation:
|
|
||||||
def test_valid_params(self):
|
|
||||||
gr = GenerationRequest(
|
|
||||||
messages=[{"role": "user", "content": "hi"}],
|
|
||||||
top_k=50,
|
|
||||||
top_p=0.9,
|
|
||||||
temperature=0.7,
|
|
||||||
)
|
|
||||||
assert gr.top_k == 50
|
|
||||||
|
|
||||||
def test_invalid_top_p_raises(self):
|
|
||||||
with pytest.raises(ValueError, match="top_p"):
|
|
||||||
GenerationRequest(messages=[{"role": "user", "content": "hi"}], top_p=1.5)
|
|
||||||
|
|
||||||
def test_invalid_top_k_raises(self):
|
|
||||||
with pytest.raises(ValueError, match="top_k"):
|
|
||||||
GenerationRequest(messages=[{"role": "user", "content": "hi"}], top_k=-1)
|
|
||||||
|
|
||||||
def test_invalid_temperature_raises(self):
|
|
||||||
with pytest.raises(ValueError, match="temperature"):
|
|
||||||
GenerationRequest(
|
|
||||||
messages=[{"role": "user", "content": "hi"}], temperature=-0.1
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_top_k_zero_valid(self):
|
|
||||||
gr = GenerationRequest(messages=[{"role": "user", "content": "hi"}], top_k=0)
|
|
||||||
assert gr.top_k == 0
|
|
||||||
|
|
@ -173,21 +173,3 @@ def test_scheduler_concurrent_get_stats(mock_model_and_tokenizer):
|
||||||
for stats in results["stats"]:
|
for stats in results["stats"]:
|
||||||
assert "total_tasks" in stats
|
assert "total_tasks" in stats
|
||||||
assert stats["total_tasks"] >= 0
|
assert stats["total_tasks"] >= 0
|
||||||
|
|
||||||
|
|
||||||
def test_prefill_skips_fully_cached_tasks(mock_model_and_tokenizer):
|
|
||||||
"""Tasks whose entire prompt is cached skip the prefill phase."""
|
|
||||||
mock_model, mock_tokenizer = mock_model_and_tokenizer
|
|
||||||
|
|
||||||
with patch("astrai.inference.core.scheduler.AutoModel"):
|
|
||||||
with patch("astrai.inference.core.scheduler.AutoTokenizer"):
|
|
||||||
scheduler = InferenceScheduler(
|
|
||||||
model=mock_model,
|
|
||||||
tokenizer=mock_tokenizer,
|
|
||||||
max_batch_size=4,
|
|
||||||
device="cpu",
|
|
||||||
)
|
|
||||||
|
|
||||||
task_id = scheduler.add_task("short prompt", stream_callback=lambda t: None)
|
|
||||||
scheduler.stop()
|
|
||||||
assert task_id.startswith("task_")
|
|
||||||
|
|
|
||||||
|
|
@ -2,12 +2,12 @@
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from astrai.inference import get_app
|
from astrai.inference import app
|
||||||
|
|
||||||
|
|
||||||
def test_health_no_model(client):
|
def test_health_no_model(client):
|
||||||
"""GET /health should return 200 even when engine not loaded."""
|
"""GET /health should return 200 even when engine not loaded."""
|
||||||
get_app().state.engine = None
|
app.state.engine = None
|
||||||
response = client.get("/health")
|
response = client.get("/health")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
data = response.json()
|
data = response.json()
|
||||||
|
|
@ -30,7 +30,7 @@ def test_chat_completions_non_stream(client, loaded_model):
|
||||||
async def async_gen():
|
async def async_gen():
|
||||||
yield "Assistant reply"
|
yield "Assistant reply"
|
||||||
|
|
||||||
get_app().state.engine = loaded_model
|
app.state.engine = loaded_model
|
||||||
loaded_model.generate_async.return_value = async_gen()
|
loaded_model.generate_async.return_value = async_gen()
|
||||||
response = client.post(
|
response = client.post(
|
||||||
"/v1/chat/completions",
|
"/v1/chat/completions",
|
||||||
|
|
@ -56,7 +56,7 @@ def test_chat_completions_stream(client, loaded_model):
|
||||||
yield "cumulative1"
|
yield "cumulative1"
|
||||||
yield "cumulative2"
|
yield "cumulative2"
|
||||||
|
|
||||||
get_app().state.engine = loaded_model
|
app.state.engine = loaded_model
|
||||||
loaded_model.generate_async.return_value = async_gen()
|
loaded_model.generate_async.return_value = async_gen()
|
||||||
response = client.post(
|
response = client.post(
|
||||||
"/v1/chat/completions",
|
"/v1/chat/completions",
|
||||||
|
|
@ -83,7 +83,7 @@ def test_messages_non_stream(client, loaded_model):
|
||||||
async def async_gen():
|
async def async_gen():
|
||||||
yield "Assistant reply"
|
yield "Assistant reply"
|
||||||
|
|
||||||
get_app().state.engine = loaded_model
|
app.state.engine = loaded_model
|
||||||
loaded_model.generate_async.return_value = async_gen()
|
loaded_model.generate_async.return_value = async_gen()
|
||||||
response = client.post(
|
response = client.post(
|
||||||
"/v1/messages",
|
"/v1/messages",
|
||||||
|
|
@ -111,7 +111,7 @@ def test_messages_stream(client, loaded_model):
|
||||||
yield "cumulative1"
|
yield "cumulative1"
|
||||||
yield "cumulative2"
|
yield "cumulative2"
|
||||||
|
|
||||||
get_app().state.engine = loaded_model
|
app.state.engine = loaded_model
|
||||||
loaded_model.generate_async.return_value = async_gen()
|
loaded_model.generate_async.return_value = async_gen()
|
||||||
response = client.post(
|
response = client.post(
|
||||||
"/v1/messages",
|
"/v1/messages",
|
||||||
|
|
@ -141,7 +141,7 @@ def test_messages_with_system(client, loaded_model):
|
||||||
async def async_gen():
|
async def async_gen():
|
||||||
yield "Reply"
|
yield "Reply"
|
||||||
|
|
||||||
get_app().state.engine = loaded_model
|
app.state.engine = loaded_model
|
||||||
loaded_model.generate_async.return_value = async_gen()
|
loaded_model.generate_async.return_value = async_gen()
|
||||||
response = client.post(
|
response = client.post(
|
||||||
"/v1/messages",
|
"/v1/messages",
|
||||||
|
|
@ -165,7 +165,7 @@ def test_chat_completions_stop_sequence(client, loaded_model):
|
||||||
yield "X"
|
yield "X"
|
||||||
yield "world"
|
yield "world"
|
||||||
|
|
||||||
get_app().state.engine = loaded_model
|
app.state.engine = loaded_model
|
||||||
loaded_model.generate_async.return_value = async_gen()
|
loaded_model.generate_async.return_value = async_gen()
|
||||||
response = client.post(
|
response = client.post(
|
||||||
"/v1/chat/completions",
|
"/v1/chat/completions",
|
||||||
|
|
@ -191,7 +191,7 @@ def test_chat_completions_stop_sequence_stream(client, loaded_model):
|
||||||
yield "X"
|
yield "X"
|
||||||
yield "world"
|
yield "world"
|
||||||
|
|
||||||
get_app().state.engine = loaded_model
|
app.state.engine = loaded_model
|
||||||
loaded_model.generate_async.return_value = async_gen()
|
loaded_model.generate_async.return_value = async_gen()
|
||||||
response = client.post(
|
response = client.post(
|
||||||
"/v1/chat/completions",
|
"/v1/chat/completions",
|
||||||
|
|
|
||||||
|
|
@ -1,355 +0,0 @@
|
||||||
import tempfile
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from astrai.config.model_config import AutoRegressiveLMConfig
|
|
||||||
from astrai.model import AutoRegressiveLM
|
|
||||||
from astrai.model.components.linear import Linear
|
|
||||||
from astrai.model.components.lora import (
|
|
||||||
LoRAConfig,
|
|
||||||
LoRALinear,
|
|
||||||
_collect_lora_info,
|
|
||||||
_get_lora_count,
|
|
||||||
inject_lora,
|
|
||||||
load_lora,
|
|
||||||
merge_lora,
|
|
||||||
save_lora,
|
|
||||||
)
|
|
||||||
|
|
||||||
MODEL_KWARGS = dict(
|
|
||||||
vocab_size=1000,
|
|
||||||
dim=64,
|
|
||||||
n_heads=4,
|
|
||||||
n_kv_heads=2,
|
|
||||||
dim_ffn=128,
|
|
||||||
n_layers=2,
|
|
||||||
max_len=32,
|
|
||||||
norm_eps=1e-5,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _make_model(**kwargs):
|
|
||||||
kw = {**MODEL_KWARGS, **kwargs}
|
|
||||||
config = AutoRegressiveLMConfig(**kw)
|
|
||||||
model = AutoRegressiveLM(config)
|
|
||||||
model.eval()
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
def test_loralinear_init():
|
|
||||||
base = Linear(64, 128)
|
|
||||||
lora = LoRALinear(base, r=8, alpha=16)
|
|
||||||
|
|
||||||
assert lora.weight is base.weight
|
|
||||||
assert not lora.weight.requires_grad
|
|
||||||
assert lora.lora_A.shape == (8, 64)
|
|
||||||
assert lora.lora_B.shape == (128, 8)
|
|
||||||
assert lora.scaling == 2.0
|
|
||||||
assert not lora._merged
|
|
||||||
assert lora.lora_A.requires_grad
|
|
||||||
assert lora.lora_B.requires_grad
|
|
||||||
|
|
||||||
|
|
||||||
def test_loralinear_forward_init_zero_delta():
|
|
||||||
base = Linear(4, 4)
|
|
||||||
with torch.no_grad():
|
|
||||||
base.weight.zero_()
|
|
||||||
|
|
||||||
x = torch.randn(2, 4)
|
|
||||||
lora = LoRALinear(base, r=2, alpha=2)
|
|
||||||
base_out = base(x)
|
|
||||||
lora_out = lora(x)
|
|
||||||
|
|
||||||
assert torch.allclose(base_out, lora_out)
|
|
||||||
|
|
||||||
|
|
||||||
def test_loralinear_forward_with_delta():
|
|
||||||
base = Linear(4, 4)
|
|
||||||
with torch.no_grad():
|
|
||||||
base.weight.zero_()
|
|
||||||
|
|
||||||
x = torch.randn(2, 4)
|
|
||||||
lora = LoRALinear(base, r=2, alpha=2)
|
|
||||||
base_out = base(x)
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
lora.lora_B.fill_(1.0)
|
|
||||||
|
|
||||||
lora_out = lora(x)
|
|
||||||
assert not torch.allclose(base_out, lora_out)
|
|
||||||
|
|
||||||
|
|
||||||
def test_loralinear_merge():
|
|
||||||
base = Linear(4, 4)
|
|
||||||
with torch.no_grad():
|
|
||||||
base.weight.zero_()
|
|
||||||
|
|
||||||
x = torch.randn(2, 4)
|
|
||||||
lora = LoRALinear(base, r=2, alpha=2)
|
|
||||||
with torch.no_grad():
|
|
||||||
lora.lora_B.fill_(1.0)
|
|
||||||
|
|
||||||
out_before = lora(x).clone()
|
|
||||||
lora.merge()
|
|
||||||
out_after = lora(x)
|
|
||||||
|
|
||||||
torch.testing.assert_close(out_before, out_after)
|
|
||||||
assert lora._merged
|
|
||||||
assert not hasattr(lora, "lora_A")
|
|
||||||
|
|
||||||
|
|
||||||
def test_loralinear_merge_is_idempotent():
|
|
||||||
base = Linear(4, 4)
|
|
||||||
with torch.no_grad():
|
|
||||||
base.weight.zero_()
|
|
||||||
|
|
||||||
lora = LoRALinear(base, r=2, alpha=2)
|
|
||||||
with torch.no_grad():
|
|
||||||
lora.lora_B.fill_(1.0)
|
|
||||||
|
|
||||||
lora.merge()
|
|
||||||
lora.merge()
|
|
||||||
|
|
||||||
|
|
||||||
def test_inject_lora_default_target():
|
|
||||||
model = _make_model()
|
|
||||||
n_before = sum(1 for m in model.modules() if isinstance(m, Linear))
|
|
||||||
|
|
||||||
inject_lora(model, r=4, alpha=8)
|
|
||||||
|
|
||||||
lora_count = _get_lora_count(model)
|
|
||||||
assert lora_count > 0
|
|
||||||
assert lora_count < n_before
|
|
||||||
|
|
||||||
|
|
||||||
def test_inject_lora_ffn():
|
|
||||||
model = _make_model()
|
|
||||||
from astrai.model.components.lora import TARGET_MODULES_FFN
|
|
||||||
|
|
||||||
inject_lora(model, r=4, alpha=8, target_modules=TARGET_MODULES_FFN)
|
|
||||||
assert _get_lora_count(model) > 0
|
|
||||||
|
|
||||||
|
|
||||||
def test_inject_lora_returns_config():
|
|
||||||
model = _make_model()
|
|
||||||
cfg = inject_lora(model, r=8, alpha=32)
|
|
||||||
assert isinstance(cfg, LoRAConfig)
|
|
||||||
assert cfg.r == 8
|
|
||||||
assert cfg.alpha == 32
|
|
||||||
|
|
||||||
|
|
||||||
def test_inject_lora_no_matching_targets_warns(caplog):
|
|
||||||
model = _make_model()
|
|
||||||
inject_lora(model, r=4, alpha=8, target_modules={"nonexistent"})
|
|
||||||
assert "No LoRA layers injected" in caplog.text
|
|
||||||
|
|
||||||
|
|
||||||
def test_inject_lora_preserves_base_output():
|
|
||||||
model = _make_model()
|
|
||||||
x = torch.randint(0, 1000, (2, 16))
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
out_before = model(x)["logits"].clone()
|
|
||||||
|
|
||||||
inject_lora(model, r=4, alpha=8)
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
out_after = model(x)["logits"]
|
|
||||||
|
|
||||||
torch.testing.assert_close(out_before, out_after)
|
|
||||||
|
|
||||||
|
|
||||||
def test_inject_lora_does_not_reinject():
|
|
||||||
model = _make_model()
|
|
||||||
inject_lora(model, r=4, alpha=8, target_modules={"q_proj"})
|
|
||||||
first_count = _get_lora_count(model)
|
|
||||||
|
|
||||||
inject_lora(model, r=2, alpha=4, target_modules={"q_proj"})
|
|
||||||
assert _get_lora_count(model) == first_count
|
|
||||||
|
|
||||||
|
|
||||||
def test_inject_lora_adds_new_modules():
|
|
||||||
model = _make_model()
|
|
||||||
inject_lora(model, r=4, alpha=8, target_modules={"q_proj"})
|
|
||||||
first = _get_lora_count(model)
|
|
||||||
|
|
||||||
inject_lora(model, r=4, alpha=8, target_modules={"v_proj"})
|
|
||||||
assert _get_lora_count(model) > first
|
|
||||||
|
|
||||||
|
|
||||||
def test_inject_lora_on_mla_model():
|
|
||||||
model = _make_model(
|
|
||||||
attn_type="mla", kv_lora_rank=16, qk_nope_head_dim=16, qk_rope_head_dim=16
|
|
||||||
)
|
|
||||||
inject_lora(model, r=4, alpha=8, target_modules={"q_proj", "o_proj"})
|
|
||||||
assert _get_lora_count(model) > 0
|
|
||||||
|
|
||||||
|
|
||||||
def test_inject_lora_on_moe_model():
|
|
||||||
model = _make_model(
|
|
||||||
ffn_type="moe",
|
|
||||||
n_routed_experts=4,
|
|
||||||
n_shared_experts=1,
|
|
||||||
n_activated_experts=2,
|
|
||||||
dim_ffn=32,
|
|
||||||
)
|
|
||||||
inject_lora(model, r=4, alpha=8, target_modules={"up", "gate", "down"})
|
|
||||||
assert _get_lora_count(model) > 0
|
|
||||||
|
|
||||||
|
|
||||||
def test_state_dict_key_format():
|
|
||||||
model = _make_model()
|
|
||||||
inject_lora(model, r=4, alpha=8, target_modules={"q_proj"})
|
|
||||||
|
|
||||||
sd = model.state_dict()
|
|
||||||
assert "layers.0.attention.q_proj.weight" in sd
|
|
||||||
assert "layers.0.attention.q_proj.lora_A" in sd
|
|
||||||
assert "layers.0.attention.q_proj.lora_B" in sd
|
|
||||||
|
|
||||||
|
|
||||||
def test_only_lora_params_trainable():
|
|
||||||
model = _make_model()
|
|
||||||
inject_lora(model, r=4, alpha=8, target_modules={"q_proj", "v_proj"})
|
|
||||||
|
|
||||||
for name, param in model.named_parameters():
|
|
||||||
if isinstance(name.split(".")[-1], str) and "lora" in name:
|
|
||||||
assert param.requires_grad, f"lora param should be trainable: {name}"
|
|
||||||
elif any(name.endswith(f".{t}.weight") for t in ("q_proj", "v_proj")):
|
|
||||||
assert not param.requires_grad, f"injected weight should be frozen: {name}"
|
|
||||||
|
|
||||||
|
|
||||||
def test_state_dict_after_inject_consistent_with_original():
|
|
||||||
model = _make_model()
|
|
||||||
sd_before = {k: v for k, v in model.state_dict().items()}
|
|
||||||
|
|
||||||
inject_lora(model, r=4, alpha=8, target_modules={"q_proj"})
|
|
||||||
sd_after = model.state_dict()
|
|
||||||
|
|
||||||
# original keys unchanged
|
|
||||||
for k in sd_before:
|
|
||||||
assert k in sd_after
|
|
||||||
assert sd_before[k].shape == sd_after[k].shape
|
|
||||||
|
|
||||||
# new lora keys present
|
|
||||||
lora_keys = [k for k in sd_after if "lora" in k]
|
|
||||||
assert len(lora_keys) > 0
|
|
||||||
|
|
||||||
|
|
||||||
def test_save_load_roundtrip():
|
|
||||||
model = _make_model()
|
|
||||||
cfg = inject_lora(model, r=4, alpha=8, target_modules={"q_proj"})
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
for m in model.modules():
|
|
||||||
if isinstance(m, LoRALinear):
|
|
||||||
m.lora_B.fill_(0.5)
|
|
||||||
|
|
||||||
x = torch.randint(0, 1000, (2, 16))
|
|
||||||
with torch.no_grad():
|
|
||||||
out_src = model(x)["logits"].clone()
|
|
||||||
|
|
||||||
tmpdir = tempfile.mkdtemp()
|
|
||||||
save_lora(model, tmpdir, cfg)
|
|
||||||
|
|
||||||
model2 = _make_model()
|
|
||||||
model2.load_state_dict(model.state_dict(), strict=False)
|
|
||||||
load_lora(model2, tmpdir)
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
out_dst = model2(x)["logits"]
|
|
||||||
|
|
||||||
torch.testing.assert_close(out_src, out_dst)
|
|
||||||
|
|
||||||
|
|
||||||
def test_save_after_merge_raises():
|
|
||||||
model = _make_model()
|
|
||||||
cfg = inject_lora(model, r=4, alpha=8, target_modules={"q_proj"})
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
for m in model.modules():
|
|
||||||
if isinstance(m, LoRALinear):
|
|
||||||
m.lora_B.fill_(0.5)
|
|
||||||
|
|
||||||
tmpdir = tempfile.mkdtemp()
|
|
||||||
save_lora(model, tmpdir, cfg)
|
|
||||||
merge_lora(model)
|
|
||||||
|
|
||||||
tmpdir2 = tempfile.mkdtemp()
|
|
||||||
with pytest.raises(RuntimeError, match="No LoRA parameters"):
|
|
||||||
save_lora(model, tmpdir2, cfg)
|
|
||||||
|
|
||||||
|
|
||||||
def test_load_lora_on_already_injected():
|
|
||||||
model = _make_model()
|
|
||||||
inject_lora(model, r=4, alpha=8, target_modules={"q_proj"})
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
for m in model.modules():
|
|
||||||
if isinstance(m, LoRALinear):
|
|
||||||
m.lora_B.fill_(0.5)
|
|
||||||
|
|
||||||
tmpdir = tempfile.mkdtemp()
|
|
||||||
save_lora(model, tmpdir, LoRAConfig(r=4, alpha=8, target_modules=("q_proj",)))
|
|
||||||
|
|
||||||
model2 = _make_model()
|
|
||||||
model2.load_state_dict(model.state_dict(), strict=False)
|
|
||||||
inject_lora(model2, r=4, alpha=8, target_modules={"q_proj"})
|
|
||||||
|
|
||||||
# load onto already-injected model
|
|
||||||
load_lora(model2, tmpdir)
|
|
||||||
assert _get_lora_count(model2) > 0
|
|
||||||
|
|
||||||
|
|
||||||
def test_load_lora_mismatched_r_raises():
|
|
||||||
model = _make_model()
|
|
||||||
cfg = inject_lora(model, r=8, alpha=16, target_modules={"q_proj"})
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
for m in model.modules():
|
|
||||||
if isinstance(m, LoRALinear):
|
|
||||||
m.lora_B.fill_(0.5)
|
|
||||||
|
|
||||||
tmpdir = tempfile.mkdtemp()
|
|
||||||
save_lora(model, tmpdir, cfg)
|
|
||||||
|
|
||||||
model2 = _make_model()
|
|
||||||
model2.load_state_dict(model.state_dict(), strict=False)
|
|
||||||
inject_lora(model2, r=4, alpha=8, target_modules={"q_proj"})
|
|
||||||
|
|
||||||
with pytest.raises(RuntimeError, match="size mismatch"):
|
|
||||||
load_lora(model2, tmpdir) # strict=False, only lora keys
|
|
||||||
|
|
||||||
|
|
||||||
def test_merge_preserves_output():
|
|
||||||
model = _make_model()
|
|
||||||
inject_lora(model, r=4, alpha=8, target_modules={"q_proj"})
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
for m in model.modules():
|
|
||||||
if isinstance(m, LoRALinear):
|
|
||||||
m.lora_B.fill_(0.5)
|
|
||||||
|
|
||||||
x = torch.randint(0, 1000, (2, 16))
|
|
||||||
with torch.no_grad():
|
|
||||||
out_before = model(x)["logits"].clone()
|
|
||||||
|
|
||||||
merge_lora(model)
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
out_after = model(x)["logits"]
|
|
||||||
torch.testing.assert_close(out_before, out_after)
|
|
||||||
|
|
||||||
|
|
||||||
def test_merge_no_lora_warns(caplog):
|
|
||||||
model = _make_model()
|
|
||||||
merge_lora(model)
|
|
||||||
assert "No LoRA layers to merge" in caplog.text
|
|
||||||
|
|
||||||
|
|
||||||
def test_collect_lora_info():
|
|
||||||
model = _make_model()
|
|
||||||
info = _collect_lora_info(model)
|
|
||||||
assert "q_proj" in info
|
|
||||||
assert "o_proj" in info
|
|
||||||
assert "q_proj" in info # each layer has one
|
|
||||||
|
|
@ -1,5 +1,3 @@
|
||||||
import os
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
|
|
@ -27,7 +25,7 @@ class TrainerDataset(Dataset):
|
||||||
|
|
||||||
|
|
||||||
def create_train_config(
|
def create_train_config(
|
||||||
model_fn,
|
model: torch.nn.Module,
|
||||||
dataset: Dataset,
|
dataset: Dataset,
|
||||||
test_dir: str,
|
test_dir: str,
|
||||||
device: str,
|
device: str,
|
||||||
|
|
@ -43,7 +41,7 @@ def create_train_config(
|
||||||
"""Factory function to create common TrainConfig for tests.
|
"""Factory function to create common TrainConfig for tests.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model_fn: Model factory (callable returning nn.Module)
|
model: The model to train
|
||||||
dataset: Training dataset
|
dataset: Training dataset
|
||||||
test_dir: Checkpoint directory
|
test_dir: Checkpoint directory
|
||||||
device: Device type ("cuda" or "cpu")
|
device: Device type ("cuda" or "cpu")
|
||||||
|
|
@ -70,12 +68,11 @@ def create_train_config(
|
||||||
|
|
||||||
return TrainConfig(
|
return TrainConfig(
|
||||||
strategy=strategy,
|
strategy=strategy,
|
||||||
model_fn=model_fn,
|
model=model,
|
||||||
dataset=dataset,
|
dataset=dataset,
|
||||||
optimizer_fn=optimizer_fn,
|
optimizer_fn=optimizer_fn,
|
||||||
scheduler_fn=scheduler_fn,
|
scheduler_fn=scheduler_fn,
|
||||||
ckpt_dir=test_dir,
|
ckpt_dir=test_dir,
|
||||||
log_dir=os.path.join(test_dir, "logs"),
|
|
||||||
n_epoch=n_epoch,
|
n_epoch=n_epoch,
|
||||||
batch_per_device=batch_per_device,
|
batch_per_device=batch_per_device,
|
||||||
ckpt_interval=ckpt_interval,
|
ckpt_interval=ckpt_interval,
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,3 @@
|
||||||
import os
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from astrai.config.train_config import TrainConfig
|
from astrai.config.train_config import TrainConfig
|
||||||
|
|
@ -106,13 +104,12 @@ def test_gradient_checkpointing_trainer_integration(base_test_env, random_datase
|
||||||
)
|
)
|
||||||
|
|
||||||
train_config = TrainConfig(
|
train_config = TrainConfig(
|
||||||
model_fn=lambda: base_test_env["model"],
|
model=base_test_env["model"],
|
||||||
strategy="seq",
|
strategy="seq",
|
||||||
dataset=random_dataset,
|
dataset=random_dataset,
|
||||||
optimizer_fn=optimizer_fn,
|
optimizer_fn=optimizer_fn,
|
||||||
scheduler_fn=scheduler_fn,
|
scheduler_fn=scheduler_fn,
|
||||||
ckpt_dir=base_test_env["test_dir"],
|
ckpt_dir=base_test_env["test_dir"],
|
||||||
log_dir=os.path.join(base_test_env["test_dir"], "logs"),
|
|
||||||
n_epoch=1,
|
n_epoch=1,
|
||||||
batch_per_device=2,
|
batch_per_device=2,
|
||||||
ckpt_interval=3,
|
ckpt_interval=3,
|
||||||
|
|
@ -140,13 +137,12 @@ def test_callback_integration(base_test_env, random_dataset):
|
||||||
)
|
)
|
||||||
|
|
||||||
train_config = TrainConfig(
|
train_config = TrainConfig(
|
||||||
model_fn=lambda: base_test_env["model"],
|
model=base_test_env["model"],
|
||||||
strategy="seq",
|
strategy="seq",
|
||||||
dataset=random_dataset,
|
dataset=random_dataset,
|
||||||
optimizer_fn=optimizer_fn,
|
optimizer_fn=optimizer_fn,
|
||||||
scheduler_fn=scheduler_fn,
|
scheduler_fn=scheduler_fn,
|
||||||
ckpt_dir=base_test_env["test_dir"],
|
ckpt_dir=base_test_env["test_dir"],
|
||||||
log_dir=os.path.join(base_test_env["test_dir"], "logs"),
|
|
||||||
n_epoch=1,
|
n_epoch=1,
|
||||||
batch_per_device=2,
|
batch_per_device=2,
|
||||||
ckpt_interval=3,
|
ckpt_interval=3,
|
||||||
|
|
|
||||||
|
|
@ -4,6 +4,7 @@ import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from astrai.config.train_config import TrainConfig
|
from astrai.config.train_config import TrainConfig
|
||||||
|
from astrai.serialization import Checkpoint
|
||||||
from astrai.trainer.schedule import SchedulerFactory
|
from astrai.trainer.schedule import SchedulerFactory
|
||||||
from astrai.trainer.trainer import Trainer
|
from astrai.trainer.trainer import Trainer
|
||||||
|
|
||||||
|
|
@ -23,10 +24,9 @@ def test_early_stopping_simulation(base_test_env, early_stopping_dataset):
|
||||||
strategy="seq",
|
strategy="seq",
|
||||||
optimizer_fn=optimizer_fn,
|
optimizer_fn=optimizer_fn,
|
||||||
scheduler_fn=scheduler_fn,
|
scheduler_fn=scheduler_fn,
|
||||||
model_fn=lambda: base_test_env["model"],
|
model=base_test_env["model"],
|
||||||
dataset=early_stopping_dataset,
|
dataset=early_stopping_dataset,
|
||||||
ckpt_dir=base_test_env["test_dir"],
|
ckpt_dir=base_test_env["test_dir"],
|
||||||
log_dir=os.path.join(base_test_env["test_dir"], "logs"),
|
|
||||||
n_epoch=2,
|
n_epoch=2,
|
||||||
batch_per_device=2,
|
batch_per_device=2,
|
||||||
ckpt_interval=1,
|
ckpt_interval=1,
|
||||||
|
|
@ -38,20 +38,17 @@ def test_early_stopping_simulation(base_test_env, early_stopping_dataset):
|
||||||
trainer = Trainer(train_config)
|
trainer = Trainer(train_config)
|
||||||
|
|
||||||
# Should handle early stopping gracefully
|
# Should handle early stopping gracefully
|
||||||
|
checkpoint = None
|
||||||
try:
|
try:
|
||||||
trainer.train()
|
checkpoint = trainer.train()
|
||||||
except Exception:
|
except Exception:
|
||||||
|
# Handle any exceptions
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# Resume from latest checkpoint
|
|
||||||
load_dir = os.path.join(base_test_env["test_dir"], "epoch_0_iter_2")
|
load_dir = os.path.join(base_test_env["test_dir"], "epoch_0_iter_2")
|
||||||
trainer = Trainer(train_config)
|
checkpoint = Checkpoint.load(load_dir)
|
||||||
trainer.train(resume_dir=load_dir)
|
trainer.train(checkpoint)
|
||||||
|
|
||||||
# Verify checkpoint was saved at expected iteration
|
|
||||||
load_dir = os.path.join(base_test_env["test_dir"], "epoch_1_iter_10")
|
load_dir = os.path.join(base_test_env["test_dir"], "epoch_1_iter_10")
|
||||||
import json
|
checkpoint = Checkpoint.load(load_dir)
|
||||||
|
assert checkpoint.iteration == 10
|
||||||
with open(os.path.join(load_dir, "meta.json")) as f:
|
|
||||||
meta = json.load(f)
|
|
||||||
assert meta["iteration"] == 10
|
|
||||||
|
|
|
||||||
|
|
@ -9,7 +9,7 @@ def test_different_batch_sizes(base_test_env, random_dataset, train_config_facto
|
||||||
|
|
||||||
for batch_per_device in batch_sizes:
|
for batch_per_device in batch_sizes:
|
||||||
train_config = train_config_factory(
|
train_config = train_config_factory(
|
||||||
model_fn=lambda: base_test_env["model"],
|
model=base_test_env["model"],
|
||||||
dataset=random_dataset,
|
dataset=random_dataset,
|
||||||
test_dir=base_test_env["test_dir"],
|
test_dir=base_test_env["test_dir"],
|
||||||
device=base_test_env["device"],
|
device=base_test_env["device"],
|
||||||
|
|
@ -25,7 +25,7 @@ def test_gradient_accumulation(base_test_env, random_dataset, train_config_facto
|
||||||
|
|
||||||
for grad_accum_steps in grad_accum_steps_list:
|
for grad_accum_steps in grad_accum_steps_list:
|
||||||
train_config = train_config_factory(
|
train_config = train_config_factory(
|
||||||
model_fn=lambda: base_test_env["model"],
|
model=base_test_env["model"],
|
||||||
dataset=random_dataset,
|
dataset=random_dataset,
|
||||||
test_dir=base_test_env["test_dir"],
|
test_dir=base_test_env["test_dir"],
|
||||||
device=base_test_env["device"],
|
device=base_test_env["device"],
|
||||||
|
|
@ -50,7 +50,7 @@ def test_memory_efficient_training(base_test_env, random_dataset, train_config_f
|
||||||
|
|
||||||
for config in small_batch_configs:
|
for config in small_batch_configs:
|
||||||
train_config = train_config_factory(
|
train_config = train_config_factory(
|
||||||
model_fn=lambda: base_test_env["model"],
|
model=base_test_env["model"],
|
||||||
dataset=random_dataset,
|
dataset=random_dataset,
|
||||||
test_dir=base_test_env["test_dir"],
|
test_dir=base_test_env["test_dir"],
|
||||||
device=base_test_env["device"],
|
device=base_test_env["device"],
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue