Compare commits

..

113 Commits
v1.3.5 ... main

Author SHA1 Message Date
ViperEkura 8ab7564d02 docs: 重构 README 结构,全文档添加目录导航
- README 新增 Getting Started 端到端流程,整合快速开始与演示,去重精简
- 中文 README 同步英文版结构,预处理配置改用 seq 策略
- inference.md 补充 SSE 流式格式、错误响应、/stats 端点文档
- params.md 扩展为 CLI 参考,覆盖 server/generate/preprocess 参数表
- dataflow.md 拆分 tokenization/format detection/backend 子节,新增流程图
- architecture/training/inference/preprocessing 均添加目录导航
- 移除 README CI badge
2026-06-19 13:53:22 +08:00
ViperEkura d096b6e29e docs: 修复文档中过时的字段、签名和缺失的类
- BaseConfig 的 from_json/to_json → from_file/to_file
- InputConfig/ProcessingConfig/OutputConfig 字段对齐源码
- 移除不存在的 Registry 类,register() 去 category/priority
- SchedulerFactory.create 参数顺序修正
- 架构图/训练/参数文档补全 WSDScheduler
- CONTRIBUTING.md 克隆地址占位符修正
- params.md label_smoothing 默认值修正,补全 neftune_alpha
- app 类更正为 get_app 函数
2026-06-18 18:49:46 +08:00
ViperEkura d88a41f8f1 fix: 修复预处理流水线 4 个致命问题
- pipeline: 单条数据异常不再崩溃整条流水线, 改 log warning 后跳过
- pipeline: _align_bucket 统一用 len(ids) 填充, 修复多输出模式下长度错配
- writer: BinWriter/H5Writer 写入失败自动清理残留文件并记录详细错误
- packing: BFDPacking 真正将序列打包进 bin 而非仅重排, 减少碎片
2026-06-18 17:38:01 +08:00
ViperEkura 376e9eba80 feat: IFEval 使用 chat template 格式化 prompt,添加 model.eval()
- generate_one 用 tokenizer.apply_chat_template 包 user 消息
- 新增 model.eval() 关闭 dropout,确保确定性输出
2026-06-18 16:45:16 +08:00
ViperEkura a62c2e11a2 feat: IFD 默认使用 chat template,支持裸文本模式
- 新增 _compute_ifd_with_template,用 tokenizer chat template 格式化后计算 IFD
- 默认开启 chat template,可通过 --no_chat_template 切换回裸拼接
- chat template 缺失时给出 RuntimeError 提示
2026-06-18 16:35:05 +08:00
ViperEkura a4e5a8c81c feat: 新增 WSD 学习率调度器
- 支持 Warmup-Stable-Decay 三段式调度
- stable 阶段保持最高 lr,decay 阶段 sqrt 衰减
- 适用于持续预训练、SFT、RLHF 场景
2026-06-18 15:55:15 +08:00
ViperEkura 3e234c46f6 fix: 使用 threading.Event 替代裸 bool,补全公共 API
- scheduler 停止信号改用 threading.Event,跨解释器安全
- 移除 _fatal_error 和 check_health,异常仅用 logger.error 记录
- 补全 astrai/__init__.py,暴露所有主要模块
2026-06-18 15:38:35 +08:00
ViperEkura 7a04b1f8ce docs: replace shields.io endpoint badges with github/ direct badges
- Switch stars/forks/release to github/ endpoints to avoid pool exhaustion
- Add CI workflow badge for tests.yml
- Delete update-badges.yml (no longer needed)
- Remove remote gh-pages branch
2026-06-18 15:09:51 +08:00
ViperEkura a30e3d5114 fix: 修复 shields.io GitHub badge 因 token 耗尽而无法显示
- 新增 Action 每天及 push 时同步 badges 至 gh-pages
- README 改用 endpoint 格式指向自建静态 JSON, 不依赖 shields.io GitHub token 池
- 同步更新中英两份 README
2026-06-16 22:21:58 +08:00
ViperEkura 1818d06576 feat: 新增 IFD 数据质量评分工具, 移动 ppl 至 eval
- 计算指令遵循难度分数用于数据筛选
- IFD = 条件交叉熵 / 无条件交叉熵
- perplexity 移至 scripts/eval/
2026-06-16 22:03:45 +08:00
ViperEkura 4e8d1ee24e feat: 新增 IFEval 指令遵循评测
- 实现 25 种正则约束 verifier
- 将评测脚本从 scripts/tools/ 移至 scripts/eval/
2026-06-16 21:57:34 +08:00
ViperEkura fec376b0dd fix : 修复策略相关文件的类型注解与抽象方法体
- 修复 strategy.py 单元素 Union 与缺失的参数/返回类型注解
- 修复 train_context.py 8 个 default=None 字段缺 Optional 标记
- 修复 sample.py/packing.py/position_id.py 方法缺参数及返回类型注解
- 修复 factory.py _resolve_type/list_registered 缺类型注解
- 修复 train_config.py 裸 dict/list 缺泛型参数
- abstractmethod body 从 ... 改为 raise NotImplementedError
- feat : checkpoint meta.json 保存 TrainConfig 超参供人工查阅
2026-06-14 16:20:10 +08:00
ViperEkura a2512f8a5a fix : resume_dir 无权重文件时不强制加载,支持仅配置训练
- Checkpoint.load_any 统一处理 meta.json / model.safetensors / 无文件三种情况
- train_context.py 调用简化为单一路径,移除 load_model_weights 直接依赖
2026-06-13 15:40:14 +08:00
ViperEkura 457e16ea3c fix : val_loss 默认改为 None,日志跳过空值;val_dataloader 补 Optional 注解 2026-06-13 14:24:13 +08:00
ViperEkura daf627a6de fix : _save_log 前确保日志目录存在,防止跨进程反序列化后目录丢失 2026-06-12 15:39:54 +08:00
ViperEkura 445378667f feat : NEFTune 噪声注入 + label_smoothing 默认值修正
- Embedding.forward 训练时注入 randn 噪声,缩放系数 neftune_noise_alpha / sqrt(seq_len)
- TrainConfig.neftune_alpha 通过 config 传递(默认 0=关闭)
- TrainContextBuilder 将 config.neftune_alpha 写入 embed_tokens
- --neftune_alpha CLI 参数(典型值 5.0)
- label_smoothing 默认值 0.05 -> 0.0
2026-06-11 15:32:43 +08:00
ViperEkura 6ae1828449 refactor : 清理工厂和配置系统中的死代码与冗余抽象
- 删除 Registry 中未使用的 category/priority 字段,_entries 简化为直接存储类引用
- 修正 __init_subclass__ 避免叶子类(AutoRegressiveLM 等)创建空注册表
- 删除 5 个工厂的薄 create() 覆写,统一使用 BaseFactory.create(name, *args, **kwargs)
- 删除 3 处零调用的 available_types/available_strategies 别名死代码
- 删除零调用的 BaseModelConfig.to_file 死代码
- 将 BaseConfig.from_json/to_json 重命名为 from_file/to_file,消除与子类重复
- 移除两个 inference builder 中总是被覆写的 prompt_tokens=0
2026-06-07 11:39:50 +08:00
ViperEkura e7b18b7c03 refactor : BaseFactory 基类类型自动推导 + 移除冗余代码
- _validate_component 从 BaseFactory[T] 泛型参数自动解析基类类型,9 个子类覆写移除
- Registry 类内联到 BaseFactory._entries,移除未用的 list_by_category/list_by_priority
- _component_base 在 __init_subclass__ 时立即解析
- 数据集 4 个子类冗余 __init__ 移除
2026-06-06 21:23:41 +08:00
ViperEkura 9e31d4ef2b feat : BaseToolParser.feed 增加可选 token_ids 参数
- format_chunk ABC 改为 (token, **kwargs),body/token_ids 通过 kw 传入
- ProtocolHandler._handle_stream 逐 token encode 并透传
- Anthropic builder 用 **kwargs 吸收不使用的参数,零变更
- 新增 3 个 token_ids 参数测试
2026-06-06 11:19:30 +08:00
ViperEkura 52aa4d01d5 feat : 推理层增加 vLLM 风格工具调用解析
- 新增 BaseToolParser 抽象基类,定义 feed/parse_complete 流式接口
- 新增 SimpleJsonToolParser,解析 {"name":"...","arguments":{...}} 格式
- 新增 ToolParserFactory,基于 BaseFactory 实现可插拔注册
- 集成 parser 到 OpenAIResponseBuilder,支持流式/非流式工具调用
- 扩展 ChatMessage 和 ChatCompletionRequest,增加 tools/tool_choice 字段
- 重构 format_chunk 接口,传入累积文本支持全量重新解析
- 新增 74 个单元测试,覆盖扫描/查找/流式解析/完整解析/工厂
2026-06-06 08:54:10 +08:00
ViperEkura 986be957ec refactor : on_batch_begin 移入 accumulate 上下文 2026-06-06 01:19:21 +08:00
ViperEkura cf9c60841b docs : 按代码反向修正所有文档错误
- 更新预处理模块目录结构和类名(SectionedMaskBuilder)
- 修正 ResponseBuilder.prepare 签名(tokenizer → engine)
- 补全缺失的 CLI 参数、配置字段和数据键名
- 修正 README 中 download.py 的描述
2026-06-06 01:06:30 +08:00
ViperEkura 31bc7f5c2a refactor : pipeline 策略化拆分,消除 _flush if/else
- PackingStrategy / PositionIdStrategy / StoreWriter 独立文件 + Factory
- Pipeline._flush 零 if/else,纯编排
- SectionRenderer 从 SectionedMaskBuilder 分离
- OutputConfig.position_ids_mode 默认改为 ""none""
2026-06-06 00:45:33 +08:00
ViperEkura 3057741de9 refactor : 合并 data config docstring 并实现 BFD 打包策略
- 将 ProcessingConfig/OutputConfig 参数描述合并到类级 docstring

- Pipeline 支持 packing_strategy/truncation_mode,新增 bfd 打包
2026-06-05 17:41:51 +08:00
ViperEkura acd1103bd0 fix : 使用 bool 注意力掩码并支持打包 SFT 文档边界阻断
- 简化 process_attention_mask,通过广播返回 bool 掩码
- 新增 make_doc_boundary_mask 生成块对角因果掩码
- SFT strategy 传入文档边界掩码
2026-06-05 17:02:28 +08:00
ViperEkura dc7d2cfbca refactor : FastAPI 懒加载单例,消除模块级副作用
- import astrai.inference 不再在模块加载时创建 FastAPI 实例
- 路由移至 APIRouter;get_app() 首次调用时懒构造单例
- _create_engine 和 run_server 的 param_path 改为必填
- 更新测试改用 get_app() 替代模块级 app
2026-06-04 15:52:27 +08:00
ViperEkura b36a78c612 test : SFT 测试数据补全 position_ids 字段
- dummy_data 添加 position_ids 匹配 required_keys
2026-06-04 14:01:04 +08:00
ViperEkura 985d940db6 feat : 数据流水拼接策略支持 position_ids 预计算
- OutputConfig.position_ids_mode 三种模式控制边界策略
- pipeline._flush() 按配置生成扁平 position_ids 数组
- SFTDataset 在 __getitem__ 中返回 position_ids
- SFTStrategy 将 position_ids 传入 model.forward()
2026-06-04 13:56:19 +08:00
ViperEkura 5e73ca20aa feat : train CLI 新增 val_split/val_step/metrics/log 参数
- --val_split 从训练集按比例切分验证集
- --val_step 控制验证间隔 optimizer step 数
- --metrics 自定义日志指标列表,默认 loss lr
- --log_dir / --log_interval 控制日志输出目录和频率
2026-06-03 14:31:22 +08:00
ViperEkura 438dc10391 fix : MMLU eval 使用 chat template 格式匹配 SFT 训练数据
- 原 prompt 为纯文本格式,与 SFT chat template 不匹配导致模型输出随机
- 新增 apply_chat() 将 MMLU prompt 包装为 user/assistant 对话格式
- choice_text 改为单字母(去掉空格前缀)适配模板输出
- 5-shot 时 few-shot 示例作为独立 user/assistant 轮次插入
2026-06-03 11:59:42 +08:00
ViperEkura 615ba5d8ef feat : 新增 HumanEval pass@k 代码生成评测
- InferenceEngine.generate() 批量生成 n 个补全
- 正则提取函数体 + 停止符截断
- multiprocessing sandbox 执行 + timeout 保护
- 标准无偏 pass@k 公式 (1, 10, 100)
2026-06-03 10:52:32 +08:00
ViperEkura 02a7cb9fa0 feat : preprocessing 支持 DPO/GRPO 多输出格式
- InputConfig 新增 sources 字段驱动多输出映射
- SectionedMaskBuilder 提取 _process_sections/_build_multi 模板方法
- Pipeline 泛化 accumulate 逻辑处理多 key 结果
- 测试拆分为 config/builder/pipeline 三文件,纯函数风格
2026-06-03 10:32:10 +08:00
ViperEkura 9fe2121743 feat : TrainConfig 支持 val_split 从训练集自动切分验证集
- val_split 比例从 dataset 中划出验证集,用 random_seed 固定随机切分
- 若 val_dataset 已显式设置则跳过自动切分
2026-06-02 20:33:40 +08:00
ViperEkura 0422d6d38e refactor : 移除 LocalStrategy._clear_env 冗余清理
- setup_parallel 已覆盖所有环境变量写入,无需前置清空
2026-06-02 11:40:45 +08:00
ViperEkura 9b416c1bbb refactor : 并行启动 Strategy 模式重构,local_rank 解耦
- setup_parallel 接收 local_rank 参数,不再读环境变量推导
- TorchrunStrategy 从 env 读取 LOCAL_RANK,LocalStrategy 用 rank
- _detect_launcher() 分级检测替代内联 RANK 检查
- _run_single_rank 统一入口,消除 _run_single/_run_multi 重复
- 优雅退出:except BaseException 终止子进程并 re-join
- gradient_checkpointing_modules 判定提取到外部变量
2026-06-02 11:22:24 +08:00
ViperEkura d6899100ac
Merge pull request #17 from yegroup001/main
增加多机DDP
2026-06-02 10:29:07 +08:00
yegroup001 0deee48602 feat : 训练脚本新增 gradient_checkpointing 与多机 DDP 参数 2026-06-02 01:01:00 +08:00
yegroup001 746a1475b2 fix : 修复存储层 rglob 死锁、DDP LOCAL_RANK 绑定 2026-06-02 01:01:00 +08:00
ViperEkura 01ce1fb9e3 refactor : Pipeline 去除去重,ids 重命名为 sequence,泛型透传
- 移除 Pipeline 内置去重逻辑及 dedup_signature 工具函数
- 删除 ProcessingConfig.deduplicate 字段
- builder 返回 'sequence' 替代 'ids',与 dataset 层统一
- pipeline 纯透传,泛型处理任意 key 补齐默认值
2026-05-31 15:14:27 +08:00
ViperEkura 14f83cbdac perf : 预编译 Jinja2 Template,避免每次 render 重新构建 2026-05-31 14:50:16 +08:00
ViperEkura dbe5891201 refactor : 统一 SectionedMaskBuilder,支持可配置 dtype
- 三合一 MaskBuilder,移除 chat/instruction/text,统一为 sections 配置
- OutputConfig 增加 dtype 字段 (per-key,默认 int32)
- 移除 from __future__ import annotations
- 测试适配新配置格式
2026-05-31 14:24:10 +08:00
ViperEkura 2a65c3314c fix : 修复 created 时间戳、bin 多 shard 覆盖与文档遗漏
- openai.py/anthropic.py: created 从 0 改为 int(time.time())
- openai.py: ChatCompletionRequest 不支持参数非默认值时 warning
- pipeline.py: bin 多 shard 使用子目录避免静默覆盖
- storage.py: MmapStore/detect_format 支持多 shard 聚合加载
- architecture.md: mermaid 类图新增 Pipeline 类
- preprocessing.md: 新增多 shard 输出布局与 Python API 示例
- protocol.py: docstring "6 methods" 改为 "5 methods"
2026-05-30 23:03:42 +08:00
ViperEkura 1c2ff05a6d docs : 三轮深度验证修复文档与代码不一致
- architecture.md: 修正 unwrap_model 返回类型、Config Optional 标注、方法签名错误、类名错误
- training.md: 补充 on_error 回调、修正训练循环顺序、补全策略参数、model.safetensors
- inference.md: 修正 GenerationRequest 参数顺序、async 语法、KVCache 描述、temperature 约束
- dataflow.md: 补充 Store.load/fetch 流程、修正可选参数默认值
- README/params: 多 GPU 示例补全 --parallel_mode、文档表补充 preprocessing.md
- preprocessing.md: Chat 模式算法补全 BOS token 步骤
2026-05-30 21:41:06 +08:00
ViperEkura 31ae2deeba refactor : BaseConfig 提供 from_json/to_json,嵌套 config 自动反序列化
- from_json/to_json 上提至 BaseConfig,所有子类自动继承
- _coerce 新增 dict 到 BaseConfig 子类的递归反序列化,消除子类 from_dict 重载
- PipelineConfig 等子类仅声明字段,零样板代码
- 测试 tokenizer 改为自包含 BPE(含 chat template),不依赖 params/ 目录
- 特殊 token 改用 ASCII 字符,兼容所有平台
2026-05-30 21:04:19 +08:00
ViperEkura 69207e2c57 refactor : 基于声明式 JSON 配置的预处理管线重构
- 用工厂注册的 MaskBuilder(chat/instruction/text)替换硬编码的 _transform_* 方法
- mask 规则以 role-to-action 映射声明在配置中,与 chat_template 完全解耦
- 单次编码 + role-span 追踪替代两次编码 + 长度差计算 mask 的方式
- 支持多轮对话训练:所有 assistant 轮次参与训练,而非仅最后一轮
- 新建 astrai.preprocessing 包(builder.py + pipeline.py),删除 astrai/preprocess.py
- CLI 精简为 --config 参数,所有参数通过 PipelineConfig JSON 配置
- 新增 PipelineConfig、InputConfig、ProcessingConfig、OutputConfig dataclass
- 文档:assets/docs/preprocessing.md
- 27 个测试覆盖 mask builder、pipeline、配置序列化、工厂注册
2026-05-30 20:45:09 +08:00
ViperEkura 138c5bcc08 feat : 添加 JSONL 预处理管线
- Pipeline 模板, Reader 加 transform 加 Writer 可组合
- 自动检测 JSONL 格式, 支持 messages 文本 prompt 加 response 三种
- chat 数据通过 apply_chat_template 适配, 自动生成 loss_mask
- 输出对齐 Store 和 DatasetFactory, 直接用于训练
- 默认 bin 格式, CLI 入口 scripts/tools/preprocess.py
2026-05-30 17:12:42 +08:00
ViperEkura a923e0a23a fix : 修复 MMLU 评测脚本数据源和依赖
- 数据源改为 Berkeley data.tar(GitHub zip 不含数据文件)
- urllib 替换为 requests,支持代理下载
- zip 解压替换为 tar,增加目录 flatten 逻辑
- 添加 model.eval() 确保推理模式正确
2026-05-30 16:51:24 +08:00
ViperEkura f521a30b22 fix : FSDP 优化器顺序、温度除零、调度器静默死亡、ref模型设备
- executor: use_orig_params 硬编码 True,FSDP 不替换 Parameter 对象
- strategy: DPO/GRPO ref 模型创建后移到 device
- sample: TemperatureStrategy clamp 1e-8,engine 验证改为 >0
- scheduler: 异常不 re-raise 避免 daemon 静默死亡,stop() 发回调给 waiting 任务
2026-05-29 21:57:44 +08:00
ViperEkura d4451f6afb fix : 并行训练 state_dict 收集与训练/推理并发缺陷
- FSDPExecutor: unwrap_model 返回全量 state_dict (state_dict_type FULL);use_orig_params=True
- DDPExecutor/BaseExecutor: unwrap_model 统一返回 model.module.state_dict() / model.state_dict()
- CheckpointCallback: 走 executor.unwrap_model 拿完整 state_dict
- strategy.py: 移除 FSDP/DDp 依赖;create_ref_model(model_fn, state_dict) 纯函数
- TrainContextBuilder: 传递 model_fn + executor 到 strategy
- GRPOStrategy.sync_ref_model: 通过 executor.unwrap_model 获取完整权重
- TaskManager.wait_for_tasks: 锁内检查队列,消除 clear/set 竞态
- ProtocolHandler: stop token 不再计入 completion_tokens(流式/非流式)
2026-05-29 21:12:52 +08:00
ViperEkura a3275423a4 release : v1.3.7
Features
- FSDP parallel backend with zero-redundancy sharded training
- LoRA fine-tuning module with low-rank adapter injection and persistence
- NTK-Aware RoPE dynamic scaling, extending context window limit
- MMLU evaluation script for standardized model knowledge assessment
- load_json/load_safetensors broadcast mechanism for cross-node distributed loading

Refactors
- Storage layer refactored to Store pattern, removed Fetcher layer, supporting multi-segment data with explicit length
- Training backend refactored to Executor pattern (none/ddp/fsdp), decoupling parallel logic
- Inference protocol layer refactored to Strategy/Builder pattern with independent OpenAI/Anthropic responders
- Unified serialization layer, eliminating scattered I/O paths
- Removed JSONStore from data pipeline, unified to H5/Bin dual format
- Simplified _disable_random_init, moved scheduler into sync block
- Removed -> None return annotations, split FSDP parameters

Fixes
- Disabled DDP static_graph to prevent no_sync/backward conflict under PyTorch 2.7.1
- Checkpoint resume restores optimizer/scheduler state and sampler remaining length
- Unwrap DDP/FSDP on checkpoint save to avoid module. prefix
- start_epoch/start_batch determined by user args, no longer overridden by checkpoint
- Left padding in perplexity.py causing incorrect PPL with batch>1
- Storage multi-segment bug, switched JSON to JSONL
- Early abort on task_extend failure after decode, notify waiting tasks on scheduler crash

Docs
- Synced architecture/training/inference/dataflow/params docs to actual code

Tests
- Completed inference protocol layer unit test coverage
- Added LoRA module tests
- Filled storage layer test gaps
2026-05-29 17:46:03 +08:00
ViperEkura b37c3d000c docs : 同步文档与实际代码
- 移除 JSONStore 引用(该类不存在)
- 修正 Store.load() 和 DatasetFactory.load() 签名(无 tokenizer 参数)
- 修正 TrainContextBuilder.with_resume_dir() 命名
- 修正 Checkpoint config 字段和 meta.json 描述
- 修正 ProtocolHandler.handle() 异步签名
- 修正采样继承图(平行子类,非线性)
- 修正训练循环:回调移入 accumulate 块内
- 更新文档日期至 2026-05-28
2026-05-28 21:01:47 +08:00
ViperEkura 6031020e37 feat : load_json/load_safetensors 支持 broadcast,跨节点分布式加载
- load_json/load_safetensors/load_state_dict 新增 broadcast 参数
- broadcast=True 时 rank-0 读取后 broadcast_object_list 分发到所有 rank
- load_state_dict 改为逐张量 broadcast,避免大模型 pickle 内存瓶颈
- 删除 _get_meta/_get_config wrapper,Checkpoint.load 直接调用 load_json
- 参数注解 str | Path 统一为 Union[str, Path]
2026-05-28 20:44:58 +08:00
ViperEkura c424dfc293 feat : checkpoint 支持保存 config.json
- Checkpoint.save 写入独立的 config.json(模型架构参数)
- Checkpoint.load 读取 config.json,恢复时覆盖 context.model_config
- TrainContext 新增 model_config 字段,builder 从 resume_dir/config.json 加载
- BaseConfig.to_dict 支持 tuple 和嵌套 dataclass(如 LoRAConfig)
- 删除 _get_meta/_get_config wrapper,直接使用 load_json
2026-05-28 20:21:51 +08:00
ViperEkura 3a28e52e98 fix : start_epoch/start_batch 由用户参数决定,不再被 checkpoint 覆盖 2026-05-28 18:24:22 +08:00
ViperEkura e371908b54 fix : 保存 checkpoint 时 unwrap DDP/FSDP 避免 module. 前缀
- 移除 state_dict_fn 参数
- _save_checkpoint 中先 unwrap_model 再 state_dict()
2026-05-28 18:10:04 +08:00
ViperEkura 7c99da155c refactor: 删除数据流中的 JSONStore
- 移除 JSONStore 及相关函数,训练框架不再依赖 tokenizer
- Store 层只保留 H5Store 和 MmapStore 两种后端
2026-05-28 15:54:26 +08:00
ViperEkura 629e72385b fix : 修复存储层 bug,JSON 切换为 JSONL,补齐测试覆盖
- save_bin/load_bin: save_json/load_json 替换为直接 json.dump/json.load,修复致命 bug
- _normalize: 空 cum 列表 guard,防止 IndexError
- load_json: 改为仅支持 JSONL 逐行解析 (json.loads),移除 .json 支持
- detect_format: 只匹配 *.jsonl,不再匹配 *.json
- save_json: 输出扩展名改为 .jsonl
- GRPODataset.__getitem__: 补齐 .to(dtype=torch.long/bool) 与其他数据集一致
- load_bin: np.memmap mode='r+' 消除 PyTorch 不可写 tensor 警告
- 新增 16 个测试: bin roundtrip, mmap load, 空 key, JSONL 多行/文本, GRPO dtype/load, detect_format bin/jsonl, fetch multi-key/越界, json_to_bin 转换, DPO from JSONL, 显式 storage_type
2026-05-28 15:29:46 +08:00
ViperEkura 0a708fff24 docs : 更新架构文档与 storage 注释,同步 Store 重构
- architecture.md: 类图/关系线全部更新 (BaseStorage→Store, StorageFactory→StoreFactory, 新增 MmapStore)
- architecture.md: 移除 BaseSegmentFetcher/MultiSegmentFetcher 类图与关系
- dataflow.md: 管线加入 .bin 格式, Store._data + _cum 架构
- storage.py: module docstring 改用缩进式注释风格
2026-05-28 14:36:18 +08:00
ViperEkura 6e150ea6d0 refactor : Storage 层重构为 Store,移除 Fetcher 中间层,支持多段数据与显式长度
- 合并 BaseStorage + MultiSegmentFetcher + BaseSegmentFetcher 三层为 Store ABC
- Store._data 直接持有 Dict[str, List[Tensor]],不做强制拼接避免 OOM
- _fetch_key 统一用 bisect 跨段切片,单段多段同一路径
- _length 显式存储(min total across keys),__len__ 返回 O(1)
- MmapStore/H5Store/JSONStore 统一走 _normalize() 注册分段并预计算累积长度
- 所有 I/O 函数 (save_h5/load_h5/json_to_bin 等) 保持不变
2026-05-28 14:23:49 +08:00
ViperEkura cb8dcb97ea refactor : 移除 -> None 返回值标注,拆分 FSDP 参数,新增 mmap 数据集存储
- 删除所有 def 函数 -> None 返回值类型标注
- FSDPExecutor 参数从 **kwargs 拆为显式声明,None 值自动过滤
- 新增 MmapStorage (bin) 存储后端,基于 numpy.memmap 零拷贝加载
- 新增 save_bin/load_bin/json_to_bin 工具函数
- detect_format 支持 bin 格式自动检测
2026-05-28 13:57:06 +08:00
ViperEkura 2d5dc93b3d fix : 修正类型标注与统一 CLI 参数命名
- AutoRegressiveLM.forward 返回类型标注 -> Dict[str, Tensor]
- EmbeddingEncoder 移除冗余 position_ids 自动创建
- CLI 脚本模型目录参数统一为 --param_path
2026-05-27 20:49:44 +08:00
ViperEkura 4145d35e3c refactor: 检查点加载重构,路径替代对象传递
- model: nn.Module -> model_fn 工厂函数,spawn 边界只传字符串
- Trainer.train(resume_dir=path) — Checkpoint 不再通过 pickle 传递
- TrainContextBuilder.with_resume_dir(path) — 自动检测 meta.json 分流 resume/from-scratch
- CheckpointCallback: 拆分 state_dict 收集(全 rank)与磁盘写入(rank-0),修复 FSDP 死锁
- serialization: load_torch 支持 broadcast,消除 _load_extra/_load_torch_broadcast
- optimizer/scheduler 恢复逻辑内联到 build(),在 executor.prepare() 之后执行
- pyproject.toml: ruff exclude build/ 避免 CI 扫描构建产物
2026-05-27 20:15:29 +08:00
ViperEkura 34c6c45bd6 feat: 初步实现 MMLU 评测脚本
- 支持 few-shot (log-likelihood ranking) 与 zero-shot
- 自动下载 Hendrycks MMLU 数据集
- --device / --dtype 可配置,默认 GPU bf16
2026-05-26 20:23:31 +08:00
ViperEkura e9def84ce7 fix : perplexity.py left padding 导致 batch>1 时 PPL 计算错误 2026-05-26 19:59:57 +08:00
ViperEkura 836e02a166 docs: 同步 architecture/inference/training 文档至实际代码,CLI 补充 fsdp 选项
- 修正 ProtocolHandler 架构:concrete + ResponseBuilder(ABC) 策略模式
- 修正训练循环 scheduler.step() 在 sync_gradients 块内
- 修正组合/聚合关系:注入组件改为 o--,删除不持有引用的关联
- --parallel_mode CLI choices 加入 fsdp
- nprocs > 1 且 parallel_mode=none 时 raise error
2026-05-26 19:37:00 +08:00
ViperEkura b558e61f63 refactor: 简化 _disable_random_init,scheduler 移入同步块
- _disable_random_init: enable=False 提前返回,dict 推导替代空字典
- scheduler.step() 移入 sync_gradients 守卫内
2026-05-26 17:05:25 +08:00
ViperEkura 65ab69543b refactor: 统一序列化层,消除分散的 I/O 路径
- Checkpoint 改为 @dataclass,内聚 save/load 方法
- 提取 save_safetensors/load_safetensors/save_json/load_json 共享工具
- 新增 save_model/load_model_config/load_model_weights 模块函数
- automodel 和 lora 统一委托到 serialization 模块
2026-05-26 16:44:40 +08:00
ViperEkura 1d26aa2e93 fix: 禁用DDP static_graph避免PyTorch 2.7.1下no_sync与backward冲突
- static_graph=True时DDP.no_sync() + loss.backward()触发expect_autograd_hooks_内部断言
- PyTorch 2.7.1中no_sync上下文切换与静态图hook状态管理存在兼容性bug
- 将static_graph设为False恢复梯度累积正常执行
- find_unused_parameters保持False(模型无不参与计算的参数)
2026-05-26 15:08:01 +08:00
ViperEkura a548d4553e fix: 断点续训恢复优化器/调度器状态及采样器剩余长度
- 使用Checkpoint.load()替代手动加载model.safetensors,恢复optimizer/scheduler状态
- TrainContextBuilder从checkpoint.extra恢复优化器和调度器state_dict
- ResumableDistributedSampler.__len__返回剩余样本数而非总数
- 训练前对state_dict置空避免mp.spawn pickle 7GB大对象
2026-05-26 13:50:25 +08:00
ViperEkura dd1b39f435 fix: ProgressBar默认输出到stdout
- file参数默认值改为None, 内部用 or sys.stdout 兜底
- 清理inference API中未使用的import (Optional, time, field)
- 删除test_protocol中未使用的ctx变量
2026-05-26 13:27:05 +08:00
ViperEkura 94d6e713e9 test: 补充推理协议层单测覆盖
- StopChecker、GenContext、StopInfo 单测
- OpenAIResponseBuilder / AnthropicResponseBuilder 全部方法
- Anthropic 停止序列裁剪逻辑(含 unyielded 边界)
- GenerationRequest 参数校验含负值边界
- Scheduler prefill 短路验证
2026-05-26 00:21:52 +08:00
ViperEkura 47c37e4876 refactor: 推理协议层重构为策略/建造者模式
- ProtocolHandler 改为具体类,格式化委托给 ResponseBuilder
- 新增 api/protocols/ 目录,含 OpenAIResponseBuilder、AnthropicResponseBuilder
- GenContext、StopInfo 参数对象替代 StreamContext
- 消除 Builder 的实例可变状态(accumulated、_yielded)
- SSE 工具和停止检测收归 ProtocolHandler 统一管理
- prepare() 方法合并原来的 build_prompt、create_response_id
- 参数校验去重:仅 GenerationRequest.init 负责校验
- Prefill 阶段提前短路完全命中的缓存任务
2026-05-26 00:12:57 +08:00
ViperEkura 737585a32a feat: 新增NTK-Aware RoPE缩放支持
- RotaryEmbedding接受rope_scaling配置,自动计算scaled base
- AutoRegressiveLMConfig和EncoderConfig新增rope_scaling字段
2026-05-25 21:22:07 +08:00
ViperEkura a4688021bf feat: 新增LoRA微调模块
- LoRALinear基于register_parameter托管base weight,state_dict路径不变
- inject_lora/merge_lora/save_lora/load_lora完备封装
- 24个单元测试覆盖注入、合并、存取、边界场景
2026-05-25 20:15:31 +08:00
ViperEkura 7df6eb9211 feat: 新增FSDP并行后端
- FSDPExecutor通过**fsdp_kwargs直传FSDP参数
- unwrap_model同时支持DDP和FSDP
- parallel_mode新增fsdp选项
2026-05-25 19:43:14 +08:00
ViperEkura 82a3f2626f docs: 更新文档与代码同步(Executor/训练循环/参数)
- architecture.md: TrainConfig 移除旧 parallel_wrapper/state_dict_fn
- architecture.md: 新增 ExecutorFactory/BaseExecutor/DDPExecutor 等类图
- architecture.md: MLA 新增 use_qk_norm/q_norm/k_norm
- architecture.md: 新增 protocols 命名空间
- training.md: 修复训练循环 hook 名和 scheduler.step 位置
- training.md: 替换 parallel_wrapper 为 parallel_mode/executor.prepare
- training.md: 修复默认回调顺序和 Callback 生命周期表
- params.md: 新增 --parallel_mode 和 --start_method
2026-05-24 22:17:49 +08:00
ViperEkura 7fa69572c0 fix: 测试日志写入临时目录避免冗余文件 2026-05-24 20:54:59 +08:00
ViperEkura 3ab4f237e5 refactor: 重构训练后端为 Executor 模式
- backend.py → executor.py,BaseTrainingBackend → BaseExecutor
- 新增 NoneExecutor(单卡)和 DDPExecutor(DDP,world_size=1 自动降级)
- 新增 GradientState 分离梯度同步状态,AccumOptimizer/AccumScheduler 包裹拦截
- 新增 astrai/protocols.py:OptimizerProtocol/SchedulerProtocol 结构子类型
- TrainContext.backend → executor,TrainConfig 移除 parallel_wrapper/state_dict_fn,新增 parallel_mode/executor_kwargs
- 训练循环用 accumulate() 包裹,on_optimizer_step 命名约定=gate
- scripts/tools/train.py 移除 ddp_wrap/prepare_checkpoint,新增 --parallel_mode
2026-05-24 20:35:44 +08:00
ViperEkura 8cbf3f36e2 feat: 新增训练后端工厂框架
- BaseTrainingBackend 定义 prepare/accumulate/unwrap_model 抽象
- DDPTrainingBackend 支持全部 DDP 参数并通过 BackendFactory 注册
- unwrap_model 改为实例方法,由子类各自实现
2026-05-24 15:15:14 +08:00
ViperEkura 0594ce1017 perf: Muon step 改用 torch._foreach_* 批处理并移除 NS 迭代的冗余 bf16 转换 2026-05-23 19:50:12 +08:00
ViperEkura ff509ff39f fix: decode后task_extend失败时提前中止,scheduler崩溃时通知waiting任务 2026-05-20 19:23:13 +08:00
ViperEkura 785d65436c fix: 修复 to_dict list 类型丢失与 OpenAI stop 参数失效
- to_dict() 增加 list 类型序列化支持,metrics 等字段不再丢失
- OpenAIHandler 补充 get_stop_sequences/on_token,读取 request.stop 并检测停止序列
- 文档类图补充缺失字段、修正关系分类、ChatCompletionRequest 字段增加 Optional
2026-05-19 21:07:07 +08:00
ViperEkura 64be81b7b3 feat: ProgressBarCallback 支持日志行输出到 stdout
- serialization 和 metric_logger 的 timestamp 统一使用 ISO 8601 格式
- ProgressBarCallback 新增 log_interval/file 参数,默认输出到 sys.stdout
2026-05-19 19:12:38 +08:00
ViperEkura 45479b5731 feat: metric 参数通过 TrainConfig 传递
- TrainConfig 新增 log_dir/log_interval/metrics 配置字段

- metric_logger 调用改用 **kwargs 传递,BaseFactory.create 自动过滤
2026-05-19 17:50:24 +08:00
ViperEkura e0a3337c22 docs: 更新视频链接 2026-05-19 17:34:01 +08:00
ViperEkura 812238060b fix: docker-compose UID/GID 添加默认值,修复 docker.sh logs 命令 2026-05-18 14:24:00 +08:00
ViperEkura 14b0d56197 fix: 修复无法创建子进程的问题
- mp.start_processes daemon=False
2026-05-18 09:40:32 +08:00
ViperEkura 6c8533f1d2 docs: 修正文档中类名/字段名与代码不一致之处
- ModelConfig → AutoRegressiveLMConfig, Transformer → AutoRegressiveLM
- 新增缺失类: EncoderConfig, EmbeddingEncoder, ConfigFactory, StorageFactory, ValidationCallback
- TrainConfig/TrainContext/ChatCompletionRequest 补充缺失字段
- dataflow.md 中 create_storage → StorageFactory.create
- 示例 --train_type=pt → seq 与代码一致
2026-05-17 21:02:21 +08:00
ViperEkura 2c2697390d feat: 新增 GradientCheckpointingCallback
- TrainConfig.gradient_checkpointing_modules 指定模块类型
- apply 递归遍历,兼容 DDP,不硬编码模型结构
- modules=None 时静默跳过,零开销
2026-05-17 18:21:05 +08:00
ViperEkura 7621f05d3f docs: AdamW beta 默认值改为 (0.9, 0.95)
- 与 Muon 优化器的 AdamW 子优化器保持一致
- 同步更新 train.py/training.md/params.md/README
2026-05-17 17:08:31 +08:00
ViperEkura 10ebd7211f feat: 新增 Muon 优化器
- 2D 参数用 Newton-Schulz 正交化 + Nesterov 动量更新
- 1D 参数用 AdamW 更新
- 支持 lr/momentum/weight_decay/ns_steps 配置
2026-05-17 16:44:03 +08:00
ViperEkura 42a391f0fb feat: 训练中新增验证循环
- TrainConfig 添加 val_dataset/val_step 字段
- TrainContext 添加 val_dataloader/val_loss 字段
- 新增 ValidationCallback 按 step 触发验证 + 训练结束时验证
- ProgressBar/MetricLogger 支持 val_loss 展示与记录
2026-05-17 16:12:42 +08:00
ViperEkura 97c7ac0f4f refactor: Transformer更名为AutoRegressiveLM并新增EmbeddingEncoder
- AutoRegressiveLM 注册名改为 autoregressive_lm
- 新增 EmbeddingEncoder 支持 mean/cls/last pooling
- ModelConfig 增加 pooling_type / normalize_embeddings 字段
- 导入、注释、测试全部同步更新
2026-05-17 15:29:20 +08:00
ViperEkura 8f1b32f2b6 fix: 移除多余 request 参数并增强 tokenizer 健壮性
- 路由和 _get_engine 不再需要 request 参数,直接引用模块级 app
- from_pretrained 增加文件完整性校验,缺 tokenizer.json 则抛 FileNotFoundError
- 移除 from_pretrained 中未使用的 **kwargs
2026-05-17 12:52:18 +08:00
ViperEkura c241a5dcef refactor: 优化并行训练配置与启动管理
- 配置新增 start_method 支持 spawn/fork/forkserver 选择
- 启动方式 mp.spawn 改为 mp.start_processes,支持 daemon=True
- validate() 改为基于 metadata 的反射式校验,不再硬编码字段列表
- CLI 新增 --start_method 参数
2026-05-17 12:33:10 +08:00
ViperEkura 44dab27fdc feat: 数据集加载时校验必填字段
- BaseDataset.required_keys 属性声明所需存储 key
- load() 时自动校验,缺失立即抛 KeyError
- SEQ/SFT/DPO/GRPO 各自声明 required_keys
2026-05-17 11:50:38 +08:00
ViperEkura a44fd22a99 fix: 修复训练与模型参数传递问题
- state_dict_fn 传入 CheckpointCallback,修复多卡 DDP 下 key 前缀丢失
- MLA 增加 use_qk_norm 支持,消除参数静默丢失
- moe_topk_method 统一命名为 topk_method
- checkpoint 回调移至最前
2026-05-17 11:20:13 +08:00
ViperEkura 8a11a7d444 fix: 修复训练脚本两处参数传递问题
- prepare_checkpoint 增加 DDP 判断,单卡时不访问 .module
- dpo_beta 改为 beta,对齐 DPOStrategy 参数名
2026-05-17 11:04:40 +08:00
ViperEkura 1d54491809 refactor: 改用递归子模块 init 替代统一 normal_(0.006)
- Embedding.reset_parameters: normal_(std=0.02)
- Linear.reset_parameters: kaiming_uniform_ + uniform_ bias
- Transformer._init_weights 通过 apply 递归调用子模块 reset_parameters
- 移除全局 normal_(0.006) 覆盖,各模块使用更合适的分布
2026-05-17 10:44:18 +08:00
ViperEkura ad9f4d9cf6 refactor: generate_ar 改用流式输出并去除冗余注释 2026-05-17 10:23:42 +08:00
ViperEkura e1638a7ade fix: 修正AdamW超参数默认值与文档示例
- 交换adamw_beta1/adamw_beta2默认值:beta1=0.95, beta2=0.99
- label_smoothing默认值改为0.05
- 文档示例统一更新:train_type=pt, weight_decay=0.01
- 移除文档中过时的strategy default标注
2026-05-16 22:46:17 +08:00
ViperEkura f91bfee33e refactor: Config序列化统一BaseConfig基类
- 新增astrai/config/base.py,提供to_dict/from_dict基类
- 统一命名:load/save → from_file/to_file
- Checkpoint.meta合并训练配置到meta.json
- sys.stderr.warn → warnings.warn
- from_file改为classmethod
2026-05-16 22:06:39 +08:00
ViperEkura d7a7f570ed refactor: 训练循环改为两重迭代并统一参数命名
- 训练循环从三重(epoch→batched→batch)改为二重(epoch→batch)
- batch_size → batch_per_device, accumulation_steps → grad_accum_steps
- scheduler 移入 step block 对齐 optimizer 更新步
- GradientClippingCallback 改用 on_step_begin 避免零梯度裁剪
- 移除 _train_impl 误导性的 -> Checkpoint 标注
- total_steps 修除为向下取整并精简为一行
- warmup_steps 改为 warmup_ratio (默认0.05)
2026-05-16 21:27:35 +08:00
ViperEkura 7dea929788 refactor: checkpoint 按 HF 方式存独立 .pt 文件,callback 接管恢复
- Checkpoint.save/load: extra 逐 key 写为 {key}.pt 而非单个 extra.pt
- meta.json 新增 timestamp
- CheckpointCallback: save_extra/load_extra 静态方法 + extra_keys 类属性
- on_train_begin 接管 optimizer/scheduler 恢复,TrainContextBuilder 不再传 load_extra_fn
2026-05-16 18:29:04 +08:00
ViperEkura 026d1fc33d fix: total_steps 改用 ceiling 匹配实际步数
原公式全用 floor 少算 optimizer step,改用逐层 ceiling
(ceil_div via (a+b-1)//b)对齐 DDP sampler padding +
DataLoader drop_last=False 尾批 + batched 尾组截断。
2026-05-16 17:53:18 +08:00
ViperEkura 7242eedbf4 fix: 学习率调度按 optimizer step 计数并防止 warmup 越界
- total_steps 除以 accumulation_steps,匹配 optimizer.step() 频率
- warmup_steps 用 min 截断,避免 lr_decay_steps 为负
2026-05-16 17:07:36 +08:00
ViperEkura 04c0dc7a47 refactor: Storage 改用工厂模式,server reload 接入 uvicorn
- 新增 StorageFactory(BaseFactory[BaseStorage]) 替代手写 dict 注册
- H5Storage / JSONStorage 通过 @StorageFactory.register 注册
- dataset.py 使用 StorageFactory.create() 替代 create_storage()
- 删除 create_storage / available_storage_types 死函数
- server.py reload 参数正式传入 uvicorn.run()
2026-05-16 17:00:26 +08:00
ViperEkura 48a53121ba refactor: 工厂 kwargs 过滤及组件参数清理
- BaseFactory.create() 按 __init__ 签名过滤多余 kwargs
- 移除 GQA/MLA/MLP/DeepSeekMoE 中多余的 **kwargs
- MLP/DeepSeekMoE 参数名统一为 dim_ffn
- scheduler max_seq_len 增加 None 显式判断
- 默认 max_prompt_len 提升至 2048
2026-05-16 16:47:41 +08:00
ViperEkura 0ba8c70ce1 fix: 修复 MLA 多个 bug 并缩小测试模型参数
- MLA kv_b_proj 输出维度和 q_rope 切分偏移修复
- 打通 MLA 配置从 ModelConfig 到 DecoderBlock 的传递路径
- rope_theta 配置不再被忽略,MLA 使用 qk_rope_head_dim
- tie_weight 使用 is True 避免 None 隐式生效
- norm_eps/rope base 类型标注修正
- 测试模型参数缩小 (dim=8, head_dim=4)
- 新增 6 种架构配置 × 2 场景的前向传播测试
2026-05-16 14:57:43 +08:00
ViperEkura 3d12a03909 docs : 拆分文档并补充类图缺失类和关系线
- 将 design.md 拆分为 architecture.md / inference.md / training.md
- 精简 dataflow.md 为纯数据管道
- 删除 design.md 和 introduction.md
- 更新 README.md 和 README-zh-CN.md 链接
- 补充 ChatMessage / AnthropicMessage 等 6 条孤立类关系线
- 补充 BaseModelConfig 和 TaskManager 两个缺失类
2026-05-15 23:38:26 +08:00
ViperEkura c169659611 docs: 修正 assets/docs/ 类图、数据流、参数文档及贡献指南
- design.md: 新增 ProtocolHandler/OpenAIHandler/AnthropicHandler 等缺失类
- design.md: 新增 Template Method、Storage 设计模式
- dataflow.md: 修正 GQA/MLA 为独立条目,补充 JSON 存储后端
- params.md: 标注 label_smoothing CLI 默认与 strategy 默认差异
- introduction.md: 修正 max_tokens 默认值 1024→2048
- CONTRIBUTING.md: 重写(纯 Python 无 conda、补充 CI 步骤与常见问题)
- .github/PULL_REQUEST_TEMPLATE.md: 修正 lint 命令,去除多余注释要求
- .github/ISSUE_TEMPLATE/bug_report.md: 修正 label(enhancement→bug)
2026-05-15 22:54:41 +08:00
ViperEkura e12f1a7ee5 feat: BaseModelConfig + DeepSeekMoE + 工厂模式替代 if/else
- BaseModelConfig: fields() 精确字段匹配 + 类型矫正 + 未知key警告
- DeepSeekMoE: 共享专家 + 路由专家 + top-K 门控
- AttnFactory/FFNFactory: 装饰器注册,DecoderBlock 零分支
- config 用 attn_type/ffn_type 驱动组件选择
2026-05-15 20:34:52 +08:00
ViperEkura ef25efffa2 refactor: 拆分 module.py 为 components 子包
- rope/linear/norm/embedding/mlp/attention/decoder_block 各自独立文件
- 依赖单向无循环
- 公开接口不变,外部无需修改
2026-05-15 20:08:36 +08:00
105 changed files with 11997 additions and 3511 deletions

View File

@ -2,7 +2,7 @@
name: Bug report
about: Create a report to help us improve
title: "[BUG]"
labels: enhancement
labels: bug
assignees: ''
---

View File

@ -16,9 +16,9 @@ Please delete options that are not relevant.
Please describe the tests that you ran to verify your changes. Provide instructions so we can reproduce.
## Checklist:
- [ ] My code follows the style guidelines of this project (run `ruff format .` and `ruff check --fix .`)
- [ ] My code follows the style guidelines of this project (run `ruff format .` and `ruff check . --select I`)
- [ ] I have performed a self-review of my own code
- [ ] I have commented my code, particularly in hard-to-understand areas
- [ ] Code is self-documenting (no unnecessary comments)
- [ ] I have made corresponding changes to the documentation
- [ ] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my feature works

View File

@ -1,68 +1,100 @@
# Contributing to AstrAI
Thank you for your interest in contributing to AstrAI! This document provides guidelines and steps for contributing.
Thank you for your interest in contributing! This document provides step-by-step guidelines.
## How to Contribute
## Quick Start
### Reporting Issues
If you encounter a bug or have a feature request, please open an issue on GitHub. Include as much detail as possible:
- A clear description of the problem or request.
- Steps to reproduce (for bugs).
- Your environment (Python version, OS, etc.).
```bash
git clone https://github.com/ViperEkura/AstrAI.git
cd AstrAI
pip install -e ".[dev]" # install with dev dependencies (pytest, ruff)
```
### Submitting Changes
1. **Fork** the repository.
2. **Clone** your fork:
```bash
git clone https://github.com/your-username/AstrAI.git
cd AstrAI
```
3. **Create a feature branch**:
```bash
git checkout -b feature/your-feature-name
```
4. **Make your changes**. Follow the code style guidelines below.
5. **Commit your changes** with a descriptive commit message:
```bash
git commit -m "Add: brief description of the change"
```
6. **Push** to your fork:
```bash
git push origin feature/your-feature-name
```
7. **Open a Pull Request** (PR) against the `main` branch of the upstream repository.
## Before You Commit
## Code Style
Run the following checks **in order** — CI will reject if any fail.
AstrAI uses [Ruff](https://docs.astral.sh/ruff/) for code formatting and linting. Please ensure your code is formatted before submitting.
### 1. Format
- Run Ruff to format and lint (requires conda environment `nlp`):
```bash
conda run -n nlp ruff format .
conda run -n nlp ruff check --fix .
```
- The project uses **double quotes** for strings and **4space indentation** (as configured in `pyproject.toml`).
```bash
ruff format .
```
## Testing
> **Note**: `ruff format` may rename parameters (e.g. `mask``attn_mask`).
> Always review the diff after formatting.
If you add or modify functionality, please include appropriate tests.
### 2. Import sorting
- Run the test suite with:
```bash
conda run -n nlp python -u -m pytest
```
- Ensure all tests pass before submitting your PR.
```bash
ruff check . --select I
```
If this fails, **manually fix** import ordering (ruff does not auto-fix in this project's CI):
```bash
ruff check . --select I --fix .
ruff format . # re-format after fix
```
### 3. Run tests
```bash
python -u -m pytest tests/ -v
```
> Failed tests may leave orphan tempdirs under `%TEMP%`. Clean them manually if needed.
### 4. (Optional) Full pre-commit check
If you have Git Bash available:
```bash
bash scripts/pre_commit.sh
```
This runs format check, import sort check, and tests in one go.
## Commit Style
```
fix/feat/chore/docs/refactor/perf/test/style/ci/build/revert : short description (~50 chars)
- bullet point body (each ~60 chars)
```
- **Type** must be one of: `fix`, `feat`, `chore`, `docs`, `refactor`, `perf`, `test`, `style`, `ci`, `build`, `revert`.
- **Subject line** ends with no period.
- **Body** uses bullet points starting with `-`.
- No `(scope)` parentheses.
## Common Issues
| Problem | Cause | Fix |
|---------|-------|-----|
| `ruff check --select I` fails | Wrong import order | `ruff check . --select I --fix .` then `ruff format .` |
| `ruff format` changed many files | Not formatted before commit | Review diff carefully before staging |
| Pre-commit hook rejects | Tests or lint failed | Fix individually, do not `--no-verify` |
| Tests fail with tempdir left | Test crash | Clean `%TEMP%` manually |
## Submitting Changes
1. Fork the repo.
2. Create a feature branch: `git checkout -b feat/my-feature`
3. Make changes following the steps above.
4. Commit with the commit style above.
5. Push: `git push origin feat/my-feature`
6. Open a Pull Request against `main`.
## Code Review
All submissions will be reviewed. We may request changes or discuss alternatives. Please be responsive to feedback.
- All PRs are reviewed. We may request changes.
- CI runs `ruff format --check .` then `ruff check . --select I` (no `--fix` in CI).
- Ensure all tests pass.
## License
By contributing, you agree that your contributions will be licensed under the same [GPL-3.0 License](LICENSE) that covers the project.
By contributing, you agree that your contributions will be licensed under the [GPL-3.0 License](LICENSE).
---
If you have any questions, feel free to ask in the [GitHub Discussions](https://github.com/ViperEkura/AstrAI/discussions) or open an issue.
Happy contributing!
Questions? Ask in [GitHub Discussions](https://github.com/ViperEkura/AstrAI/discussions) or open an issue.

View File

@ -1,7 +1,7 @@
# AstrAI Dockerfile - Multi-stage Build (Optimized)
# Build stage - use base image with minimal build tools
FROM nvidia/cuda:12.6.0-base-ubuntu24.04 AS builder
FROM ubuntu:24.04 AS builder
WORKDIR /app
@ -18,7 +18,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
RUN python3.12 -m venv --copies /opt/venv
ENV PATH="/opt/venv/bin:$PATH"
# Copy source code and install dependencies
# Copy source code and install (deps read from pyproject.toml)
COPY astrai/ ./astrai/
COPY pyproject.toml .
RUN pip install --no-cache-dir --upgrade pip \
@ -26,13 +26,14 @@ RUN pip install --no-cache-dir --upgrade pip \
--extra-index-url https://download.pytorch.org/whl/cu126
# Production stage
FROM nvidia/cuda:12.6.0-base-ubuntu24.04 AS production
FROM ubuntu:24.04 AS production
WORKDIR /app
# Install Python 3.12 runtime
# Install Python 3.12 runtime and healthcheck dependency
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
python3.12 \
curl \
&& rm -rf /var/lib/apt/lists/*
# Copy virtual environment from builder

191
README.md
View File

@ -9,9 +9,9 @@
<div align="center">
<img src="https://img.shields.io/badge/python-3.12+-blue.svg" alt="python">
<img src="https://img.shields.io/badge/license-GPL--3.0-blue.svg" alt="license">
<img src="https://img.shields.io/github/v/release/ViperEkura/AstrAI?color=76bad9" alt="release">
<img src="https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.github.com%2Frepos%2FViperEkura%2FAstrAI&query=%24.stargazers_count&label=stars&suffix=%20stars&color=76bad9" alt="stars">
<img src="https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.github.com%2Frepos%2FViperEkura%2FAstrAI&query=%24.forks_count&label=forks&suffix=%20forks&color=76bad9" alt="forks">
<img src="https://img.shields.io/github/v/release/ViperEkura/AstrAI?label=Release&color=76bad9" alt="release">
<img src="https://img.shields.io/github/stars/ViperEkura/AstrAI?style=flat&label=Stars&color=76bad9" alt="stars">
<img src="https://img.shields.io/github/forks/ViperEkura/AstrAI?style=flat&label=Forks&color=76bad9" alt="forks">
</div>
<br>
@ -28,7 +28,8 @@
## 📖 Table of Contents
- [Features](#features)
- [Quick Start](#quick-start)
- [Getting Started](#getting-started)
- [Demo](#demo)
- [Documentation](#documentation)
- [Contributing](#contributing)
- [Community](#community)
@ -49,55 +50,117 @@
- 🤗 **HuggingFace-Style API**: AutoModel/AutoTokenizer APIs inspired by HuggingFace for easy model and tokenizer loading.
- 🔌 **Dual API Compatibility**: Supports both OpenAI and Anthropic chat completion APIs out of the box.
### Quick Start
### Getting Started
#### Installation
End-to-end walkthrough in 5 steps:
**1. Install**
```bash
git clone https://github.com/ViperEkura/AstrAI.git
cd AstrAI
pip install -e .
# pip install -e ".[dev]" # optional: dev dependencies (pytest, ruff)
```
For development dependencies:
**2. Download model**
```bash
pip install -e ".[dev]"
python scripts/demo/download.py # downloads 1B checkpoint to params/
```
#### Download Pre-trained Model
**3. Preprocess data**
Download pre-trained model weights (1B bilingual checkpoint) to `params/`:
Create `pretrain.json` (preprocessing config for `seq` strategy):
```json
{
"version": 1,
"input": {"sections": [{"field": "text", "action": "train"}]},
"preprocessing": {"max_seq_len": 2048},
"output": {"storage_format": "bin"}
}
```
```bash
python scripts/demo/download.py
python scripts/tools/preprocess.py data/*.jsonl -o output/ -c pretrain.json
```
Or download manually from [HuggingFace](https://huggingface.co/ViperEk/KHAOSZ) into `params/`.
#### Train a Model
**4. Train**
```bash
CUDA_VISIBLE_DEVICES=0,1,2,3 python scripts/tools/train.py \
--train_type seq \
--data_root_path /path/to/dataset \
--param_path /path/to/model \
--batch_size 4 \
--accumulation_steps 8 \
--max_lr 3e-4 \
--warmup_steps 1000 \
--n_epoch 1
export CUDA_VISIBLE_DEVICES=0,1,2,3
nohup python scripts/tools/train.py \
--nprocs=4 \
--parallel_mode=ddp \
--train_type=seq \
--data_root_path=/path/to/dataset \
--param_path=/path/to/model \
--batch_per_device=4 \
--grad_accum_steps=8 \
--warmup_ratio=0.05 \
--max_lr=1e-4 \
--max_grad_norm=1.0 \
--adamw_beta1=0.9 \
--adamw_beta2=0.95 \
--adamw_weight_decay=0.01 \
--window_size=2048 \
--ckpt_interval=10000 \
--ckpt_dir=./checkpoint \
--random_seed=3407 \
--label_smoothing=0.05 \
> out.log 2> err.log &
```
Full reference at [Parameter Guide](assets/docs/params.md).
**5. Serve & query**
#### Generate Text
```bash
# Terminal 1: start server
python scripts/tools/server.py --param_path ./params --device cuda
# Terminal 2: query
curl http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{"messages":[{"role":"user","content":"Hello"}],"max_tokens":512}'
```
### Demo
Check out the demos in the `scripts/demo/` folder:
```bash
# Download model weights (required before running demos)
python scripts/demo/download.py # model → params/
# Interactive streaming chat (multi-turn, maintains history)
python scripts/demo/stream_chat.py
# Type your message after >>, type !exit to quit
# Batch generation (5 hardcoded prompts, non-streaming)
python scripts/demo/generate_batch.py
# Single-prompt autoregressive streaming
python scripts/demo/generate_ar.py
```
All generation demos use `temperature=0.8`, `top_p=0.95`, `top_k=50`, `max_tokens=2048` by default and require `params/` to contain model weights (run `download.py` first).
Watch a video walkthrough on [bilibili](https://www.bilibili.com/video/BV1fuLB6yEj6).
---
See [Documentation](#documentation) for full references beyond the examples above.
#### Text Generation
Batch generation from a JSONL file:
```bash
python scripts/tools/generate.py \
--param_path /path/to/model \
--input_json_file /path/to/input.json \
--output_json_file /path/to/output.json
--param_path ./params \
--input_json_file input.jsonl \
--output_json_file output.jsonl
```
#### Docker
@ -111,9 +174,6 @@ docker build -t astrai:latest .
# Run with GPU support
docker run --gpus all -it astrai:latest
# Run with specific GPUs
docker run --gpus '"device=0,1"' -it astrai:latest
# Run inference server
docker run --gpus all -p 8000:8000 astrai:latest \
python -m scripts.tools.server --port 8000 --device cuda
@ -130,87 +190,42 @@ docker compose --profile cpu up -d
> **Note**: `--gpus all` is required for CUDA support. Without it, `torch.cuda.is_available()` will return `False`.
#### Start HTTP Server
#### HTTP API Examples
Start the inference server with OpenAI and Anthropic-compatible HTTP API:
Additional request examples beyond the [Getting Started](#getting-started) flow:
```bash
python -m scripts.tools.server --port 8000 --device cuda
```
Make requests:
```bash
# OpenAI-compatible
curl -X POST http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"messages": [{"role": "user", "content": "Hello"}],
"max_tokens": 512
}'
# OpenAI-compatible streaming
curl -X POST http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"messages": [{"role": "user", "content": "Tell a story"}],
"stream": true,
"max_tokens": 500
}'
-d '{"messages":[{"role":"user","content":"Tell a story"}],"stream":true,"max_tokens":500}'
# Anthropic-compatible
curl -X POST http://localhost:8000/v1/messages \
-H "Content-Type: application/json" \
-d '{
"model": "astrai",
"system": "You are a helpful assistant.",
"messages": [{"role": "user", "content": "Hello"}],
"max_tokens": 512
}'
-d '{"model":"astrai","system":"You are a helpful assistant.","messages":[{"role":"user","content":"Hello"}],"max_tokens":512}'
# Anthropic-compatible streaming with stop sequences
curl -X POST http://localhost:8000/v1/messages \
-H "Content-Type: application/json" \
-d '{
"model": "astrai",
"messages": [{"role": "user", "content": "Write a story"}],
"max_tokens": 500,
"stream": true,
"stop_sequences": ["The end"]
}'
-d '{"model":"astrai","messages":[{"role":"user","content":"Write a story"}],"max_tokens":500,"stream":true,"stop_sequences":["The end"]}'
# Health check
curl http://localhost:8000/health
```
#### Demo
Check out the demos in the `scripts/demo/` folder:
```bash
# Download preprocessed data (required before running demos)
python scripts/demo/download.py
# Interactive streaming chat
python scripts/demo/stream_chat.py
# Batch generation
python scripts/demo/generate_batch.py
# Autoregressive generation
python scripts/demo/generate_ar.py
```
Watch a video walkthrough on [bilibili](https://www.bilibili.com/video/BV1z5RPYHEkd).
See [Inference Guide](assets/docs/inference.md) for SSE streaming format, error codes, and stats endpoint.
### Documentation
| Document | Description |
|----------|-------------|
| [Parameter Guide](./assets/docs/params.md) | Training & inference parameters |
| [Design Document](./assets/docs/design.md) | Framework architecture & module design |
| [Data Flow](./assets/docs/dataflow.md) | Data processing pipeline details |
| [Model Introduction](./assets/docs/introduction.md) | Model architecture & technical details |
| [CLI Reference](./assets/docs/params.md) | Parameters for all CLI tools (train, server, generate, preprocess) |
| [Architecture](./assets/docs/architecture.md) | System architecture, class diagram & design patterns |
| [Training](./assets/docs/training.md) | Training loop, strategies & formulas |
| [Inference](./assets/docs/inference.md) | KVCache, continuous batching, sampling & HTTP API |
| [Data Flow](./assets/docs/dataflow.md) | Data pipeline, storage backends & dataset architecture |
| [Preprocessing](./assets/docs/preprocessing.md) | Declarative JSON-driven data preprocessing |
### Contributing

View File

@ -15,9 +15,9 @@
<div align="center">
<img src="https://img.shields.io/badge/python-3.12+-blue.svg" alt="python">
<img src="https://img.shields.io/badge/license-GPL--3.0-blue.svg" alt="license">
<img src="https://img.shields.io/github/v/release/ViperEkura/AstrAI?color=76bad9" alt="release">
<img src="https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.github.com%2Frepos%2FViperEkura%2FAstrAI&query=%24.stargazers_count&label=stars&suffix=%20stars&color=76bad9" alt="stars">
<img src="https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.github.com%2Frepos%2FViperEkura%2FAstrAI&query=%24.forks_count&label=forks&suffix=%20forks&color=76bad9" alt="forks">
<img src="https://img.shields.io/github/v/release/ViperEkura/AstrAI?label=Release&color=76bad9" alt="release">
<img src="https://img.shields.io/github/stars/ViperEkura/AstrAI?style=flat&label=Stars&color=76bad9" alt="stars">
<img src="https://img.shields.io/github/forks/ViperEkura/AstrAI?style=flat&label=Forks&color=76bad9" alt="forks">
</div>
<br>
@ -34,7 +34,8 @@
## 📖 目录
- [特性](#特性)
- [快速开始](#快速开始)
- [快速上手](#快速上手)
- [演示](#演示)
- [文档](#文档)
- [贡献](#贡献)
- [社区](#社区)
@ -55,55 +56,117 @@
- 🤗 **HuggingFace 风格 API**: 类 HuggingFace 的 AutoModel/AutoTokenizer 接口,方便加载模型和分词器。
- 🔌 **双 API 兼容**: 同时支持 OpenAI 和 Anthropic 聊天补全 API开箱即用。
### 快速开始
### 快速上手
#### 安装
端到端演示,只需 5 步:
**1. 安装**
```bash
git clone https://github.com/ViperEkura/AstrAI.git
cd AstrAI
pip install -e .
# pip install -e ".[dev]" # 可选开发依赖pytest, ruff
```
安装开发依赖:
**2. 下载模型**
```bash
pip install -e ".[dev]"
python scripts/demo/download.py # 下载 1B 检查点到 params/
```
#### 下载预训练模型
**3. 预处理数据**
下载预训练模型权重1B 双语检查点)到 `params/` 目录:
创建 `pretrain.json``seq` 策略的预处理配置):
```json
{
"version": 1,
"input": {"sections": [{"field": "text", "action": "train"}]},
"preprocessing": {"max_seq_len": 2048},
"output": {"storage_format": "bin"}
}
```
```bash
python scripts/demo/download.py
python scripts/tools/preprocess.py data/*.jsonl -o output/ -c pretrain.json
```
或从 [HuggingFace](https://huggingface.co/ViperEk/KHAOSZ) 手动下载放入 `params/`
#### 训练模型
**4. 训练**
```bash
CUDA_VISIBLE_DEVICES=0,1,2,3 python scripts/tools/train.py \
--train_type seq \
--data_root_path /path/to/dataset \
--param_path /path/to/model \
--batch_size 4 \
--accumulation_steps 8 \
--max_lr 3e-4 \
--warmup_steps 1000 \
--n_epoch 1
export CUDA_VISIBLE_DEVICES=0,1,2,3
nohup python scripts/tools/train.py \
--nprocs=4 \
--parallel_mode=ddp \
--train_type=seq \
--data_root_path=/path/to/dataset \
--param_path=/path/to/model \
--batch_per_device=4 \
--grad_accum_steps=8 \
--warmup_ratio=0.05 \
--max_lr=1e-4 \
--max_grad_norm=1.0 \
--adamw_beta1=0.9 \
--adamw_beta2=0.95 \
--adamw_weight_decay=0.01 \
--window_size=2048 \
--ckpt_interval=10000 \
--ckpt_dir=./checkpoint \
--random_seed=3407 \
--label_smoothing=0.05 \
> out.log 2> err.log &
```
完整参数列表见[参数说明](./params.md)。
**5. 启动服务并调用**
```bash
# 终端 1启动服务
python scripts/tools/server.py --param_path ./params --device cuda
# 终端 2发起请求
curl http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{"messages":[{"role":"user","content":"你好"}],"max_tokens":512}'
```
### 演示
查看 `scripts/demo/` 文件夹中的演示:
```bash
# 下载模型权重(运行演示前必需)
python scripts/demo/download.py # model → params/
# 交互式流式聊天(多轮对话,保持历史记录)
python scripts/demo/stream_chat.py
# 在 >> 后输入消息,输入 !exit 退出
# 批量生成5 条硬编码提示词,非流式)
python scripts/demo/generate_batch.py
# 单条提示词自回归流式生成
python scripts/demo/generate_ar.py
```
所有生成演示默认使用 `temperature=0.8`、`top_p=0.95`、`top_k=50`、`max_tokens=2048`,需要 `params/` 目录包含模型权重(请先运行 `download.py`)。
观看 [bilibili](https://www.bilibili.com/video/BV1fuLB6yEj6) 上的视频演示。
---
更多选项请参考[文档](#文档)。
#### 文本生成
从 JSONL 文件批量生成:
```bash
python scripts/tools/generate.py \
--param_path /path/to/model \
--input_json_file /path/to/input.json \
--output_json_file /path/to/output.json
--param_path ./params \
--input_json_file input.jsonl \
--output_json_file output.jsonl
```
#### Docker
@ -117,9 +180,6 @@ docker build -t astrai:latest .
# 启用 GPU 运行
docker run --gpus all -it astrai:latest
# 指定特定 GPU
docker run --gpus '"device=0,1"' -it astrai:latest
# 运行推理服务
docker run --gpus all -p 8000:8000 astrai:latest \
python -m scripts.tools.server --port 8000 --device cuda
@ -136,87 +196,42 @@ docker compose --profile cpu up -d
> **注意**: 必须使用 `--gpus all` 才能启用 CUDA 支持,否则 `torch.cuda.is_available()` 将返回 `False`
#### 启动 HTTP 服务
#### HTTP API 示例
启动推理服务器,支持 OpenAI 和 Anthropic 兼容的 HTTP API
除[快速上手](#快速上手)流程外,更多请求示例
```bash
python -m scripts.tools.server --port 8000 --device cuda
```
发起请求:
```bash
# OpenAI 兼容
curl -X POST http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"messages": [{"role": "user", "content": "你好"}],
"max_tokens": 512
}'
# OpenAI 兼容流式
curl -X POST http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"messages": [{"role": "user", "content": "讲个故事"}],
"stream": true,
"max_tokens": 500
}'
-d '{"messages":[{"role":"user","content":"讲个故事"}],"stream":true,"max_tokens":500}'
# Anthropic 兼容
curl -X POST http://localhost:8000/v1/messages \
-H "Content-Type: application/json" \
-d '{
"model": "astrai",
"system": "你是一个乐于助人的助手。",
"messages": [{"role": "user", "content": "你好"}],
"max_tokens": 512
}'
-d '{"model":"astrai","system":"你是一个乐于助人的助手。","messages":[{"role":"user","content":"你好"}],"max_tokens":512}'
# Anthropic 兼容流式并设置停止序列
curl -X POST http://localhost:8000/v1/messages \
-H "Content-Type: application/json" \
-d '{
"model": "astrai",
"messages": [{"role": "user", "content": "写个故事"}],
"max_tokens": 500,
"stream": true,
"stop_sequences": ["结束"]
}'
-d '{"model":"astrai","messages":[{"role":"user","content":"写个故事"}],"max_tokens":500,"stream":true,"stop_sequences":["结束"]}'
# 健康检查
curl http://localhost:8000/health
```
#### 演示
查看 `scripts/demo/` 文件夹中的演示:
```bash
# 下载预处理数据(运行演示前必需)
python scripts/demo/download.py
# 交互式流式聊天
python scripts/demo/stream_chat.py
# 批量生成
python scripts/demo/generate_batch.py
# 自回归生成
python scripts/demo/generate_ar.py
```
观看 [bilibili](https://www.bilibili.com/video/BV1z5RPYHEkd) 上的视频演示。
SSE 流式格式、错误码和统计端点详见[推理文档](./inference.md)。
### 文档
| 文档 | 说明 |
|------|------|
| [参数说明](./params.md) | 训练与推理参数配置 |
| [设计文档](./design.md) | 系统架构与模块设计 |
| [数据流程](./dataflow.md) | 数据处理管道详解 |
| [模型介绍](./introduction.md) | 模型架构与技术细节 |
| [CLI 参考](./params.md) | 所有 CLI 工具参数(训练、服务、生成、预处理) |
| [架构文档](./architecture.md) | 系统架构、类图与设计模式 |
| [训练文档](./training.md) | 训练循环、策略与公式 |
| [推理文档](./inference.md) | KVCache、连续批处理、采样与 HTTP API |
| [数据流程](./dataflow.md) | 数据管道、存储后端与数据集架构 |
| [数据预处理](./preprocessing.md) | 声明式 JSON 驱动数据预处理 |
### 贡献

1212
assets/docs/architecture.md Normal file

File diff suppressed because it is too large Load Diff

View File

@ -1,237 +1,109 @@
# AstrAI Data Flow Documentation
# Data Flow
This document describes the data flow of the AstrAI project (a training and inference framework for autoregressive Transformer language models). It covers the complete flow from raw data to model training and inference.
This document describes the data pipeline: from raw text to model input tensors. For creating preprocessing configs, see [Preprocessing Guide](preprocessing.md).
## Contents
- [Overview](#overview)
- [Data Preparation](#data-preparation) — tokenization, format detection, backends
- [Data Keys by Training Type](#data-keys-by-training-type)
- [Dataset Architecture](#dataset-architecture)
- [Sampler](#sampler)
- [DataLoader](#dataloader)
## Overview
AstrAI adopts a modular design with the following main components:
- **Dataset Module** (`astrai/dataset/`): Dataset, sampler, serialization tools
- **Model Module** (`astrai/model/`): AutoModel, Transformer model and its submodules
- **Training Module** (`astrai/trainer/`): Trainer, training context, strategies, schedulers, callbacks, metric utilities
- **Inference Module** (`astrai/inference/`): Inference engine with continuous batching, streaming generation
- **Config Module** (`astrai/config/`): ModelConfig, TrainConfig
- **Factory Module** (`astrai/factory/`): Registry, BaseFactory for component registration
- **Parallel Module** (`astrai/parallel/`): Distributed training support
- **Serialization** (`astrai/serialization.py`): Checkpoint management with safetensors
## Data Flow Diagram
```mermaid
flowchart LR
subgraph A[Data Preparation]
direction TB
A1[Raw Text] --> A2[AutoTokenizer]
A2 --> A3[Tokenized .h5 files]
A3 --> A4[BaseDataset]
A4 --> A5[ResumableDistributedSampler]
A5 --> A6[DataLoader]
end
subgraph B[Training]
direction TB
B1[DataLoader] --> B2[BaseStrategy]
B2 --> B3[Transformer Forward]
B3 --> B4[Loss + Backward]
B4 --> B5[Gradient Accumulation]
B5 -->|every accum_steps| B6[Optimizer Step]
B6 --> B7[LR Scheduler]
B7 -->|next batch| B2
B6 --> B8[CheckpointCallback]
end
subgraph C[Inference]
direction TB
C1[Checkpoint] --> C2[AutoModel]
C1 --> C3[AutoTokenizer]
C2 --> C4[InferenceEngine]
C3 --> C4
C4 --> C5[InferenceScheduler]
C5 --> C6[Transformer Forward]
C6 --> C7[sample]
C7 --> C8{End?}
C8 -->|No| C6
C8 -->|Yes| C9[Generated Text]
end
A --> B
B --> C
```
JSONL Lines → Pipeline (mask builder) → Tokenized Tensors
.h5 or .bin storage
Store.load()
Store.fetch(begin, end, keys)
BaseDataset.__getitem__(idx)
Sampler → DataLoader → Training / Inference
```
## Detailed Module Descriptions
## Data Preparation
### 1. Data Serialization (`astrai/dataset/storage.py` & `astrai/serialization.py`)
Raw text is tokenized via `AutoTokenizer.encode()` and saved as HDF5 (`.h5`) or binary (`.bin` + `meta.json`) files with keyed tensor groups.
- **`save_h5`**: Saves tensors by groups as HDF5 files (`.h5`), each key maps to a list of tensors
- **`load_h5`**: Loads `.h5` files, returns `Dict[str, List[Tensor]]`, supports shared memory
- **`Checkpoint`**: Encapsulates model state dict + epoch + iteration; uses safetensors
### Tokenization
### 2. Dataset Module
The `Pipeline` reads JSONL lines, applies the mask builder (see [Preprocessing](preprocessing.md)), and produces flat token sequences:
#### 2.1 Dataset (`dataset.py`)
- **`BaseDataset`**: Abstract base class for windowed sequence sampling
- **`BaseSegmentFetcher` / `MultiSegmentFetcher`**: Fetch tensor segments by index range
- **`DatasetFactory`**: Creates dataset instances by `train_type` (`seq`, `sft`, `dpo`, `grpo`)
- Data keys: `"sequence"` (SEQ), `"loss_mask"` (SFT), `"chosen_mask"/"rejected_mask"` (DPO), `"masks"` (GRPO)
#### 2.2 Sampler (`sampler.py`)
- **`ResumableDistributedSampler`**: Tracks `epoch` and `iter` for breakpoint resume; supports shuffle and drop_last
### 3. Model Module
#### 3.1 Transformer / AutoModel
- **`AutoModel`**: Base class with `from_pretrained()` / `save_pretrained()`
- **`Transformer`**: Decoder-only architecture, registered via `@AutoModel.register('transformer')`
- Embedding → N×DecoderBlock → RMSNorm → Linear lm_head
- RoPE position encoding, optional weight tying
#### 3.2 Submodules (`module.py`)
- **`DecoderBlock`**: GQA attention + residual + MLP + RMSNorm
- **`GQA`**: Grouped Query Attention (also `MLA` for multi-latent attention)
- **`MLP`**: `SiLU(gate(x)) * up(x)` → down projection
- **`RotaryEmbedding`**: RoPE complex cache (freqs_cis)
- **`RMSNorm`**: Layer normalization
### 4. Training Module
#### 4.1 Training Context (`train_context.py`)
- **`TrainContext`**: Dataclass holding model, optimizer, dataloader, strategy, scheduler, checkpoint state
- **`TrainContextBuilder`**: Builder pattern — takes checkpoint for resume, builds all components
#### 4.2 Trainer (`trainer.py`)
The training loop is nested: **epoch****batch** (with step phase interspersed):
```
on_train_begin
on_epoch_begin
for each accumulation window of batches: ← step phase
on_step_begin
for each batch in window: ← batch phase
on_batch_begin → strategy(batch) → loss → backward → on_batch_end
iteration += 1
on_step_end
optimizer.step() → zero_grad
on_epoch_end
on_train_end
```python
# Per JSONL line: messages → chat template → token IDs + loss mask
tokens = tokenizer.encode(rendered_text) # List[int]
loss_mask = [0, 0, 0, 1, 1, 1, 1, 1, 1] # 0=masked, 1=train
# Stored as flat tensors, packed with other lines by packing strategy
```
Key points:
- `on_step_*` fires every `accumulation_steps` batches, wrapping optimizer step AFTER the hook
- `on_batch_*` fires every batch, wrapping loss computation
- `GradientClippingCallback` fires on `on_step_end`
- LR scheduler steps inline (no `SchedulerCallback` class)
The output `meta.json` records the storage format, key names, dtype, total token count, and tensor shapes for each shard.
#### 4.3 Strategy (`strategy.py`)
- **`SEQStrategy`**: Next-token prediction, cross-entropy with label smoothing
- **`SFTStrategy`**: Supervised fine-tuning with loss masking
- **`DPOStrategy`**: Direct Preference Optimization with reference model
- **`GRPOStrategy`**: Group Relative Policy Optimization with clipped ratio
### Format Detection
#### 4.4 Scheduler (`schedule.py`)
- **`CosineScheduler`**: Cosine decay + linear warmup
- **`SGDRScheduler`**: Cosine annealing with warm restarts
- Created by `SchedulerFactory` and bound to optimizer
`detect_format(load_path)` inspects the directory:
#### 4.5 Callbacks
- **`CheckpointCallback`**: Saves safetensors at `ckpt_interval` iterations
- **`ProgressBarCallback`**: tqdm progress display
- **`MetricLoggerCallback`**: Writes JSONL metrics to `{ckpt_dir}/logs/`
- **`GradientClippingCallback`**: `clip_grad_norm_` on `on_step_end`
- If `*.h5` files exist → `"h5"` (HDF5 backend)
- If `*.bin` + `meta.json` files exist → `"bin"` (memory-mapped backend)
### 5. Inference Module
### Store Backends
#### 5.1 Inference Engine (`engine.py`)
- **`InferenceEngine`**: Facade over scheduler; provides `generate()`, `generate_with_request()`, `generate_async()`
- Accepts `prompt: str | List[str]`, returns generator (stream) or string (non-stream)
#### 5.2 Scheduler 4-Phase Loop (`scheduler.py`)
Background thread runs continuously:
Storage format is auto-detected by `detect_format()`; backends are dispatched via registry:
```
1. Cleanup → Remove finished tasks, free KV cache pages
2. Refill → Pop from waiting_queue, alloc pages, add to active
3. Prefill → Group active tasks by prompt_len, run full forward pass
4. Decode → Pick largest same-position group, run single-token forward
StoreFactory.create("h5") → H5Store
StoreFactory.create("bin") → MmapStore
```
- **`Task`**: Tracks prompt_ids, output_ids, status (PENDING/RUNNING/FINISHED/ABORTED)
- **`KVCache`**: Facade over `Allocator` + `PrefixCache` + `PagePool` + `Storage` for paged KV cache
- **`KvcacheView`**: Batch view bundling cache + page table for attention layers
- **`sample()`**: Temperature → top-k → top-p → multinomial
**H5Store**: Reads HDF5 files, supports `share_memory_()` for multi-process DataLoader workers (copies tensors to shared memory).
#### 5.3 Server (`server.py`)
- FastAPI with OpenAI `/v1/chat/completions` and Anthropic `/v1/messages` endpoints
- Streaming via SSE, health check at `/health`, stats at `/stats`
**MmapStore**: Memory-maps `.bin` files. OS page cache sharing is native — no explicit `share_memory_()` needed. Uses `torch.from_numpy(np.memmap(...))`.
### 6. Tokenizer Module
Both backends normalise tensors into `Store._data[Dict[str, List[Tensor]]]` + `Store._cum[Dict[str, List[int]]]` (cumulative lengths for bisect-based indexing).
- **`AutoTokenizer`**: Wraps HuggingFace tokenizers (BBPE); `encode`/`decode`/`apply_chat_template`
- **`ChatTemplate`**: Jinja2-based template rendering for multi-turn chat
## Data Keys by Training Type
### 7. Factory & Parallel
| Type | Storage Keys |
|------|-------------|
| `seq` | `sequence` (→ input_ids, target_ids via offset-by-1) |
| `sft` | `sequence`, `loss_mask`, `position_ids` |
| `dpo` | `chosen`, `rejected`, `chosen_mask`, `rejected_mask` |
| `grpo` | `prompts`, `responses`, `masks`, `rewards` |
- **`Registry` / `BaseFactory`**: Decorator-based component registration
- **`spawn_parallel_fn`**: Multi-process DDP launcher with NCCL backend
- **`ParallelModel` / `ColumnParallelLinear` / `RowParallelLinear`**: Tensor model parallelism
## Dataset Architecture
## Training Data Flow — Detailed Steps
```
DatasetFactory.load(train_type, load_path, window_size, stride=None, storage_type=None)
→ BaseDataset.load(load_path, storage_type=None)
→ detect_format(load_path)
→ StoreFactory.create(storage_type)
→ Store.load(load_path)
→ H5Store._normalize() / MmapStore._normalize()
→ Store._data[Dict[str, List[Tensor]]] + _cum[Dict[str, List[int]]]
→ BaseDataset.__getitem__(idx)
→ get_index(idx) → [begin, end)
→ Store.fetch(begin, end, keys) → Tensor / Dict[str, Tensor]
```
1. **Data Preparation**
- Raw text → token IDs via `AutoTokenizer.encode()`
- Save as `.h5` files (groups of tensor lists per data key)
`window_size` = max input length, `stride` = step between consecutive samples (defaults to `window_size`, optional). `storage_type` defaults to `None` (auto-detect via `detect_format`).
2. **Dataset Loading**
- `BaseDataset.load()` calls `load_h5()`, builds `MultiSegmentFetcher`
- Sliding window of `window_size` with `stride` determines sample boundaries
`Store.fetch(begin, end, keys)` accepts a single key (`str`) returning a `Tensor`, or a list of keys returning `Dict[str, Tensor]`. Internally uses `bisect` across multi-segment tensors. Raises `RuntimeError("Store not loaded")` if called before `load()`.
3. **Sampling & Batching**
- `ResumableDistributedSampler` produces shuffled index sequences
- `DataLoader` fetches `[batch_size, window_size]` tensors via `__getitem__`
## Sampler
4. **Strategy Forward**
- Strategy receives batch, calls `Transformer.forward()` for logits
- Computes task-specific loss (cross-entropy, DPO, GRPO)
`ResumableDistributedSampler` supports checkpoint-aware distributed sampling:
5. **Backward & Accumulation**
- `loss = raw_loss / accumulation_steps`
- `loss.backward()` accumulates gradients
- Every `accumulation_steps` batches: `optimizer.step()``zero_grad()`
- Every batch: `scheduler.step()` updates learning rate
- Tracks `start_epoch` / `start_iter` for resume
- Shuffle via `torch.Generator(seed + epoch)`
- Per-replica index slicing for DDP
6. **Checkpoint**
- `CheckpointCallback` saves `model.state_dict()` + metadata to safetensors at `ckpt_interval` iterations
- Does NOT save optimizer/scheduler state (resume resets those)
## DataLoader
## Inference Data Flow — Detailed Steps
Standard PyTorch `DataLoader` with configurable `batch_size`, `num_workers`, `pin_memory`, `prefetch_factor`. Sampler produces indices; dataloader fetches tensor batches via `__getitem__`.
1. **Model Loading**
- `AutoModel.from_pretrained(path)` loads weights from safetensors
- `torch.inference_mode()` wraps generation
2. **Prompt Construction**
- Messages → `apply_chat_template(messages, tokenize=False)` → prompt string
- `tokenizer.encode(prompt)` → token IDs (truncated to `max_prompt_len`)
3. **Continuous Batching Loop**
- **Cleanup**: Finished tasks → `stream_callback(STOP)`, free KV pages
- **Refill**: Pop from waiting queue, `PagePool.task_alloc()` for prompt pages
- **Prefill**: Group by prompt length, run full forward with `start_pos=0`
- **Decode**: Pick position group with most tasks, single-token forward:
- Model forward → `logits``sample()` → next token ID
- Append to `output_ids`, update `output_tokens`
- `PagePool.task_alloc()` allocates pages as needed
- `stream_callback(token)` for streaming clients
4. **Output**
- `tokenizer.decode(output_ids)` → text
- Return to caller (streaming: token-by-token; non-streaming: complete string)
## Checkpoint & Serialization
- **Training Checkpoint**: safetensors weights + epoch/iteration metadata. Optimizer/scheduler state is NOT persisted.
- **Inference Loading**: `AutoModel.from_pretrained()` loads from the same safetensors format.
- **Dataset Serialization**: HDF5 with shared memory support for large-scale pre-training data.
> Document Update Time: 2026-05-14
> Document Update Time: 2026-06-19

View File

@ -1,779 +0,0 @@
## 1. Why I Created This Project
There are many large language models on the market today, such as GPT, LLaMA, and others, with tens of billions or even hundreds of billions of parameters. But honestly, these models have extremely high hardware requirements, making them inaccessible for ordinary developers. I thought: **Can we create a model that is both useful and can run on ordinary computers?** This is also what most people currently hope for - a locally deployable AI project that achieves complete privatization while maintaining some level of intelligence.
Thus, the AstrAI project was born - 1B parameters, Chinese-English bilingual, supporting dialogue, text generation, and the training code is open source!
## 2. System Architecture
```mermaid
classDiagram
namespace config {
class ModelConfig {
+int vocab_size
+int dim
+int n_layers
+float norm_eps
+int dim_ffn
+bool tie_weight
+int max_len
+float rope_theta
+int n_heads
+int n_kv_heads
+bool use_qk_norm
+bool use_gated_attention
+load(config_path) ModelConfig
+save(config_path)
}
class TrainConfig {
+nn.Module model
+str strategy
+Dataset dataset
+Callable optimizer_fn
+Callable scheduler_fn
+int n_epoch
+int batch_size
+int accumulation_steps
+float max_grad_norm
+int start_epoch
+int start_batch
+str ckpt_dir
+int ckpt_interval
+int random_seed
+int num_workers
+int prefetch_factor
+bool pin_memory
+int nprocs
+str backend
+str master_addr
+str master_port
+Callable parallel_wrapper
+Callable state_dict_fn
+str device_type
+dict extra_kwargs
+validate()
}
}
namespace dataset {
class BaseDataset {
+int window_size
+int stride
+BaseStorage storage
+load(load_path, storage_type, tokenizer)
+__getitem__(index)
+__len__()
}
class SEQDataset {
+__getitem__(index) Dict
}
class SFTDataset {
+__getitem__(index) Dict
}
class DPODataset {
+__getitem__(index) Dict
}
class GRPODataset {
+__getitem__(index) Dict
}
class BaseSegmentFetcher {
+List[Tensor] segments
+List[int] cum_lengths
+int total_length
+fetch_data(begin_idx, end_idx) Tensor
}
class BaseStorage {
+MultiSegmentFetcher _fetcher
+keys (property)
+load(load_path, tokenizer)
+fetch(begin, end, keys)
+__len__()
}
class H5Storage {
+load(load_path, tokenizer)
+fetch(begin, end, keys) Dict
+keys() List
}
class JSONStorage {
+load(load_path, tokenizer)
+fetch(begin, end, keys) Dict
+keys() List
}
class MultiSegmentFetcher {
+Dict multi_fetchers
+List multi_keys
+key_fetch(begin_idx, end_idx, keys) Dict
+fetch_data(begin_idx, end_idx) Dict
}
class ResumableDistributedSampler {
+int epoch
+int iter
}
class DatasetFactory {
+Registry _registry
+register(name) decorator
+create(train_type, window_size, stride) BaseDataset
+load(train_type, load_path, window_size, stride) BaseDataset
}
}
namespace serialization {
class Checkpoint {
+dict state_dict
+int epoch
+int iteration
+save(save_dir)
+load(save_dir) Checkpoint
}
}
namespace model {
class AutoModel {
+ModelConfig config
+Registry _registry
+register(model_type) decorator
+get_component_class(model_type) Type
+from_pretrained(path, disable_random_init) nn.Module
+save_pretrained(save_directory)
+to(*args, **kwargs) Self
}
class Transformer {
+ModelConfig config
+RotaryEmbedding rotary_embedding
+Embedding embed_tokens
+ModuleList layers
+RMSNorm norm
+Linear lm_head
+forward(input_ids, input_mask, paged_cache, position_ids) Tensor
+load_state_dict(state_dict)
+state_dict()
}
class DecoderBlock {
+GQA attention
+RMSNorm input_norm
+MLP mlp
+RMSNorm post_attention_norm
+forward(x, rotary_emb, attention_mask, paged_cache) Tensor
}
class GQA {
+int n_heads
+int n_kv_heads
+int head_dim
+Linear q_proj, k_proj, v_proj, o_proj
+RMSNorm q_norm, k_norm
+forward(x, rotary_emb, attn_mask, paged_cache) Tensor
}
class MLA {
+int n_heads
+int n_kv_heads
+int head_dim
+int kv_lora_rank
+int qk_nope_head_dim
+int qk_rope_head_dim
+Linear q_proj, kv_a_proj, kv_b_proj
+Linear o_proj
+RMSNorm kv_norm
+forward(x, rotary_emb, attn_mask, paged_cache) Tensor
}
class MLP {
+Linear up, gate, down
+forward(x) Tensor
}
class RMSNorm {
+Parameter weight
+float norm_eps
+forward(x) Tensor
}
class Linear {
+Parameter weight
+Parameter bias
+forward(x) Tensor
}
class RotaryEmbedding {
+int dim
+int max_len
+float base
+forward(x, position_ids=None) Tensor
}
class Embedding {
+Parameter weight
+forward(x) Tensor
}
}
namespace tokenize {
class AutoTokenizer {
+vocab_size int
+encode(tokens, out_ids, add_special_tokens) List[int]
+decode(tokens, skip_special_tokens) str
+__getattr__(name) Any (bos_id, eos_id, pad_id, stop_ids)
+apply_chat_template(messages, tokenize) Union[str, List[int]]
+set_chat_template(template)
+load(path)
+from_pretrained(path) AutoTokenizer
+save_pretrained(save_path)
}
class ChatTemplate {
+String template_str
+render(messages, system_prompt, **extra_variables) str
+from_string(template) ChatTemplate
}
}
namespace factory {
class Registry {
+Dict _entries
+register(name, component_cls, category, priority)
+get(name) Type
+list_names() List[str]
}
class BaseFactory {
+Registry _registry
+register(name, category, priority) decorator
+create(name, *args, **kwargs) T
+list_registered() list
}
}
namespace trainer {
class Trainer {
+TrainConfig train_config
+List[TrainCallback] callbacks
+train(checkpoint)
+_build_context(checkpoint) TrainContext
+_get_default_callbacks() List[TrainCallback]
}
class TrainContext {
+nn.Module model
+BaseStrategy strategy
+DataLoader dataloader
+Optimizer optimizer
+LRScheduler scheduler
+Checkpoint checkpoint
+int epoch
+int iteration
+float loss
+int world_size
+int rank
}
class TrainContextBuilder {
+TrainConfig config
+with_checkpoint(checkpoint) TrainContextBuilder
+build() TrainContext
}
class BaseStrategy {
+nn.Module model
+str device
+compute_loss(batch) Tensor
}
class StrategyFactory {
+Registry _registry
+register(name) decorator
+create(model, train_type, device, **kwargs) BaseStrategy
}
class SEQStrategy {
+float label_smoothing
+compute_loss(batch) Tensor
}
class SFTStrategy {
+float label_smoothing
+compute_loss(batch) Tensor
}
class DPOStrategy {
+nn.Module ref_model
+float beta
+str reduction
+compute_loss(batch) Tensor
}
class GRPOStrategy {
+nn.Module ref_model
+float clip_eps
+float kl_coef
+int group_size
+str reduction
+int sync_interval
+compute_loss(batch) Tensor
}
class BaseScheduler {
+get_lr() List[float]
+step()
}
class SchedulerFactory {
+Registry _registry
+register(name) decorator
+create(optimizer, schedule_type, **kwargs) BaseScheduler
}
class CosineScheduler {
+int warmup_steps
+int lr_decay_steps
+float min_rate
}
class SGDRScheduler {
+int warmup_steps
+int cycle_length
+float min_rate
+int t_mult
}
class TrainCallback {
+on_train_begin(context)
+on_train_end(context)
+on_epoch_begin(context)
+on_epoch_end(context)
+on_step_begin(context)
+on_step_end(context)
+on_batch_begin(context)
+on_batch_end(context)
+on_error(context)
}
class GradientClippingCallback {
+float max_grad_norm
+on_step_begin(context)
}
class CheckpointCallback {
+str save_dir
+int interval
+_save_checkpoint(context)
+on_batch_end(context)
+on_train_end(context)
+on_error(context)
}
class ProgressBarCallback {
+int num_epoch
+on_epoch_begin(context)
+on_batch_end(context)
+on_epoch_end(context)
}
class MetricLoggerCallback {
+str log_dir
+int save_interval
+on_batch_end(context)
+on_train_end(context)
}
class CallbackFactory {
+Registry _registry
+register(name) decorator
+create(name, **kwargs) TrainCallback
}
}
namespace inference {
class InferenceEngine {
+nn.Module model
+AutoTokenizer tokenizer
+InferenceScheduler scheduler
+generate(prompt, stream, max_tokens, temperature, top_p, top_k) Union[Generator, str, List[str]]
+generate_with_request(request) Union[Generator, str, List[str]]
+generate_async(prompt, max_tokens, temperature, top_p, top_k) AsyncGenerator
+get_stats() Dict
+shutdown()
}
class InferenceScheduler {
+nn.Module model
+AutoTokenizer tokenizer
+KVCache _page_cache
+int max_batch_size
+int max_seq_len
+int max_prompt_len
+int page_size
+TaskManager _task_mgr
+add_task(prompt, max_tokens, temperature, top_p, top_k, stream_callback) str
+remove_task(task_id)
+start()
+stop()
+get_stats() Dict
}
class Allocator {
+int _free_mask
+int refs_count
+LRU _lru
+alloc() int
+free(idx, keep_cached)
+inc_ref(idx)
+touch(idx)
+ref_count(idx) int
}
class PrefixCache {
+int _page_size
+evict(page_idx)
+has_page(idx) bool
+lookup(token_ids) List[int]
+record(page_idx, token_ids, logical_page_idx)
}
class PagePool {
-Allocator _alloc
-PrefixCache _prefix
+alloc() int
+free(idx)
+inc_ref(idx)
+lookup(token_ids) List[int]
+record(page_idx, token_ids, logical_page_idx)
}
class Storage {
+int n_layers
+int page_size
+int head_dim
+int n_kv_heads
+Tensor k_cache
+Tensor v_cache
+write(layer_id, page_table, start_pos, k, v)
+gather(layer_id, page_table, total_len) Tuple[Tensor, Tensor]
}
class KVCache {
-PagePool _pool
-Storage _storage
-TaskTable _table
+int page_size
+task_alloc(task_id, prompt_ids) bool
+task_free(task_id)
+task_extend(task_id, pos) bool
+task_cached(task_id) int
+task_record_hashes(task_id, prompt_ids, start_logical_page)
+make_table_tensor(task_ids, device) Tensor
+bind(page_table, total_len) KvcacheView
}
class KvcacheView {
-Storage _storage
+Tensor _page_table
+int _total_len
+write(layer_id, k, v)
+gather(layer_id) Tuple[Tensor, Tensor]
}
class TaskTable {
+set(task_id, page_table, cached)
+get(task_id) List[int]
+get_cached(task_id) int
+get_ref(task_id) List[int]
+pop(task_id) Tuple[List[int], int]
+table_tensor(task_ids, device) Tensor
}
class Task {
+str task_id
+List prompt_ids
+int max_tokens
+float temperature
+float top_p
+int top_k
+TaskStatus status
+List output_ids
+int input_tokens
+int output_tokens
+float arrival_time
+float finish_time
+Callable stream_callback
+int next_pos
+is_finished(stop_ids) bool
}
class TaskStatus {
<<enumeration>>
PENDING
RUNNING
FINISHED
ABORTED
}
class GenerationRequest {
+List[Dict] messages
+int top_k
+float top_p
+float temperature
+Optional[int] max_tokens
+bool stream
}
class BaseSamplingStrategy {
<<abstract>>
+apply(logits, filter_value) Tensor
}
class TemperatureStrategy {
+float temperature
+apply(logits, filter_value) Tensor
}
class TopKStrategy {
+int top_k
+apply(logits, filter_value) Tensor
}
class TopPStrategy {
+float top_p
+apply(logits, filter_value) Tensor
}
class SamplingPipeline {
+List strategies
+apply(logits, filter_value) Tensor
+sample(logits, filter_value) Tensor
}
class GenerateResult {
+List[Tuple[int, str]] tokens
+List[str] results
+List[bool] _done
+append(token, idx)
+get_results() List[str]
+pop_all() List[str]
+wait(timeout) bool
+wait_completion()
}
class ChatMessage {
+str role
+str content
}
class ChatCompletionRequest {
+List[ChatMessage] messages
+float temperature
+float top_p
+int top_k
+int max_tokens
+bool stream
+Optional[str] stop
+Optional[int] n
}
}
namespace parallel {
class Functions {
+spawn_parallel_fn(fn, nprocs)
+setup_parallel(rank, world_size, backend, master_addr, master_port, device_type)
+get_current_device() str
+get_world_size() int
+get_rank() int
}
class ParallelModel {
+dist.ProcessGroup process_group
+int rank
+int world_size
}
class ColumnParallelLinear {
+forward(x) Tensor
}
class RowParallelLinear {
+forward(x) Tensor
}
}
%% Relationships
TrainConfig --> BaseDataset : uses
TrainConfig ..> BaseStrategy : selects
StrategyFactory ..> BaseStrategy : creates
BaseStrategy <|-- SEQStrategy
BaseStrategy <|-- SFTStrategy
BaseStrategy <|-- DPOStrategy
BaseStrategy <|-- GRPOStrategy
DPOStrategy --> Transformer : uses
GRPOStrategy --> Transformer : uses
Trainer --> TrainConfig : uses
Trainer --> TrainContextBuilder : uses
Trainer --> TrainCallback : manages
TrainContextBuilder --> TrainContext : creates
TrainContextBuilder --> StrategyFactory : uses
Checkpoint ..> Checkpoint : serializes
TrainContext --> Checkpoint : manages
TrainContext --> BaseStrategy : uses
TrainContext --> BaseScheduler : uses
SchedulerFactory ..> BaseScheduler : creates
BaseScheduler <|-- CosineScheduler
BaseScheduler <|-- SGDRScheduler
CallbackFactory ..> TrainCallback : creates
TrainCallback <|-- GradientClippingCallback
TrainCallback <|-- CheckpointCallback
TrainCallback <|-- ProgressBarCallback
TrainCallback <|-- MetricLoggerCallback
PagePool --> Allocator : composes
PagePool --> PrefixCache : composes
KVCache --> PagePool : composes
KVCache --> Storage : composes
KVCache --> TaskTable : composes
KvcacheView --> Storage : wraps
InferenceEngine --> InferenceScheduler : uses
InferenceEngine --> GenerationRequest : uses
InferenceEngine --> GenerateResult : creates
InferenceScheduler --> Task : manages
InferenceScheduler --> TaskStatus : uses
InferenceScheduler --> KVCache : uses
InferenceScheduler --> Transformer : uses
Task --> TaskStatus : uses
InferenceEngine --> Transformer : uses
BaseSamplingStrategy <|-- TemperatureStrategy
BaseSamplingStrategy <|-- TopKStrategy
BaseSamplingStrategy <|-- TopPStrategy
SamplingPipeline --> BaseSamplingStrategy : composes
BaseDataset <|-- SEQDataset
BaseDataset <|-- SFTDataset
BaseDataset <|-- DPODataset
BaseDataset <|-- GRPODataset
DatasetFactory ..> BaseDataset : creates
BaseStorage <|-- H5Storage
BaseStorage <|-- JSONStorage
BaseDataset --> BaseStorage : uses
MultiSegmentFetcher --> BaseSegmentFetcher : uses
AutoModel <|-- Transformer
AutoModel --> ModelConfig : contains
Transformer --> DecoderBlock : uses
Transformer --> RotaryEmbedding : uses
Transformer --> Embedding : uses
DecoderBlock --> GQA : uses
DecoderBlock --> MLP : uses
DecoderBlock --> RMSNorm : uses
TrainContextBuilder --> ResumableDistributedSampler : creates
ResumableDistributedSampler --> BaseDataset : samples
ParallelModel <|-- RowParallelLinear
ParallelModel <|-- ColumnParallelLinear
AutoTokenizer --> ChatTemplate : uses
BaseFactory <|-- AutoModel
BaseFactory <|-- DatasetFactory
BaseFactory <|-- StrategyFactory
BaseFactory <|-- SchedulerFactory
BaseFactory <|-- CallbackFactory
```
### Module Overview
| Module | Components | Description |
|--------|------------|-------------|
| **astrai.config** | ModelConfig, TrainConfig | Configuration management |
| **astrai.dataset** | BaseDataset, SEQDataset, SFTDataset, DPODataset, GRPODataset, BaseStorage, H5Storage, JSONStorage, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory, save_h5, load_h5 | Dataset loading and management |
| **astrai.serialization** | Checkpoint | Model serialization and checkpoint management |
| **astrai.model** | AutoModel, Transformer, DecoderBlock, GQA, MLA, MLP, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
| **astrai.tokenize** | AutoTokenizer, ChatTemplate | Tokenizer and chat template |
| **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy, StrategyFactory, BaseScheduler, SchedulerFactory, TrainCallback, CallbackFactory | Training workflow management |
| **astrai.inference** | InferenceEngine, InferenceScheduler, KVCache, KvcacheView, Allocator, PrefixCache, PagePool, Storage, TaskTable, Task, TaskStatus, GenerationRequest, BaseSamplingStrategy, TemperatureStrategy, TopKStrategy, TopPStrategy, SamplingPipeline, ChatMessage, ChatCompletionRequest | Inference service with continuous batching and paged KV cache |
| **astrai.parallel** | spawn_parallel_fn, setup_parallel, get_rank, get_world_size, get_current_device, ParallelModel, ColumnParallelLinear, RowParallelLinear | Distributed parallel |
| **astrai.factory** | Registry, BaseFactory | Generic component registration |
### Design Patterns
| Pattern | Classes | Purpose |
|---------|---------|---------|
| **Strategy** | `BaseStrategy`, `SEQStrategy`, `SFTStrategy`, `DPOStrategy`, `GRPOStrategy`, `StrategyFactory` | Flexible training strategy switching, supports SEQ/SFT/DPO/GRPO |
| **Builder** | `TrainContextBuilder` | Chain-building training context, step-by-step initialization of components |
| **Factory** | `StrategyFactory`, `SchedulerFactory`, `DatasetFactory`, `CallbackFactory`, `BaseFactory` | Decorator registration mechanism, dynamically create training strategies, schedulers, datasets, and callbacks |
| **Observer** | `TrainCallback`, `CallbackFactory` | Callback mechanism for training process monitoring (checkpoint, early stopping, metrics) |
| **Context** | `TrainContext` | Training process state container with model, optimizer, scheduler and checkpoint |
| **Registry** | `BaseFactory`, `Registry` | Generic component registration with category and priority support |
| **Object Pool** | `Allocator`, `PagePool` | Page-based KV cache with O(1) alloc/free via bitmask + LRU eviction |
| **Strategy (Sampling)** | `BaseSamplingStrategy`, `TemperatureStrategy`, `TopKStrategy`, `TopPStrategy`, `SamplingPipeline` | Composable logit transformations with temperature, top-k, top-p |
| **Producer-Consumer** | `InferenceScheduler`, `Task`, `waiting_queue`, `active_tasks` | Continuous batching with dynamic task queue management |
| **Event-Driven** | `threading.Event`, `_task_event` | Non-blocking wait mechanism for task scheduling using Python's `threading` module |
| **AutoModel Registry** | `AutoModel`, `Transformer` | Model type registration and dynamic loading via decorator pattern |
| **Generator Pattern** | `GenerateResult`, `GenerationRequest` | Event-based result notification for streaming/non-streaming generation |
### Core Relationships
1. **Configuration → Training**: `TrainConfig` holds model, dataset, optimizer_fn, scheduler_fn and other training configuration references
2. **Training Flow**: `Trainer``TrainContextBuilder``TrainContext`, uses `BaseStrategy` to compute loss
3. **Strategy Selection**: `StrategyFactory` creates corresponding strategy instance based on `train_type`
4. **Inference Flow**: `InferenceEngine``InferenceScheduler``Transformer`, uses `KVCache` (backed by `Allocator` + `PrefixCache` + `PagePool` + `Storage`) for paged KV cache management and `SamplingPipeline` for efficient continuous batching with streaming/non-streaming
5. **Distributed Support**: `spawn_parallel_fn` and `setup_parallel` provide multi-process training capability for `Trainer`
6. **Dataset Loading**: `DatasetFactory` creates datasets (SEQDataset, SFTDataset, DPODataset, GRPODataset), supports HDF5 loading via `BaseSegmentFetcher` and `MultiSegmentFetcher`
7. **Checkpoint Management**: `Checkpoint` handles model state serialization/deserialization with safetensors
8. **Scheduler Support**: `SchedulerFactory` creates learning rate schedulers (CosineScheduler, SGDRScheduler)
9. **AutoModel Loading**: `AutoModel.from_pretrained()` dynamically loads model based on `config.json` model_type, uses `Registry` pattern for model type registration
## 3. Training Process
The common training process for large language models (LLM) typically includes three stages: **Pre-training (SEQ)**, **Supervised Fine-Tuning (SFT)**, and **Reinforcement Learning from Human Feedback (DPO/GRPO)**. This system is designed to support seamless end-to-end flow, achieving efficient switching and state management of different training stages through modular strategies.
### Core Formulas
**Pre-training (SEQ):**
$$
L_{\text{PT}} = - \sum_{t=1}^{T} \log P(x_t \mid x_{\lt t}; \theta)
$$
**SFT:**
$$
L_{\text{SFT}} = - \sum_{t=P+1}^{P+L} \log P(s_t \mid s_{\lt t}; \theta)
$$
**DPO:**
$$
L_{\text{DPO}} = -\mathbb{E}_{(x, y_w, y_l) \sim D} \left[ \log \sigma\left( \beta \log \frac{\pi_\theta(y_w \mid x)}{\pi_{\text{ref}}(y_w \mid x)} - \beta \log \frac{\pi_\theta(y_l \mid x)}{\pi_{\text{ref}}(y_l \mid x)} \right) \right]
$$
**GRPO:**
GRPO (Group Relative Policy Optimization) computes advantages from multiple responses to the same prompt, then optimizes using a PPO-style clipped objective:
$$
\text{Advantage}_i = \frac{r_i - \mu}{\sigma + \epsilon}
$$
Where $r_i$ is the reward for the $i$-th response, $\mu$ and $\sigma$ are the mean and standard deviation of group rewards.
$$
L_{\text{GRPO}} = -\mathbb{E} \left[ \min\left( \frac{\pi_\theta(a|s)}{\pi_{\text{ref}}(a|s)} \cdot A, \text{clip}\left(\frac{\pi_\theta(a|s)}{\pi_{\text{ref}}(a|s)}, 1-\epsilon, 1+\epsilon\right) \cdot A \right) \right] + \lambda \cdot D_{KL}
$$
The KL divergence term uses mean squared error approximation:
$$
L_{KL} = \lambda \cdot \mathbb{E} \left[ (\log \pi_\theta - \log \pi_{\text{ref}})^2 \right]
$$
The final loss is the sum of both: $L = L_{\text{policy}} + L_{KL}$
Through the above three-stage progressive training, the model completes its evolution from a general language foundation to a specialized, highly-aligned dialogue intelligence.
> Document Update Time: 2026-05-14

249
assets/docs/inference.md Normal file
View File

@ -0,0 +1,249 @@
# Inference
## Contents
- [KV Cache](#kv-cache)
- [KVCache System](#kvcache-system)
- [Continuous Batching](#continuous-batching)
- [Sampling](#sampling-strategy-pattern)
- [Protocol Handlers](#protocol-handlers-strategy-pattern)
- [Engine & GenerateResult](#engine--generateresult)
- [HTTP API](#http-api) — endpoints, SSE, errors, stats
- [Engine API](#engine-api)
## KV Cache
At decode time, only the last query token matters. All previous K/V are cached to avoid recomputation:
$$
o_n = \sum_j \text{softmax}\left(\frac{q_n k_j}{\sqrt{d_k}}\right) v_j
$$
RoPE is applied **before** KV cache write, not after — otherwise position encoding drift occurs.
## KVCache System
Six classes (plus two helpers) working together:
```
KVCache (facade)
├── PagePool orchestrates page allocation + prefix matching
│ ├── Allocator bitmask-based page allocator + ref-count + LRU eviction (inside PagePool)
│ └── PrefixCache hash-based prefix matching (page_hash via polynomial hash) (inside PagePool)
├── TaskTable maps task_id → page_table + cached token count
├── Storage k_cache / v_cache tensors (n_layers × n_pages × page_size × n_kv_heads × head_dim)
└── KvcacheView bundles Storage + page_table + total_len for attention layers (returned by bind())
```
`KVCache.bind(page_table, total_len)` returns a `KvcacheView` used by attention layers via `write()` / `gather()`.
## Continuous Batching
`InferenceScheduler` runs a daemon thread with a 4-phase loop:
```
1. Cleanup → Remove finished tasks, free KV pages
2. Refill → Pop from waiting_queue, task_alloc pages, activate
3. Prefill → Group by (prompt_len, start_pos), run full forward
4. Decode → Pick largest same-position group, single-token forward
```
## Sampling (Strategy Pattern)
```
BaseSamplingStrategy (ABC)
├── TemperatureStrategy
├── TopKStrategy
├── TopPStrategy
└── SamplingPipeline
```
`SamplingPipeline` composes them: Temperature → Top-K → Top-P → softmax → multinomial.
`sample()` is a convenience shortcut for one-shot usage.
## Protocol Handlers (Strategy Pattern)
```python
class ProtocolHandler: # concrete orchestrator
def __init__(self, request, engine, builder): ...
async def handle(self):
prompt, ctx, stops = builder.prepare(request, engine)
agen = engine.generate_async(prompt, ...)
if stream: self._handle_stream(agen, ctx, stops)
else: return await self._handle_non_stream(agen, ctx, stops)
```
`ResponseBuilder` (ABC): `prepare()`, `format_stream_start()`, `format_chunk()`, `format_stream_end()`, `format_response()`.
`OpenAIResponseBuilder``/v1/chat/completions`, `AnthropicResponseBuilder``/v1/messages`.
Adding a protocol = one builder file, no handler subclassing needed.
## Engine & GenerateResult
```
InferenceEngine
├── generate(prompt, stream, ...) → str | List[str] | Generator
├── generate_with_request(req) → same
├── generate_async(prompt, ...) → AsyncGenerator
├── get_stats() → Dict
└── shutdown()
```
`GenerateResult` uses `Condition` for non-streaming (`wait_completion()`) and `Event` for streaming (`wait()`). Stream callback is `cb(token)`.
## HTTP API
```
POST /v1/chat/completions OpenAI
POST /v1/messages Anthropic
GET /health {"status":"ok","model_loaded":true}
GET /stats scheduler statistics
```
### OpenAI
```bash
curl -X POST http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{"messages":[{"role":"user","content":"Hello"}],"max_tokens":512}'
```
Response:
```json
{
"id": "chatcmpl-abc123",
"object": "chat.completion",
"created": 1717000000,
"model": "astrai",
"choices": [{"index": 0, "message": {"role": "assistant", "content": "Hello!"}, "finish_reason": "stop"}],
"usage": {"prompt_tokens": 5, "completion_tokens": 10, "total_tokens": 15}
}
```
Streaming SSE: `object: "chat.completion.chunk"` — starts with role delta, then token chunks, ends with finish chunk + usage stats, then `data: [DONE]`.
### Anthropic
```bash
curl -X POST http://localhost:8000/v1/messages \
-H "Content-Type: application/json" \
-d '{"model":"astrai","system":"You are helpful.","messages":[{"role":"user","content":"Hello"}],"max_tokens":512}'
```
Supports `stop_sequences` and streaming via `event: content_block_delta`.
### GenerationRequest Parameters
| Param | Type | Default | Description |
|-------|------|---------|-------------|
| `messages` | List[dict] | required | Chat messages (role, content) |
| `top_k` | int | 50 | Top-k count |
| `top_p` | float | 1.0 | Nucleus threshold |
| `temperature` | float | 1.0 | Sampling temperature (> 0.0) |
| `max_tokens` | Optional[int] | None | Max generation length |
| `stream` | bool | False | Stream output |
### SSE Streaming Format
**OpenAI** (`/v1/chat/completions`, `stream=true`):
```
data: {"id":"chatcmpl-...","object":"chat.completion.chunk","created":...,"model":"astrai",
"choices":[{"index":0,"delta":{"role":"assistant"},"finish_reason":null}]}
data: {"id":"chatcmpl-...","object":"chat.completion.chunk",...,
"choices":[{"index":0,"delta":{"content":"Hello"},"finish_reason":null}]}
data: {"id":"chatcmpl-...","object":"chat.completion.chunk",...,
"choices":[{"index":0,"delta":{},"finish_reason":"stop"}],
"usage":{"prompt_tokens":5,"completion_tokens":1,"total_tokens":6}}
data: [DONE]
```
**Anthropic** (`/v1/messages`, `stream=true`):
```
event: message_start
data: {"type":"message_start","message":{"id":"msg_...","model":"astrai","role":"assistant",
"content":[],"stop_reason":null,...}}
event: content_block_start
data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hello"}}
event: content_block_stop
data: {"type":"content_block_stop","index":0}
event: message_delta
data: {"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{...}}
event: message_stop
data: {"type":"message_stop"}
```
### Error Responses
All endpoints use standard HTTP status codes:
| Status | Meaning |
|--------|---------|
| 200 | Success |
| 400 | Invalid request (bad JSON, missing fields, validation error) |
| 405 | Method not allowed |
| 422 | Unprocessable entity (Pydantic validation) |
| 500 | Internal server error (model crash, OOM, scheduler failure) |
| 503 | Service unavailable (model not loaded, engine not ready) |
Error response body:
```json
{
"error": {
"message": "Invalid request: max_tokens must be > 0",
"type": "invalid_request_error",
"code": 400
}
}
```
### Stats Endpoint
```
GET /stats
```
Response:
```json
{
"active_requests": 3,
"waiting_requests": 2,
"total_requests": 128,
"cache_usage": 0.45,
"tokens_generated": 10240
}
```
`cache_usage` is the fraction of KV cache pages currently in use (0.01.0).
## Engine API
```python
# Non-streaming
engine.generate("Hello", stream=False) # -> str
engine.generate(["A", "B"], stream=False) # -> List[str]
# Streaming
engine.generate("Hello", stream=True) # -> Generator[str]
engine.generate(["A", "B"], stream=True) # -> Generator[Tuple[int, str]]
# Async
async for token in engine.generate_async("Hello", ...): # -> AsyncGenerator[str]
print(token)
```
> Document Update Time: 2026-06-19

View File

@ -1,334 +0,0 @@
## Model Introduction
### 1. Model Architecture
This model uses the Transformer architecture with GQA mechanism (q_head=24, kv_head=4), which saves KV cache memory compared to traditional MHA. The model is built by stacking multiple layers of Transformer blocks, with 1.0 billion parameters. Transformer is an autoregressive model that calculates the relationship between all previous tokens to obtain the probability distribution of the next token.
The model now uses the **AutoModel** base class for flexible loading and saving:
```python
from astrai.model import AutoModel
# Load model from checkpoint
model = AutoModel.from_pretrained("path/to/model")
# Save model to new directory
model.save_pretrained("path/to/save")
```
The Transformer model is registered via `@AutoModel.register('transformer')` decorator, allowing easy extension for new model types.
```mermaid
flowchart TB
subgraph Layers["Transformer Layers"]
direction TB
A[Input Embedding] --> B[Transformer Block\nLayer 1]
B --> C[Transformer Block\nLayer ...]
C --> D[Transformer Block\nLayer ...]
D --> E[RMSNorm]
E --> F[Linear]
F --> G[SoftMax]
end
subgraph TransformerBlock["Transformer Block"]
direction TB
H[x] --> I[RMSNorm]
I --> J[Linear → Q/K/V]
J --> K[Q]
J --> L[K]
J --> M[V]
K --> N[RoPE]
L --> O[RoPE]
N --> P["Q @ K^T / sqrt(d)"]
O --> P
P --> Q[Masked SoftMax]
Q --> R[S @ V]
M --> R
R --> S[Linear]
S --> T[+]
H --> T
T --> U[RMSNorm]
U --> V["Linear (gate)"]
U --> W["Linear (up)"]
V --> X[SiLU]
X --> Y[×]
W --> Y
Y --> Z["Linear (down)"]
Z --> AA[+]
T --> AA
AA --> BB[x']
end
classDef main fill:#e6f3ff,stroke:#0066cc;
classDef block fill:#fff2e6,stroke:#cc6600;
class Layers main;
class TransformerBlock block;
```
What is an autoregressive model? After splitting a sentence into tokens, the model predicts the probability distribution of the next token. This means the model calculates the probability of the next possible token and its corresponding probability based on the given context (the sequence of tokens that have already appeared).
#### 1. Autoregression
In autoregressive modeling, when a sentence is tokenized into a sequence of tokens, the model learns to predict what comes next. Given a sequence of tokens as input, the model calculates a probability distribution over all possible next tokens. This distribution tells us how likely each potential next token is, given the current context.
For instance, if the input sequence contains tokens representing a question, the model might predict that certain response tokens have higher probabilities than others. The sampling process then selects one token from this distribution—controlled by parameters like top_k, top_p, and temperature—to serve as the next token in the sequence.
Once a token is selected, it is appended to the input sequence, and the model repeats this process. The updated sequence is then fed back into the model to predict the next token. This iterative process continues until either a special end-of-sequence token is generated, or the maximum sequence length is reached. These control tokens are essential because without them, the model would continue generating tokens indefinitely, eventually exhausting available memory.
#### 2. Causal Mask
Transformers use attention mechanism. The input shape is generally [bsz, seq_len], and the output is [bsz, seq_len, n_dim]. To predict the next token, the model's input and output must be offset by one position. The target predicted by the model must be offset by one position, and during training we also use the offset-by-one method:
```
sequence : [[1, 2, 3, 4, 5, 6]]
input_ids: [[1, 2, 3, 4, 5]]
target_ids: [[2, 3, 4, 5, 6]]
```
The attention score calculation formula is:
$$ s_{ij} = softmax(\frac{q_i^Tk_j}{\sqrt{d_k}}) $$
$$ s_{ij} := s_{ij} + mask_{ij} $$
Here, the attention score represents the degree to which the model attends to the similarity between two tokens.
For decoder-only structure models, to prevent the model from "stealing" information from future positions, a mask needs to be added during attention calculation. We need to apply a mask before attention score calculation. This mask is typically a lower triangular matrix, and for a sequence of length n, its shape is [n, n]. Below is an example of how to create such a causal mask matrix for a sequence of length 5:
```
[[0, -inf, -inf, -inf, -inf],
[0, 0, -inf, -inf, -inf],
[0, 0, 0, -inf, -inf],
[0, 0, 0, 0, -inf],
[0, 0, 0, 0, 0]]
```
In this matrix, 0 represents positions that can be attended to, while -inf represents positions that should be masked (i.e., should not be attended to). Because this matrix ensures that after the softmax, the parts of the attention scores where $j > i$ change from `inf` to 0, meaning the model cannot see future information.
#### 3. Rotary Position Embedding
Rotary Position Embedding (RoPE) is a position encoding method designed to solve the problem of lacking direct modeling of sequence position information in Transformer models. Unlike traditional position encodings (such as sine and cosine function position encodings), RoPE embeds position information directly into the Query (Q) and Key (K) vectors, allowing the model to more naturally handle relative position relationships in sequences.
$$ q_i = R_i W_q x_i $$
$$ k_j = R_j W_k x_j $$
$$ q_i^T k_j = (R_i W_q x_i)^T( R_j W_k x_j) = x_i^T W_q^T R_{i-j} W_k x_j $$
The $R_{i-j}$ controls the attenuation of attention for different tokens at different relative distances. When the absolute value of $i - j$ is larger, the degree of attenuation is stronger. This approach allows the model to learn relative position relationships, enabling the model to scale and adapt to longer sequences.
## KV Cache Implementation
According to the attention calculation formula:
$$
\begin{align*}
o_i &= \sum_j s_{ij} v_{j} \newline
s_{ij} &= \text{softmax}\left( \frac{q_{i} k_{j}}{\sqrt{d_k}} \right)
\end{align*}
$$
Since the model is an autoregressive model, we only need to calculate for the last part of the sequence, meaning the index $i$ is fixed as the last element of the sequence, and we compute $o_{n}$:
$$
\begin{align*}
o_n &= \sum_j s_{j}v_{j} \newline
s_j &= \text{softmax}\left(\frac{q_n k_{j}}{\sqrt{d_k}} \right)
\end{align*}
$$
If we expand the expression:
$$
o_n = \sum_j \text{softmax}\left(\frac{q_n k_{j}}{\sqrt{d_k}}\right)v_{j}
$$
In the above expression, only k and v have length indices, while $q$ does not. Therefore, during the calculation process, the input of $q$ is fixed as the last token from the previous input, while $k$ and $v$ need to be cached for parts of different lengths. Also, when caching, note that position encoding calculation should be performed before KV cache computation, otherwise there will be position encoding calculation errors.
### 4. AutoModel Loading
The project now uses the **AutoModel** base class for flexible model loading and saving:
```python
from astrai.model import AutoModel
# Load model from checkpoint
model = AutoModel.from_pretrained("path/to/model")
# Save model to new directory
model.save_pretrained("path/to/save")
```
The Transformer model is registered via `@AutoModel.register('transformer')` decorator, allowing easy extension for new model types. The `from_pretrained` method automatically loads the `config.json` to determine the model type and uses safetensors format for weights.
### 5. Continuous Batching Inference
The inference engine supports **continuous batching** for efficient batch processing:
```python
from astrai.inference import InferenceEngine, GenerationRequest
# Create inference engine with continuous batching
engine = InferenceEngine(
model=model,
tokenizer=tokenizer,
)
# Use GenerationRequest with messages format
request = GenerationRequest(
messages=[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello"},
],
temperature=0.8,
top_p=0.95,
top_k=50,
max_tokens=None,
stream=True,
)
# Generate with streaming
for token in engine.generate_with_request(request):
print(token, end="", flush=True)
```
The continuous batching feature allows dynamic batch composition where new requests can join at any time and completed requests are released immediately.
## HTTP API Usage
The inference server provides HTTP endpoints for remote inference. Start the server first:
```bash
python -m scripts.tools.server --port 8000
```
### OpenAI-Compatible Endpoint
The server provides an OpenAI-compatible chat completion endpoint at `/v1/chat/completions`:
```bash
curl -X POST http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello, how are you?"}
],
"temperature": 0.8,
"max_tokens": 2048,
"stream": false
}'
```
**Request Parameters:**
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `messages` | List[dict] | Required | Chat messages with role and content |
| `temperature` | float | 1.0 | Sampling temperature (0.0-2.0) |
| `top_p` | float | 1.0 | Nucleus sampling threshold |
| `top_k` | int | 50 | Top-k sampling parameter |
| `max_tokens` | int | 1024 | Maximum tokens to generate |
| `stream` | bool | false | Enable streaming response |
**Response (non-streaming):**
```json
{
"id": "chatcmpl-1234567890",
"object": "chat.completion",
"created": 1234567890,
"model": "astrai",
"choices": [
{
"index": 0,
"message": {"role": "assistant", "content": "Hello! I'm doing well..."},
"finish_reason": "stop"
}
],
"usage": {
"prompt_tokens": 20,
"completion_tokens": 15,
"total_tokens": 35
}
}
```
### Streaming Response
Enable streaming for real-time token-by-token output:
```bash
curl -X POST http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"messages": [{"role": "user", "content": "Write a story"}],
"stream": true,
"max_tokens": 500
}'
```
The server uses Server-Sent Events (SSE) with content type `text/event-stream`.
### Anthropic-Compatible Endpoint
The server also provides an Anthropic-compatible endpoint at `/v1/messages`:
```bash
curl -X POST http://localhost:8000/v1/messages \
-H "Content-Type: application/json" \
-d '{
"model": "astrai",
"system": "You are a helpful assistant.",
"messages": [{"role": "user", "content": "Hello, how are you?"}],
"max_tokens": 2048
}'
```
Response:
```json
{
"id": "msg_abc123...",
"type": "message",
"role": "assistant",
"model": "astrai",
"content": [{"type": "text", "text": "Hello! I am doing well..."}],
"stop_reason": "end_turn",
"stop_sequence": null,
"usage": {"input_tokens": 20, "output_tokens": 15}
}
```
Streaming:
```bash
curl -X POST http://localhost:8000/v1/messages \
-H "Content-Type: application/json" \
-d '{
"model": "astrai",
"system": "You are a helpful assistant.",
"messages": [{"role": "user", "content": "Write a short poem"}],
"max_tokens": 500,
"stream": true
}'
```
Supports `stop_sequences` for early termination:
```bash
curl -X POST http://localhost:8000/v1/messages \
-H "Content-Type: application/json" \
-d '{
"model": "astrai",
"messages": [{"role": "user", "content": "Write a story"}],
"max_tokens": 500,
"stop_sequences": ["The end", "THE END"]
}'
```
### Health Check
Monitor server and model status:
```bash
curl http://localhost:8000/health
# {"status": "ok", "model_loaded": true}
curl http://localhost:8000/stats
# {"total_tasks": 10, "total_tokens": 5000, "active_tasks": 1, "waiting_queue": 0}
```
> Document Update Time: 2026-05-14

View File

@ -1,4 +1,11 @@
# Parameter Documentation
# CLI Parameter Reference
## Contents
- [Training Parameters](#training-parameters)
- [Inference Server](#inference-server-serverpy)
- [Generate](#generate-generatepy)
- [Preprocess](#preprocess-preprocesspy)
## Training Parameters
@ -10,14 +17,14 @@
| `--data_root_path` | Dataset root directory | required |
| `--param_path` | Model parameters or checkpoint path | required |
| `--n_epoch` | Total training epochs | 1 |
| `--batch_size` | Batch size | 1 |
| `--accumulation_steps` | Gradient accumulation steps between optimizer steps | 1 |
| `--batch_per_device` | Batch size per device | 1 |
| `--grad_accum_steps` | Gradient accumulation steps between optimizer steps | 1 |
### Learning Rate Scheduling
| Parameter | Description | Default |
|-----------|-------------|---------|
| `--warmup_steps` | Warmup steps | 1000 |
| `--warmup_ratio` | Fraction of total steps used for LR warmup | 0.05 |
| `--max_lr` | Maximum learning rate (cosine decay after warmup) | 3e-4 |
| `--max_grad_norm` | Maximum gradient norm for clipping | 1.0 |
@ -48,111 +55,138 @@
| `--start_epoch` | Resume from epoch (0 = from scratch) | 0 |
| `--start_batch` | Resume from batch iteration | 0 |
### Validation
| Parameter | Description | Default |
|-----------|-------------|---------|
| `--val_split` | Ratio to split from training dataset for validation (e.g. 0.05) | None |
| `--val_step` | Number of optimizer steps between validation runs | 1000 |
### Logging
| Parameter | Description | Default |
|-----------|-------------|---------|
| `--log_dir` | Directory for metric logs | checkpoint/logs |
| `--log_interval` | Number of batch iterations between metric logs | 100 |
| `--metrics` | Metrics to log (e.g. --metrics loss lr val_loss) | ["loss", "lr"] |
### Gradient Checkpointing
| Parameter | Description | Default |
|-----------|-------------|---------|
| `--gradient_checkpointing` | Enable activation checkpointing for DecoderBlock modules | False |
### Distributed Training
| Parameter | Description | Default |
|-----------|-------------|---------|
| `--nprocs` | Number of GPUs / processes | 1 |
| `--parallel_mode` | Parallel strategy (`none`, `ddp`, or `fsdp`) | none |
| `--device_type` | Device type | cuda |
| `--start_method` | Multiprocessing start method (`spawn`, `fork`, `forkserver`) | spawn |
| `--backend` | Distributed training backend | nccl |
| `--master_addr` | Master node address | localhost |
| `--master_port` | Master node port | 29500 |
### Strategy-specific
| Parameter | Description | Default | Used by |
|-----------|-------------|---------|---------|
| `--dpo_beta` | DPO beta value | 0.1 | `dpo` |
| `--label_smoothing` | Label smoothing for cross-entropy loss | 0.1 | `seq`, `sft` |
| `--label_smoothing` | Label smoothing for cross-entropy loss | 0.0 | `seq`, `sft` |
| `--group_size` | GRPO group size | 4 | `grpo` |
| `--grpo_clip_eps` | GRPO clipping epsilon | 0.2 | `grpo` |
| `--grpo_kl_coef` | GRPO KL penalty coefficient | 0.01 | `grpo` |
| `--grpo_sync_interval` | GRPO ref_model sync interval (steps) | 200 | `grpo` |
| `--neftune_alpha` | NEFTune noise alpha (0=disabled, typical: 5.0) | 0.0 | `sft` |
### Usage Example
```bash
python scripts/tools/train.py \
--train_type seq \
--data_root_path /path/to/dataset \
--param_path /path/to/model \
--n_epoch 3 \
--batch_size 4 \
--accumulation_steps 8 \
--max_lr 3e-4 \
--warmup_steps 2000 \
--max_grad_norm 1.0 \
--ckpt_interval 5000 \
--ckpt_dir ./checkpoints \
--num_workers 4 \
--nprocs 1 \
--device_type cuda
export CUDA_VISIBLE_DEVICES=0,1,2,3
nohup python scripts/tools/train.py \
--nprocs=4 \
--parallel_mode=ddp \
--train_type=seq \
--data_root_path=/path/to/dataset \
--param_path=/path/to/model \
--batch_per_device=4 \
--grad_accum_steps=8 \
--warmup_ratio=0.05 \
--max_lr=1e-4 \
--max_grad_norm=1.0 \
--adamw_beta1=0.9 \
--adamw_beta2=0.95 \
--adamw_weight_decay=0.01 \
--window_size=2048 \
--ckpt_interval=10000 \
--ckpt_dir=./checkpoint \
--random_seed=3407 \
--label_smoothing=0.05 \
> out.log 2> err.log &
```
---
## Generation Parameters
## Inference Server (`server.py`)
### GenerationRequest Parameters
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `--host` | str | `0.0.0.0` | Host address |
| `--port` | int | `8000` | Port number |
| `--param_path` | path | `project_root/params` | Path to model parameters |
| `--device` | str | `cuda` | Device to load model on |
| `--dtype` | str | `bfloat16` | Model weights dtype (`bfloat16`, `float16`, `float32`) |
| `--max_batch_size` | int | `16` | Maximum batch size for continuous batching |
| `--reload` | flag | `False` | Enable auto-reload for development |
| Parameter | Description | Default Value |
|-----------|-------------|---------------|
| `messages` | List of message dictionaries (role, content) | required |
| `temperature` | Sampling temperature (higher = more random) | 1.0 |
| `top_p` | Nucleus sampling threshold | 1.0 |
| `top_k` | Top-k sampling count | 50 |
| `max_tokens` | Maximum generation length | None (unlimited) |
| `stream` | Whether to stream output | False |
### Usage Example
```python
import torch
from astrai.model import AutoModel
from astrai.tokenize import AutoTokenizer
from astrai.inference import InferenceEngine, GenerationRequest
# Load model using AutoModel
model = AutoModel.from_pretrained("your_model_dir")
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("your_model_dir")
# Create engine with separate model and tokenizer
engine = InferenceEngine(
model=model,
tokenizer=tokenizer,
)
# Build request with messages format
request = GenerationRequest(
messages=[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello"},
],
temperature=0.8,
top_p=0.95,
top_k=50,
max_tokens=None,
)
# Generate (streaming)
for token in engine.generate_with_request(request):
print(token, end="", flush=True)
# Or use simple generate interface
result = engine.generate(
prompt="Hello",
stream=False,
max_tokens=1024,
temperature=0.8,
top_p=0.95,
top_k=50,
)
Usage:
```bash
python scripts/tools/server.py --param_path ./params --device cuda --dtype bfloat16
```
### Generation Modes
See [Inference Guide](inference.md) for HTTP API documentation.
| Mode | Description |
|------|-------------|
| `stream=True` | Streaming output, yields token by token |
| `stream=False` | Non-streaming output, returns complete result |
## Generate (`generate.py`)
> Document Update Time: 2026-05-14
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `--param_path` | str | required | Path to the model directory |
| `--input_json_file` | str | required | Path to the input JSONL file |
| `--output_json_file` | str | required | Path to the output JSONL file |
| `--question_key` | str | `question` | Key for the question in input JSON |
| `--response_key` | str | `response` | Key for the response in output JSON |
| `--temperature` | float | `0.60` | Sampling temperature |
| `--top_k` | int | `30` | Top-k filtering |
| `--top_p` | float | `0.95` | Nucleus sampling threshold |
| `--batch_size` | int | `1` | Batch size for generation |
| `--max_tokens` | int | `2048` | Maximum tokens to generate |
Usage:
```bash
python scripts/tools/generate.py \
--param_path ./params \
--input_json_file input.jsonl \
--output_json_file output.jsonl
```
## Preprocess (`preprocess.py`)
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `input_files` | path(s) | required | Input JSONL file(s), supports glob (`data/*.jsonl`) |
| `--output_dir`, `-o` | path | required | Output directory for processed data |
| `--config`, `-c` | path | required | Preprocessing pipeline config (JSON) |
| `--num_workers` | int | `4` | Number of parallel workers |
Usage:
```bash
python scripts/tools/preprocess.py data/*.jsonl -o output/ -c sft.json
```
See [Preprocessing Guide](preprocessing.md) for config file format and examples.
---
> Document Update Time: 2026-06-19

View File

@ -0,0 +1,361 @@
# Preprocessing Pipeline
Declarative JSON-driven data preprocessing. One `SectionedMaskBuilder` handles all formats via `input.sections` (single-output) or `input.sources` (multi-output).
## Contents
- [Philosophy](#philosophy)
- [Config Structure](#config-structure)
- [Quick Start](#quick-start) — SFT Chat, SFT Instruction, Pretrain, DPO, GRPO examples
- [Configuration Reference](#configuration-reference) — all fields
- [Mask Algorithm](#mask-algorithm)
- [Output Layout](#output-layout)
- [CLI](#cli)
- [Python API](#python-api)
## Philosophy
| Component | Responsibility |
|-----------|---------------|
| `tokenizer_config.json` (`chat_template`) | Formatting -- how roles become tokens |
| `pipeline.json` (`mask`) | Masking -- which roles participate in training |
A single config file captures the entire pipeline, reusable and version-controllable.
## Config Structure
```json
{
"input": {}, // sections (single) or sources (multi)
"mask": {}, // role → "train" | "mask"
"mask_default": "mask",
"preprocessing": {},
"output": {}
}
```
### Section Fields
| Field | Type | Default | Description |
|-------|------|---------|-------------|
| `field` | str | -- | JSONL key to read |
| `action` | str | -- | `"train"` / `"mask"` / `"$role"` |
| `template` | bool | `false` | Apply `chat_template` per message |
| `add_special_tokens` | bool | `true` for first non-template section | Add special tokens during encode |
### Source Fields (multi-output mode)
| Field | Type | Default | Description |
|-------|------|---------|-------------|
| `sections` | list[dict] | -- | Same as single-output section list |
| `list_field` | bool | `false` | JSONL field holds a list; tokenise each element |
| `mask_key` | str | `"{key}_mask"` | Explicit output key for loss mask |
---
## Quick Start
### SFT Chat
Input JSONL:
```json
{"messages": [{"role": "system", "content": "You are helpful."}, {"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello!"}]}
```
Config:
```json
{
"input": {
"sections": [
{"field": "messages", "action": "$role", "template": true}
]
},
"mask": {
"system": "mask",
"user": "mask",
"assistant": "train"
},
"mask_default": "mask",
"preprocessing": {
"max_seq_len": 2048
},
"output": {
"storage_format": "bin",
"dtype": {"loss_mask": "bool"}
}
}
```
Output keys: `sequence` (int32), `loss_mask` (bool)
### SFT Instruction
Input JSONL:
```json
{"prompt": "Translate to French: Hello", "response": "Bonjour"}
```
Config:
```json
{
"input": {
"sections": [
{"field": "prompt", "action": "mask", "add_special_tokens": true},
{"field": "response", "action": "train"}
]
},
"mask_default": "mask",
"preprocessing": {
"max_seq_len": 2048
}
}
```
Output keys: `sequence`, `loss_mask`
### Pretrain
Input JSONL:
```json
{"text": "Artificial Intelligence is a field of computer science..."}
```
Config:
```json
{
"input": {
"sections": [
{"field": "text", "action": "train"}
]
},
"preprocessing": {
"max_seq_len": 8192,
"min_chars": 100
}
}
```
Output keys: `sequence` (no `loss_mask` — all tokens trained)
### DPO
Input JSONL:
```json
{"chosen": [{"role": "user", "content": "What is 2+2?"}, {"role": "assistant", "content": "4"}], "rejected": [{"role": "user", "content": "What is 2+2?"}, {"role": "assistant", "content": "5"}]}
```
Config:
```json
{
"input": {
"sources": {
"chosen": {
"sections": [
{"field": "chosen", "action": "$role", "template": true}
]
},
"rejected": {
"sections": [
{"field": "rejected", "action": "$role", "template": true}
]
}
}
},
"mask": {
"user": "mask",
"assistant": "train"
},
"mask_default": "mask"
}
```
Output keys: `chosen`, `chosen_mask`, `rejected`, `rejected_mask`
### GRPO
Input JSONL:
```json
{"prompt": [{"role": "user", "content": "What is 2+2?"}], "responses": ["4", "Five", "Four"], "rewards": [1.0, 0.3, 0.8]}
```
Config:
```json
{
"input": {
"sources": {
"prompts": {
"sections": [
{"field": "prompt", "action": "mask", "template": true}
]
},
"responses": {
"sections": [
{"field": "responses", "action": "train"}
],
"list_field": true,
"mask_key": "masks"
},
"rewards": {
"sections": [
{"field": "rewards", "action": "value"}
]
}
}
},
"mask": {
"user": "mask",
"assistant": "train"
},
"mask_default": "mask"
}
```
Output keys: `prompts`, `responses`, `masks`, `rewards` (float32)
- `action: "value"` — extract raw values from JSONL without tokenisation
- `list_field: true` — tokenise each list element independently, then concatenate
- `mask_key: "masks"` — rename the auto-generated mask key (default: `responses_mask`)
---
## Configuration Reference
### `input`
| Field | Type | Default | Description |
|-------|------|---------|-------------|
| `sections` | list[dict] or null | `null` | Section specs for single-output mode |
| `sources` | dict[str, dict] or null | `null` | Source specs for multi-output mode (DPO/GRPO) |
When `sources` is set, `sections` is ignored.
### `mask`
| Field | Type | Default | Description |
|-------|------|---------|-------------|
| `mask` | dict | `{}` | `{role: "train" \| "mask"}` |
| `mask_default` | str | `"mask"` | Default action for unlisted roles |
### `preprocessing`
| Field | Type | Default | Description |
|-------|------|---------|-------------|
| `max_seq_len` | int | `2048` | Truncate sequences to this length |
| `min_chars` | int | `50` | Skip text-mode items shorter than this |
| `max_chars` | int | `2000000` | Skip text-mode items longer than this |
| `max_items` | int or null | `null` | Stop after N documents |
| `packing_strategy` | str | `"simple"` | Packing strategy: `"simple"`, `"bfd"`, `"bfd_split"` |
| `max_packed_len` | int | `8192` | Maximum length of a packed bin |
| `truncation_mode` | str | `"keep_start"` | How to truncate sequences: `"keep_start"` or `"keep_end"` |
### `output`
| Field | Type | Default | Description |
|-------|------|---------|-------------|
| `domain_key` | str or null | `null` | JSONL key for domain grouping |
| `storage_format` | str | `"bin"` | `"bin"` (mmap) or `"h5"` |
| `max_tokens_per_shard` | int | `100000000` | Flush threshold in cumulative tokens |
| `dtype` | dict[str, str] | `{}` | Per-key tensor dtype override (e.g. `{"loss_mask": "bool"}`) |
| `position_ids_mode` | str | `"none"` | How to compute position_ids: `"none"`, `"doc_reset"`, `"continuous"` |
---
## Mask Algorithm
### Template mode (`template: true`)
For each message in the field's array:
1. Prepend BOS token (masked)
2. Render through `chat_template` for that single message
3. Encode rendered text
4. Apply mask rule for the message's role
### Non-template mode
Encode the field value as text. Mask value is 1 (train) or 0 (mask) per the section's `action`.
### Text config detection
When no section uses `template` and all sections have `action: "train"`, the builder skips mask generation entirely — all tokens are trained.
---
## Output Layout
### Single-Shard (`bin`)
```
output/
__default__/
meta.json
sequence.bin
loss_mask.bin
wiki/
meta.json
sequence.bin
loss_mask.bin
```
### Multi-Shard (`bin`)
When `max_tokens_per_shard` is exceeded:
```
output/
__default__/
shard_0000/
meta.json
sequence.bin
loss_mask.bin
shard_0001/
meta.json
sequence.bin
loss_mask.bin
```
`MmapStore` discovers all shards under the domain directory via `rglob("meta.json")`.
---
## CLI
```bash
# SFT
python scripts/tools/preprocess.py data/sft/*.jsonl -o output/sft/ -c configs/sft_chat.json
# DPO
python scripts/tools/preprocess.py data/dpo/*.jsonl -o output/dpo/ -c configs/dpo.json --tokenizer_path params
# GRPO
python scripts/tools/preprocess.py data/grpo/*.jsonl -o output/grpo/ -c configs/grpo.json
```
---
## Python API
```python
from astrai.preprocessing.pipeline import Pipeline
from astrai.config.preprocess_config import PipelineConfig
config = PipelineConfig.from_json("sft.json")
Pipeline(
config,
["data_part1.jsonl", "data_part2.jsonl"],
output_dir="output/",
tokenizer_path="params",
).run()
```
> Document Update Time: 2026-06-03

215
assets/docs/training.md Normal file
View File

@ -0,0 +1,215 @@
# Training
## Contents
- [Autoregression](#autoregression)
- [Causal Mask](#causal-mask)
- [Rotary Position Embedding (RoPE)](#rotary-position-embedding-rope)
- [Training Loop](#training-loop)
- [Strategies](#strategies) — SEQ, SFT, DPO, GRPO
- [LR Schedulers](#lr-schedulers)
- [Gradient Checkpointing](#gradient-checkpointing)
- [Checkpoint](#checkpoint)
- [TrainContextBuilder](#traincontextbuilder-builder-pattern)
- [Training CLI](#training-cli)
### Autoregression
Given a token sequence, the model predicts the probability of the next token. Each generated token is appended to the input and fed back, repeating until an end-of-sequence token or max length.
### Causal Mask
```
sequence : [[1, 2, 3, 4, 5, 6]]
input_ids: [[1, 2, 3, 4, 5]]
target_ids: [[2, 3, 4, 5, 6]]
```
Lower-triangular mask prevents attending to future positions:
```
[[0, -inf, -inf, -inf, -inf],
[0, 0, -inf, -inf, -inf],
[0, 0, 0, -inf, -inf],
[0, 0, 0, 0, -inf],
[0, 0, 0, 0, 0]]
```
### Rotary Position Embedding (RoPE)
RoPE embeds position into Q/K vectors via complex rotation:
$$ q_i = R_i W_q x_i, \quad k_j = R_j W_k x_j, \quad q_i^T k_j = x_i^T W_q^T R_{i-j} W_k x_j $$
The complex rotation `freqs_cis` is pre-computed once (`cos, sin` pairs per position). `apply_rotary_emb` multiplies Q/K as complex numbers.
## Training Loop
Two-level loop: **epoch****batch**. Optimizer step fires every `grad_accum_steps` batches.
```
on_train_begin
model.train()
on_epoch_begin
for batch in dataloader:
on_batch_begin
with executor.accumulate(model):
loss = strategy.compute_loss(batch)
context.loss = loss.item()
stand_loss = loss / executor.grad_accum_steps
executor.backward(stand_loss)
context.iteration += 1
on_batch_end
if executor.sync_gradients:
on_optimizer_step
optimizer.step()
optimizer.zero_grad()
if scheduler:
scheduler.step()
on_epoch_end
on_train_end
```
### Callback Lifecycle
| Hook | Fires | Default callback |
|------|-------|-----------------|
| `on_train_begin` | Before training starts | `GradientCheckpointingCallback` |
| `on_epoch_begin` | Start of each epoch | `ProgressBarCallback` |
| `on_batch_begin` | Every batch | — |
| `on_optimizer_step` | Every accumulation window | `GradientClippingCallback`, `ValidationCallback` |
| `on_batch_end` | Every batch | `CheckpointCallback`, `MetricLoggerCallback`, `ProgressBarCallback` |
| `on_epoch_end` | End of each epoch | `ProgressBarCallback` |
| `on_error` | On exception during training | `CheckpointCallback`, `MetricLoggerCallback` |
| `on_train_end` | Training ends (always via finally) | `CheckpointCallback`, `MetricLoggerCallback`, `GradientCheckpointingCallback` |
Default callbacks (in order): `gradient_checkpointing` (activation checkpointing, optional), `checkpoint` (safetensors, rank-0), `metric_logger` (JSONL, rank-0), `progress_bar` (tqdm), `gradient_clipping`, `validation` (periodic validation on val_dataset).
## Strategies
### SEQ (Pre-training)
Next-token cross-entropy with optional label smoothing:
$$
L_{\text{PT}} = -\sum_{t=1}^{T} \log P(x_t \mid x_{\lt t}; \theta)
$$
Keys: `input_ids`, `target_ids`. Optional: `label_smoothing`.
### SFT (Supervised Fine-Tuning)
Masked cross-entropy (`ignore_index=-100`) over response tokens:
$$
L_{\text{SFT}} = -\sum_{t=P+1}^{P+L} \log P(s_t \mid s_{\lt t}; \theta)
$$
Keys: `input_ids`, `target_ids`, `loss_mask`. Optional: `label_smoothing`.
### DPO (Direct Preference Optimization)
Frozen reference model, preference margin via log-ratio:
$$
L_{\text{DPO}} = -\mathbb{E}\left[\log\sigma\left(\beta\log\frac{\pi_\theta(y_w\mid x)}{\pi_{\text{ref}}(y_w\mid x)} - \beta\log\frac{\pi_\theta(y_l\mid x)}{\pi_{\text{ref}}(y_l\mid x)}\right)\right]
$$
Parameters: `beta=0.1`, `reduction="mean"`. Keys: `chosen`, `rejected`, `chosen_mask`, `rejected_mask`.
### GRPO (Group Relative Policy Optimization)
On-policy PPO with group-normalized advantages:
$$
\text{Advantage}_i = \frac{r_i - \mu}{\sigma + \epsilon}
$$
$$
L_{\text{GRPO}} = -\mathbb{E}\left[\min\left(\frac{\pi_\theta}{\pi_{\text{ref}}}A,\; \text{clip}\left(\frac{\pi_\theta}{\pi_{\text{ref}}}, 1-\epsilon, 1+\epsilon\right)A\right)\right] + \lambda \cdot \mathbb{E}\left[(\log\pi_\theta - \log\pi_{\text{ref}})^2\right]
$$
Parameters: `group_size=4`, `clip_eps=0.2`, `kl_coef=0.01`, `sync_interval=200`, `reduction="mean"`.
Keys: `prompts`, `responses`, `masks`, `rewards`.
## LR Schedulers
| Type | Class | Description |
|------|-------|-------------|
| Cosine | `CosineScheduler` | Linear warmup → cosine decay to `min_rate` |
| SGDR | `SGDRScheduler` | Cosine annealing with warm restarts (`t_mult=2`) |
| WSD | `WSDScheduler` | Warmup-Stable-Decay with sqrt cooldown |
Created by `SchedulerFactory.create(schedule_type, optimizer, **kwargs)`. Valid types: `"cosine"`, `"sgdr"`, `"wsd"`. Omit to use no scheduler.
## Gradient Checkpointing
Trades compute for memory by recomputing activations during backward pass. Specify module types via `gradient_checkpointing_modules`:
```python
from astrai.model.components.decoder_block import DecoderBlock
config = TrainConfig(..., gradient_checkpointing_modules=[DecoderBlock])
```
Callback wraps each `DecoderBlock.forward` with `torch.utils.checkpoint.checkpoint(use_reentrant=False)`, compatible with `torch.compile`. Uses `nn.Module.apply()` for traversal — works through DDP wrappers without manual unwrap. Empty list (default) means no-op.
## Checkpoint
```
Checkpoint(state_dict, epoch, iteration, extra, meta, config)
├── save(save_dir) rank-0 only: meta.json (epoch/iteration/timestamp) + config.json (model config) + model.safetensors + optional {key}.pt (optimizer.pt, scheduler.pt)
└── load(save_dir, broadcast=False) loads from local disk; set broadcast=True to broadcast metadata from rank-0
```
Optimizer/scheduler state persisted by default via `Checkpoint.extra`.
Model config (`context.model_config`) saved into `config.json` during training via `CheckpointCallback`.
## TrainContextBuilder (Builder Pattern)
```python
context = (
TrainContextBuilder(config)
.with_resume_dir(resume_dir)
.build()
)
# Returns TrainContext with model, strategy, optimizer, scheduler, dataloader, checkpoint
```
- Loads checkpoint weights if provided
- Creates executor via `ExecutorFactory.create(cfg.parallel_mode, grad_accum_steps=cfg.grad_accum_steps, **cfg.executor_kwargs)`
- Calls `executor.prepare(model, optimizer, dataloader, scheduler)` for model distribution (e.g. DDP) + gradient accumulation wrappers
- Creates `ResumableDistributedSampler` for shuffle+resume
- Builds strategy via `StrategyFactory.create(train_type, model, device, **kwargs)`
## Training CLI
```bash
export CUDA_VISIBLE_DEVICES=0,1,2,3
nohup python scripts/tools/train.py \
--nprocs=4 \
--parallel_mode=ddp \
--train_type=seq \
--data_root_path=/path/to/dataset \
--param_path=/path/to/model \
--batch_per_device=4 \
--grad_accum_steps=8 \
--warmup_ratio=0.05 \
--max_lr=1e-4 \
--max_grad_norm=1.0 \
--adamw_beta1=0.9 \
--adamw_beta2=0.95 \
--adamw_weight_decay=0.01 \
--window_size=2048 \
--ckpt_interval=10000 \
--ckpt_dir=./checkpoint \
--random_seed=3407 \
--label_smoothing=0.05 \
> out.log 2> err.log &
```
Full parameter reference at [params.md](params.md).
> Document Update Time: 2026-05-30

View File

@ -1,32 +1,100 @@
__version__ = "1.3.5"
__version__ = "1.3.7"
__author__ = "ViperEkura"
from astrai.config import (
ModelConfig,
AutoRegressiveLMConfig,
BaseModelConfig,
ConfigFactory,
EncoderConfig,
PipelineConfig,
TrainConfig,
)
from astrai.dataset import DatasetFactory
from astrai.dataset import (
BaseDataset,
DatasetFactory,
ResumableDistributedSampler,
Store,
StoreFactory,
)
from astrai.factory import BaseFactory
from astrai.inference import (
GenerationRequest,
InferenceEngine,
ProtocolHandler,
SamplingPipeline,
get_app,
run_server,
sample,
)
from astrai.model import (
AutoModel,
AutoRegressiveLM,
EmbeddingEncoder,
LoRAConfig,
inject_lora,
)
from astrai.parallel import (
ExecutorFactory,
get_rank,
get_world_size,
only_on_rank,
spawn_parallel_fn,
)
from astrai.preprocessing import Pipeline, filter_by_length
from astrai.serialization import Checkpoint
from astrai.tokenize import AutoTokenizer, ChatTemplate
from astrai.trainer import (
BaseScheduler,
BaseStrategy,
CallbackFactory,
Muon,
SchedulerFactory,
StrategyFactory,
TrainCallback,
Trainer,
)
from astrai.model import AutoModel, Transformer
from astrai.tokenize import AutoTokenizer
from astrai.trainer import CallbackFactory, SchedulerFactory, StrategyFactory, Trainer
__all__ = [
"Transformer",
"ModelConfig",
"TrainConfig",
"DatasetFactory",
"AutoRegressiveLM",
"AutoRegressiveLMConfig",
"AutoModel",
"AutoTokenizer",
"BaseDataset",
"BaseFactory",
"BaseModelConfig",
"BaseScheduler",
"BaseStrategy",
"CallbackFactory",
"ChatTemplate",
"Checkpoint",
"ConfigFactory",
"DatasetFactory",
"EmbeddingEncoder",
"EncoderConfig",
"ExecutorFactory",
"GenerationRequest",
"InferenceEngine",
"Trainer",
"CallbackFactory",
"StrategyFactory",
"LoRAConfig",
"Muon",
"Pipeline",
"PipelineConfig",
"ProtocolHandler",
"ResumableDistributedSampler",
"SamplingPipeline",
"SchedulerFactory",
"BaseFactory",
"AutoModel",
"Store",
"StoreFactory",
"StrategyFactory",
"TrainCallback",
"TrainConfig",
"Trainer",
"filter_by_length",
"get_app",
"get_rank",
"get_world_size",
"inject_lora",
"only_on_rank",
"run_server",
"sample",
"spawn_parallel_fn",
]

View File

@ -1,8 +1,25 @@
from astrai.config.model_config import ModelConfig
from astrai.config.model_config import (
AutoRegressiveLMConfig,
BaseModelConfig,
ConfigFactory,
EncoderConfig,
)
from astrai.config.preprocess_config import (
InputConfig,
OutputConfig,
PipelineConfig,
ProcessingConfig,
)
from astrai.config.train_config import TrainConfig
__all__ = [
# Model configuration
"ModelConfig",
"BaseModelConfig",
"AutoRegressiveLMConfig",
"EncoderConfig",
"ConfigFactory",
"TrainConfig",
"InputConfig",
"OutputConfig",
"PipelineConfig",
"ProcessingConfig",
]

98
astrai/config/base.py Normal file
View File

@ -0,0 +1,98 @@
import json
from dataclasses import MISSING, dataclass, fields
from pathlib import Path
from typing import Any, Dict, Optional, Self, Union, get_type_hints
@dataclass
class BaseConfig:
def to_dict(self) -> Dict[str, Any]:
d = {}
for fld in fields(self):
v = getattr(self, fld.name)
if isinstance(v, (str, int, float, bool)):
d[fld.name] = v
elif v is None:
d[fld.name] = None
elif isinstance(v, (dict, list, tuple)):
try:
val = list(v) if isinstance(v, tuple) else v
json.dumps(val)
d[fld.name] = val
except (TypeError, ValueError):
pass
elif isinstance(v, BaseConfig):
d[fld.name] = v.to_dict()
elif hasattr(v, "__dataclass_fields__"):
sub = {}
for f in fields(v):
a = getattr(v, f.name)
sub[f.name] = list(a) if isinstance(a, tuple) else a
d[fld.name] = sub
return d
@classmethod
def from_dict(cls, d: Dict[str, Any]) -> Self:
hints = get_type_hints(cls)
inst = cls.__new__(cls)
for fld in fields(cls):
if fld.name in d:
v = d[fld.name]
target = cls._unwrap_optional(hints.get(fld.name))
if target is not None:
try:
v = cls._coerce(v, target)
except (TypeError, ValueError):
pass
object.__setattr__(inst, fld.name, v)
elif fld.default is not MISSING:
object.__setattr__(inst, fld.name, fld.default)
elif fld.default_factory is not MISSING:
object.__setattr__(inst, fld.name, fld.default_factory())
else:
object.__setattr__(inst, fld.name, None)
return inst
@staticmethod
def _unwrap_optional(tp) -> Optional[type]:
if tp is None:
return None
origin = getattr(tp, "__origin__", None)
if origin is not None:
args = getattr(tp, "__args__", ())
non_none = [a for a in args if a is not type(None)]
return non_none[0] if non_none else None
return tp
@staticmethod
def _coerce(value: Any, target_type: type) -> Any:
if target_type is bool and isinstance(value, bool):
return value
if (
target_type is int
and isinstance(value, (int, float))
and not isinstance(value, bool)
):
return int(value)
if (
target_type is float
and isinstance(value, (int, float))
and not isinstance(value, bool)
):
return float(value)
if target_type is str and isinstance(value, str):
return value
if isinstance(value, target_type):
return value
if isinstance(value, dict) and issubclass(target_type, BaseConfig):
return target_type.from_dict(value)
raise TypeError
@classmethod
def from_file(cls, path: Union[str, Path]) -> Self:
with open(path, "r", encoding="utf-8") as f:
return cls.from_dict(json.load(f))
def to_file(self, path: Union[str, Path]):
with open(path, "w", encoding="utf-8") as f:
json.dump(self.to_dict(), f, indent=2, ensure_ascii=False)

View File

@ -1,42 +1,79 @@
import json
from dataclasses import asdict, dataclass
from typing import Optional, Self
from dataclasses import dataclass
from typing import Any, Dict, Optional
from astrai.config.base import BaseConfig
from astrai.factory import BaseFactory
class ConfigFactory(BaseFactory[BaseConfig]):
"""Factory that dispatches config classes by ``model_type``."""
@classmethod
def load(cls, raw: Dict[str, Any]) -> BaseConfig:
model_type = raw.get("model_type") or "autoregressive_lm"
config_cls = cls.get_component_class(model_type)
return config_cls.from_dict(raw)
@dataclass
class ModelConfig:
# basic config
class BaseModelConfig(BaseConfig):
"""Base config with ``model_type`` dispatch and file I/O."""
model_type: Optional[str] = None
@dataclass
@ConfigFactory.register("autoregressive_lm")
class AutoRegressiveLMConfig(BaseModelConfig):
"""Configuration for autoregressive language model."""
vocab_size: Optional[int] = None
dim: Optional[int] = None
n_layers: Optional[int] = None
norm_eps: Optional[float] = None
dim_ffn: Optional[int] = None
tie_weight: Optional[bool] = None
# RoPE
max_len: Optional[int] = None
rope_theta: Optional[float] = None
rope_scaling: Optional[dict] = None
# GQA
attn_type: str = "gqa"
n_heads: Optional[int] = None
n_kv_heads: Optional[int] = None
use_qk_norm: Optional[bool] = None
use_gated_attention: Optional[bool] = None
def load(self, config_path: str) -> Self:
config = {}
with open(config_path, "r") as f:
config.update(json.load(f))
kv_lora_rank: Optional[int] = None
qk_nope_head_dim: Optional[int] = None
qk_rope_head_dim: Optional[int] = None
for key, value in config.items():
if hasattr(self, key):
setattr(self, key, value)
ffn_type: str = "mlp"
n_routed_experts: Optional[int] = None
n_shared_experts: Optional[int] = None
n_activated_experts: Optional[int] = None
topk_method: Optional[str] = None
return self
def save(self, config_path: str):
config_dict = {k: v for k, v in asdict(self).items() if v is not None}
with open(config_path, "w") as f:
json.dump(config_dict, f, indent=4)
@dataclass
@ConfigFactory.register("embedding")
class EncoderConfig(BaseModelConfig):
"""Configuration for embedding encoder model."""
vocab_size: Optional[int] = None
dim: Optional[int] = None
n_layers: Optional[int] = None
norm_eps: Optional[float] = None
dim_ffn: Optional[int] = None
max_len: Optional[int] = None
rope_theta: Optional[float] = None
rope_scaling: Optional[dict] = None
n_heads: Optional[int] = None
n_kv_heads: Optional[int] = None
use_qk_norm: Optional[bool] = None
use_gated_attention: Optional[bool] = None
pooling_type: Optional[str] = None
normalize_embeddings: Optional[bool] = None

View File

@ -0,0 +1,109 @@
"""Pipeline configuration for JSONL preprocessing.
Supports single-sequence (SFT/pretrain) and multi-output (DPO/GRPO)
modes, both driven declaratively through ``input.sections`` or
``input.sources``.
"""
from dataclasses import dataclass, field
from typing import Dict, List, Optional
from astrai.config.base import BaseConfig
@dataclass
class InputConfig(BaseConfig):
"""Declarative input mapping.
Single-output mode (backward-compatible)::
{"input": {"sections": [{"field": "messages", ...}]}}
Multi-output mode (DPO / GRPO)::
{"input": {"sources": {
"chosen": {"sections": [{"field": "chosen", ...}]},
"rejected": {"sections": [{"field": "rejected", ...}]},
}}}
"""
sections: Optional[List[Dict]] = None
sources: Optional[Dict[str, Dict]] = None
@dataclass
class ProcessingConfig(BaseConfig):
"""Processing configuration.
Parameters
----------
max_seq_len : int
Maximum sequence length (default: 2048).
min_chars : int
Minimum number of characters to keep (default: 50).
max_chars : int
Maximum number of characters to keep (default: 2_000_000).
max_items : Optional[int]
Maximum number of items to process (default: None, unlimited).
packing_strategy : str
How to pack sequences into a contiguous stream.
- ``"simple"``: sequential concatenation (default, backward compatible).
- ``"bfd"``: best-fit decreasing bin packing, minimises wasted tokens.
- ``"bfd_split"``: BFD with over-length sequences split into chunks.
max_packed_len : int
Maximum length of a packed bin. Sequences longer than this are
truncated or split depending on ``packing_strategy`` (default: 8192).
truncation_mode : str
How to truncate sequences longer than ``max_packed_len``.
- ``"keep_start"``: keep the first ``max_packed_len`` tokens (default).
- ``"keep_end"``: keep the last ``max_packed_len`` tokens.
"""
max_seq_len: int = 2048
min_chars: int = 50
max_chars: int = 2_000_000
max_items: Optional[int] = None
packing_strategy: str = "simple"
max_packed_len: int = 8192
truncation_mode: str = "keep_start"
@dataclass
class OutputConfig(BaseConfig):
"""Output configuration.
Parameters
----------
domain_key : Optional[str]
Domain key for the output store (default: None).
storage_format : str
Storage format, one of ``"bin"``, ``"jsonl"`` (default: ``"bin"``).
max_tokens_per_shard : int
Maximum tokens per shard before splitting (default: 100_000_000).
dtype : Dict[str, str]
Per-key dtype overrides, e.g. ``{"input_ids": "int32"}`` (default: {}).
position_ids_mode : Optional[str]
How to compute position_ids in packed sequences.
- ``"none"``: do not generate (default).
- ``"doc_reset"``: reset to 0 at each document boundary.
- ``"continuous"``: sequential 0, 1, 2, ... (pretrain, single doc).
"""
domain_key: Optional[str] = None
storage_format: str = "bin"
max_tokens_per_shard: int = 100_000_000
dtype: Dict[str, str] = field(default_factory=dict)
position_ids_mode: str = "none"
@dataclass
class PipelineConfig(BaseConfig):
version: int = 1
input: InputConfig = field(default_factory=InputConfig)
mask: Dict[str, str] = field(default_factory=dict)
mask_default: str = "mask"
preprocessing: ProcessingConfig = field(default_factory=ProcessingConfig)
output: OutputConfig = field(default_factory=OutputConfig)

View File

@ -1,32 +1,49 @@
from dataclasses import dataclass, field
from typing import Callable, Optional
from dataclasses import dataclass, field, fields
from typing import Any, Callable, Dict, List, Optional
import torch.nn as nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from torch.utils.data import Dataset
from astrai.config.base import BaseConfig
from astrai.model.components.lora import LoRAConfig
def required(**kw):
return {"required": True, **kw}
@dataclass
class TrainConfig:
class TrainConfig(BaseConfig):
# basic setting
model: nn.Module = field(default=None, metadata={"help": "Model for training."})
strategy: str = field(default=None, metadata={"help": "Training strategy."})
dataset: Dataset = field(default=None, metadata={"help": "Dataset for training."})
model_fn: Callable[[], nn.Module] = field(
default=None, metadata=required(help="Model factory for training.")
)
strategy: str = field(default=None, metadata=required(help="Training strategy."))
dataset: Dataset = field(
default=None, metadata=required(help="Dataset for training.")
)
optimizer_fn: Callable[[nn.Module], Optimizer] = field(
default=None, metadata={"help": "Optimizer factory for training."}
default=None, metadata=required(help="Optimizer factory for training.")
)
scheduler_fn: Callable[[Optimizer], LRScheduler] = field(
default=None, metadata={"help": "Scheduler factory for training."}
default=None, metadata=required(help="Scheduler factory for training.")
)
n_epoch: int = field(default=1, metadata={"help": "Number of epochs for training."})
batch_size: int = field(default=4, metadata={"help": "Batch size for training."})
accumulation_steps: int = field(
batch_per_device: int = field(
default=4, metadata={"help": "Batch size per device."}
)
grad_accum_steps: int = field(
default=1, metadata={"help": "Number of iterations between steps."}
)
max_grad_norm: float = field(
default=1.0, metadata={"help": "Maximum gradient norm."}
)
gradient_checkpointing_modules: List[str] = field(
default_factory=list,
metadata={"help": "Module types to enable activation checkpointing for."},
)
# checkpoint setting
start_epoch: int = field(default=0, metadata={"help": "Start epoch for training."})
@ -40,6 +57,25 @@ class TrainConfig:
default=5000, metadata={"help": "Number of iterations between checkpoints."}
)
# lora setting
lora: Optional[LoRAConfig] = field(
default=None,
metadata={"help": "LoRA config. None means full fine-tuning."},
)
# metric setting
log_dir: str = field(
default="./checkpoint/logs", metadata={"help": "Directory for metric logs."}
)
log_interval: int = field(
default=100,
metadata={"help": "Number of batch iterations between metric logs."},
)
metrics: List[str] = field(
default_factory=lambda: ["loss", "lr"],
metadata={"help": "Metrics to record during training."},
)
# dataloader setting
random_seed: int = field(default=3407, metadata={"help": "Random seed."})
num_workers: int = field(
@ -66,18 +102,42 @@ class TrainConfig:
master_port: str = field(
default="29500", metadata={"help": "Master port for distributed training."}
)
parallel_wrapper: Optional[Callable] = field(
default=None, metadata={"help": "Parallel function for training."}
parallel_mode: str = field(
default="none",
metadata={"help": "Parallel strategy: none, ddp, fsdp."},
)
state_dict_fn: Optional[Callable] = field(
default=None, metadata={"help": "Parallel function for state dict saving."}
start_method: str = field(
default="spawn",
metadata={"help": "Multiprocessing start method (spawn/fork/forkserver)."},
)
# others
device_type: str = field(
default="cuda", metadata={"help": "Device type for distributed training."}
)
extra_kwargs: dict = field(
val_dataset: Optional[Dataset] = field(
default=None, metadata={"help": "Dataset for validation."}
)
val_split: Optional[float] = field(
default=None,
metadata={
"help": "Ratio to split from training dataset for validation (e.g. 0.05). Ignored if val_dataset is set."
},
)
val_step: int = field(
default=1000,
metadata={"help": "Number of optimizer steps between validation runs."},
)
neftune_alpha: float = field(
default=0.0,
metadata={"help": "NEFTune noise alpha (0=disabled, typical: 5.0)."},
)
executor_kwargs: Dict[str, Any] = field(
default_factory=dict,
metadata={"help": "Extra kwargs passed to ExecutorFactory.create()."},
)
extra_kwargs: Dict[str, Any] = field(
default_factory=dict, metadata={"help": "Other arguments."}
)
@ -85,14 +145,6 @@ class TrainConfig:
self.validate()
def validate(self):
required_fields = [
"model",
"strategy",
"dataset",
"optimizer_fn",
"scheduler_fn",
]
for field_name in required_fields:
if getattr(self, field_name) is None:
raise ValueError(f"{field_name} is required.")
for fld in fields(self):
if fld.metadata.get("required") and getattr(self, fld.name) is None:
raise ValueError(f"TrainConfig.{fld.name} is required but got None.")

View File

@ -4,34 +4,28 @@ from astrai.dataset.dataset import (
)
from astrai.dataset.sampler import ResumableDistributedSampler
from astrai.dataset.storage import (
BaseSegmentFetcher,
BaseStorage,
H5Storage,
JSONStorage,
MultiSegmentFetcher,
available_storage_types,
create_storage,
H5Store,
MmapStore,
Store,
StoreFactory,
detect_format,
load_bin,
load_h5,
load_json,
save_bin,
save_h5,
save_json,
)
__all__ = [
"BaseDataset",
"DatasetFactory",
"BaseSegmentFetcher",
"MultiSegmentFetcher",
"BaseStorage",
"H5Storage",
"JSONStorage",
"create_storage",
"Store",
"StoreFactory",
"H5Store",
"MmapStore",
"detect_format",
"available_storage_types",
"save_h5",
"load_h5",
"save_json",
"load_json",
"save_bin",
"load_bin",
"ResumableDistributedSampler",
]

View File

@ -8,8 +8,8 @@ from torch import Tensor
from torch.utils.data import Dataset
from astrai.dataset.storage import (
BaseStorage,
create_storage,
Store,
StoreFactory,
detect_format,
)
from astrai.factory import BaseFactory
@ -26,33 +26,47 @@ class BaseDataset(Dataset, ABC):
super().__init__()
self.window_size = window_size
self.stride = stride
self.storage: Optional[BaseStorage] = None
self.storage: Optional[Store] = None
def load(self, load_path: str, storage_type: Optional[str] = None, tokenizer=None):
@property
def required_keys(self) -> List[str]:
"""Return required storage keys for this dataset type.
Subclasses should override to specify expected keys.
"""
return []
def _validate_keys(self):
if not self.required_keys:
return
actual_keys = set(self.storage.keys)
missing = [k for k in self.required_keys if k not in actual_keys]
if missing:
raise KeyError(
f"Dataset {type(self).__name__} requires keys {self.required_keys}, "
f"but storage at {self._load_path} only has {sorted(actual_keys)}. "
f"Missing: {missing}"
)
def load(self, load_path: str, storage_type: Optional[str] = None):
"""Load dataset from the given path.
Auto-detects the storage format if not specified.
Args:
load_path: Path to the data directory or file
storage_type: Force a specific storage type ("h5", "json"),
storage_type: Force a specific storage type ("h5", "bin"),
or None for auto-detection
tokenizer: Callable str -> List[int], used to tokenize raw text
in JSON files. Ignored for HDF5.
Raises:
KeyError: If the loaded storage is missing required keys.
"""
if storage_type is None:
storage_type = detect_format(load_path)
self.storage = create_storage(storage_type)
self.storage.load(load_path, tokenizer=tokenizer)
def load_json(self, load_path: str, tokenizer=None):
"""Load dataset from JSON files explicitly.
Args:
load_path: Path to the JSON data file or directory
tokenizer: Optional tokenizer callable for raw text JSON.
"""
self.load(load_path, storage_type="json", tokenizer=tokenizer)
self.storage = StoreFactory.create(storage_type)
self._load_path = load_path
self.storage.load(load_path)
self._validate_keys()
@property
def count(self) -> int:
@ -122,26 +136,6 @@ class DatasetFactory(BaseFactory["BaseDataset"]):
dataset = DatasetFactory.create("custom", window_size, stride)
"""
@classmethod
def _validate_component(cls, dataset_cls: type) -> None:
"""Validate that the dataset class inherits from BaseDataset."""
if not issubclass(dataset_cls, BaseDataset):
raise TypeError(f"{dataset_cls.__name__} must inherit from BaseDataset")
@classmethod
def create(cls, train_type: str, window_size: int, stride: int) -> "BaseDataset":
"""Create a dataset instance.
Args:
train_type: Type of training ("seq", "sft", "dpo", "grpo")
window_size: Window size for data sampling
stride: Stride between consecutive samples
Returns:
Dataset instance
"""
return super().create(train_type, window_size, stride)
@classmethod
def load(
cls,
@ -150,7 +144,6 @@ class DatasetFactory(BaseFactory["BaseDataset"]):
window_size: int,
stride: Optional[int] = None,
storage_type: Optional[str] = None,
tokenizer=None,
) -> "BaseDataset":
"""Create and load a dataset in one step.
@ -159,8 +152,7 @@ class DatasetFactory(BaseFactory["BaseDataset"]):
load_path: Path to the data file
window_size: Window size for data sampling
stride: Stride between consecutive samples (default: same as window_size)
storage_type: Storage type ("h5", "json") or None for auto-detection
tokenizer: Callable str -> List[int] for raw text JSON tokenization
storage_type: Storage type ("h5", "bin") or None for auto-detection
Returns:
Loaded dataset instance
@ -169,22 +161,18 @@ class DatasetFactory(BaseFactory["BaseDataset"]):
stride = window_size
dataset = cls.create(train_type, window_size, stride)
dataset.load(load_path, storage_type=storage_type, tokenizer=tokenizer)
dataset.load(load_path, storage_type=storage_type)
return dataset
@classmethod
def available_types(cls) -> list:
"""Return list of registered dataset type names."""
return cls.list_registered()
@DatasetFactory.register("seq")
class SEQDataset(BaseDataset):
"""Dataset for sequential next-token prediction training."""
def __init__(self, window_size: int, stride: int):
super().__init__(window_size, stride)
@property
def required_keys(self) -> List[str]:
return ["sequence"]
def _fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
return self.storage.fetch(begin_idx, end_idx, "sequence")
@ -202,8 +190,9 @@ class SEQDataset(BaseDataset):
class SFTDataset(BaseDataset):
"""Dataset for supervised fine-tuning with loss masking."""
def __init__(self, window_size: int, stride: int):
super().__init__(window_size, stride)
@property
def required_keys(self) -> List[str]:
return ["sequence", "loss_mask", "position_ids"]
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
return self.storage.fetch(begin_idx, end_idx, key)
@ -211,23 +200,26 @@ class SFTDataset(BaseDataset):
def __getitem__(self, index):
begin_idx, end_idx = self.get_index(index)
x = self._fetch_data(begin_idx, end_idx, "sequence").to(dtype=torch.long)
y = self._fetch_data(begin_idx + 1, end_idx + 1, "sequence").to(
dtype=torch.long
)
loss_mask = self._fetch_data(begin_idx + 1, end_idx + 1, "loss_mask").to(
dtype=torch.bool
)
x = self._fetch_data(begin_idx, end_idx, "sequence")
y = self._fetch_data(begin_idx + 1, end_idx + 1, "sequence")
position_ids = self._fetch_data(begin_idx, end_idx, "position_ids")
loss_mask = self._fetch_data(begin_idx + 1, end_idx + 1, "loss_mask")
return {"input_ids": x, "target_ids": y, "loss_mask": loss_mask}
return {
"input_ids": x.to(dtype=torch.long),
"target_ids": y.to(dtype=torch.long),
"position_ids": position_ids.to(dtype=torch.long),
"loss_mask": loss_mask.to(dtype=torch.bool),
}
@DatasetFactory.register("dpo")
class DPODataset(BaseDataset):
"""Dataset for Direct Preference Optimization training."""
def __init__(self, window_size: int, stride: int):
super().__init__(window_size, stride)
@property
def required_keys(self) -> List[str]:
return ["chosen", "rejected", "chosen_mask", "rejected_mask"]
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
return self.storage.fetch(begin_idx, end_idx, key)
@ -256,8 +248,9 @@ class DPODataset(BaseDataset):
class GRPODataset(BaseDataset):
"""Dataset for Group Relative Policy Optimization training."""
def __init__(self, window_size: int, stride: int):
super().__init__(window_size, stride)
@property
def required_keys(self) -> List[str]:
return ["prompts", "responses", "masks", "rewards"]
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
return self.storage.fetch(begin_idx, end_idx, key)
@ -265,9 +258,11 @@ class GRPODataset(BaseDataset):
def __getitem__(self, index: int) -> Dict[str, Tensor]:
begin_idx, end_idx = self.get_index(index)
prompts = self._fetch_data(begin_idx, end_idx, "prompts")
responses = self._fetch_data(begin_idx, end_idx, "responses")
masks = self._fetch_data(begin_idx, end_idx, "masks")
prompts = self._fetch_data(begin_idx, end_idx, "prompts").to(dtype=torch.long)
responses = self._fetch_data(begin_idx, end_idx, "responses").to(
dtype=torch.long
)
masks = self._fetch_data(begin_idx, end_idx, "masks").to(dtype=torch.bool)
rewards = self._fetch_data(begin_idx, end_idx, "rewards")
return {

View File

@ -43,6 +43,7 @@ class ResumableDistributedSampler(Sampler[int]):
offset = 0 if drop_last else self.num_replicas - 1
self.num_samples_per_replica = (self.num_samples + offset) // self.num_replicas
self.total_size = self.num_samples_per_replica * self.num_replicas
self.iter = self.iter % self.num_samples_per_replica
self._indices = None
@ -74,5 +75,10 @@ class ResumableDistributedSampler(Sampler[int]):
self.epoch += 1
self._indices = None
@property
def _remaining(self):
remaining = self.num_samples_per_replica - self.iter
return max(remaining, 0)
def __len__(self):
return self.num_samples_per_replica
return self._remaining

View File

@ -1,20 +1,37 @@
"""Storage backends for different data formats.
Each storage handles format-specific loading (HDF5, JSON, etc.) and provides
a uniform interface for data access and length observation via fetchers.
Layers:
- I/O layer: save_* / load_* functions, read/write raw files (HDF5/bin)
return Dict[str, List[Tensor]] format-specific, no state
- Store (ABC): central abstraction, normalizes multi-segment into
Dict[str, List[Tensor]] per key via _normalize(),
fetch() uses bisect across segments no forced concat
- Dataset layer: BaseDataset owns a Store, only calls store.fetch(begin, end, key)
Key properties:
- Multi-segment: segments kept as-is, no forced concatenation safe for
datasets larger than RAM
- Explicit length: _length = min(total elements across keys), set at load,
__len__ returns O(1)
- Zero-copy mmap: MmapStore wraps np.memmap(mode="r"), all DataLoader
workers share OS page-cache pages
"""
import bisect
import glob
import json
import os
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Callable, Dict, List, Optional, Union
from typing import Dict, List, Union
import h5py
import numpy as np
import torch
from torch import Tensor
from astrai.factory import BaseFactory
def save_h5(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]):
os.makedirs(file_path, exist_ok=True)
@ -52,54 +69,30 @@ def load_h5(file_path: str, share_memory=True) -> Dict[str, List[Tensor]]:
return tensor_group
def save_json(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]):
def save_bin(file_path: str, tensor_group: Dict[str, List[Tensor]]):
os.makedirs(file_path, exist_ok=True)
full_file_path = os.path.join(file_path, f"{file_name}.json")
json_data = {}
meta = {}
for key, tensors in tensor_group.items():
json_data[key] = [tensor.tolist() for tensor in tensors]
with open(full_file_path, "w", encoding="utf-8") as f:
json.dump(json_data, f, ensure_ascii=False)
cat = torch.cat(tensors, dim=0)
meta[key] = {"shape": list(cat.shape), "dtype": str(cat.dtype).split(".")[-1]}
np.asarray(cat.cpu().numpy()).tofile(os.path.join(file_path, f"{key}.bin"))
with open(os.path.join(file_path, "meta.json"), "w") as f:
json.dump(meta, f)
def load_json(
file_path: str,
share_memory: bool = True,
tokenizer: Optional[Callable[[str], List[int]]] = None,
) -> Dict[str, List[Tensor]]:
"""Load tensor data from JSON files.
Supports two modes:
- Pre-tokenized: JSON values are List[List[int]] (token IDs), loaded as-is.
- Raw text: JSON values are List[str], tokenized via ``tokenizer`` callable
at load time. A ``tokenizer`` receives a str and returns List[int].
Non-data JSON files (e.g. config.json) with scalar/object values are
silently skipped.
"""
tensor_group: Dict[str, List[Tensor]] = {}
root_path = Path(file_path)
json_files = list(root_path.rglob("*.json")) + list(root_path.rglob("*.jsonl"))
for json_file in json_files:
with open(json_file, "r", encoding="utf-8") as f:
data = json.load(f)
if not isinstance(data, dict):
continue
for key, sequences in data.items():
if not isinstance(sequences, list):
continue
tensors = []
for seq in sequences:
if tokenizer is not None and isinstance(seq, str):
seq = tokenizer(seq)
tensor = torch.tensor(seq, dtype=torch.long)
if share_memory:
tensor = tensor.share_memory_()
tensors.append(tensor)
if tensor_group.get(key) is None:
tensor_group[key] = []
tensor_group[key].extend(tensors)
return tensor_group
def load_bin(file_path: str) -> Dict[str, List[Tensor]]:
with open(os.path.join(file_path, "meta.json"), "r") as f:
meta = json.load(f)
segments: Dict[str, List[Tensor]] = {}
for key, info in meta.items():
arr = np.memmap(
os.path.join(file_path, f"{key}.bin"),
dtype=info["dtype"],
mode="r+",
shape=tuple(info["shape"]),
)
segments[key] = [torch.from_numpy(arr)]
return segments
def detect_format(load_path: str) -> str:
@ -109,7 +102,7 @@ def detect_format(load_path: str) -> str:
load_path: Directory or file path
Returns:
Format string ("h5" or "json")
Format string ("h5" or "bin")
Raises:
FileNotFoundError: If no supported data files are found
@ -119,194 +112,155 @@ def detect_format(load_path: str) -> str:
suffix = root.suffix.lower()
if suffix in (".h5", ".hdf5"):
return "h5"
if suffix in (".json", ".jsonl"):
return "json"
raise ValueError(f"Unsupported file format: {suffix}")
h5_files = list(root.rglob("*.h5")) + list(root.rglob("*.hdf5"))
h5_files = [
Path(p)
for pattern in ("*.h5", "*.hdf5")
for p in glob.glob(str(root / "**" / pattern), recursive=True)
]
if h5_files:
return "h5"
json_files = list(root.rglob("*.json")) + list(root.rglob("*.jsonl"))
if json_files:
return "json"
bin_files = [Path(p) for p in glob.glob(str(root / "**" / "*.bin"), recursive=True)]
if bin_files:
has_meta = (root / "meta.json").exists() or len(
[Path(p) for p in glob.glob(str(root / "**" / "meta.json"), recursive=True)]
) > 0
if has_meta:
return "bin"
raise FileNotFoundError(f"No supported data files found at {load_path}")
class BaseSegmentFetcher:
"""Fetches data segments across multiple tensor segments.
class Store(ABC):
"""String keys -> segmented tensors with ``fetch(begin, end, keys)``.
Maintains cumulative lengths for efficient range queries across
multiple discontinuous segments.
"""
Each key maps to one or more tensor segments (no forced concatenation).
``len(store)`` returns ``self._length`` (explicit, O(1)), the minimum
total element count across all keys.
def __init__(self, segments: List[Tensor]):
self.segments = segments
self.cum_lengths = []
total = 0
for seg in segments:
total += torch.numel(seg)
self.cum_lengths.append(total)
self.total_length = total
def __len__(self) -> int:
return self.total_length
def fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
"""Fetch data in the range [begin_idx, end_idx)."""
if not (
0 <= begin_idx < self.total_length and 0 <= end_idx <= self.total_length
):
raise ValueError("begin_idx or end_idx out of bounds")
if begin_idx >= end_idx:
return torch.tensor([], dtype=torch.long)
seg_start_idx = bisect.bisect_right(self.cum_lengths, begin_idx)
seg_end_idx = bisect.bisect_left(self.cum_lengths, end_idx)
result_segments = []
for i in range(seg_start_idx, seg_end_idx + 1):
prev_cum = self.cum_lengths[i - 1] if i > 0 else 0
start = max(begin_idx - prev_cum, 0)
end = min(end_idx - prev_cum, len(self.segments[i]))
result_segments.append(self.segments[i][start:end])
return torch.cat(result_segments, dim=0)
class MultiSegmentFetcher:
"""Manages multiple segment fetchers for different data keys."""
def __init__(self, multi_segments: Dict):
self.multi_keys = list(multi_segments.keys())
self.multi_fetchers = {
key: BaseSegmentFetcher(segments)
for key, segments in multi_segments.items()
}
def __len__(self) -> int:
"""Returns the minimum length across all fetchers."""
if not self.multi_fetchers:
return 0
len_list = [len(seg) for seg in self.multi_fetchers.values()]
return min(len_list)
def key_fetch(
self, begin_idx: int, end_idx: int, keys: Union[str, List[str]]
) -> Dict:
"""Fetch data for specific keys."""
fetch_dict = {}
keys = [keys] if isinstance(keys, str) else keys
for key in keys:
fetcher = self.multi_fetchers[key]
fetch_tensor = fetcher.fetch_data(begin_idx, end_idx)
fetch_dict[key] = fetch_tensor
return fetch_dict if len(keys) > 1 else fetch_dict[keys[0]]
def fetch_data(self, begin_idx: int, end_idx: int) -> Dict:
"""Fetch all keys."""
return self.key_fetch(begin_idx, end_idx, self.multi_keys)
class BaseStorage(ABC):
"""Abstract storage backend for loading and dispatching data.
Storage encapsulates format-specific loading and provides a uniform
interface for data access and length observation. Subclasses handle
different data formats (HDF5, JSON, etc.) while exposing the same
fetch interface.
Subclasses fill ``self._data`` and ``self._cum`` during ``load()``
via ``_normalize()``.
"""
def __init__(self):
self._fetcher: Optional[MultiSegmentFetcher] = None
self._data: Dict[str, List[Tensor]] = {}
self._cum: Dict[str, List[int]] = {}
self._length: int = 0
@abstractmethod
def load(self, load_path: str, tokenizer=None) -> None:
"""Load data from the given path into internal fetcher."""
def load(self, path: str) -> None:
raise NotImplementedError
def __len__(self) -> int:
"""Total number of raw elements (tokens) in storage."""
if self._fetcher is None:
return 0
return len(self._fetcher)
def fetch(self, begin_idx: int, end_idx: int, keys: Union[str, List[str]]):
"""Fetch data for the given keys and index range.
Args:
begin_idx: Starting index (inclusive)
end_idx: Ending index (exclusive)
keys: Single key or list of keys to fetch
Returns:
Tensor if single key, Dict[str, Tensor] if multiple keys
"""
if self._fetcher is None:
raise RuntimeError("Storage not loaded")
return self._fetcher.key_fetch(begin_idx, end_idx, keys)
@property
def keys(self) -> List[str]:
"""Return the data keys available in this storage."""
if self._fetcher is None:
return []
return self._fetcher.multi_keys
return list(self._data.keys())
def __len__(self) -> int:
return self._length
def fetch(
self,
begin: int,
end: int,
keys: Union[str, List[str]],
):
if not self._data:
raise RuntimeError("Store not loaded")
if not (0 <= begin < self._length and 0 <= end <= self._length):
raise ValueError(
f"Index out of bounds: begin={begin}, end={end}, length={self._length}"
)
if isinstance(keys, str):
return self._fetch_key(keys, begin, end)
return {k: self._fetch_key(k, begin, end) for k in keys}
def _fetch_key(self, key: str, begin: int, end: int) -> Tensor:
"""Fetch slice [begin, end) across potentially multiple segments."""
segments = self._data[key]
cum = self._cum[key]
seg_start = bisect.bisect_right(cum, begin)
seg_end = bisect.bisect_left(cum, end)
results = []
for i in range(seg_start, seg_end + 1):
prev = cum[i - 1] if i > 0 else 0
s = max(begin - prev, 0)
e = min(end - prev, segments[i].shape[0])
results.append(segments[i][s:e])
return results[0] if len(results) == 1 else torch.cat(results, dim=0)
def _normalize(self, raw: Dict[str, List[Tensor]]):
"""Register segments and pre-compute cumulative lengths.
Does NOT concatenate segments are kept as-is to avoid OOM on
large datasets. Sets ``self._length`` to the minimum total
element count across all keys.
"""
for key, tensors in raw.items():
self._data[key] = tensors
cum = []
total = 0
for t in tensors:
total += t.shape[0]
cum.append(total)
self._cum[key] = cum
self._length = (
min((cum[-1] if cum else 0) for cum in self._cum.values())
if self._cum
else 0
)
class H5Storage(BaseStorage):
class StoreFactory(BaseFactory["Store"]):
"""Factory for creating Store instances by type name.
Example::
@StoreFactory.register("custom")
class CustomStore(Store):
...
"""
@StoreFactory.register("h5")
class H5Store(Store):
"""HDF5-based storage backend (pre-tokenized data)."""
def load(self, load_path: str, tokenizer=None) -> None:
segments = load_h5(load_path)
self._fetcher = MultiSegmentFetcher(segments)
def load(self, path: str):
self._normalize(load_h5(path))
class JSONStorage(BaseStorage):
"""JSON-based storage backend.
@StoreFactory.register("bin")
class MmapStore(Store):
"""Memory-mapped binary storage backend.
Supports two modes:
- Pre-tokenized: JSON values are List[List[int]], loaded as-is.
- Raw text: JSON values are List[str], tokenized via ``tokenizer``
callable (str -> List[int]) at load time.
Each key is a single .bin file backed by ``np.memmap(mode="r")``.
No per-process memory duplication all DataLoader workers share the
same OS page-cache pages.
Format on disk::
data_root/
meta.json # {key: {shape, dtype}, ...}
<key>.bin # raw numpy array, one per key
"""
def load(self, load_path: str, tokenizer=None) -> None:
segments = load_json(load_path, tokenizer=tokenizer)
self._fetcher = MultiSegmentFetcher(segments)
_STORAGE_REGISTRY: Dict[str, type] = {
"h5": H5Storage,
"json": JSONStorage,
}
def create_storage(storage_type: str) -> BaseStorage:
"""Create a storage instance by type name.
Args:
storage_type: Storage type name ("h5", "json")
Returns:
Storage instance
Raises:
ValueError: If the storage type is unknown
"""
storage_cls = _STORAGE_REGISTRY.get(storage_type)
if storage_cls is None:
raise ValueError(
f"Unknown storage type: '{storage_type}'. "
f"Available: {sorted(_STORAGE_REGISTRY.keys())}"
)
return storage_cls()
def available_storage_types() -> List[str]:
"""Return list of registered storage type names."""
return sorted(_STORAGE_REGISTRY.keys())
def load(self, path: str):
self._mmap_refs = []
root = Path(path)
all_raw: Dict[str, List[Tensor]] = {}
meta_paths = [
Path(p) for p in glob.glob(str(root / "**" / "meta.json"), recursive=True)
]
for meta_path in meta_paths:
raw = load_bin(str(meta_path.parent))
for key, tensors in raw.items():
if key not in all_raw:
all_raw[key] = []
all_raw[key].extend(tensors)
if not meta_paths:
raise FileNotFoundError(f"No meta.json found under {path}")
self._normalize(all_raw)
for tensors in self._data.values():
self._mmap_refs.extend(tensors)

View File

@ -1,210 +1,145 @@
"""Base factory class for extensible component registration."""
"""Base factory with decorator-based registration and kwarg-filtered instantiation."""
import inspect
import sys
from abc import ABC
from typing import Callable, Dict, Generic, List, Optional, Tuple, Type, TypeVar
from typing import (
Any,
Callable,
Dict,
ForwardRef,
Generic,
List,
Optional,
Type,
TypeVar,
Union,
)
from typing import get_args as _get_args
from typing import get_origin as _get_origin
T = TypeVar("T")
class Registry:
"""Flexible registry for component classes with category and priority support.
def _resolve_type(
arg: Union[Type, str, ForwardRef], factory_cls: type
) -> Optional[Type]:
"""Resolve a generic type-arg (str forward-ref, ForwardRef, or class)."""
if not isinstance(arg, (str, ForwardRef)):
return arg
This registry stores component classes with optional metadata (category, priority).
It provides methods for registration, retrieval, and listing with filtering.
"""
name = arg if isinstance(arg, str) else arg.__forward_arg__
if name == factory_cls.__name__:
return factory_cls
def __init__(self):
self._entries = {} # name -> (component_cls, category, priority)
mod = sys.modules.get(factory_cls.__module__)
if mod is None:
return None
ns = vars(mod)
def register(
self,
name: str,
component_cls: Type,
category: Optional[str] = None,
priority: int = 0,
) -> None:
"""Register a component class with optional category and priority."""
if name in self._entries:
raise ValueError(f"Component '{name}' is already registered")
self._entries[name] = (component_cls, category, priority)
if isinstance(arg, ForwardRef):
return arg._evaluate(ns, None, frozenset(), recursive_guard=frozenset())
def get(self, name: str) -> Type:
"""Get component class by name."""
if name not in self._entries:
raise KeyError(f"Component '{name}' not found in registry")
return self._entries[name][0]
def get_with_metadata(self, name: str) -> Tuple[Type, Optional[str], int]:
"""Get component class with its metadata."""
entry = self._entries.get(name)
if entry is None:
raise KeyError(f"Component '{name}' not found in registry")
return entry
def contains(self, name: str) -> bool:
"""Check if a name is registered."""
return name in self._entries
def list_names(self) -> List[str]:
"""Return list of registered component names."""
return sorted(self._entries.keys())
def list_by_category(self, category: str) -> List[str]:
"""Return names of components belonging to a specific category."""
return sorted(
name for name, (_, cat, _) in self._entries.items() if cat == category
)
def list_by_priority(self, reverse: bool = False) -> List[str]:
"""Return names sorted by priority (default ascending)."""
return sorted(
self._entries.keys(),
key=lambda name: self._entries[name][2],
reverse=reverse,
)
def entries(self) -> Dict[str, Tuple[Type, Optional[str], int]]:
"""Return raw entries dictionary."""
return self._entries.copy()
return ns.get(name)
class BaseFactory(ABC, Generic[T]):
"""Generic factory class for component registration and creation.
"""Generic factory with decorator-based component registration.
This base class provides a decorator-based registration pattern
for creating extensible component factories.
Example usage:
class MyFactory(BaseFactory[MyBaseClass]):
class MyFactory(BaseFactory[MyBase]):
pass
@MyFactory.register("custom")
class CustomComponent(MyBaseClass):
class CustomComponent(MyBase):
...
component = MyFactory.create("custom", *args, **kwargs)
obj = MyFactory.create("custom", *args, **kwargs)
``create()`` filters kwargs to match the component's ``__init__``
signature so components don't need ``**kwargs`` just to absorb
unrelated parameters.
"""
_registry: Registry
_entries: Dict[str, Type[T]]
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
cls._registry = Registry()
for orig_base in getattr(cls, "__orig_bases__", ()):
if _get_origin(orig_base) is BaseFactory:
(arg,) = _get_args(orig_base)
cls._entries = {}
cls._component_base = _resolve_type(arg, cls)
return
@classmethod
def register(
cls, name: str, category: Optional[str] = None, priority: int = 0
) -> Callable[[Type[T]], Type[T]]:
"""Decorator to register a component class with optional category and priority.
def register(cls, name: str) -> Callable[[Type[T]], Type[T]]:
"""Decorator to register a component class.
Args:
name: Registration name for the component
category: Optional category for grouping components
priority: Priority for ordering (default 0)
Returns:
Decorator function that registers the component class
Raises:
TypeError: If the decorated class doesn't inherit from the base type
Validates that the decorated class inherits from the generic
type parameter ``T`` declared on the factory.
"""
def decorator(component_cls: Type[T]) -> Type[T]:
cls._validate_component(component_cls)
cls._registry.register(
name, component_cls, category=category, priority=priority
)
if name in cls._entries:
raise ValueError(f"Component '{name}' is already registered")
cls._entries[name] = component_cls
return component_cls
return decorator
@classmethod
def create(cls, name: str, *args, **kwargs) -> T:
"""Create a component instance by name.
Args:
name: Registered name of the component
*args: Positional arguments passed to component constructor
**kwargs: Keyword arguments passed to component constructor
Returns:
Component instance
Raises:
ValueError: If the component name is not registered
"""Create a component instance by name, filtering kwargs to match
the component's ``__init__`` signature.
"""
if not cls._registry.contains(name):
entry = cls._entries.get(name)
if entry is None:
raise ValueError(
f"Unknown component: '{name}'. "
f"Supported types: {sorted(cls._registry.list_names())}"
f"Unknown component: '{name}'. Supported types: {sorted(cls._entries)}"
)
component_cls = cls._registry.get(name)
component_cls = entry
sig = inspect.signature(component_cls.__init__)
has_var_kwargs = any(
p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()
)
if not has_var_kwargs:
valid = {
p.name
for p in sig.parameters.values()
if p.name != "self" and p.kind != inspect.Parameter.VAR_KEYWORD
}
kwargs = {k: v for k, v in kwargs.items() if k in valid}
return component_cls(*args, **kwargs)
@classmethod
def _validate_component(cls, component_cls: Type[T]) -> None:
"""Validate that the component class is valid for this factory.
def _validate_component(cls, component_cls: Type[T]):
"""Validate the decorated class inherits from the factory's base type.
Override this method in subclasses to add custom validation.
Args:
component_cls: Component class to validate
Raises:
TypeError: If the component class is invalid
Override for custom validation beyond ``issubclass``.
"""
pass
base = cls._component_base
if base is not None and not issubclass(component_cls, base):
raise TypeError(
f"{component_cls.__name__} must inherit from {base.__name__}"
)
@classmethod
def get_component_class(cls, name: str) -> Type[T]:
"""Get the registered component class by name without instantiating it.
Args:
name: Registered name of the component
Returns:
The component class itself
Raises:
ValueError: If the component name is not registered
"""
if not cls._registry.contains(name):
"""Get the registered component class without instantiating it."""
entry = cls._entries.get(name)
if entry is None:
raise ValueError(
f"Unknown component: '{name}'. "
f"Supported types: {sorted(cls._registry.list_names())}"
f"Unknown component: '{name}'. Supported types: {sorted(cls._entries)}"
)
return cls._registry.get(name)
return entry
@classmethod
def list_registered(cls) -> list:
"""List all registered component names.
Returns:
List of registered component names
"""
return cls._registry.list_names()
def list_registered(cls) -> List[str]:
"""List all registered component names."""
return sorted(cls._entries)
@classmethod
def is_registered(cls, name: str) -> bool:
"""Check if a component name is registered.
Args:
name: Component name to check
Returns:
True if registered, False otherwise
"""
return cls._registry.contains(name)
@classmethod
def list_by_category(cls, category: str) -> List[str]:
"""List registered component names in a category."""
return cls._registry.list_by_category(category)
@classmethod
def list_by_priority(cls, reverse: bool = False) -> List[str]:
"""List registered component names sorted by priority."""
return cls._registry.list_by_priority(reverse)
__all__ = ["Registry", "BaseFactory"]
"""Check if a component name is registered."""
return name in cls._entries

View File

@ -1,25 +1,32 @@
"""Inference module for continuous batching.
Layers:
- core/: Core inference loop (cache, executor, scheduler, task)
- api/: HTTP protocol handlers (OpenAI, Anthropic)
- engine.py: Facade (InferenceEngine), Value Object (GenerationRequest)
- sample.py: Strategy pattern (TemperatureStrategy, TopKStrategy, TopPStrategy)
- core/: Core inference loop (cache, executor, scheduler, task)
- api/: HTTP orchestration (ProtocolHandler, server)
- protocols/: Response builders (OpenAI, Anthropic)
- transport/: SSE transport utilities
- engine.py: Facade (InferenceEngine), Value Object (GenerationRequest)
- sample.py: Strategy pattern (TemperatureStrategy, TopKStrategy, TopPStrategy)
"""
from astrai.inference.api import (
AnthropicHandler,
AnthropicMessage,
BaseToolParser,
ChatCompletionRequest,
ChatMessage,
FunctionDef,
GenContext,
MessagesRequest,
OpenAIHandler,
ProtocolHandler,
SimpleJsonToolParser,
StopChecker,
StreamContext,
app,
ToolDef,
ToolParserFactory,
get_app,
run_server,
)
from astrai.inference.api.anthropic import AnthropicResponseBuilder
from astrai.inference.api.openai import OpenAIResponseBuilder
from astrai.inference.core import (
STOP,
Allocator,
@ -36,10 +43,7 @@ from astrai.inference.core import (
TaskTable,
page_hash,
)
from astrai.inference.engine import (
GenerationRequest,
InferenceEngine,
)
from astrai.inference.engine import GenerationRequest, InferenceEngine
from astrai.inference.sample import (
BaseSamplingStrategy,
SamplingPipeline,
@ -50,17 +54,14 @@ from astrai.inference.sample import (
)
__all__ = [
# Engine / Requests
"InferenceEngine",
"GenerationRequest",
# Core scheduler
"InferenceScheduler",
"Executor",
"STOP",
"Task",
"TaskManager",
"TaskStatus",
# Core cache
"Allocator",
"KVCache",
"KvcacheView",
@ -69,24 +70,26 @@ __all__ = [
"Storage",
"TaskTable",
"page_hash",
# Sampling (Strategy pattern)
"sample",
"BaseSamplingStrategy",
"TemperatureStrategy",
"TopKStrategy",
"TopPStrategy",
"SamplingPipeline",
# Protocol
"ProtocolHandler",
"StopChecker",
"StreamContext",
"AnthropicHandler",
"OpenAIHandler",
# Server
"GenContext",
"BaseToolParser",
"SimpleJsonToolParser",
"ToolParserFactory",
"OpenAIResponseBuilder",
"AnthropicResponseBuilder",
"ChatMessage",
"ChatCompletionRequest",
"FunctionDef",
"ToolDef",
"AnthropicMessage",
"MessagesRequest",
"app",
"get_app",
"run_server",
]

View File

@ -1,31 +1,39 @@
"""Inference API: protocol handlers and FastAPI server."""
"""Inference API: protocol handler, stop checker, tool parsers, and FastAPI server.
from astrai.inference.api.protocol import (
AnthropicHandler,
OpenAIHandler,
ProtocolHandler,
StopChecker,
StreamContext,
)
``app`` is no longer a module-level global. Use :func:`get_app` to access the
lazy singleton FastAPI instance.
"""
from astrai.inference.api.protocol import GenContext, ProtocolHandler, StopChecker
from astrai.inference.api.server import (
AnthropicMessage,
ChatCompletionRequest,
ChatMessage,
FunctionDef,
MessagesRequest,
app,
ToolDef,
get_app,
run_server,
)
from astrai.inference.api.tool_parser import (
BaseToolParser,
SimpleJsonToolParser,
ToolParserFactory,
)
__all__ = [
"AnthropicHandler",
"OpenAIHandler",
"ProtocolHandler",
"StopChecker",
"StreamContext",
"GenContext",
"BaseToolParser",
"SimpleJsonToolParser",
"ToolParserFactory",
"AnthropicMessage",
"ChatCompletionRequest",
"ChatMessage",
"FunctionDef",
"ToolDef",
"MessagesRequest",
"app",
"get_app",
"run_server",
]

View File

@ -0,0 +1,142 @@
"""Anthropic message completion response builder."""
import time
import uuid
from typing import Any, Dict, List, Tuple, Union
from pydantic import BaseModel
from astrai.inference.api.protocol import (
GenContext,
ResponseBuilder,
StopInfo,
sse_event,
)
from astrai.inference.engine import InferenceEngine
def _extract_text(content: Union[str, List[Dict[str, Any]]]) -> str:
if isinstance(content, str):
return content
if isinstance(content, list):
for block in content:
if isinstance(block, dict) and block.get("type") == "text":
return block.get("text", "")
return ""
class AnthropicResponseBuilder(ResponseBuilder):
def prepare(
self, request: BaseModel, engine: InferenceEngine
) -> Tuple[str, GenContext, List[str]]:
messages: List[Dict[str, str]] = []
system = getattr(request, "system", None)
if system:
messages.append({"role": "system", "content": system})
for m in request.messages:
text = _extract_text(m.content)
if text:
messages.append({"role": m.role, "content": text})
prompt = engine.tokenizer.apply_chat_template(messages, tokenize=False)
ctx = GenContext(
resp_id=f"msg_{uuid.uuid4().hex[:24]}",
created=int(time.time()),
model=request.model,
)
stop_sequences = getattr(request, "stop_sequences", None) or []
return prompt, ctx, stop_sequences
def format_stream_start(self, ctx: GenContext) -> List[str]:
return [
sse_event(
{
"type": "message_start",
"message": {
"id": ctx.resp_id,
"type": "message",
"role": "assistant",
"model": ctx.model,
"content": [],
"usage": {"input_tokens": ctx.prompt_tokens},
},
},
event="message_start",
),
sse_event(
{
"type": "content_block_start",
"index": 0,
"content_block": {"type": "text", "text": ""},
},
event="content_block_start",
),
]
def format_chunk(self, token: str, **kwargs) -> List[str]:
return [
sse_event(
{
"type": "content_block_delta",
"index": 0,
"delta": {"type": "text_delta", "text": token},
},
event="content_block_delta",
)
]
def format_stream_end(self, ctx: GenContext, stop: StopInfo) -> List[str]:
events: List[str] = []
if stop.matched:
trimmed = stop.body[: stop.body.rfind(stop.matched)]
unyielded = trimmed[len(stop.yielded) :]
if unyielded:
events.append(
sse_event(
{
"type": "content_block_delta",
"index": 0,
"delta": {"type": "text_delta", "text": unyielded},
},
event="content_block_delta",
)
)
events.append(
sse_event(
{"type": "content_block_stop", "index": 0},
event="content_block_stop",
)
)
events.append(
sse_event(
{
"type": "message_delta",
"delta": {
"stop_reason": "stop_sequence" if stop.matched else "end_turn",
"stop_sequence": stop.matched,
},
"usage": {"output_tokens": ctx.completion_tokens},
},
event="message_delta",
)
)
events.append(sse_event({"type": "message_stop"}, event="message_stop"))
return events
def format_response(
self, ctx: GenContext, content: str, stop: StopInfo
) -> Dict[str, Any]:
if stop.matched:
content = content[: content.rfind(stop.matched)]
return {
"id": ctx.resp_id,
"type": "message",
"role": "assistant",
"model": ctx.model,
"content": [{"type": "text", "text": content}],
"stop_reason": "stop_sequence" if stop.matched else "end_turn",
"stop_sequence": stop.matched,
"usage": {
"input_tokens": ctx.prompt_tokens,
"output_tokens": ctx.completion_tokens,
},
}

View File

@ -0,0 +1,278 @@
"""OpenAI chat completion response builder."""
import logging
import time
import uuid
from typing import Any, Dict, List, Optional, Tuple, Union
from pydantic import BaseModel
from astrai.inference.api.protocol import (
GenContext,
ResponseBuilder,
StopInfo,
sse_event,
)
from astrai.inference.api.tool_parser import BaseToolParser, ToolParserFactory
from astrai.inference.engine import InferenceEngine
logger = logging.getLogger(__name__)
_UNSUPPORTED_PARAMS = (
"n",
"presence_penalty",
"frequency_penalty",
"logit_bias",
"user",
)
def _resolve_tool_choice(
request: BaseModel,
) -> Union[str, Dict[str, Any]]:
tc = getattr(request, "tool_choice", None)
if tc is None:
return "auto"
if isinstance(tc, str):
return tc
if isinstance(tc, dict):
return tc
return "auto"
def _resolve_tools(request: BaseModel) -> Optional[List[Dict[str, Any]]]:
raw = getattr(request, "tools", None)
if not raw:
return None
if isinstance(raw, list):
return [t.model_dump() if hasattr(t, "model_dump") else t for t in raw]
return None
class OpenAIResponseBuilder(ResponseBuilder):
def prepare(
self, request: BaseModel, engine: InferenceEngine
) -> Tuple[str, GenContext, List[str]]:
messages = [{"role": m.role, "content": m.content} for m in request.messages]
tools = _resolve_tools(request)
prompt = engine.tokenizer.apply_chat_template(
messages, tokenize=False, tools=tools or []
)
self._resp_id = f"chatcmpl-{uuid.uuid4().hex[:12]}"
self._model = request.model
for param in _UNSUPPORTED_PARAMS:
value = getattr(request, param, None)
fields = getattr(type(request), "model_fields", {})
default = fields[param].default if param in fields else None
if value is not None and value != default:
logger.warning(
"ChatCompletionRequest param '%s'=%r is not supported"
" and will be ignored",
param,
value,
)
self._parser: Optional[BaseToolParser] = None
if tools:
tool_choice = _resolve_tool_choice(request)
self._parser = ToolParserFactory.create(
"simple_json", tools=tools, tool_choice=tool_choice
)
self._content_started = False
ctx = GenContext(
resp_id=self._resp_id,
created=int(time.time()),
model=self._model,
)
stop = request.stop
stop_sequences = (
[] if stop is None else [stop] if isinstance(stop, str) else stop
)
return prompt, ctx, stop_sequences
def format_stream_start(self, ctx: GenContext) -> List[str]:
return [
sse_event(
{
"id": self._resp_id,
"object": "chat.completion.chunk",
"created": ctx.created,
"model": self._model,
"choices": [
{
"index": 0,
"delta": {"role": "assistant"},
"finish_reason": None,
}
],
}
)
]
def format_chunk(self, token: str, **kwargs) -> List[str]:
body = kwargs.get("body", "")
if self._parser is not None:
return self._format_tool_chunk(body, **kwargs)
return [
sse_event(
{
"id": self._resp_id,
"object": "chat.completion.chunk",
"created": 0,
"model": self._model,
"choices": [
{
"index": 0,
"delta": {"content": token},
"finish_reason": None,
}
],
}
)
]
def _format_tool_chunk(self, body: str, **kwargs) -> List[str]:
deltas = self._parser.feed(
body,
current_token_ids=kwargs.get("current_token_ids"),
delta_token_ids=kwargs.get("delta_token_ids"),
)
events: List[str] = []
for d in deltas:
if "content" in d:
if not self._content_started:
events.append(self._role_chunk())
self._content_started = True
events.append(
sse_event(
{
"id": self._resp_id,
"object": "chat.completion.chunk",
"created": 0,
"model": self._model,
"choices": [
{
"index": 0,
"delta": {"content": d["content"]},
"finish_reason": None,
}
],
}
)
)
elif "tool_calls" in d:
if not self._content_started:
events.append(self._role_chunk())
self._content_started = True
events.append(
sse_event(
{
"id": self._resp_id,
"object": "chat.completion.chunk",
"created": 0,
"model": self._model,
"choices": [
{
"index": 0,
"delta": {"tool_calls": d["tool_calls"]},
"finish_reason": None,
}
],
}
)
)
return events
def _role_chunk(self) -> str:
return sse_event(
{
"id": self._resp_id,
"object": "chat.completion.chunk",
"created": 0,
"model": self._model,
"choices": [
{
"index": 0,
"delta": {"role": "assistant"},
"finish_reason": None,
}
],
}
)
def format_stream_end(self, ctx: GenContext, stop: StopInfo) -> List[str]:
finish_reason = "stop"
if self._parser is not None and self._parser.has_tool_calls:
finish_reason = "tool_calls"
return [
sse_event(
{
"id": self._resp_id,
"object": "chat.completion.chunk",
"created": ctx.created,
"model": self._model,
"choices": [
{"index": 0, "delta": {}, "finish_reason": finish_reason}
],
}
),
sse_event(
{
"prompt_tokens": ctx.prompt_tokens,
"completion_tokens": ctx.completion_tokens,
"total_tokens": ctx.prompt_tokens + ctx.completion_tokens,
}
),
]
def format_response(
self, ctx: GenContext, content: str, stop: StopInfo
) -> Dict[str, Any]:
if self._parser is not None:
parsed = self._parser.parse_complete(content)
if parsed and parsed.get("tool_calls"):
return {
"id": self._resp_id,
"object": "chat.completion",
"created": ctx.created,
"model": self._model,
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": parsed.get("content"),
"tool_calls": parsed["tool_calls"],
},
"finish_reason": "tool_calls",
}
],
"usage": {
"prompt_tokens": ctx.prompt_tokens,
"completion_tokens": ctx.completion_tokens,
"total_tokens": ctx.prompt_tokens + ctx.completion_tokens,
},
}
return {
"id": self._resp_id,
"object": "chat.completion",
"created": ctx.created,
"model": self._model,
"choices": [
{
"index": 0,
"message": {"role": "assistant", "content": content},
"finish_reason": "stop",
}
],
"usage": {
"prompt_tokens": ctx.prompt_tokens,
"completion_tokens": ctx.completion_tokens,
"total_tokens": ctx.prompt_tokens + ctx.completion_tokens,
},
}

View File

@ -1,15 +1,13 @@
"""Protocol handlers for OpenAI and Anthropic chat completion APIs.
"""Orchestration layer: ProtocolHandler, StopChecker, GenContext, StopInfo, ResponseBuilder, SSE utils.
Template Method + Builder patterns eliminate the 45% code duplication between
stream/non-stream branches and across protocol adapters.
ProtocolHandler orchestrates the async generation loop and delegates
protocol-specific formatting to a ResponseBuilder.
"""
import json
import time
import uuid
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
@ -17,7 +15,7 @@ from pydantic import BaseModel
from astrai.inference.engine import InferenceEngine
def _sse_event(data: Dict[str, Any], event: Optional[str] = None) -> str:
def sse_event(data: Dict[str, Any], event: Optional[str] = None) -> str:
lines: List[str] = []
if event:
lines.append(f"event: {event}")
@ -26,22 +24,28 @@ def _sse_event(data: Dict[str, Any], event: Optional[str] = None) -> str:
return "\n".join(lines)
def _sse_done() -> str:
def sse_done() -> str:
return "data: [DONE]\n\n"
@dataclass
class StreamContext:
"""Shared state across the streaming generation lifecycle."""
class GenContext:
"""Per-generation metadata passed to builder format methods."""
resp_id: str
created: int
model: str
prompt_tokens: int
prompt_tokens: int = 0
completion_tokens: int = 0
accumulated: str = ""
stop_matched: Optional[str] = None
last_yield_trimmed: str = ""
@dataclass
class StopInfo:
"""Stop-check result passed to format_stream_end / format_response."""
matched: Optional[str] = None
body: str = ""
yielded: str = ""
class StopChecker:
@ -56,95 +60,67 @@ class StopChecker:
return seq
return None
def trim(self, text: str, matched: str) -> str:
idx = text.rfind(matched)
return text[:idx] if idx != -1 else text
@property
def has_sequences(self) -> bool:
return len(self._sequences) > 0
class ResponseBuilder(ABC):
"""Interface for protocol-specific response formatting.
class ProtocolHandler(ABC):
"""Template-method base for API protocol handlers.
Subclasses implement format hooks; the base class orchestrates the
generate-async loop and SSE/JSON response construction.
Lifecycle::
handle()
build_prompt() # protocol-specific prompt assembly
create_response_id() # unique response identifier
[stream]
format_stream_start()
format_stream_token() × N
on_token() hook for stop-sequence interception
format_stream_end()
[non-stream]
(accumulate tokens)
format_non_stream_response()
A new protocol requires one concrete builder implementing 5 methods.
"""
request_model: type[BaseModel]
@abstractmethod
def prepare(
self, request: BaseModel, engine: InferenceEngine
) -> Tuple[str, GenContext, List[str]]:
"""Return (prompt, ctx, stop_sequences) for a generation request."""
def __init__(self, request: BaseModel, engine: InferenceEngine):
@abstractmethod
def format_stream_start(self, ctx: GenContext) -> List[str]:
"""SSE events that open the stream."""
@abstractmethod
def format_chunk(self, token: str, **kwargs) -> List[str]:
"""SSE events for a single generated token.
``body`` (the full accumulated text so far) is always provided
as a keyword argument. Additional keyword arguments such as
``current_token_ids`` and ``delta_token_ids`` may be included
for tool parsers that need token-level information.
Returns a list of SSE event strings (may be empty).
"""
@abstractmethod
def format_stream_end(self, ctx: GenContext, stop: StopInfo) -> List[str]:
"""SSE events that close the stream."""
@abstractmethod
def format_response(
self, ctx: GenContext, content: str, stop: StopInfo
) -> Dict[str, Any]:
"""JSON response body for non-streaming mode."""
class ProtocolHandler:
"""Orchestrates the generation loop, delegates formatting to a builder.
Usage::
handler = ProtocolHandler(request, engine, OpenAIResponseBuilder())
response = await handler.handle()
"""
def __init__(
self, request: BaseModel, engine: InferenceEngine, builder: ResponseBuilder
):
self.request = request
self.engine = engine
@abstractmethod
def build_prompt(self) -> str:
"""Build the full prompt string from the request messages."""
@abstractmethod
def create_response_id(self) -> str:
"""Generate a unique response ID following the protocol convention."""
@abstractmethod
def format_stream_start(self, ctx: StreamContext) -> List[str]:
"""Yield SSE events that open the stream (role marker, metadata)."""
@abstractmethod
def format_stream_token(self, ctx: StreamContext, token: str) -> str:
"""Yield an SSE event for a single generated token."""
@abstractmethod
def format_stream_end(self, ctx: StreamContext) -> List[str]:
"""Yield SSE events that close the stream (finish reason, usage stats)."""
@abstractmethod
def format_non_stream_response(
self, ctx: StreamContext, content: str
) -> Dict[str, Any]:
"""Build the JSON response body for non-streaming mode."""
def get_stop_sequences(self) -> List[str]:
return []
def create_stop_checker(self) -> StopChecker:
return StopChecker(self.get_stop_sequences())
def on_token(
self, ctx: StreamContext, token: str, stop_checker: StopChecker
) -> Optional[str]:
"""Hook after each token is appended to accumulated.
Return a matched stop-sequence string to break the loop,
or None to continue.
"""
return None
self.builder = builder
async def handle(self) -> Union[StreamingResponse, Dict[str, Any]]:
ctx = StreamContext(
resp_id=self.create_response_id(),
created=int(time.time()),
model=self.request.model,
prompt_tokens=self._count_prompt_tokens(),
)
prompt, ctx, stop_sequences = self.builder.prepare(self.request, self.engine)
ctx.prompt_tokens = len(self.engine.tokenizer.encode(prompt))
agen = self.engine.generate_async(
prompt=self.build_prompt(),
prompt=prompt,
max_tokens=self.request.max_tokens,
temperature=self.request.temperature,
top_p=self.request.top_p,
@ -152,33 +128,47 @@ class ProtocolHandler(ABC):
)
if self.request.stream:
return self._handle_stream(agen, ctx)
return self._handle_stream(agen, ctx, stop_sequences)
else:
return await self._handle_non_stream(agen, ctx)
return await self._handle_non_stream(agen, ctx, stop_sequences)
def _count_prompt_tokens(self) -> int:
return len(self.engine.tokenizer.encode(self.build_prompt()))
def _handle_stream(self, agen, ctx: StreamContext) -> StreamingResponse:
stop_checker = self.create_stop_checker()
def _handle_stream(
self, agen: AsyncGenerator, ctx: GenContext, stop_sequences: List[str]
) -> StreamingResponse:
checker = StopChecker(stop_sequences)
async def event_stream():
for event in self.format_stream_start(ctx):
for event in self.builder.format_stream_start(ctx):
yield event
body = ""
yielded = ""
matched = None
token_ids: List[int] = []
async for token in agen:
ctx.completion_tokens += 1
ctx.accumulated += token
body += token
matched = self.on_token(ctx, token, stop_checker)
new_ids = self.engine.tokenizer.encode(token)
token_ids.extend(new_ids)
matched = checker.check(body)
if matched:
break
yield self.format_stream_token(ctx, token)
ctx.completion_tokens += 1
for event in self.builder.format_chunk(
token,
body=body,
current_token_ids=token_ids,
delta_token_ids=new_ids,
):
yield event
yielded += token
for event in self.format_stream_end(ctx):
stop = StopInfo(matched=matched, body=body, yielded=yielded)
for event in self.builder.format_stream_end(ctx, stop):
yield event
yield _sse_done()
yield sse_done()
return StreamingResponse(
event_stream(),
@ -186,249 +176,24 @@ class ProtocolHandler(ABC):
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
)
async def _handle_non_stream(self, agen, ctx: StreamContext) -> Dict[str, Any]:
stop_checker = self.create_stop_checker()
async def _handle_non_stream(
self, agen: AsyncGenerator, ctx: GenContext, stop_sequences: List[str]
) -> Dict[str, Any]:
checker = StopChecker(stop_sequences)
chunks: List[str] = []
body = ""
matched = None
async for token in agen:
ctx.completion_tokens += 1
ctx.accumulated += token
chunks.append(token)
body += token
matched = self.on_token(ctx, token, stop_checker)
matched = checker.check(body)
if matched:
break
ctx.completion_tokens += 1
content = "".join(chunks)
return self.format_non_stream_response(ctx, content)
def _extract_text_content(content: Union[str, List[Dict[str, Any]]]) -> str:
"""Extract plain text from an Anthropic content block (string or list)."""
if isinstance(content, str):
return content
if isinstance(content, list):
for block in content:
if isinstance(block, dict) and block.get("type") == "text":
return block.get("text", "")
return ""
class OpenAIHandler(ProtocolHandler):
"""OpenAI-compatible /v1/chat/completions handler."""
def build_prompt(self) -> str:
messages = [
{"role": m.role, "content": m.content} for m in self.request.messages
]
return self.engine.tokenizer.apply_chat_template(messages, tokenize=False)
def create_response_id(self) -> str:
return f"chatcmpl-{uuid.uuid4().hex[:12]}"
def format_stream_start(self, ctx: StreamContext) -> List[str]:
return [
_sse_event(
{
"id": ctx.resp_id,
"object": "chat.completion.chunk",
"created": ctx.created,
"model": ctx.model,
"choices": [
{
"index": 0,
"delta": {"role": "assistant"},
"finish_reason": None,
}
],
}
)
]
def format_stream_token(self, ctx: StreamContext, token: str) -> str:
return _sse_event(
{
"id": ctx.resp_id,
"object": "chat.completion.chunk",
"created": ctx.created,
"model": ctx.model,
"choices": [
{"index": 0, "delta": {"content": token}, "finish_reason": None}
],
}
)
def format_stream_end(self, ctx: StreamContext) -> List[str]:
return [
_sse_event(
{
"id": ctx.resp_id,
"object": "chat.completion.chunk",
"created": ctx.created,
"model": ctx.model,
"choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
}
),
_sse_event(
{
"prompt_tokens": ctx.prompt_tokens,
"completion_tokens": ctx.completion_tokens,
"total_tokens": ctx.prompt_tokens + ctx.completion_tokens,
}
),
]
def format_non_stream_response(
self, ctx: StreamContext, content: str
) -> Dict[str, Any]:
return {
"id": ctx.resp_id,
"object": "chat.completion",
"created": ctx.created,
"model": ctx.model,
"choices": [
{
"index": 0,
"message": {"role": "assistant", "content": content},
"finish_reason": "stop",
}
],
"usage": {
"prompt_tokens": ctx.prompt_tokens,
"completion_tokens": ctx.completion_tokens,
"total_tokens": ctx.prompt_tokens + ctx.completion_tokens,
},
}
class AnthropicHandler(ProtocolHandler):
"""Anthropic-compatible /v1/messages handler."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._yielded = ""
def build_prompt(self) -> str:
messages: List[Dict[str, str]] = []
system = getattr(self.request, "system", None)
if system:
messages.append({"role": "system", "content": system})
for m in self.request.messages:
content = _extract_text_content(m.content)
if content:
messages.append({"role": m.role, "content": content})
return self.engine.tokenizer.apply_chat_template(messages, tokenize=False)
def create_response_id(self) -> str:
return f"msg_{uuid.uuid4().hex[:24]}"
def get_stop_sequences(self) -> List[str]:
return getattr(self.request, "stop_sequences", None) or []
def on_token(
self, ctx: StreamContext, token: str, stop_checker: StopChecker
) -> Optional[str]:
matched = stop_checker.check(ctx.accumulated)
if not matched:
return None
ctx.stop_matched = matched
trimmed = ctx.accumulated[: ctx.accumulated.rfind(matched)]
unyielded = trimmed[len(self._yielded) :]
if unyielded:
ctx.last_yield_trimmed = unyielded
return matched
def format_stream_start(self, ctx: StreamContext) -> List[str]:
return [
_sse_event(
{
"type": "message_start",
"message": {
"id": ctx.resp_id,
"type": "message",
"role": "assistant",
"model": ctx.model,
"content": [],
"usage": {"input_tokens": ctx.prompt_tokens},
},
},
event="message_start",
),
_sse_event(
{
"type": "content_block_start",
"index": 0,
"content_block": {"type": "text", "text": ""},
},
event="content_block_start",
),
]
def format_stream_token(self, ctx: StreamContext, token: str) -> str:
self._yielded += token
return _sse_event(
{
"type": "content_block_delta",
"index": 0,
"delta": {"type": "text_delta", "text": token},
},
event="content_block_delta",
)
def format_stream_end(self, ctx: StreamContext) -> List[str]:
matched = ctx.stop_matched
events: List[str] = []
last_yielded = ctx.last_yield_trimmed
if last_yielded:
events.append(
_sse_event(
{
"type": "content_block_delta",
"index": 0,
"delta": {"type": "text_delta", "text": last_yielded},
},
event="content_block_delta",
)
)
events.append(
_sse_event(
{"type": "content_block_stop", "index": 0},
event="content_block_stop",
)
)
events.append(
_sse_event(
{
"type": "message_delta",
"delta": {
"stop_reason": "stop_sequence" if matched else "end_turn",
"stop_sequence": matched,
},
"usage": {"output_tokens": ctx.completion_tokens},
},
event="message_delta",
)
)
events.append(_sse_event({"type": "message_stop"}, event="message_stop"))
return events
def format_non_stream_response(
self, ctx: StreamContext, content: str
) -> Dict[str, Any]:
matched = ctx.stop_matched
if matched:
content = content[: content.rfind(matched)]
return {
"id": ctx.resp_id,
"type": "message",
"role": "assistant",
"model": ctx.model,
"content": [{"type": "text", "text": content}],
"stop_reason": "stop_sequence" if matched else "end_turn",
"stop_sequence": matched,
"usage": {
"input_tokens": ctx.prompt_tokens,
"output_tokens": ctx.completion_tokens,
},
}
stop = StopInfo(matched=matched, body=body)
return self.builder.format_response(ctx, content, stop)

View File

@ -3,6 +3,9 @@ OpenAI / Anthropic-compatible chat completion server backed by continuous-batchi
Protocol-specific formatting is delegated to ``astrai.inference.protocol``.
This module owns the FastAPI app, request/response schemas, and dependency wiring.
``app`` is lazily constructed importing this module does NOT create a FastAPI instance.
Use :func:`get_app` to access the singleton.
"""
import logging
@ -12,22 +15,37 @@ from typing import Any, Dict, List, Optional, Union
import torch
import uvicorn
from fastapi import FastAPI, HTTPException, Request
from fastapi import APIRouter, FastAPI, HTTPException
from pydantic import BaseModel, Field
from astrai.inference.api.protocol import AnthropicHandler, OpenAIHandler
from astrai.inference.api.anthropic import AnthropicResponseBuilder
from astrai.inference.api.openai import OpenAIResponseBuilder
from astrai.inference.api.protocol import ProtocolHandler
from astrai.inference.engine import InferenceEngine
from astrai.model import AutoModel
from astrai.tokenize import AutoTokenizer
logger = logging.getLogger(__name__)
_project_root = Path(__file__).parent.parent.parent
_app_instance: Optional[FastAPI] = None
class ChatMessage(BaseModel):
role: str
content: str
content: Optional[str] = None
tool_calls: Optional[List[Dict[str, Any]]] = None
tool_call_id: Optional[str] = None
class FunctionDef(BaseModel):
name: str
description: Optional[str] = None
parameters: Optional[Dict[str, Any]] = None
class ToolDef(BaseModel):
type: str = "function"
function: FunctionDef
class ChatCompletionRequest(BaseModel):
@ -46,6 +64,8 @@ class ChatCompletionRequest(BaseModel):
frequency_penalty: Optional[float] = Field(default=0.0, ge=-2.0, le=2.0)
logit_bias: Optional[Dict[int, float]] = None
user: Optional[str] = None
tools: Optional[List[ToolDef]] = None
tool_choice: Optional[Union[str, Dict[str, Any]]] = "auto"
class AnthropicMessage(BaseModel):
@ -67,14 +87,30 @@ class MessagesRequest(BaseModel):
stop_sequences: Optional[List[str]] = None
@asynccontextmanager
async def lifespan(app: FastAPI):
config = app.state.server_config
if not config.get("_test", False):
try:
app.state.engine = _create_engine(**config)
except Exception as e:
logger.error(f"Failed to load model: {e}")
raise
yield
if app.state.engine:
app.state.engine.shutdown()
logger.info("Inference engine shutdown complete")
router = APIRouter()
def _create_engine(
param_path: Optional[Path] = None,
param_path: Path,
device: str = "cuda",
dtype: torch.dtype = torch.bfloat16,
max_batch_size: int = 16,
) -> InferenceEngine:
if param_path is None:
param_path = _project_root / "params"
if not param_path.exists():
raise FileNotFoundError(f"Parameter directory not found: {param_path}")
@ -92,67 +128,66 @@ def _create_engine(
return engine
@asynccontextmanager
async def lifespan(app: FastAPI):
config = app.state.server_config
if not config.get("_test", False):
try:
app.state.engine = _create_engine(**config)
except Exception as e:
logger.error(f"Failed to load model: {e}")
raise
yield
if app.state.engine:
app.state.engine.shutdown()
logger.info("Inference engine shutdown complete")
def get_app() -> FastAPI:
"""Return the singleton FastAPI instance (lazily created on first call)."""
global _app_instance
if _app_instance is None:
_app_instance = FastAPI(
title="AstrAI Inference Server",
version="0.2.0",
lifespan=lifespan,
)
_app_instance.include_router(router)
_app_instance.state.server_config = {}
_app_instance.state.engine = None
return _app_instance
app = FastAPI(title="AstrAI Inference Server", version="0.2.0", lifespan=lifespan)
def _get_engine(request: Request) -> InferenceEngine:
engine = request.app.state.engine
def _get_engine() -> InferenceEngine:
engine = get_app().state.engine
if engine is None:
raise HTTPException(status_code=503, detail="Engine not initialized")
return engine
@app.get("/health")
async def health(request: Request):
@router.get("/health")
async def health():
app = get_app()
return {
"status": "ok",
"model_loaded": request.app.state.engine is not None,
"model_loaded": app.state.engine is not None,
}
@app.get("/stats")
async def get_stats(request: Request):
return _get_engine(request).get_stats()
@router.get("/stats")
async def get_stats():
return _get_engine().get_stats()
@app.post("/v1/chat/completions")
async def chat_completion(request: ChatCompletionRequest, req: Request):
engine = _get_engine(req)
handler = OpenAIHandler(request, engine)
@router.post("/v1/chat/completions")
async def chat_completion(request: ChatCompletionRequest):
engine = _get_engine()
handler = ProtocolHandler(request, engine, OpenAIResponseBuilder())
return await handler.handle()
@app.post("/v1/messages")
async def create_message(request: MessagesRequest, req: Request):
engine = _get_engine(req)
handler = AnthropicHandler(request, engine)
@router.post("/v1/messages")
async def create_message(request: MessagesRequest):
engine = _get_engine()
handler = ProtocolHandler(request, engine, AnthropicResponseBuilder())
return await handler.handle()
def run_server(
param_path: Path,
host: str = "0.0.0.0",
port: int = 8000,
reload: bool = False,
device: str = "cuda",
dtype: torch.dtype = torch.bfloat16,
param_path: Optional[Path] = None,
max_batch_size: int = 16,
):
app = get_app()
app.state.server_config = {
"device": device,
"dtype": dtype,
@ -163,4 +198,5 @@ def run_server(
app,
host=host,
port=port,
reload=reload,
)

View File

@ -0,0 +1,325 @@
"""Tool call parsers for extracting structured tool calls from model output.
Patterned after vLLM's ToolParser abstraction. Each parser knows how to
detect and incrementally extract tool calls from raw generated text.
Subclasses may optionally consume ``token_ids`` for token-level parsing
(e.g. Harmony / VLM-style parsers).
"""
import re
import uuid
from abc import ABC, abstractmethod
from typing import Dict, List, Optional
from astrai.factory import BaseFactory
class BaseToolParser(ABC):
"""Abstract tool call parser — one instance per request.
Maintains streaming state internally so that each call to :meth:`feed`
can diff against previously emitted content.
Parameters
----------
tools : list of dict, optional
Tool definitions from the request.
tool_choice : str
``"auto"`` / ``"required"`` / ``"none"`` or a named tool choice
dict.
"""
def __init__(self, tools: Optional[List[Dict]] = None, tool_choice: str = "auto"):
self.tools = tools or []
self.tool_choice = tool_choice
@abstractmethod
def feed(
self,
body: str,
current_token_ids: Optional[List[int]] = None,
delta_token_ids: Optional[List[int]] = None,
) -> List[Dict]:
"""Feed the *full* accumulated text each step.
Returns a list of delta dicts to emit. Each delta is one of:
- ``{"content": "text"}`` plain text delta
- ``{"tool_calls": [...]}`` tool-call delta (OpenAI format)
Returns an empty list when nothing new should be emitted.
Parameters
----------
body : str
The complete accumulated generated text so far.
current_token_ids : list of int, optional
All token IDs decoded into *body* (cumulative).
delta_token_ids : list of int, optional
Only the token IDs for this chunk.
"""
@abstractmethod
def parse_complete(self, body: str) -> Optional[Dict]:
"""Parse the *complete* generated text after generation ends.
Returns ``None`` when no tool calls were found, otherwise a dict
with ``content`` (str or None) and ``tool_calls`` (list of dicts).
"""
@property
@abstractmethod
def has_tool_calls(self) -> bool:
"""True if the parser detected at least one tool call in the stream."""
class ToolParserFactory(BaseFactory["BaseToolParser"]):
pass
_TOOL_CALL_HEAD_RE = re.compile(r'\{\s*"name"\s*:')
def _scan_json(text: str, start: int = 0):
"""Scan for a complete JSON object starting at *start*.
Returns ``(end, complete)`` where *end* is one-past the closing
brace (or ``len(text)`` if unclosed), and *complete* is a bool.
"""
depth = 0
in_string = False
escape = False
for i in range(start, len(text)):
c = text[i]
if escape:
escape = False
continue
if c == "\\":
escape = True
continue
if c == '"':
in_string = not in_string
continue
if in_string:
continue
if c == "{":
depth += 1
elif c == "}":
depth -= 1
if depth == 0:
return i + 1, True
return len(text), False
def _parse_tool_call_json(json_str: str, complete: bool):
"""Extract *name* and *arguments* from a tool-call JSON string.
Returns ``(name, args, valid)``.
"""
name_match = re.search(r'"name"\s*:\s*"([^"]*)"', json_str)
if not name_match:
return None, "", False
name = name_match.group(1)
args_match = re.search(r'"arguments"\s*:\s*(.*)', json_str, re.DOTALL)
if not args_match:
return name, "", True
raw = args_match.group(1).rstrip()
if complete and raw.endswith("}"):
raw = raw[:-1].rstrip()
if raw.startswith("{"):
inner = raw[1:].rstrip()
if inner.endswith("}"):
inner = inner[:-1].rstrip()
raw = inner
return name, raw, True
def _find_tool_calls(text: str, start_pos: int = 0):
"""Find all complete ``{...}`` tool-call objects in *text*.
Returns a list of dicts with keys *start*, *end*, *name*, *args*,
*complete*.
"""
results = []
pos = start_pos
while True:
brace = text.find("{", pos)
if brace == -1:
break
end, complete = _scan_json(text, brace)
if not complete:
break
json_str = text[brace:end]
if not _TOOL_CALL_HEAD_RE.search(json_str):
pos = end
continue
name, args, valid = _parse_tool_call_json(json_str, complete=True)
if not valid or name is None:
pos = end
continue
results.append(
{
"start": brace,
"end": end,
"name": name,
"args": args,
"complete": True,
}
)
pos = end
return results
def _find_partial_tool_call(text: str, start_pos: int = 0):
"""Find one incomplete (still-generating) tool-call JSON object."""
brace = text.find("{", start_pos)
if brace == -1:
return None
json_str = text[brace:]
if not _TOOL_CALL_HEAD_RE.search(json_str):
return None
name, args, valid = _parse_tool_call_json(json_str, complete=False)
if not valid or name is None:
return None
return {
"start": brace,
"name": name,
"args": args,
"complete": False,
}
@ToolParserFactory.register("simple_json")
class SimpleJsonToolParser(BaseToolParser):
"""Parser for models that output tool calls as plain JSON objects.
Detects ``{"name": "<func>", "arguments": {...}}`` anywhere in the
generated text. Handles single and (non-overlapping) multiple tool
calls. Text preceding the first tool call is emitted as plain
``content`` deltas.
"""
def __init__(self, tools=None, tool_choice="auto"):
super().__init__(tools, tool_choice)
self._emitted_content_len = 0
self._tc_state: List[Dict] = []
self._has_tool_calls = False
# -------------------------------------------------------------- feed
def feed(
self,
body: str,
current_token_ids: Optional[List[int]] = None,
delta_token_ids: Optional[List[int]] = None,
) -> List[Dict]:
deltas: List[Dict] = []
completed = _find_tool_calls(body)
if not completed:
partial = _find_partial_tool_call(body)
if not partial:
return self._emit_plain_content(body, deltas)
all_tcs = [partial]
else:
all_tcs = completed
partial = _find_partial_tool_call(body, completed[-1]["end"])
if partial:
all_tcs = completed + [partial]
first_start = all_tcs[0]["start"]
if first_start > self._emitted_content_len:
content = body[self._emitted_content_len : first_start]
self._emitted_content_len = first_start
if content:
deltas.append({"content": content})
for i, tc in enumerate(all_tcs):
if i >= len(self._tc_state):
self._tc_state.append(
{
"id": f"call_{uuid.uuid4().hex[:12]}",
"name_emitted": False,
"args_emitted_len": 0,
}
)
self._has_tool_calls = True
st = self._tc_state[i]
if not st["name_emitted"]:
st["name_emitted"] = True
deltas.append(
{
"tool_calls": [
{
"index": i,
"id": st["id"],
"type": "function",
"function": {"name": tc["name"], "arguments": ""},
}
]
}
)
new_args = tc["args"]
if len(new_args) > st["args_emitted_len"]:
diff = new_args[st["args_emitted_len"] :]
st["args_emitted_len"] = len(new_args)
deltas.append(
{
"tool_calls": [
{
"index": i,
"function": {"arguments": diff},
}
]
}
)
return deltas
def _emit_plain_content(self, body: str, deltas: List[Dict]) -> List[Dict]:
new_content = body[self._emitted_content_len :]
if new_content:
self._emitted_content_len = len(body)
deltas.append({"content": new_content})
return deltas
# -------------------------------------------------------- complete
def parse_complete(self, body: str) -> Optional[Dict]:
completed = _find_tool_calls(body)
if not completed:
return None
content = body[: completed[0]["start"]].strip() or None
tool_calls = []
for i, tc in enumerate(completed):
tool_calls.append(
{
"id": f"call_{uuid.uuid4().hex[:12]}",
"type": "function",
"function": {
"name": tc["name"],
"arguments": tc["args"],
},
}
)
return {"content": content, "tool_calls": tool_calls}
@property
def has_tool_calls(self) -> bool:
return self._has_tool_calls

View File

@ -42,7 +42,7 @@ class Allocator:
return idx
return -1
def free(self, idx: int, keep_cached: bool = False) -> None:
def free(self, idx: int, keep_cached: bool = False):
with self._lock:
self._refs[idx] -= 1
if self._refs[idx] == 0:
@ -51,7 +51,7 @@ class Allocator:
else:
self._free_mask |= 1 << idx
def inc_ref(self, idx: int) -> None:
def inc_ref(self, idx: int):
with self._lock:
self._refs[idx] += 1
self._lru.pop(idx, None)
@ -60,7 +60,7 @@ class Allocator:
with self._lock:
return self._refs[idx]
def touch(self, idx: int) -> None:
def touch(self, idx: int):
with self._lock:
self._lru.move_to_end(idx)
@ -74,7 +74,7 @@ class PrefixCache:
self._hash_to_page: Dict[int, int] = {}
self._lock = threading.Lock()
def evict(self, idx: int) -> None:
def evict(self, idx: int):
with self._lock:
h = self._page_to_hash.pop(idx, None)
if h is not None:
@ -96,9 +96,7 @@ class PrefixCache:
hits.append(p)
return hits
def record(
self, page_idx: int, token_ids: List[int], logical_page_idx: int
) -> None:
def record(self, page_idx: int, token_ids: List[int], logical_page_idx: int):
with self._lock:
h = page_hash(token_ids, logical_page_idx, self._page_size)
old_h = self._page_to_hash.pop(page_idx, None)
@ -127,13 +125,13 @@ class PagePool:
def alloc(self) -> int:
return self._alloc.alloc()
def free(self, idx: int) -> None:
def free(self, idx: int):
keep = self._prefix.has_page(idx)
self._alloc.free(idx, keep_cached=keep)
if not keep:
self._prefix.evict(idx)
def inc_ref(self, idx: int) -> None:
def inc_ref(self, idx: int):
self._alloc.inc_ref(idx)
def lookup(self, token_ids: List[int]) -> List[int]:
@ -142,9 +140,7 @@ class PagePool:
self._alloc.touch(p)
return hits
def record(
self, page_idx: int, token_ids: List[int], logical_page_idx: int
) -> None:
def record(self, page_idx: int, token_ids: List[int], logical_page_idx: int):
self._prefix.record(page_idx, token_ids, logical_page_idx)
@ -157,7 +153,7 @@ class TaskTable:
self._cached: Dict[str, int] = {}
self._lock = threading.Lock()
def set(self, task_id: str, page_table: List[int], cached: int) -> None:
def set(self, task_id: str, page_table: List[int], cached: int):
with self._lock:
self._pages[task_id] = page_table
self._cached[task_id] = cached
@ -220,7 +216,7 @@ class Storage:
start_pos: int,
k: Tensor,
v: Tensor,
) -> None:
):
seq_len = k.size(1)
if seq_len == 0:
return
@ -286,7 +282,7 @@ class KvcacheView:
self._page_table = page_table
self._total_len = total_len
def write(self, layer_id: int, k: Tensor, v: Tensor) -> None:
def write(self, layer_id: int, k: Tensor, v: Tensor):
start_pos = self._total_len - k.size(1)
self._storage.write(layer_id, self._page_table, start_pos, k, v)
@ -339,7 +335,7 @@ class KVCache:
self._table.set(task_id, hits + new_pages, cached)
return True
def task_free(self, task_id: str) -> None:
def task_free(self, task_id: str):
page_table, _ = self._table.pop(task_id)
for idx in page_table:
self._pool.free(idx)
@ -359,7 +355,7 @@ class KVCache:
def task_record_hashes(
self, task_id: str, prompt_ids: List[int], start_logical_page: int = 0
) -> None:
):
page_table = self._table.get(task_id)
full_pages = len(prompt_ids) // self.page_size
for i in range(start_logical_page, full_pages):

View File

@ -29,9 +29,7 @@ class Executor:
self.device = device or next(model.parameters()).device
self.dtype = dtype or next(model.parameters()).dtype
def execute_prefill(
self, tasks: List[Task], prompt_len: int, start_pos: int = 0
) -> None:
def execute_prefill(self, tasks: List[Task], prompt_len: int, start_pos: int = 0):
if start_pos >= prompt_len:
return

View File

@ -22,14 +22,22 @@ class InferenceScheduler:
tokenizer: AutoTokenizer,
max_batch_size: int = 16,
max_seq_len: Optional[int] = None,
max_prompt_len: int = 512,
max_prompt_len: int = 2048,
page_size: int = 64,
device: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
):
config = model.config
self.max_seq_len = max_seq_len or config.max_len
if max_seq_len is not None:
self.max_seq_len = max_seq_len
elif config.max_len is not None:
self.max_seq_len = config.max_len
else:
raise ValueError(
"max_seq_len must be provided either as argument "
"or in model config (config.max_len)"
)
self.device = device or next(model.parameters()).device
self.dtype = dtype or next(model.parameters()).dtype
@ -62,22 +70,23 @@ class InferenceScheduler:
dtype=self.dtype,
)
self._running = False
self._stop_event = threading.Event()
self._loop_thread: Optional[threading.Thread] = None
def add_task(self, prompt: str, **kwargs) -> str:
return self._task_mgr.add_task(prompt, **kwargs)
def remove_task(self, task_id: str) -> None:
def remove_task(self, task_id: str):
for task in self._task_mgr.remove_task(task_id):
self._page_cache.task_free(task.task_id)
def get_stats(self) -> Dict[str, Any]:
return self._task_mgr.get_stats()
def _run_generation_loop(self) -> None:
def _run_generation_loop(self):
stop_ids = self._task_mgr.tokenizer.stop_ids
try:
while self._running:
while not self._stop_event.is_set():
finished = self._task_mgr.remove_finished_tasks(stop_ids)
for task in finished:
self._page_cache.task_free(task.task_id)
@ -100,7 +109,10 @@ class InferenceScheduler:
continue
to_prefill = [
t for t in self._task_mgr.get_active_tasks() if t.output_tokens == 0
t
for t in self._task_mgr.get_active_tasks()
if t.output_tokens == 0
and self._page_cache.task_cached(t.task_id) < len(t.prompt_ids)
]
if to_prefill:
for t in to_prefill:
@ -148,11 +160,15 @@ class InferenceScheduler:
t.output_ids.append(ntok)
t.output_tokens += 1
pos = t.input_tokens + t.output_tokens
self._page_cache.task_extend(t.task_id, pos)
extend_ok = self._page_cache.task_extend(t.task_id, pos)
if t.stream_callback:
t.stream_callback(
self._task_mgr.tokenizer.decode([ntok])
)
if not extend_ok:
t.status = TaskStatus.ABORTED
if t.stream_callback:
t.stream_callback(STOP)
for t in valid:
if t.is_finished(stop_ids):
@ -160,28 +176,38 @@ class InferenceScheduler:
t.stream_callback(STOP)
except Exception as e:
self._stop_event.set()
logger.error(f"Scheduler loop crashed: {e}", exc_info=True)
for task in self._task_mgr.get_active_tasks():
if task.stream_callback:
task.stream_callback(STOP)
self._page_cache.task_free(task.task_id)
for task in self._task_mgr.get_waiting_tasks():
if task.stream_callback:
task.stream_callback(STOP)
self._task_mgr.clear_queues()
raise
def start(self) -> None:
if not self._running:
self._running = True
t = threading.Thread(target=self._run_generation_loop, daemon=True)
t.start()
self._loop_thread = t
def start(self):
if self._loop_thread is not None and self._loop_thread.is_alive():
return
self._stop_event.clear()
t = threading.Thread(target=self._run_generation_loop, daemon=True)
t.start()
self._loop_thread = t
def stop(self) -> None:
self._running = False
def stop(self):
self._stop_event.set()
self._task_mgr.wake()
if hasattr(self, "_loop_thread"):
if self._loop_thread is not None:
self._loop_thread.join(timeout=2.0)
self._loop_thread = None
for task in self._task_mgr.get_active_tasks():
if task.stream_callback:
task.stream_callback(STOP)
self._page_cache.task_free(task.task_id)
for task in self._task_mgr.get_waiting_tasks():
if task.stream_callback:
task.stream_callback(STOP)
self._task_mgr.clear_queues()
if torch.cuda.is_available():
torch.cuda.empty_cache()

View File

@ -172,12 +172,12 @@ class TaskManager:
to_add.append(self.waiting_queue.popleft())
return to_add
def activate(self, task: Task) -> None:
def activate(self, task: Task):
task.status = TaskStatus.RUNNING
with self._lock:
self.active_tasks.append(task)
def return_to_waiting(self, tasks: List[Task]) -> None:
def return_to_waiting(self, tasks: List[Task]):
with self._lock:
for task in reversed(tasks):
self.waiting_queue.appendleft(task)
@ -185,18 +185,25 @@ class TaskManager:
def has_work(self) -> bool:
return bool(self.active_tasks or self.waiting_queue)
def wait_for_tasks(self, timeout: float = 1.0) -> None:
self._task_event.clear()
def wait_for_tasks(self, timeout: float = 1.0):
with self._lock:
if self.waiting_queue or self.active_tasks:
return
self._task_event.clear()
self._task_event.wait(timeout=timeout)
def get_active_tasks(self) -> List[Task]:
with self._lock:
return list(self.active_tasks)
def clear_queues(self) -> None:
def get_waiting_tasks(self) -> List[Task]:
with self._lock:
return list(self.waiting_queue)
def clear_queues(self):
with self._lock:
self.waiting_queue.clear()
self.active_tasks.clear()
def wake(self) -> None:
def wake(self):
self._task_event.set()

View File

@ -13,17 +13,6 @@ from astrai.inference.core.task import STOP
from astrai.tokenize import AutoTokenizer
def _validate_sampling_params(
top_k: int, top_p: float, temperature: float, max_tokens: Optional[int] = None
):
if not (isinstance(top_k, int) and top_k >= 0):
raise ValueError("top_k must be a non-negative integer")
if not (0.0 <= top_p <= 1.0):
raise ValueError("top_p must be a float between 0.0 and 1.0")
if not (isinstance(temperature, (int, float)) and temperature >= 0):
raise ValueError("temperature must be a non-negative number")
class GenerateResult:
"""Thread-safe token accumulator for streaming and non-streaming modes."""
@ -59,7 +48,7 @@ class GenerateResult:
def wait(self, timeout: Optional[float] = None) -> bool:
return self._event.wait(timeout=timeout)
def wait_completion(self, timeout: float = 300.0) -> None:
def wait_completion(self, timeout: float = 300.0):
with self._cond:
if not self._cond.wait_for(
lambda: self._completed >= self._total, timeout=timeout
@ -86,7 +75,12 @@ class GenerationRequest:
max_tokens: Optional[int] = None,
stream: bool = False,
):
_validate_sampling_params(top_k, top_p, temperature, max_tokens)
if not (isinstance(top_k, int) and top_k >= 0):
raise ValueError("top_k must be a non-negative integer")
if not (0.0 <= top_p <= 1.0):
raise ValueError("top_p must be a float between 0.0 and 1.0")
if not (isinstance(temperature, (int, float)) and temperature > 0):
raise ValueError("temperature must be a positive number")
self.messages = messages
self.top_k = top_k
@ -137,7 +131,6 @@ class InferenceEngine:
top_p: float = 1.0,
top_k: int = 50,
) -> Union[Generator, str, List[str]]:
_validate_sampling_params(top_k, top_p, temperature, max_tokens)
is_batch = isinstance(prompt, list)
prompts = prompt if is_batch else [prompt]
@ -158,7 +151,6 @@ class InferenceEngine:
top_p: float = 1.0,
top_k: int = 50,
) -> AsyncGenerator[str, None]:
_validate_sampling_params(top_k, top_p, temperature, max_tokens)
sync_gen = self._generate_streaming(
[prompt], False, max_tokens, temperature, top_p, top_k
)
@ -289,7 +281,7 @@ class InferenceEngine:
def get_stats(self) -> Dict[str, Any]:
return self.scheduler.get_stats()
def shutdown(self) -> None:
def shutdown(self):
self.scheduler.stop()
if torch.cuda.is_available():
torch.cuda.empty_cache()

View File

@ -29,6 +29,7 @@ class BaseSamplingStrategy(ABC):
Returns:
Transformed logits tensor.
"""
raise NotImplementedError
class TemperatureStrategy(BaseSamplingStrategy):
@ -41,13 +42,15 @@ class TemperatureStrategy(BaseSamplingStrategy):
def __init__(self, temperature: Union[float, Tensor] = 1.0):
self.temperature = temperature
def apply(self, logits, filter_value=-float("inf")):
def apply(self, logits: Tensor, filter_value: float = -float("inf")) -> Tensor:
t = self.temperature
if isinstance(t, Tensor):
t = t.to(logits.device, non_blocking=True).view(-1, 1)
t = torch.clamp(t, min=1e-8)
if (t != 1.0).any():
logits = logits / t.to(logits.device, non_blocking=True).view(-1, 1)
logits = logits / t
elif t != 1.0:
logits = logits / t
logits = logits / max(t, 1e-8)
return logits
@ -61,7 +64,7 @@ class TopKStrategy(BaseSamplingStrategy):
def __init__(self, top_k: Union[int, Tensor] = 0):
self.top_k = top_k
def apply(self, logits, filter_value=-float("inf")):
def apply(self, logits: Tensor, filter_value: float = -float("inf")) -> Tensor:
tk = self.top_k
if isinstance(tk, Tensor):
tk = tk.to(logits.device, non_blocking=True).long().clamp(min=0)
@ -98,7 +101,9 @@ class TopPStrategy(BaseSamplingStrategy):
def __init__(self, top_p: Union[float, Tensor] = 1.0):
self.top_p = top_p
def _apply(self, logits, top_p, filter_value):
def _apply(
self, logits: Tensor, top_p: Union[float, Tensor], filter_value: float
) -> Tensor:
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
cum_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
remove = cum_probs > top_p
@ -109,7 +114,7 @@ class TopPStrategy(BaseSamplingStrategy):
logits[mask] = filter_value
return logits
def apply(self, logits, filter_value=-float("inf")):
def apply(self, logits: Tensor, filter_value: float = -float("inf")) -> Tensor:
tp = self.top_p
if isinstance(tp, Tensor):
tp = tp.to(logits.device, non_blocking=True)
@ -140,7 +145,7 @@ class SamplingPipeline(BaseSamplingStrategy):
def __init__(self, strategies: List[BaseSamplingStrategy]):
self.strategies = strategies
def apply(self, logits, filter_value=-float("inf")):
def apply(self, logits: Tensor, filter_value: float = -float("inf")) -> Tensor:
for strategy in self.strategies:
logits = strategy.apply(logits, filter_value)
return logits

View File

@ -1,12 +1,18 @@
from astrai.model.automodel import AutoModel
from astrai.model.module import (
GQA,
MLP,
DecoderBlock,
Linear,
RMSNorm,
from astrai.model.components.attention import GQA
from astrai.model.components.decoder_block import DecoderBlock
from astrai.model.components.linear import Linear
from astrai.model.components.lora import (
LoRAConfig,
inject_lora,
load_lora,
merge_lora,
save_lora,
)
from astrai.model.transformer import Transformer
from astrai.model.components.mlp import MLP
from astrai.model.components.norm import RMSNorm
from astrai.model.encoder import EmbeddingEncoder
from astrai.model.transformer import AutoRegressiveLM
__all__ = [
# Modules
@ -16,6 +22,13 @@ __all__ = [
"GQA",
"DecoderBlock",
# Models
"Transformer",
"AutoRegressiveLM",
"EmbeddingEncoder",
"AutoModel",
# LoRA
"LoRAConfig",
"inject_lora",
"merge_lora",
"save_lora",
"load_lora",
]

View File

@ -6,16 +6,20 @@ from contextlib import contextmanager
from pathlib import Path
from typing import Self, Union
import safetensors.torch as st
import torch.nn as nn
from astrai.config import ModelConfig
from astrai.config.model_config import BaseModelConfig, ConfigFactory
from astrai.factory import BaseFactory
from astrai.serialization import load_model_config, load_model_weights, save_model
@contextmanager
def _disable_random_init(enable: bool = True):
init_functions = [
if not enable:
yield
return
names = (
"xavier_normal_",
"xavier_uniform_",
"kaiming_normal_",
@ -25,18 +29,15 @@ def _disable_random_init(enable: bool = True):
"constant_",
"normal_",
"uniform_",
]
original_funcs = {}
for name in init_functions:
if enable and hasattr(nn.init, name):
original_funcs[name] = getattr(nn.init, name)
setattr(nn.init, name, lambda *args, **kwargs: None)
)
orig = {n: getattr(nn.init, n) for n in names if hasattr(nn.init, n)}
for n in orig:
setattr(nn.init, n, lambda *a, **kw: None)
try:
yield
finally:
if enable:
for name, orig_func in original_funcs.items():
setattr(nn.init, name, orig_func)
for n, fn in orig.items():
setattr(nn.init, n, fn)
class AutoModel(BaseFactory["AutoModel"], nn.Module):
@ -45,7 +46,7 @@ class AutoModel(BaseFactory["AutoModel"], nn.Module):
Provides model loading/saving, registration, and generation.
"""
def __init__(self, config: ModelConfig):
def __init__(self, config: BaseModelConfig):
super().__init__()
self.config = config
@ -59,24 +60,22 @@ class AutoModel(BaseFactory["AutoModel"], nn.Module):
model_path = Path(path)
# Load config
config = ModelConfig()
config_path = model_path / "config.json"
if config_path.exists():
config.load(str(config_path))
else:
if not config_path.exists():
raise FileNotFoundError(f"Config file not found: {config_path}")
model_type = config.model_type or "transformer"
raw = load_model_config(str(model_path))
config = ConfigFactory.load(raw)
model_type = config.model_type or "autoregressive_lm"
actual_cls = AutoModel.get_component_class(model_type)
with _disable_random_init(enable=disable_random_init):
model = actual_cls(config)
# Load weights
weights_path = model_path / "model.safetensors"
if weights_path.exists():
state_dict = st.load_file(str(weights_path))
state_dict = load_model_weights(str(model_path))
model.load_state_dict(state_dict, strict=strict)
return model
@ -84,15 +83,12 @@ class AutoModel(BaseFactory["AutoModel"], nn.Module):
def save_pretrained(
self,
save_directory: Union[str, Path],
) -> None:
save_path = Path(save_directory)
save_path.mkdir(parents=True, exist_ok=True)
# Save config
self.config.save(str(save_path / "config.json"))
# Save weights
st.save_file(self.state_dict(), str(save_path / "model.safetensors"))
):
save_model(
config=self.config.to_dict(),
state_dict=self.state_dict(),
save_directory=str(save_directory),
)
def to(self, *args, **kwargs) -> Self:
"""Move model to device/dtype."""

View File

@ -0,0 +1,25 @@
from astrai.model.components.attention import GQA, MLA, repeat_kv
from astrai.model.components.decoder_block import DecoderBlock
from astrai.model.components.embedding import Embedding
from astrai.model.components.linear import Linear
from astrai.model.components.mlp import MLP
from astrai.model.components.norm import RMSNorm
from astrai.model.components.rope import (
RotaryEmbedding,
apply_rotary_emb,
get_rotary_emb,
)
__all__ = [
"Linear",
"RMSNorm",
"MLP",
"Embedding",
"GQA",
"MLA",
"DecoderBlock",
"RotaryEmbedding",
"apply_rotary_emb",
"get_rotary_emb",
"repeat_kv",
]

View File

@ -5,11 +5,14 @@ import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from astrai.factory import BaseFactory
from astrai.inference.core.cache import KvcacheView
from astrai.model.components.linear import Linear
from astrai.model.components.norm import RMSNorm
from astrai.model.components.rope import apply_rotary_emb
def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
"""Repeat KV heads n_rep times for GQA."""
bs, slen, n_heads, head_dim = x.shape
if n_rep == 1:
return x
@ -20,88 +23,11 @@ def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
)
def get_rotary_emb(
dim: int,
max_len: int,
base: float = 10000,
device: Optional[torch.device] = None,
) -> Tensor:
theta = base ** (-torch.arange(0, dim, 2, dtype=torch.float64, device=device) / dim)
t = torch.arange(0, max_len, dtype=torch.float64, device=device)
freqs = torch.outer(t, theta).float()
cos = torch.cos(freqs)
sin = torch.sin(freqs)
return torch.complex(cos, sin)
def apply_rotary_emb(x: torch.Tensor, freqs_cis: Tensor) -> Tensor:
dtype = x.dtype
x_ = x.float().reshape(*x.shape[:-1], -1, 2)
x_complex = torch.view_as_complex(x_)
freqs_cis = freqs_cis.unsqueeze(2)
x_rotated = x_complex * freqs_cis
x_out = torch.view_as_real(x_rotated).flatten(-2)
return x_out.to(dtype)
class RotaryEmbedding(nn.Module):
def __init__(self, dim: int, max_len: int, base: int = 10000):
super().__init__()
self.dim = dim
self.max_len = max_len
self.base = base
self._set_rotary_buffer(self.max_len)
def _set_rotary_buffer(self, max_len: int):
rotary_emb = get_rotary_emb(self.dim, max_len, self.base)
freqs_cis = torch.view_as_real(rotary_emb)
self.register_buffer("freqs_cis", freqs_cis, persistent=False)
def forward(self, x: Tensor, position_ids: Optional[Tensor] = None) -> Tensor:
if position_ids is None:
position_ids = (
torch.arange(x.size(1), device=x.device)
.unsqueeze(0)
.expand(x.size(0), -1)
)
position_freq_cis = self.freqs_cis[position_ids].float()
return torch.view_as_complex(position_freq_cis)
class Linear(nn.Module):
def __init__(self, in_dim: int, out_dim: int, bias: bool = False):
super().__init__()
self.weight = nn.Parameter(torch.empty((out_dim, in_dim)))
self.bias = nn.Parameter(torch.zeros(out_dim)) if bias else None
def forward(self, x: Tensor) -> Tensor:
return F.linear(x, self.weight, self.bias)
class RMSNorm(nn.Module):
def __init__(self, dim, norm_eps):
super().__init__()
self.weight = nn.Parameter(torch.ones(dim))
self.normalized_shape = (dim,)
self.norm_eps = norm_eps
def forward(self, x: Tensor) -> Tensor:
return F.rms_norm(x, self.normalized_shape, self.weight, self.norm_eps)
class MLP(nn.Module):
def __init__(self, dim: int, dim_feed_forward: int):
super().__init__()
self.up = Linear(dim, dim_feed_forward)
self.gate = Linear(dim, dim_feed_forward)
self.down = Linear(dim_feed_forward, dim)
def forward(self, x: Tensor) -> Tensor:
gated = self.up(x) * F.silu(self.gate(x))
out = self.down(gated)
return out
class AttnFactory(BaseFactory[nn.Module]):
pass
@AttnFactory.register("gqa")
class GQA(nn.Module):
def __init__(
self,
@ -152,7 +78,6 @@ class GQA(nn.Module):
) -> Tensor:
is_causal = attn_mask is None
# (bsz, seq_len, dim) -> (bsz, seq_len, n_heads, head_dim)
q = self._split_heads(self.q_proj(x), self.n_heads)
k = self._split_heads(self.k_proj(x), self.n_kv_heads)
v = self._split_heads(self.v_proj(x), self.n_kv_heads)
@ -167,7 +92,6 @@ class GQA(nn.Module):
k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep)
# (bsz, seq_len, n_heads, head_dim) -> (bsz, n_heads, seq_len, head_dim)
q, k, v = q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3)
sdqa_out = (
F.scaled_dot_product_attention(q, k, v, attn_mask, is_causal=is_causal)
@ -183,6 +107,7 @@ class GQA(nn.Module):
return out
@AttnFactory.register("mla")
class MLA(nn.Module):
def __init__(
self,
@ -193,6 +118,7 @@ class MLA(nn.Module):
qk_nope_head_dim: int,
qk_rope_head_dim: int,
norm_eps: float,
use_qk_norm: bool,
use_gated_attention: bool,
layer_id: int,
):
@ -206,16 +132,20 @@ class MLA(nn.Module):
self.head_dim = qk_nope_head_dim + qk_rope_head_dim
self.layer_id = layer_id
self.n_rep = n_heads // n_kv_heads
self.use_qk_norm = use_qk_norm
self.use_gated_attention = use_gated_attention
self.q_proj = Linear(dim, n_heads * self.head_dim, bias=False)
if self.use_qk_norm:
self.q_norm = RMSNorm(self.head_dim, norm_eps)
self.k_norm = RMSNorm(self.head_dim, norm_eps)
self.kv_a_proj = Linear(dim, kv_lora_rank, bias=False)
self.kv_norm = RMSNorm(kv_lora_rank, norm_eps)
# fused KV: (k_nope, k_rope, v)
self.kv_b_proj = Linear(
kv_lora_rank,
n_kv_heads * (self.head_dim + qk_rope_head_dim + self.head_dim),
n_kv_heads * (2 * self.head_dim),
)
self.o_proj = Linear(dim, dim, bias=False)
@ -248,7 +178,7 @@ class MLA(nn.Module):
q_nope, q_rope = (
q[..., : self.qk_nope_head_dim],
q[..., self.qk_rope_head_dim :],
q[..., self.qk_nope_head_dim :],
)
q_rope = apply_rotary_emb(q_rope, rotary_emb)
k_rope = apply_rotary_emb(k_rope, rotary_emb)
@ -256,6 +186,10 @@ class MLA(nn.Module):
q = torch.cat([q_nope, q_rope], dim=-1)
k = torch.cat([k_nope, k_rope], dim=-1)
if self.use_qk_norm:
q = self.q_norm(q)
k = self.k_norm(k)
if paged_cache is not None:
paged_cache.write(self.layer_id, k, v)
k, v = paged_cache.gather(self.layer_id)
@ -274,57 +208,3 @@ class MLA(nn.Module):
out = self.o_proj(attn_out)
return out
class DecoderBlock(nn.Module):
def __init__(
self,
dim: int,
n_heads: int,
dim_ffn: int,
n_kv_heads: int,
norm_eps: int,
use_qk_norm: bool,
use_gated_attention: bool,
layer_id: int,
):
super().__init__()
self.attention = GQA(
dim,
n_heads,
n_kv_heads,
use_qk_norm,
norm_eps,
use_gated_attention,
layer_id,
)
self.input_norm = RMSNorm(dim, norm_eps)
self.mlp = MLP(dim, dim_ffn)
self.post_attention_norm = RMSNorm(dim, norm_eps)
def forward(
self,
x: Tensor,
rotary_emb: Tensor,
attention_mask: Optional[Tensor] = None,
paged_cache: Optional[KvcacheView] = None,
) -> Tensor:
attn_output = self.attention(
self.input_norm(x),
rotary_emb,
attention_mask,
paged_cache,
)
x = attn_output + x
x = self.mlp(self.post_attention_norm(x)) + x
return x
class Embedding(nn.Module):
def __init__(self, vocab_size: int, embedding_dim: int):
super().__init__()
self.weight = nn.Parameter(torch.empty((vocab_size, embedding_dim)))
def forward(self, x: Tensor) -> Tensor:
return F.embedding(x, self.weight)

View File

@ -0,0 +1,59 @@
from typing import Optional
import torch.nn as nn
from torch import Tensor
from astrai.inference.core.cache import KvcacheView
from astrai.model.components.attention import AttnFactory
from astrai.model.components.mlp import FFNFactory
from astrai.model.components.norm import RMSNorm
class DecoderBlock(nn.Module):
def __init__(
self,
dim: int,
n_heads: int,
dim_ffn: int,
n_kv_heads: int,
norm_eps: float,
use_qk_norm: bool,
use_gated_attention: bool,
layer_id: int,
attn_type: str = "gqa",
ffn_type: str = "mlp",
**kwargs,
):
super().__init__()
self.attention = AttnFactory.create(
attn_type,
dim=dim,
n_heads=n_heads,
n_kv_heads=n_kv_heads,
use_qk_norm=use_qk_norm,
norm_eps=norm_eps,
use_gated_attention=use_gated_attention,
layer_id=layer_id,
**kwargs,
)
self.input_norm = RMSNorm(dim, norm_eps)
self.post_attention_norm = RMSNorm(dim, norm_eps)
self.mlp = FFNFactory.create(ffn_type, dim, dim_ffn, **kwargs)
def forward(
self,
x: Tensor,
rotary_emb: Tensor,
attention_mask: Optional[Tensor] = None,
paged_cache: Optional[KvcacheView] = None,
) -> Tensor:
attn_output = self.attention(
self.input_norm(x),
rotary_emb,
attention_mask,
paged_cache,
)
x = attn_output + x
x = self.mlp(self.post_attention_norm(x)) + x
return x

View File

@ -0,0 +1,23 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
class Embedding(nn.Module):
def __init__(self, vocab_size: int, embedding_dim: int):
super().__init__()
self.weight = nn.Parameter(torch.empty((vocab_size, embedding_dim)))
self.neftune_noise_alpha = 0.0
def reset_parameters(self):
nn.init.normal_(self.weight, mean=0.0, std=0.02)
def forward(self, x: Tensor) -> Tensor:
out = F.embedding(x, self.weight)
if self.training and self.neftune_noise_alpha > 0.0:
eps = self.neftune_noise_alpha / math.sqrt(out.size(1))
out = out + eps * torch.randn_like(out)
return out

View File

@ -0,0 +1,21 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
class Linear(nn.Module):
def __init__(self, in_dim: int, out_dim: int, bias: bool = False):
super().__init__()
self.weight = nn.Parameter(torch.empty((out_dim, in_dim)))
self.bias = nn.Parameter(torch.zeros(out_dim)) if bias else None
def reset_parameters(self):
nn.init.kaiming_uniform_(self.weight, a=5**0.5)
if self.bias is not None:
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
bound = 1 / (fan_in**0.5)
nn.init.uniform_(self.bias, -bound, bound)
def forward(self, x: Tensor) -> Tensor:
return F.linear(x, self.weight, self.bias)

View File

@ -0,0 +1,194 @@
import logging
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Optional, Set
import torch
import torch.nn as nn
import torch.nn.functional as F
from astrai.model.components.linear import Linear
from astrai.serialization import (
load_json,
load_safetensors,
save_json,
save_safetensors,
)
logger = logging.getLogger(__name__)
TARGET_MODULES_ATTN = {"q_proj", "k_proj", "v_proj", "o_proj"}
TARGET_MODULES_FFN = {"up", "gate", "down"}
@dataclass
class LoRAConfig:
r: int = 16
alpha: int = 32
target_modules: tuple = ("q_proj", "v_proj")
class LoRALinear(nn.Module):
def __init__(self, base: Linear, r: int = 16, alpha: int = 32):
super().__init__()
self.register_parameter("weight", base.weight)
self.weight.requires_grad_(False)
self.bias = base.bias
if self.bias is not None:
self.bias.requires_grad_(False)
self.r = r
self.scaling = alpha / r
self.lora_A = nn.Parameter(torch.randn(r, self.weight.shape[1]) / r)
self.lora_B = nn.Parameter(torch.zeros(self.weight.shape[0], r))
self._merged = False
def forward(self, x):
out = F.linear(x, self.weight, self.bias)
if not self._merged:
out += (F.linear(x, self.lora_A) @ self.lora_B.T) * self.scaling
return out
def merge(self):
if self._merged:
return
self.weight.data += (self.lora_B @ self.lora_A) * self.scaling
self._merged = True
del self.lora_A
del self.lora_B
def _collect_lora_info(model: nn.Module) -> dict:
names = {}
for n, m in model.named_modules():
if isinstance(m, Linear):
_, _, child = n.rpartition(".")
names.setdefault(child, []).append(n)
return names
def _get_lora_count(model: nn.Module) -> int:
return sum(1 for m in model.modules() if isinstance(m, LoRALinear))
def inject_lora(
model: nn.Module,
r: int = 16,
alpha: int = 32,
target_modules: Optional[Set[str]] = None,
) -> LoRAConfig:
if target_modules is None:
target_modules = TARGET_MODULES_ATTN
available = _collect_lora_info(model)
injected = 0
for name, module in list(model.named_modules()):
if not isinstance(module, Linear):
continue
parent_name, _, child_name = name.rpartition(".")
if child_name not in target_modules:
continue
parent = model.get_submodule(parent_name) if parent_name else model
setattr(parent, child_name, LoRALinear(module, r=r, alpha=alpha))
injected += 1
if injected == 0:
logger.warning(
"No LoRA layers injected. Available Linear child names: %s. "
"target_modules: %s. Check model type and target_modules.",
sorted(available),
sorted(target_modules),
)
else:
logger.info("LoRA injected: %d layers (r=%d, alpha=%d)", injected, r, alpha)
return LoRAConfig(r=r, alpha=alpha, target_modules=tuple(target_modules))
def merge_lora(model: nn.Module):
n = 0
for module in model.modules():
if isinstance(module, LoRALinear):
module.merge()
n += 1
if n == 0:
logger.warning("No LoRA layers to merge.")
else:
logger.info("Merged %d LoRA layers", n)
def save_lora(model: nn.Module, save_dir: str, config: LoRAConfig):
lora_sd = {
k: v
for k, v in model.state_dict().items()
if k.endswith((".lora_A", ".lora_B"))
}
if not lora_sd:
raise RuntimeError(
"No LoRA parameters found in model. "
"The model may not have been injected or was already merged."
)
path = Path(save_dir)
path.mkdir(parents=True, exist_ok=True)
save_safetensors(lora_sd, path / "adapter_model.safetensors")
save_json(asdict(config), path / "adapter_config.json")
logger.info("LoRA adapter saved to %s (%d keys)", save_dir, len(lora_sd))
def load_lora(model: nn.Module, load_dir: str) -> LoRAConfig:
path = Path(load_dir)
raw = load_json(path / "adapter_config.json")
config = LoRAConfig(
r=raw["r"], alpha=raw["alpha"], target_modules=tuple(raw["target_modules"])
)
existing = _get_lora_count(model)
if existing > 0:
logger.warning(
"Model already has %d LoRA layers. Skipping injection, "
"loading weights onto existing layers only.",
existing,
)
else:
inject_lora(
model,
r=config.r,
alpha=config.alpha,
target_modules=set(config.target_modules),
)
weights = load_safetensors(path / "adapter_model.safetensors")
try:
missing, unexpected = model.load_state_dict(weights, strict=False)
except RuntimeError as e:
msg = str(e)
if "size mismatch" in msg:
raise RuntimeError(
f"LoRA weight shapes do not match the model. "
f"The adapter config (r={config.r}) may not match the injected layers. "
f"Original error: {msg}"
) from e
raise
injected = _get_lora_count(model)
if injected == 0:
raise RuntimeError(
"No LoRA layers found after loading. "
"Inject LoRA before calling load_lora, or check the adapter config."
)
if missing:
lora_missing = [k for k in missing if "lora" in k]
if lora_missing:
raise RuntimeError(
f"LoRA weight keys not found in model: {lora_missing}. "
f"The adapter config (r={config.r}) may not match the model."
)
logger.debug("LoRA load: %d missing base-weight keys (expected)", len(missing))
if unexpected:
logger.warning("LoRA load: %d unexpected keys", len(unexpected))
logger.info("LoRA adapter loaded from %s", load_dir)
return config

View File

@ -0,0 +1,91 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from astrai.factory import BaseFactory
from astrai.model.components.linear import Linear
class FFNFactory(BaseFactory[nn.Module]):
pass
@FFNFactory.register("mlp")
class MLP(nn.Module):
def __init__(self, dim: int, dim_ffn: int):
super().__init__()
self.up = Linear(dim, dim_ffn)
self.gate = Linear(dim, dim_ffn)
self.down = Linear(dim_ffn, dim)
def forward(self, x: Tensor) -> Tensor:
gated = self.up(x) * F.silu(self.gate(x))
out = self.down(gated)
return out
@FFNFactory.register("moe")
class DeepSeekMoE(nn.Module):
def __init__(
self,
dim: int,
dim_ffn: int,
n_routed_experts: int,
n_shared_experts: int = 1,
n_activated_experts: int = 2,
topk_method: str = "greedy",
):
super().__init__()
self.dim = dim
self.n_routed_experts = n_routed_experts
self.n_shared_experts = n_shared_experts
self.n_activated_experts = n_activated_experts
self.topk_method = topk_method
self.router = Linear(dim, n_routed_experts, bias=False)
self.shared_experts = nn.ModuleList(
[MLP(dim, dim_ffn) for _ in range(n_shared_experts)]
)
self.routed_experts = nn.ModuleList(
[MLP(dim, dim_ffn) for _ in range(n_routed_experts)]
)
def forward(self, x: Tensor) -> Tensor:
bsz, seq_len, dim = x.shape
x_flat = x.view(-1, dim)
shared_out = self._shared_forward(x_flat)
routed_out = self._routed_forward(x_flat)
out = (shared_out + routed_out).view(bsz, seq_len, dim)
return out
def _shared_forward(self, x: Tensor) -> Tensor:
if self.n_shared_experts == 0:
return torch.zeros_like(x)
return sum(e(x) for e in self.shared_experts) / self.n_shared_experts
def _routed_forward(self, x: Tensor) -> Tensor:
N, D = x.shape
K = self.n_activated_experts
router_logits = self.router(x)
router_probs = torch.softmax(router_logits.float(), dim=-1).to(x.dtype)
topk_weights, topk_indices = torch.topk(router_probs, K, dim=-1)
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
output = torch.zeros(N, D, device=x.device, dtype=x.dtype)
for expert_idx in range(self.n_routed_experts):
expert_mask = topk_indices == expert_idx
token_idx, k_idx = expert_mask.nonzero(as_tuple=True)
if token_idx.numel() == 0:
continue
expert_input = x[token_idx]
expert_output = self.routed_experts[expert_idx](expert_input)
weights = topk_weights[token_idx, k_idx].unsqueeze(-1)
output.index_add_(0, token_idx, expert_output * weights)
return output

View File

@ -0,0 +1,15 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
class RMSNorm(nn.Module):
def __init__(self, dim, norm_eps):
super().__init__()
self.weight = nn.Parameter(torch.ones(dim))
self.normalized_shape = (dim,)
self.norm_eps = norm_eps
def forward(self, x: Tensor) -> Tensor:
return F.rms_norm(x, self.normalized_shape, self.weight, self.norm_eps)

View File

@ -0,0 +1,71 @@
from typing import Dict, Optional
import torch
import torch.nn as nn
from torch import Tensor
def get_rotary_emb(
dim: int,
max_len: int,
base: float = 10000,
device: Optional[torch.device] = None,
) -> Tensor:
theta = base ** (-torch.arange(0, dim, 2, dtype=torch.float64, device=device) / dim)
t = torch.arange(0, max_len, dtype=torch.float64, device=device)
freqs = torch.outer(t, theta).float()
cos = torch.cos(freqs)
sin = torch.sin(freqs)
return torch.complex(cos, sin)
def ntk_base(base: float, dim: int, factor: float) -> float:
return base * (factor ** (dim / (dim - 2)))
def apply_rotary_emb(x: torch.Tensor, freqs_cis: Tensor) -> Tensor:
dtype = x.dtype
x_ = x.float().reshape(*x.shape[:-1], -1, 2)
x_complex = torch.view_as_complex(x_)
freqs_cis = freqs_cis.unsqueeze(2)
x_rotated = x_complex * freqs_cis
x_out = torch.view_as_real(x_rotated).flatten(-2)
return x_out.to(dtype)
class RotaryEmbedding(nn.Module):
def __init__(
self,
dim: int,
max_len: int,
base: float = 10000,
rope_scaling: Optional[Dict] = None,
):
super().__init__()
self.dim = dim
self.max_len = max_len
self.base = base
self.rope_scaling = rope_scaling
if rope_scaling is not None:
scaling_type = rope_scaling.get("type", "ntk")
factor = rope_scaling.get("factor", 1.0)
if scaling_type == "ntk":
self.base = ntk_base(base, dim, factor)
self._set_rotary_buffer(self.max_len)
def _set_rotary_buffer(self, max_len: int):
rotary_emb = get_rotary_emb(self.dim, max_len, self.base)
freqs_cis = torch.view_as_real(rotary_emb)
self.register_buffer("freqs_cis", freqs_cis, persistent=False)
def forward(self, x: Tensor, position_ids: Optional[Tensor] = None) -> Tensor:
if position_ids is None:
position_ids = (
torch.arange(x.size(1), device=x.device)
.unsqueeze(0)
.expand(x.size(0), -1)
)
position_freq_cis = self.freqs_cis[position_ids].float()
return torch.view_as_complex(position_freq_cis)

99
astrai/model/encoder.py Normal file
View File

@ -0,0 +1,99 @@
from typing import Any, Mapping, Optional
import torch
import torch.nn as nn
from torch import Tensor
from astrai.config.model_config import EncoderConfig
from astrai.model.automodel import AutoModel
from astrai.model.components.decoder_block import DecoderBlock
from astrai.model.components.embedding import Embedding
from astrai.model.components.norm import RMSNorm
from astrai.model.components.rope import RotaryEmbedding
from astrai.model.transformer import process_attention_mask
@AutoModel.register("embedding")
class EmbeddingEncoder(AutoModel):
def __init__(self, config: EncoderConfig):
super().__init__(config)
self.config = config
rope_dim = config.dim // config.n_heads
rope_base = config.rope_theta if config.rope_theta is not None else 10000
self.rotary_embedding = RotaryEmbedding(
rope_dim, config.max_len, rope_base, rope_scaling=config.rope_scaling
)
self.embed_tokens = Embedding(config.vocab_size, config.dim)
self.layers = nn.ModuleList(
[
DecoderBlock(
config.dim,
config.n_heads,
config.dim_ffn,
config.n_kv_heads,
config.norm_eps,
config.use_qk_norm,
config.use_gated_attention,
layer_id,
)
for layer_id in range(config.n_layers)
]
)
self.norm = RMSNorm(config.dim, config.norm_eps)
self.pooling_type = config.pooling_type or "mean"
self.normalize_embeddings = config.normalize_embeddings or False
self.apply(self._init_weights)
def _init_weights(self, module):
if hasattr(module, "reset_parameters"):
module.reset_parameters()
def load_state_dict(self, state_dict: Mapping[str, Any], strict=True, assign=False):
state_dict = dict(state_dict)
state_dict.pop("lm_head.weight", None)
return super().load_state_dict(state_dict, strict=strict, assign=assign)
def forward(
self,
input_ids: Tensor,
input_mask: Optional[Tensor] = None,
position_ids: Optional[Tensor] = None,
) -> Tensor:
assert input_ids.ndim == 2
B, S = input_ids.shape
x = self.embed_tokens(input_ids)
rotary_emb = self.rotary_embedding(x, position_ids)
attn_mask = process_attention_mask(x, position_ids, input_mask, is_causal=False)
for layer in self.layers:
x = layer(x, rotary_emb, attn_mask, paged_cache=None)
hidden_states = self.norm(x)
if self.pooling_type == "cls":
pooled = hidden_states[:, 0]
elif self.pooling_type == "last":
if input_mask is not None:
lengths = input_mask.sum(dim=1) - 1
pooled = hidden_states[torch.arange(B, device=x.device), lengths]
else:
pooled = hidden_states[:, -1]
else:
if input_mask is not None:
mask = input_mask.unsqueeze(-1).to(dtype=hidden_states.dtype)
pooled = (hidden_states * mask).sum(dim=1) / mask.sum(dim=1).clamp(
min=1.0
)
else:
pooled = hidden_states.mean(dim=1)
if self.normalize_embeddings:
pooled = torch.nn.functional.normalize(pooled, p=2, dim=-1)
return pooled

View File

@ -1,19 +1,17 @@
from typing import Any, Mapping, Optional
from typing import Any, Dict, Mapping, Optional
import torch
import torch.nn as nn
from torch import Tensor
from astrai.config.model_config import ModelConfig
from astrai.config.model_config import AutoRegressiveLMConfig
from astrai.inference.core.cache import KvcacheView
from astrai.model.automodel import AutoModel
from astrai.model.module import (
DecoderBlock,
Embedding,
Linear,
RMSNorm,
RotaryEmbedding,
)
from astrai.model.components.decoder_block import DecoderBlock
from astrai.model.components.embedding import Embedding
from astrai.model.components.linear import Linear
from astrai.model.components.norm import RMSNorm
from astrai.model.components.rope import RotaryEmbedding
def process_attention_mask(
@ -28,35 +26,38 @@ def process_attention_mask(
return input_mask
device = input_tensor.device
dtype = input_tensor.dtype
B, S = input_tensor.size()[:2]
B = input_tensor.size(0)
T = position_ids.max().item() + 1
if input_mask is None:
if position_ids.min().item() == 0 and is_causal:
return None
pad = torch.ones(B, T, dtype=torch.bool, device=device)
attend = torch.ones(B, 1, T, dtype=torch.bool, device=device)
else:
pad = input_mask[:, :T].to(device=device, dtype=torch.bool)
attend = input_mask[:, :T].to(device=device, dtype=torch.bool).unsqueeze(1)
attend = pad.view(B, 1, T).expand(B, S, T).clone()
if is_causal:
attend &= position_ids.unsqueeze(-1) >= torch.arange(T, device=device)
causal = position_ids.unsqueeze(-1) >= torch.arange(T, device=device)
attend = attend & causal
return torch.full(
(B, 1, S, T), -torch.finfo(dtype).max / 2, dtype=dtype, device=device
).masked_fill_(attend.unsqueeze(1), 0.0)
return attend.unsqueeze(1)
@AutoModel.register("transformer")
class Transformer(AutoModel):
"""Transformer language model with paged KV cache."""
@AutoModel.register("autoregressive_lm")
class AutoRegressiveLM(AutoModel):
"""Autoregressive language model with paged KV cache."""
def __init__(self, config: ModelConfig):
def __init__(self, config: AutoRegressiveLMConfig):
super().__init__(config)
self.config = config
rope_dim = (
config.qk_rope_head_dim
if config.attn_type == "mla"
else config.dim // config.n_heads
)
rope_base = config.rope_theta if config.rope_theta is not None else 10000
self.rotary_embedding = RotaryEmbedding(
config.dim // config.n_heads, config.max_len
rope_dim, config.max_len, rope_base, rope_scaling=config.rope_scaling
)
self.embed_tokens = Embedding(config.vocab_size, config.dim)
@ -71,6 +72,15 @@ class Transformer(AutoModel):
config.use_qk_norm,
config.use_gated_attention,
layer_id,
attn_type=config.attn_type,
ffn_type=config.ffn_type,
n_routed_experts=config.n_routed_experts,
n_shared_experts=config.n_shared_experts,
n_activated_experts=config.n_activated_experts,
topk_method=config.topk_method,
kv_lora_rank=config.kv_lora_rank,
qk_nope_head_dim=config.qk_nope_head_dim,
qk_rope_head_dim=config.qk_rope_head_dim,
)
for layer_id in range(config.n_layers)
]
@ -79,15 +89,14 @@ class Transformer(AutoModel):
self.norm = RMSNorm(config.dim, config.norm_eps)
self.lm_head = Linear(config.dim, config.vocab_size)
if self.config.tie_weight:
if self.config.tie_weight is True:
self.lm_head.weight = self.embed_tokens.weight
self._init_weights()
self.apply(self._init_weights)
def _init_weights(self):
for param in self.parameters():
if param.dim() > 1:
nn.init.normal_(param, mean=0.0, std=0.006)
def _init_weights(self, module):
if hasattr(module, "reset_parameters"):
module.reset_parameters()
def load_state_dict(self, state_dict: Mapping[str, Any], strict=True, assign=False):
lm_head_key = "lm_head.weight"
@ -95,7 +104,7 @@ class Transformer(AutoModel):
state_dict = dict(state_dict)
if self.config.tie_weight:
if self.config.tie_weight is True:
# same tensor for embed and lm_head
if embed_key in state_dict:
state_dict[lm_head_key] = state_dict[embed_key]
@ -111,7 +120,7 @@ class Transformer(AutoModel):
destination=destination, prefix=prefix, keep_vars=keep_vars
)
if self.config.tie_weight:
if self.config.tie_weight is True:
lm_head_key = prefix + "lm_head.weight"
if lm_head_key in state_dict:
del state_dict[lm_head_key]
@ -124,7 +133,7 @@ class Transformer(AutoModel):
input_mask: Optional[Tensor] = None,
paged_cache: Optional[KvcacheView] = None,
position_ids: Optional[Tensor] = None,
) -> Tensor:
) -> Dict[str, Tensor]:
assert input_ids.ndim == 2
x = self.embed_tokens(input_ids)

View File

@ -1,3 +1,13 @@
from astrai.parallel.executor import (
AccumOptimizer,
AccumScheduler,
BaseExecutor,
DDPExecutor,
ExecutorFactory,
FSDPExecutor,
GradientState,
NoneExecutor,
)
from astrai.parallel.module import ColumnParallelLinear, RowParallelLinear
from astrai.parallel.setup import (
get_current_device,
@ -17,4 +27,12 @@ __all__ = [
"spawn_parallel_fn",
"RowParallelLinear",
"ColumnParallelLinear",
"ExecutorFactory",
"BaseExecutor",
"GradientState",
"AccumOptimizer",
"AccumScheduler",
"NoneExecutor",
"DDPExecutor",
"FSDPExecutor",
]

272
astrai/parallel/executor.py Normal file
View File

@ -0,0 +1,272 @@
"""Unified training executor — parallel strategy + gradient accumulation."""
import contextlib
import logging
import os
from contextlib import contextmanager
from typing import Optional, Tuple
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullStateDictConfig, StateDictType
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from torch.utils.data import DataLoader
from astrai.factory import BaseFactory
from astrai.parallel.setup import get_rank, get_world_size
logger = logging.getLogger(__name__)
class GradientState:
def __init__(self, grad_accum_steps: int = 1):
self.num_steps = max(grad_accum_steps, 1)
self._step: int = 0
self._sync_gradients: bool = True
@property
def sync_gradients(self) -> bool:
return self._sync_gradients
def _do_sync(self):
self._step += 1
self._sync_gradients = self._step % self.num_steps == 0
class AccumOptimizer:
def __init__(self, optimizer: Optimizer, gradient_state: GradientState):
self.optimizer = optimizer
self.gradient_state = gradient_state
def step(self, closure=None):
if self.gradient_state.sync_gradients:
self.optimizer.step(closure)
def zero_grad(self):
if self.gradient_state.sync_gradients:
self.optimizer.zero_grad()
@property
def param_groups(self):
return self.optimizer.param_groups
def state_dict(self):
return self.optimizer.state_dict()
def load_state_dict(self, d):
self.optimizer.load_state_dict(d)
class AccumScheduler:
def __init__(self, scheduler: LRScheduler, gradient_state: GradientState):
self.scheduler = scheduler
self.gradient_state = gradient_state
def step(self):
if self.gradient_state.sync_gradients:
self.scheduler.step()
def state_dict(self):
return self.scheduler.state_dict()
def load_state_dict(self, d):
self.scheduler.load_state_dict(d)
def get_last_lr(self):
return self.scheduler.get_last_lr()
class BaseExecutor:
def __init__(self, grad_accum_steps: int = 1):
self.gradient_state = GradientState(grad_accum_steps)
def prepare(
self,
model: nn.Module,
optimizer: Optional[Optimizer] = None,
dataloader: Optional[DataLoader] = None,
scheduler: Optional[LRScheduler] = None,
) -> Tuple[
nn.Module, Optional[Optimizer], Optional[DataLoader], Optional[LRScheduler]
]:
model = self._prepare_model(model)
if optimizer is not None:
optimizer = AccumOptimizer(optimizer, self.gradient_state)
if scheduler is not None:
scheduler = AccumScheduler(scheduler, self.gradient_state)
return model, optimizer, dataloader, scheduler
def _prepare_model(self, model: nn.Module) -> nn.Module:
return model
def _no_sync(self, model: nn.Module):
return contextlib.nullcontext()
@contextmanager
def accumulate(self, model: nn.Module):
self.gradient_state._do_sync()
if not self.gradient_state.sync_gradients:
with self._no_sync(model):
yield
else:
yield
def backward(self, loss: torch.Tensor):
loss.backward()
def unwrap_model(self, model: nn.Module):
return model.state_dict()
@property
def use_distributed(self) -> bool:
return get_world_size() > 1
@property
def sync_gradients(self) -> bool:
return self.gradient_state.sync_gradients
@property
def grad_accum_steps(self) -> int:
return self.gradient_state.num_steps
class ExecutorFactory(BaseFactory[BaseExecutor]):
pass
@ExecutorFactory.register("none")
class NoneExecutor(BaseExecutor):
pass
@ExecutorFactory.register("ddp")
class DDPExecutor(BaseExecutor):
def __init__(
self,
grad_accum_steps: int = 1,
dim: int = 0,
broadcast_buffers: bool = True,
init_sync: bool = True,
process_group=None,
bucket_cap_mb: int = 25,
find_unused_parameters: bool = False,
check_reduction: bool = False,
gradient_as_bucket_view: bool = False,
static_graph: bool = False,
delay_all_reduce_named_params=None,
param_to_hook_all_reduce=None,
mixed_precision=None,
device_mesh=None,
):
super().__init__(grad_accum_steps=grad_accum_steps)
self._ddp_kwargs = dict(
dim=dim,
broadcast_buffers=broadcast_buffers,
init_sync=init_sync,
process_group=process_group,
bucket_cap_mb=bucket_cap_mb,
find_unused_parameters=find_unused_parameters,
check_reduction=check_reduction,
gradient_as_bucket_view=gradient_as_bucket_view,
static_graph=static_graph,
delay_all_reduce_named_params=delay_all_reduce_named_params,
param_to_hook_all_reduce=param_to_hook_all_reduce,
mixed_precision=mixed_precision,
device_mesh=device_mesh,
)
def _prepare_model(self, model: nn.Module) -> nn.Module:
if not self.use_distributed:
logger.warning("DDP backend selected but world_size=1, model not wrapped")
return model
local_rank = int(os.environ.get("LOCAL_RANK", get_rank()))
model = DDP(
model,
device_ids=[local_rank],
output_device=local_rank,
**self._ddp_kwargs,
)
logger.info("Model wrapped with DDP (world_size=%d)", get_world_size())
return model
def _no_sync(self, model: nn.Module):
if isinstance(model, DDP):
return model.no_sync()
return contextlib.nullcontext()
def unwrap_model(self, model: nn.Module):
if isinstance(model, DDP):
return model.module.state_dict()
return model.state_dict()
@ExecutorFactory.register("fsdp")
class FSDPExecutor(BaseExecutor):
def __init__(
self,
grad_accum_steps: int = 1,
process_group=None,
sharding_strategy=None,
cpu_offload=None,
auto_wrap_policy=None,
backward_prefetch=None,
mixed_precision=None,
ignored_modules=None,
param_init_fn=None,
sync_module_states: bool = False,
forward_prefetch: bool = False,
limit_all_gathers: bool = True,
ignored_states=None,
device_mesh=None,
):
super().__init__(grad_accum_steps=grad_accum_steps)
self._fsdp_kwargs = {
k: v
for k, v in dict(
process_group=process_group,
sharding_strategy=sharding_strategy,
cpu_offload=cpu_offload,
auto_wrap_policy=auto_wrap_policy,
backward_prefetch=backward_prefetch,
mixed_precision=mixed_precision,
ignored_modules=ignored_modules,
param_init_fn=param_init_fn,
sync_module_states=sync_module_states,
forward_prefetch=forward_prefetch,
limit_all_gathers=limit_all_gathers,
use_orig_params=True,
ignored_states=ignored_states,
device_mesh=device_mesh,
).items()
if v is not None
}
self._original_model: Optional[nn.Module] = None
def _prepare_model(self, model: nn.Module) -> nn.Module:
if not self.use_distributed:
logger.warning("FSDP backend selected but world_size=1, model not wrapped")
return model
self._original_model = model
device_id = torch.device("cuda", get_rank())
model = FSDP(model, device_id=device_id, **self._fsdp_kwargs)
logger.info("Model wrapped with FSDP (world_size=%d)", get_world_size())
return model
def _no_sync(self, model: nn.Module):
if isinstance(model, FSDP):
return model.no_sync()
return contextlib.nullcontext()
def unwrap_model(self, model: nn.Module):
if isinstance(model, FSDP) and self.use_distributed:
with FSDP.state_dict_type(
model,
StateDictType.FULL_STATE_DICT,
FullStateDictConfig(offload_to_cpu=True, rank0_only=False),
):
return model.state_dict()
return model.state_dict()

View File

@ -1,4 +1,5 @@
import os
from abc import ABC, abstractmethod
from contextlib import contextmanager
from functools import wraps
from typing import Callable
@ -30,6 +31,7 @@ def get_rank() -> int:
def setup_parallel(
rank: int,
world_size: int,
local_rank: int,
backend: str = "nccl",
master_addr: str = "localhost",
master_port: str = "29500",
@ -41,14 +43,18 @@ def setup_parallel(
return
if world_size <= 1:
device_id = torch.device(device_type, local_rank)
os.environ["LOCAL_RANK"] = str(local_rank)
os.environ["WORLD_SIZE"] = "1"
os.environ["LOCAL_DEVICE"] = str(device_id)
yield None
return
device_id = torch.device(device_type, rank)
device_id = torch.device(device_type, local_rank)
os.environ["MASTER_ADDR"] = master_addr
os.environ["MASTER_PORT"] = master_port
os.environ["LOCAL_RANK"] = str(rank)
os.environ["LOCAL_RANK"] = str(local_rank)
os.environ["WORLD_SIZE"] = str(world_size)
os.environ["LOCAL_DEVICE"] = str(device_id)
@ -90,7 +96,7 @@ def only_on_rank(rank, sync=False):
return decorator
def wrapper_spawn_func(
def _run_single_rank(
rank: int,
world_size: int,
backend: str,
@ -100,20 +106,108 @@ def wrapper_spawn_func(
func: Callable,
kwargs: dict,
):
try:
with setup_parallel(
rank=rank,
world_size=world_size,
local_rank=rank,
backend=backend,
master_addr=master_addr,
master_port=master_port,
device_type=device_type,
):
func(**kwargs)
class LaunchStrategy(ABC):
"""Strategy for launching a function in a distributed context."""
def __init__(
self,
world_size: int,
backend: str,
master_addr: str,
master_port: str,
device_type: str,
start_method: str,
):
self.world_size = world_size
self.backend = backend
self.master_addr = master_addr
self.master_port = master_port
self.device_type = device_type
self.start_method = start_method
@abstractmethod
def launch(self, func: Callable, **kwargs):
raise NotImplementedError
class TorchrunStrategy(LaunchStrategy):
"""External orchestrator (torchrun, SLURM, K8s) — env vars pre-set."""
def launch(self, func: Callable, **kwargs):
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
local_rank = int(os.environ.get("LOCAL_RANK", rank))
with setup_parallel(
rank=rank,
world_size=world_size,
backend=backend,
master_addr=master_addr,
master_port=master_port,
device_type=device_type,
local_rank=local_rank,
backend=self.backend,
master_addr=os.environ.get("MASTER_ADDR", self.master_addr),
master_port=os.environ.get("MASTER_PORT", self.master_port),
device_type=self.device_type,
):
func(**kwargs)
except Exception as e:
print(f"Error in rank {rank}: {e}")
raise
class LocalStrategy(LaunchStrategy):
"""Local launcher — single-process or mp.start_processes."""
def launch(self, func: Callable, **kwargs):
args = (
self.world_size,
self.backend,
self.master_addr,
self.master_port,
self.device_type,
func,
kwargs,
)
if self.world_size == 1:
_run_single_rank(0, *args)
return
ctx = mp.start_processes(
_run_single_rank,
args=args,
nprocs=self.world_size,
start_method=self.start_method,
join=False,
)
try:
while not ctx.join():
pass
except BaseException:
for p in ctx.processes:
p.terminate()
ctx.join()
raise
def _detect_launcher() -> str:
"""Detect the distributed launcher from environment.
Returns one of: "torchelastic", "torchrun", "external", "local".
"""
if dist.is_torchelastic_launched():
return "torchelastic"
if "LOCAL_WORLD_SIZE" in os.environ:
return "torchrun"
if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
return "external"
return "local"
def spawn_parallel_fn(
@ -123,39 +217,16 @@ def spawn_parallel_fn(
master_addr: str = "localhost",
master_port: str = "29500",
device_type: str = "cuda",
start_method: str = "spawn",
**kwargs,
):
# clear environment variables
for key in [
"MASTER_ADDR",
"MASTER_PORT",
"RANK",
"WORLD_SIZE",
"LOCAL_RANK",
"LOCAL_DEVICE",
]:
if key in os.environ:
del os.environ[key]
if world_size == 1:
device_id = torch.device(device_type, 0)
os.environ["LOCAL_RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"
os.environ["LOCAL_DEVICE"] = str(device_id)
func(**kwargs)
return
wrapper_spawn_func_args = (
world_size,
backend,
master_addr,
master_port,
device_type,
func,
kwargs,
)
mp.spawn(
wrapper_spawn_func, nprocs=world_size, args=wrapper_spawn_func_args, join=True
)
launcher = _detect_launcher()
if launcher in ("torchelastic", "torchrun", "external"):
strategy = TorchrunStrategy(
world_size, backend, master_addr, master_port, device_type, start_method
)
else:
strategy = LocalStrategy(
world_size, backend, master_addr, master_port, device_type, start_method
)
strategy.launch(func, **kwargs)

View File

@ -0,0 +1,32 @@
from astrai.preprocessing.builder import (
BaseMaskBuilder,
MaskBuilderFactory,
SectionedMaskBuilder,
)
from astrai.preprocessing.packing import (
PackingStrategy,
PackingStrategyFactory,
)
from astrai.preprocessing.pipeline import Pipeline, filter_by_length
from astrai.preprocessing.position_id import (
PositionIdStrategy,
PositionIdStrategyFactory,
)
from astrai.preprocessing.writer import (
StoreWriter,
StoreWriterFactory,
)
__all__ = [
"BaseMaskBuilder",
"MaskBuilderFactory",
"PackingStrategy",
"PackingStrategyFactory",
"Pipeline",
"PositionIdStrategy",
"PositionIdStrategyFactory",
"SectionedMaskBuilder",
"StoreWriter",
"StoreWriterFactory",
"filter_by_length",
]

View File

@ -0,0 +1,315 @@
"""Mask building for preprocessing pipeline.
:class:`SectionRenderer` converts section specs into token ids and loss
masks (template / text / value extraction). :class:`SectionedMaskBuilder`
orchestrates single-output / multi-output (DPO / GRPO) assembly.
"""
from abc import ABC, abstractmethod
from typing import Optional
from astrai.factory import BaseFactory
def _extract_domain(item: dict, domain_key: Optional[str]) -> str:
if not domain_key:
return "__default__"
val = item.get(domain_key, "__default__")
return val if isinstance(val, str) else "__default__"
def _resolve_action(action: str, role: str, config) -> str:
if action == "$role":
return config.mask.get(role, config.mask_default)
return action
class SectionRenderer:
"""Render section specs into ``(ids, loss_mask)`` tuples."""
def process_sections(
self,
item: dict,
sections: list,
config,
tokenizer,
*,
is_top_level: bool = False,
):
all_ids: list[int] = []
loss_mask: list[int] = []
has_template = any(s.get("template") for s in sections)
is_text_config = not has_template and all(
s["action"] == "train" for s in sections
)
if is_top_level and has_template and tokenizer.bos_token_id is not None:
all_ids.append(tokenizer.bos_token_id)
loss_mask.append(0)
first_section = True
for sec in sections:
field = sec["field"]
action = sec["action"]
use_template = sec.get("template", False)
add_special = sec.get(
"add_special_tokens", not use_template and first_section
)
if use_template:
success = self._append_template(
item, field, action, tokenizer, config, all_ids, loss_mask
)
if not success:
continue
else:
success = self._append_text(
item,
field,
action,
tokenizer,
add_special,
is_text_config,
config,
all_ids,
loss_mask,
)
if not success:
continue
first_section = False
max_len = config.preprocessing.max_seq_len
all_ids = all_ids[:max_len]
loss_mask = loss_mask[: len(all_ids)]
if not all_ids:
return None, None
if is_top_level and has_template and len(all_ids) <= 1:
return None, None
return all_ids, loss_mask
def process_list_field(self, item: dict, sections: list, config, tokenizer):
all_ids: list[int] = []
loss_mask: list[int] = []
for sec in sections:
field = sec["field"]
action = sec["action"]
use_template = sec.get("template", False)
values = item.get(field)
if not isinstance(values, list):
continue
for val in values:
if use_template:
if isinstance(val, list):
wrapper = {field: val}
self._append_template(
wrapper,
field,
action,
tokenizer,
config,
all_ids,
loss_mask,
)
else:
wrapper = {field: str(val)}
self._append_text(
wrapper,
field,
action,
tokenizer,
False,
False,
config,
all_ids,
loss_mask,
)
max_len = config.preprocessing.max_seq_len
all_ids = all_ids[:max_len]
loss_mask = loss_mask[: len(all_ids)]
if not all_ids:
return None, None
return all_ids, loss_mask
@staticmethod
def is_value_section(sections: list) -> bool:
return len(sections) == 1 and sections[0].get("action") == "value"
@staticmethod
def extract_raw_value(item: dict, sections: list):
sec = sections[0]
field = sec["field"]
raw = item.get(field)
if raw is None:
return None
if isinstance(raw, list):
return [float(v) for v in raw]
return [float(raw)]
def _append_template(
self, item, field, action, tokenizer, config, all_ids, loss_mask
):
messages = item.get(field)
if not isinstance(messages, list) or not messages:
return False
for msg in messages:
role = msg.get("role", "")
act = _resolve_action(action, role, config)
rendered = tokenizer.apply_chat_template(
[msg], tokenize=False, add_generation_prompt=False
)
ids = tokenizer.encode(rendered, add_special_tokens=False)
all_ids.extend(ids)
val = 1 if act == "train" else 0
loss_mask.extend([val] * len(ids))
return True
def _append_text(
self,
item,
field,
action,
tokenizer,
add_special,
is_text_config,
config,
all_ids,
loss_mask,
):
text = str(item.get(field, ""))
if not text.strip():
return False
if is_text_config:
pp = config.preprocessing
if pp.min_chars > 0 and len(text) < pp.min_chars:
return False
if len(text) > pp.max_chars:
return False
ids = tokenizer.encode(text, add_special_tokens=add_special)
all_ids.extend(ids)
val = 1 if action == "train" else 0
loss_mask.extend([val] * len(ids))
return True
class BaseMaskBuilder(ABC):
"""Convert a JSONL item into token ids and optional loss_mask."""
@abstractmethod
def build(self, item: dict, config, tokenizer) -> Optional[dict]: ...
class MaskBuilderFactory(BaseFactory["BaseMaskBuilder"]):
pass
@MaskBuilderFactory.register("sectioned")
class SectionedMaskBuilder(BaseMaskBuilder):
"""Config-driven builder supporting single and multi-output modes.
Single-output::
{"input": {"sections": [
{"field": "messages", "action": "$role", "template": true}
]}}
{"sequence": [...], "loss_mask": [...], "domain": "..."}
Multi-output (DPO / GRPO)::
{"input": {"sources": {
"chosen": {"sections": [{"field": "chosen", "action": "$role", "template": true}]},
"rejected": {"sections": [{"field": "rejected", "action": "$role", "template": true}]},
}}}
{"chosen": [...], "chosen_mask": [...], "rejected": [...], "rejected_mask": [...], "domain": "..."}
Output spec fields::
sections list of section specs (same format as single-output)
list_field True when JSONL field holds a list (GRPO responses)
mask_key explicit loss-mask output key (default: ``"{output_key}_mask"``)
"""
def __init__(self):
self.renderer = SectionRenderer()
def build(self, item: dict, config, tokenizer) -> Optional[dict]:
sources_spec = getattr(config.input, "sources", None)
if sources_spec:
return self._build_multi(item, sources_spec, config, tokenizer)
return self._build_single(item, config, tokenizer)
def _build_single(self, item: dict, config, tokenizer) -> Optional[dict]:
sections = config.input.sections
if not sections:
return None
ids, mask = self.renderer.process_sections(
item, sections, config, tokenizer, is_top_level=True
)
if ids is None:
return None
result: dict = {
"sequence": ids,
"domain": _extract_domain(item, config.output.domain_key),
}
if not all(m == 1 for m in mask):
result["loss_mask"] = mask
return result
def _build_multi(
self, item: dict, sources_spec: dict, config, tokenizer
) -> Optional[dict]:
result: dict = {}
any_output = False
for output_key, spec in sources_spec.items():
sections = spec.get("sections", [])
if not sections:
continue
if self.renderer.is_value_section(sections):
ids = self.renderer.extract_raw_value(item, sections)
if ids is None:
continue
result[output_key] = ids
any_output = True
continue
list_field = spec.get("list_field", False)
mask_key = spec.get("mask_key", f"{output_key}_mask")
if list_field:
ids, mask = self.renderer.process_list_field(
item, sections, config, tokenizer
)
else:
ids, mask = self.renderer.process_sections(
item, sections, config, tokenizer, is_top_level=True
)
if ids is None:
continue
result[output_key] = ids
if not all(m == 1 for m in mask):
result[mask_key] = mask
elif "mask_key" in spec:
result[mask_key] = mask
any_output = True
if not any_output:
return None
result["domain"] = _extract_domain(item, config.output.domain_key)
return result

View File

@ -0,0 +1,121 @@
"""Sequence packing strategies for shard-level reordering and truncation.
Each strategy receives the accumulated ``{key: [list of token lists]}``
dict for a shard and returns a reordered / truncated version. The
pipeline later flattens the result into contiguous tensors.
"""
from abc import ABC, abstractmethod
from typing import Dict, List
from astrai.factory import BaseFactory
def _truncate(seq: List[int], max_len: int, mode: str) -> List[int]:
if len(seq) <= max_len:
return seq
if mode == "keep_end":
return seq[-max_len:]
return seq[:max_len]
class PackingStrategy(ABC):
"""Reorder and truncate sequences within a shard."""
@abstractmethod
def apply(
self,
keys: Dict[str, List[List[int]]],
max_packed_len: int,
truncation_mode: str,
) -> Dict[str, List[List[int]]]:
raise NotImplementedError
class PackingStrategyFactory(BaseFactory["PackingStrategy"]):
pass
@PackingStrategyFactory.register("simple")
class SimplePacking(PackingStrategy):
def apply(
self,
keys: Dict[str, List[List[int]]],
max_packed_len: int,
truncation_mode: str,
) -> Dict[str, List[List[int]]]:
return {
k: [_truncate(v, max_packed_len, truncation_mode) for v in vals]
for k, vals in keys.items()
}
@PackingStrategyFactory.register("bfd")
class BFDPacking(PackingStrategy):
"""Best-Fit Decreasing bin packing.
Assigns sequences to bins using a best-fit heuristic (sorted by
decreasing length) and concatenates sequences within each bin into
a single packed sequence. Packed sequences are truncated to
*max_packed_len* so that each packed bin fits within one context
window during training.
"""
def apply(
self,
keys: Dict[str, List[List[int]]],
max_packed_len: int,
truncation_mode: str,
) -> Dict[str, List[List[int]]]:
sequences = keys.get("sequence", [])
if not sequences:
return keys
bins = self._plan(sequences, max_packed_len, truncation_mode)
packed: Dict[str, List[List[int]]] = {}
for k, vals in keys.items():
packed[k] = [
_truncate(
self._concat_bin(vals, bin_indices),
max_packed_len,
truncation_mode,
)
for bin_indices in bins
]
return packed
@staticmethod
def _concat_bin(vals: List[List[int]], indices: List[int]) -> List[int]:
result: List[int] = []
for i in indices:
result.extend(vals[i])
return result
@staticmethod
def _plan(
sequences: List[List[int]], max_packed_len: int, truncation_mode: str
) -> List[List[int]]:
n = len(sequences)
order = sorted(range(n), key=lambda i: len(sequences[i]), reverse=True)
bins: List[List[int]] = []
bin_lengths: List[int] = []
for orig_idx in order:
seq_len = len(
_truncate(sequences[orig_idx], max_packed_len, truncation_mode)
)
best_bin = None
best_remain = max_packed_len + 1
for i, bl in enumerate(bin_lengths):
remain = max_packed_len - bl
if seq_len <= remain < best_remain:
best_remain = remain
best_bin = i
if best_bin is not None:
bins[best_bin].append(orig_idx)
bin_lengths[best_bin] += seq_len
else:
bins.append([orig_idx])
bin_lengths.append(seq_len)
return bins

View File

@ -0,0 +1,185 @@
"""Config-driven JSONL preprocessing pipeline.
Composes a :class:`BaseMaskBuilder` (selected by ``input.type``) with
sharding and flush to ``.h5`` / ``.bin`` storage. Packing, position-id
generation and storage writing are each delegated to pluggable strategies,
dispatched by configuration keys.
"""
import json
import logging
import os
from collections import defaultdict
from itertools import chain
from typing import Dict, List, Optional
import torch
import tqdm
from astrai.config.preprocess_config import PipelineConfig
from astrai.preprocessing.builder import MaskBuilderFactory
from astrai.preprocessing.packing import PackingStrategyFactory
from astrai.preprocessing.position_id import PositionIdStrategyFactory
from astrai.preprocessing.writer import StoreWriterFactory
from astrai.tokenize import AutoTokenizer
logger = logging.getLogger(__name__)
_STR_TO_DTYPE: dict[str, torch.dtype] = {
"bool": torch.bool,
"uint8": torch.uint8,
"int8": torch.int8,
"int16": torch.int16,
"int32": torch.int32,
"int64": torch.int64,
"float16": torch.float16,
"float32": torch.float32,
"float64": torch.float64,
}
def filter_by_length(text: str, min_len: int = 50, max_len: int = 2_000_000) -> bool:
return min_len <= len(text) <= max_len
class Pipeline:
"""Tokenization pipeline driven by a declarative :class:`PipelineConfig`.
Usage::
config = PipelineConfig.from_file("sft_pipeline.json")
Pipeline(config, ["data.jsonl"], output_dir="out", tokenizer_path="params").run()
"""
def __init__(
self,
config: PipelineConfig,
input_paths: list[str],
output_dir: str,
tokenizer_path: str,
):
os.makedirs(output_dir, exist_ok=True)
self.config = config
self.paths = input_paths
self.output_dir = output_dir
self.tokenizer_path = tokenizer_path
self.mask_builder = MaskBuilderFactory.create("sectioned")
self._packer = PackingStrategyFactory.create(
config.preprocessing.packing_strategy
)
self._position_id = PositionIdStrategyFactory.create(
config.output.position_ids_mode
)
self._writer = StoreWriterFactory.create(config.output.storage_format)
def transform(self, item: dict) -> Optional[dict]:
return self.mask_builder.build(item, self.config, self._tokenizer)
def run(self):
self._tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_path)
domains: dict = defaultdict(lambda: defaultdict(list))
total_tokens = 0
shard_idx: dict[str, int] = defaultdict(int)
count = 0
pp = self.config.preprocessing
for item in tqdm.tqdm(
self._iter_items(), desc="Tokenizing", unit="docs", mininterval=0.5
):
if pp.max_items and count >= pp.max_items:
break
try:
result = self.transform(item)
except Exception:
logger.warning(
"Failed to process item #%d, skipping", count + 1, exc_info=True
)
continue
if result is None:
continue
domain = result.pop("domain", "__default__")
is_multi = bool(getattr(self.config.input, "sources", None))
if is_multi:
ids = self._primary_ids(result)
else:
ids = result.pop("sequence")
result["sequence"] = ids
if not ids:
continue
bucket = domains[domain]
self._align_bucket(bucket, result, ids)
for key, val in result.items():
bucket[key].append(val)
count += 1
total_tokens += len(ids)
if total_tokens >= self.config.output.max_tokens_per_shard:
self._flush(domains, shard_idx)
domains.clear()
total_tokens = 0
if total_tokens > 0:
self._flush(domains, shard_idx)
@staticmethod
def _primary_ids(result: dict) -> list:
"""Return the first list-valued entry in *result* as the primary id
sequence for token counting."""
for val in result.values():
if isinstance(val, list) and val and isinstance(val[0], int):
return val
return []
@staticmethod
def _align_bucket(bucket: dict, result: dict, ids: list):
"""Pad previously-accumulated keys that are missing from *result*."""
for key in list(bucket.keys()):
if key in result:
continue
bucket[key].append([1] * len(ids))
def _iter_items(self):
for path in self.paths:
with open(path, "r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if not line:
continue
yield json.loads(line)
def _flush(self, domains, shard_idx):
for domain, keys in domains.items():
idx = shard_idx[domain]
pp = self.config.preprocessing
keys = self._packer.apply(dict(keys), pp.max_packed_len, pp.truncation_mode)
tensors: Dict[str, List[torch.Tensor]] = {}
for key, ids_list in keys.items():
dt = _STR_TO_DTYPE.get(
self.config.output.dtype.get(key, "int32"), torch.int32
)
tensors[key] = [
torch.tensor(list(chain.from_iterable(ids_list)), dtype=dt)
]
pos_ids = self._position_id.generate(keys.get("sequence", []))
if pos_ids:
tensors["position_ids"] = [torch.tensor(pos_ids, dtype=torch.int32)]
self._writer.save(self.output_dir, domain, idx, tensors)
shard_idx[domain] = idx + 1
first_key = "sequence" if "sequence" in tensors else next(iter(tensors))
tqdm.tqdm.write(
f" saved {domain}/shard_{idx:04d} "
f"({tensors[first_key][0].numel():,} tokens)"
)

View File

@ -0,0 +1,46 @@
"""Position-id generation strategies for packed sequences.
Each strategy takes the list of per-document token sequences after packing
and returns a flat list of position ids (same total length as all
sequences combined). The pipeline wraps the result into a tensor and
attaches it as ``position_ids``.
"""
from abc import ABC, abstractmethod
from typing import List
from astrai.factory import BaseFactory
class PositionIdStrategy(ABC):
"""Generate ``position_ids`` for packed sequences."""
@abstractmethod
def generate(self, sequences: List[List[int]]) -> List[int]:
raise NotImplementedError
class PositionIdStrategyFactory(BaseFactory["PositionIdStrategy"]):
pass
@PositionIdStrategyFactory.register("none")
class NoPositionId(PositionIdStrategy):
def generate(self, sequences: List[List[int]]) -> List[int]:
return []
@PositionIdStrategyFactory.register("doc_reset")
class DocResetPositionId(PositionIdStrategy):
def generate(self, sequences: List[List[int]]) -> List[int]:
pos_ids = []
for seq in sequences:
pos_ids.extend(range(len(seq)))
return pos_ids
@PositionIdStrategyFactory.register("continuous")
class ContinuousPositionId(PositionIdStrategy):
def generate(self, sequences: List[List[int]]) -> List[int]:
total = sum(len(seq) for seq in sequences)
return list(range(total))

View File

@ -0,0 +1,75 @@
"""Storage writer strategies for pipeline output.
The :class:`StoreWriter` abstraction decouples the pipeline from the
concrete storage format (bin / h5). The pipeline builds a ``{key:
List[Tensor]}`` dict and delegates the write to the writer selected
by ``output.storage_format``.
"""
import logging
import os
import shutil
from abc import ABC, abstractmethod
from typing import Dict, List
import torch
from astrai.dataset.storage import save_bin, save_h5
from astrai.factory import BaseFactory
logger = logging.getLogger(__name__)
class StoreWriter(ABC):
"""Write pre-tokenized tensors to disk in a format-specific way."""
@abstractmethod
def save(
self,
output_dir: str,
domain: str,
shard_idx: int,
tensors: Dict[str, List[torch.Tensor]],
) -> None: ...
class StoreWriterFactory(BaseFactory["StoreWriter"]):
pass
@StoreWriterFactory.register("bin")
class BinWriter(StoreWriter):
def save(self, output_dir, domain, shard_idx, tensors):
shard_path = os.path.join(output_dir, domain, f"shard_{shard_idx:04d}")
try:
save_bin(shard_path, tensors)
except Exception:
if os.path.exists(shard_path):
shutil.rmtree(shard_path, ignore_errors=True)
logger.error(
"Failed to write shard %s/%s_%04d, cleaned up partial output",
domain,
"shard",
shard_idx,
exc_info=True,
)
raise
@StoreWriterFactory.register("h5")
class H5Writer(StoreWriter):
def save(self, output_dir, domain, shard_idx, tensors):
chunk_dir = os.path.join(output_dir, domain)
file_path = os.path.join(chunk_dir, f"data_{shard_idx:04d}.h5")
try:
save_h5(chunk_dir, f"data_{shard_idx:04d}", tensors)
except Exception:
if os.path.exists(file_path):
os.remove(file_path)
logger.error(
"Failed to write shard %s/data_%04d.h5, cleaned up partial output",
domain,
shard_idx,
exc_info=True,
)
raise

21
astrai/protocols.py Normal file
View File

@ -0,0 +1,21 @@
"""Training component protocols — structural subtyping for optimizer/scheduler wrappers."""
from typing import Any, Protocol, runtime_checkable
@runtime_checkable
class OptimizerProtocol(Protocol):
def step(self, closure=None): ...
def zero_grad(self): ...
@property
def param_groups(self) -> Any: ...
def state_dict(self) -> dict: ...
def load_state_dict(self, d: dict): ...
@runtime_checkable
class SchedulerProtocol(Protocol):
def step(self): ...
def state_dict(self) -> dict: ...
def load_state_dict(self, d: dict): ...
def get_last_lr(self): ...

View File

@ -1,6 +1,9 @@
import io
import json
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, Union
import safetensors.torch as st
import torch
@ -8,70 +11,191 @@ import torch.distributed as dist
from astrai.parallel.setup import get_rank
_META_FILE = "meta.json"
_CONFIG_FILE = "config.json"
_WEIGHTS_FILE = "model.safetensors"
def save_safetensors(state_dict: dict, path: Union[str, Path]):
st.save_file(state_dict, str(path))
def load_safetensors(path: Union[str, Path], broadcast: bool = False) -> dict:
if not broadcast or not dist.is_initialized():
return st.load_file(str(path))
rank = get_rank()
if rank == 0:
state_dict = st.load_file(str(path))
else:
state_dict = {}
tmp = [state_dict]
dist.broadcast_object_list(tmp, src=0)
return tmp[0]
def save_json(data: dict, path: Union[str, Path]):
with open(str(path), "w") as f:
json.dump(data, f, indent=2)
def load_json(path: Union[str, Path], broadcast: bool = False) -> dict:
if not broadcast or not dist.is_initialized():
with open(str(path), "r") as f:
return json.load(f)
rank = get_rank()
if rank == 0:
with open(str(path), "r") as f:
data = json.load(f)
else:
data = {}
tmp = [data]
dist.broadcast_object_list(tmp, src=0)
return tmp[0]
def save_torch(obj: Any, path: Union[str, Path]):
torch.save(obj, str(path))
def load_torch(path: Union[str, Path], broadcast: bool = False) -> Any:
if not broadcast or not dist.is_initialized():
return torch.load(str(path), map_location="cpu", weights_only=False)
path = Path(path)
rank = get_rank()
if rank == 0:
with open(path, "rb") as f:
raw = f.read()
data_tensor = torch.frombuffer(bytearray(raw), dtype=torch.uint8)
num_bytes = torch.tensor([len(raw)], dtype=torch.long)
else:
num_bytes = torch.tensor([0], dtype=torch.long)
dist.broadcast(num_bytes, src=0)
if rank != 0:
data_tensor = torch.empty(num_bytes.item(), dtype=torch.uint8)
dist.broadcast(data_tensor, src=0)
buf = io.BytesIO(data_tensor.numpy().tobytes())
return torch.load(buf, map_location="cpu", weights_only=False)
def save_model(config: dict, state_dict: dict, save_directory: str):
save_path = Path(save_directory)
save_path.mkdir(parents=True, exist_ok=True)
save_json(config, save_path / _CONFIG_FILE)
save_safetensors(state_dict, save_path / _WEIGHTS_FILE)
def load_model_config(save_directory: str) -> dict:
return load_json(Path(save_directory) / _CONFIG_FILE)
def load_model_weights(save_directory: str) -> dict:
return load_state_dict(Path(save_directory) / _WEIGHTS_FILE)
def load_state_dict(path: Union[str, Path], broadcast: bool = False) -> dict:
path = Path(path)
if not broadcast or not dist.is_initialized():
return load_safetensors(path)
rank = get_rank()
if rank == 0:
state_dict = load_safetensors(path)
specs = [
(k, list(state_dict[k].shape), str(state_dict[k].dtype).split(".")[-1])
for k in sorted(state_dict)
]
else:
state_dict = {}
specs = []
specs_list = [specs]
dist.broadcast_object_list(specs_list, src=0)
specs = specs_list[0]
for key, shape, dtype_name in specs:
dtype = getattr(torch, dtype_name)
if rank != 0:
tensor = torch.empty(shape, dtype=dtype, device="cpu")
else:
tensor = state_dict[key].contiguous().cpu()
dist.broadcast(tensor, src=0)
if rank != 0:
state_dict[key] = tensor
return state_dict
@dataclass
class Checkpoint:
def __init__(
self,
state_dict: Dict[str, Any],
epoch: int = 0,
iteration: int = 0,
extra: Optional[Dict[str, Any]] = None,
):
self.state_dict = state_dict
self.epoch = epoch
self.iteration = iteration
self.extra = extra or {}
def save(
self,
save_dir: str,
) -> None:
state_dict: Dict[str, Any] = field(default_factory=dict)
epoch: int = 0
iteration: int = 0
extra: Dict[str, Any] = field(default_factory=dict)
meta: Dict[str, Any] = field(default_factory=dict)
config: Dict[str, Any] = field(default_factory=dict)
def save(self, save_dir: str):
save_path = Path(save_dir)
save_path.mkdir(parents=True, exist_ok=True)
rank = get_rank()
if rank == 0:
meta = {
"epoch": self.epoch,
"iteration": self.iteration,
}
with open(save_path / "meta.json", "w") as f:
json.dump(meta, f, indent=2)
if get_rank() != 0:
return
st.save_file(self.state_dict, save_path / "state_dict.safetensors")
if self.extra:
torch.save(self.extra, save_path / "extra.pt")
meta = {
"epoch": self.epoch,
"iteration": self.iteration,
"timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"),
**self.meta,
}
save_json(meta, save_path / _META_FILE)
save_json(self.config, save_path / _CONFIG_FILE)
save_safetensors(self.state_dict, save_path / _WEIGHTS_FILE)
for key, value in self.extra.items():
save_torch(value, save_path / f"{key}.pt")
@classmethod
def load(
cls,
save_dir: str,
) -> "Checkpoint":
rank = get_rank()
def load(cls, save_dir: str, broadcast: bool = False) -> "Checkpoint":
save_path = Path(save_dir)
meta = {}
if rank == 0:
with open(Path(save_dir) / "meta.json", "r") as f:
meta = json.load(f)
meta = load_json(save_path / _META_FILE, broadcast)
config = load_json(save_path / _CONFIG_FILE, broadcast)
state_dict = load_state_dict(save_path / _WEIGHTS_FILE, broadcast=broadcast)
if dist.is_initialized():
meta_list = [meta]
dist.broadcast_object_list(meta_list, src=0)
meta = meta_list[0]
state_dict = st.load_file(save_path / "state_dict.safetensors")
extra = None
extra_path = save_path / "extra.pt"
if extra_path.exists():
extra = torch.load(extra_path, map_location="cpu", weights_only=False)
extra = {}
for f in sorted(save_path.iterdir()):
if f.suffix == ".pt":
extra[f.stem] = load_torch(f, broadcast=broadcast)
return cls(
state_dict=state_dict,
epoch=meta["epoch"],
iteration=meta["iteration"],
epoch=meta.get("epoch", 0),
iteration=meta.get("iteration", 0),
extra=extra,
config=config,
)
@classmethod
def load_any(cls, save_dir: str, broadcast: bool = False) -> Optional["Checkpoint"]:
save_path = Path(save_dir)
meta_path = save_path / _META_FILE
weights_path = save_path / _WEIGHTS_FILE
if meta_path.exists():
return cls.load(save_dir, broadcast=broadcast)
if weights_path.exists():
state_dict = load_state_dict(weights_path, broadcast=broadcast)
config = {}
config_path = save_path / _CONFIG_FILE
if config_path.exists():
config = load_json(config_path, broadcast)
return cls(state_dict=state_dict, config=config)
return None

View File

@ -1,13 +1,10 @@
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
from jinja2 import Template
# Message type for chat messages
type MessageType = Dict[str, Any]
@dataclass
class ChatTemplate:
"""A chat template with Jinja2 rendering support.
@ -15,23 +12,24 @@ class ChatTemplate:
name: Unique identifier for the template.
template_str: Jinja2 template string.
description: Optional description.
default_variables: Optional dictionary of default variable values
that will be passed to the template if not overridden during rendering.
default_variables: Optional dictionary of default variable values.
special_tokens: Optional dictionary mapping token names to their string values.
These tokens are automatically added to the template variables.
"""
name: str
template_str: str
description: str = ""
default_variables: Dict[str, Any] = None
special_tokens: Dict[str, str] = None
def __post_init__(self):
if self.default_variables is None:
self.default_variables = {}
if self.special_tokens is None:
self.special_tokens = {}
def __init__(
self,
name: str = "",
template_str: str = "",
description: str = "",
default_variables: Optional[Dict[str, Any]] = None,
special_tokens: Optional[Dict[str, str]] = None,
):
self.name = name
self.template_str = template_str
self.description = description
self.default_variables = default_variables or {}
self.special_tokens = special_tokens or {}
self._compiled: Template = Template(template_str)
@classmethod
def from_string(
@ -43,7 +41,7 @@ class ChatTemplate:
) -> "ChatTemplate":
"""Create a ChatTemplate instance directly from a template string."""
return cls(
name="", # empty name for adhoc templates
name="",
template_str=template_str,
description=description,
default_variables=default_variables,
@ -73,5 +71,4 @@ class ChatTemplate:
if system_prompt is not None:
variables["system_prompt"] = system_prompt
jinja_template = Template(self.template_str)
return jinja_template.render(**variables)
return self._compiled.render(**variables)

View File

@ -51,9 +51,26 @@ class AutoTokenizer:
self.set_chat_template(config["chat_template"])
@classmethod
def from_pretrained(cls, path: Union[str, Path], **kwargs) -> "AutoTokenizer":
"""Load tokenizer from pretrained directory."""
def from_pretrained(cls, path: Union[str, Path]) -> "AutoTokenizer":
"""Load tokenizer from pretrained directory.
Raises:
FileNotFoundError: If tokenizer.json is missing.
RuntimeError: If tokenizer failed to initialize.
"""
path = Path(path)
tokenizer_file = path / "tokenizer.json"
if not tokenizer_file.exists():
raise FileNotFoundError(
f"Tokenizer file not found: {tokenizer_file}. "
"A valid tokenizer.json is required."
)
instance = cls(path)
if instance._tokenizer is None:
raise RuntimeError(
f"Failed to load tokenizer from {path}. "
"The tokenizer.json may be corrupted or incompatible."
)
return instance
def save_pretrained(self, save_path: str):

View File

@ -1,3 +1,4 @@
from astrai.trainer.optim import Muon
from astrai.trainer.schedule import BaseScheduler, SchedulerFactory
from astrai.trainer.strategy import BaseStrategy, StrategyFactory
from astrai.trainer.train_callback import (
@ -9,6 +10,8 @@ from astrai.trainer.trainer import Trainer
__all__ = [
# Main trainer
"Trainer",
# Optimizer
"Muon",
# Strategy factory
"StrategyFactory",
"BaseStrategy",

View File

@ -47,6 +47,10 @@ def ctx_get_lr(ctx):
return ctx.optimizer.param_groups[-1]["lr"]
def ctx_get_val_loss(ctx):
return ctx.val_loss
def ctx_get_grad_norm(ctx):
return grad_norm(ctx.model)

143
astrai/trainer/optim.py Normal file
View File

@ -0,0 +1,143 @@
import torch
from torch.optim import Optimizer
def _zeropower_via_newtonschulz(G: torch.Tensor, steps: int = 5):
assert G.ndim == 2
X = G
scale = max(1, G.size(0) / G.size(1)) ** 0.5
X = X / (X.norm() + 1e-7) * scale
if steps == 0:
return X
a, b, c = (3.4445, -4.7750, 2.0315)
for _ in range(steps):
A = X @ X.T
B = A @ X
X = a * X + b * B + c * (A @ B)
return X
class Muon(Optimizer):
def __init__(
self,
params,
lr: float = 2e-3,
momentum: float = 0.95,
weight_decay: float = 0.0,
nesterov: bool = True,
ns_steps: int = 5,
adamw_lr: float = None,
adamw_betas: tuple = (0.9, 0.95),
adamw_eps: float = 1e-8,
adamw_wd: float = 0.0,
):
defaults = dict(
lr=lr,
momentum=momentum,
weight_decay=weight_decay,
nesterov=nesterov,
ns_steps=ns_steps,
adamw_lr=adamw_lr if adamw_lr is not None else lr * 0.1,
adamw_betas=adamw_betas,
adamw_eps=adamw_eps,
adamw_wd=adamw_wd,
)
super().__init__(params, defaults)
@torch.no_grad()
def step(self, closure=None):
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
for group in self.param_groups:
params_2d, params_1d = [], []
grads_2d, grads_1d = [], []
for p in group["params"]:
if p.grad is None:
continue
if p.grad.is_sparse:
raise RuntimeError("Muon does not support sparse gradients")
if p.ndim >= 2:
params_2d.append(p)
grads_2d.append(p.grad)
else:
params_1d.append(p)
grads_1d.append(p.grad)
if params_2d:
self._muon_update_foreach(params_2d, grads_2d, group)
if params_1d:
self._adamw_update_foreach(params_1d, grads_1d, group)
return loss
def _muon_update_foreach(self, params_2d, grads_2d, group):
lr = group["lr"]
momentum = group["momentum"]
wd = group["weight_decay"]
nesterov = group["nesterov"]
ns_steps = group["ns_steps"]
if wd != 0:
torch._foreach_mul_(params_2d, 1 - lr * wd)
if nesterov:
grads_2d = torch._foreach_add(grads_2d, params_2d, alpha=wd)
bufs = []
for p, grad in zip(params_2d, grads_2d):
state = self.state[p]
if "momentum_buffer" not in state:
state["momentum_buffer"] = torch.zeros_like(grad)
bufs.append(state["momentum_buffer"])
torch._foreach_lerp_(bufs, grads_2d, 1 - momentum)
for p, buf in zip(params_2d, bufs):
update = _zeropower_via_newtonschulz(buf, steps=ns_steps)
scale = max(1, p.size(0) / p.size(1)) ** 0.5
p.add_(update, alpha=-lr * scale)
def _adamw_update_foreach(self, params_1d, grads_1d, group):
lr = group["adamw_lr"]
betas = group["adamw_betas"]
eps = group["adamw_eps"]
wd = group["adamw_wd"]
steps: list[int] = []
exp_avgs, exp_avg_sqs = [], []
has_state = []
for p in params_1d:
state = self.state[p]
if not state:
state["step"] = 0
state["exp_avg"] = torch.zeros_like(p)
state["exp_avg_sq"] = torch.zeros_like(p)
has_state.append(False)
else:
has_state.append(True)
state["step"] += 1
steps.append(state["step"])
exp_avgs.append(state["exp_avg"])
exp_avg_sqs.append(state["exp_avg_sq"])
beta1, beta2 = betas
torch._foreach_lerp_(exp_avgs, grads_1d, 1 - beta1)
grads_sq = torch._foreach_mul(grads_1d, grads_1d)
torch._foreach_lerp_(exp_avg_sqs, grads_sq, 1 - beta2)
bias_correction1 = [1 - beta1**s for s in steps]
bias_correction2 = [1 - beta2**s for s in steps]
if wd != 0:
torch._foreach_mul_(params_1d, 1 - lr * wd)
exp_avg_corrected = torch._foreach_div(exp_avgs, bias_correction1)
denom = torch._foreach_div(exp_avg_sqs, bias_correction2)
denom = torch._foreach_sqrt(denom)
torch._foreach_add_(denom, eps)
torch._foreach_addcdiv_(params_1d, exp_avg_corrected, denom, value=-lr)

View File

@ -2,7 +2,7 @@
import math
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Type
from typing import Any, Dict, List
from torch.optim.lr_scheduler import LRScheduler
@ -31,7 +31,6 @@ class SchedulerFactory(BaseFactory["BaseScheduler"]):
"""Factory class for creating learning rate schedulers.
Supports decorator-based registration for extensible scheduler types.
Also supports creation from ScheduleConfig objects.
Example usage:
@SchedulerFactory.register("custom")
@ -41,33 +40,6 @@ class SchedulerFactory(BaseFactory["BaseScheduler"]):
scheduler = SchedulerFactory.create("custom", optimizer, **kwargs)
"""
@classmethod
def _validate_component(cls, scheduler_cls: Type[BaseScheduler]) -> None:
"""Validate that the scheduler class inherits from BaseScheduler."""
if not issubclass(scheduler_cls, BaseScheduler):
raise TypeError(f"{scheduler_cls.__name__} must inherit from BaseScheduler")
@classmethod
def create(
cls, optimizer, schedule_type: str = "none", **kwargs
) -> "BaseScheduler":
"""Create a scheduler instance by type name.
Args:
optimizer: PyTorch optimizer
schedule_type: Type of scheduler ("cosine", "sgdr")
**kwargs: Arguments passed to the scheduler constructor
Returns:
Scheduler instance
"""
return super().create(schedule_type, optimizer, **kwargs)
@classmethod
def available_types(cls) -> list:
"""Return list of registered scheduler type names."""
return cls.list_registered()
# ----------- Scheduler implementations -----------
@ -192,3 +164,66 @@ class SGDRScheduler(BaseScheduler):
self.min_rate = state_dict.pop("min_rate")
self.t_mult = state_dict.pop("t_mult")
super().load_state_dict(state_dict)
@SchedulerFactory.register("wsd")
class WSDScheduler(BaseScheduler):
"""WSD (Warmup-Stable-Decay) scheduler with sqrt cooldown.
warmup_steps: linear warmup from min_rate to 1.0
stable_steps: constant at base_lr
decay_steps: sqrt decay from base_lr to min_rate
min_rate: minimum lr as fraction of base_lr (default 0.0)
"""
def __init__(
self,
optimizer,
warmup_steps: int,
stable_steps: int,
decay_steps: int,
min_rate: float = 0.0,
last_epoch: int = -1,
):
self.warmup_steps = warmup_steps
self.stable_steps = stable_steps
self.decay_steps = decay_steps
self.min_rate = min_rate
self.total_steps = warmup_steps + stable_steps + decay_steps
super().__init__(optimizer, last_epoch)
def get_lr(self) -> List[float]:
if self.last_epoch < self.warmup_steps:
factor = self.last_epoch / max(self.warmup_steps, 1)
return [base_lr * factor for base_lr in self.base_lrs]
offset = self.last_epoch - self.warmup_steps
if offset < self.stable_steps:
return list(self.base_lrs)
decay_ratio = (offset - self.stable_steps) / max(self.decay_steps, 1)
decay_ratio = min(decay_ratio, 1.0)
factor = (1.0 - self.min_rate) * (1.0 - decay_ratio) ** 2 + self.min_rate
return [base_lr * factor for base_lr in self.base_lrs]
def state_dict(self):
state = super().state_dict()
state.update(
{
"warmup_steps": self.warmup_steps,
"stable_steps": self.stable_steps,
"decay_steps": self.decay_steps,
"min_rate": self.min_rate,
"total_steps": self.total_steps,
}
)
return state
def load_state_dict(self, state_dict):
self.warmup_steps = state_dict.pop("warmup_steps")
self.stable_steps = state_dict.pop("stable_steps")
self.decay_steps = state_dict.pop("decay_steps")
self.min_rate = state_dict.pop("min_rate")
self.total_steps = state_dict.pop("total_steps")
super().load_state_dict(state_dict)

View File

@ -1,39 +1,28 @@
"""Training strategy implementations with factory pattern."""
import copy
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Union
from typing import Callable, Dict, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.nn.parallel import DistributedDataParallel as DDP
from astrai.factory import BaseFactory
def unwrap_model(model: nn.Module) -> nn.Module:
"""Unwrap DDP wrapper if present to get the original model."""
if isinstance(model, DDP):
return model.module
return model
def create_ref_model(model: nn.Module) -> nn.Module:
"""Create a reference model for DPO/GRPO training.
Handles DDP-wrapped models safely by unwrapping first,
then creating a deep copy with frozen gradients.
"""
original_model = unwrap_model(model)
ref_model = copy.deepcopy(original_model)
def create_ref_model(
model_fn: Callable[[], nn.Module], state_dict: Dict[str, Tensor]
) -> nn.Module:
"""Create a frozen reference model from model_fn + full state dict."""
ref_model = model_fn()
ref_model.load_state_dict(state_dict)
ref_model.requires_grad_(False)
ref_model.eval()
return ref_model
def move_to_device(batch: Dict[str, Tensor], device: str) -> Any:
def move_to_device(batch: Dict[str, Tensor], device: str) -> Dict[str, Tensor]:
"""Move batch tensors to specified device with non-blocking transfer."""
return {key: value.to(device, non_blocking=True) for key, value in batch.items()}
@ -43,7 +32,7 @@ def get_logprobs(
input_ids: Tensor,
mask: Tensor,
reduction: str,
):
) -> Tensor:
"""Compute token-wise log probabilities from model outputs.
Args:
@ -81,14 +70,35 @@ def get_logprobs(
return token_logprobs * shifted_mask
def make_doc_boundary_mask(position_ids: Tensor) -> Tensor:
S = position_ids.size(1)
device = position_ids.device
boundaries = position_ids[:, 1:] <= position_ids[:, :-1]
doc_ids = torch.cat(
[
torch.zeros(position_ids.size(0), 1, dtype=torch.long, device=device),
boundaries.long().cumsum(dim=1),
],
dim=1,
)
same_doc = doc_ids.unsqueeze(-1) == doc_ids.unsqueeze(-2)
causal = torch.tril(torch.ones(S, S, dtype=torch.bool, device=device))
return (same_doc & causal).unsqueeze(1)
class BaseStrategy(ABC):
"""Abstract base class for training strategies."""
def __init__(
self, model: Union[Callable[..., Dict[str, Tensor]]], device: str, **kwargs
self,
model: Union[nn.Module, Callable[..., Dict[str, Tensor]]],
device: str,
**kwargs,
):
self.model = model
self.device = device
self.executor = kwargs.pop("executor", None)
self.model_fn = kwargs.pop("model_fn", None)
self.extra_kwargs = kwargs
@abstractmethod
@ -122,32 +132,6 @@ class StrategyFactory(BaseFactory["BaseStrategy"]):
strategy = StrategyFactory.create("custom", model, device)
"""
@classmethod
def _validate_component(cls, strategy_cls: type) -> None:
"""Validate that the strategy class inherits from BaseStrategy."""
if not issubclass(strategy_cls, BaseStrategy):
raise TypeError(f"{strategy_cls.__name__} must inherit from BaseStrategy")
@classmethod
def create(cls, train_type: str, model, device: str, **kwargs) -> "BaseStrategy":
"""Create a strategy instance based on training type.
Args:
train_type: Type of training ("seq", "sft", "dpo", "grpo")
model: Model instance for the strategy
device: Device to run the strategy on
**kwargs: Additional arguments passed to strategy constructor
Returns:
Strategy instance
"""
return super().create(train_type, model, device, **kwargs)
@classmethod
def available_strategies(cls) -> list:
"""Return list of registered strategy names."""
return cls.list_registered()
# ============== Strategy Classes ==============
# All strategies are registered at class definition time using the decorator
@ -160,7 +144,13 @@ class SEQStrategy(BaseStrategy):
Computes cross-entropy loss for next token prediction.
"""
def __init__(self, model, device, label_smoothing: float = 0.0, **kwargs):
def __init__(
self,
model: Union[nn.Module, Callable[..., Dict[str, Tensor]]],
device: str,
label_smoothing: float = 0.0,
**kwargs,
):
super().__init__(model, device, **kwargs)
self.label_smoothing = label_smoothing
@ -185,21 +175,31 @@ class SFTStrategy(BaseStrategy):
Applies cross-entropy loss only to tokens where loss_mask is True.
"""
def __init__(self, model, device, label_smoothing: float = 0.0, **kwargs):
def __init__(
self,
model: Union[nn.Module, Callable[..., Dict[str, Tensor]]],
device: str,
label_smoothing: float = 0.0,
**kwargs,
):
super().__init__(model, device, **kwargs)
self.label_smoothing = label_smoothing
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
batch = move_to_device(batch, self.device)
input_ids, target_ids, loss_mask = (
input_ids, target_ids, position_ids, loss_mask = (
batch["input_ids"],
batch["target_ids"],
batch["position_ids"],
batch["loss_mask"],
)
ignore_index = -100
logits = self.model(input_ids=input_ids)["logits"]
input_mask = make_doc_boundary_mask(position_ids)
target_ids = target_ids.masked_fill(loss_mask == 0, ignore_index)
logits = self.model(
input_ids=input_ids, position_ids=position_ids, input_mask=input_mask
)["logits"]
loss = F.cross_entropy(
input=logits.flatten(0, 1).float(),
@ -228,7 +228,9 @@ class DPOStrategy(BaseStrategy):
**kwargs,
):
super().__init__(model, device, **kwargs)
self.ref_model = create_ref_model(model)
self.ref_model = create_ref_model(
self.model_fn, self.executor.unwrap_model(model)
).to(device=self.device)
self.beta = beta
self.reduction = reduction
@ -282,7 +284,9 @@ class GRPOStrategy(BaseStrategy):
**kwargs,
):
super().__init__(model, device, **kwargs)
self.ref_model = create_ref_model(model)
self.ref_model = create_ref_model(
self.model_fn, self.executor.unwrap_model(model)
).to(device=self.device)
self.clip_eps = clip_eps
self.kl_coef = kl_coef
self.group_size = group_size
@ -292,8 +296,7 @@ class GRPOStrategy(BaseStrategy):
def sync_ref_model(self):
"""Copy current model weights to ref model."""
ref_state = self.model.state_dict()
self.ref_model.load_state_dict(ref_state)
self.ref_model.load_state_dict(self.executor.unwrap_model(self.model))
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
self._step += 1

View File

@ -1,15 +1,21 @@
import json
import logging
import os
import sys
import time
from pathlib import Path
from typing import Callable, List, Optional, Protocol, runtime_checkable
from typing import IO, Callable, List, Optional, Protocol, runtime_checkable
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.utils import clip_grad_norm_
from torch.utils.checkpoint import checkpoint as torch_checkpoint
from tqdm import tqdm
from astrai.factory import BaseFactory
from astrai.parallel import only_on_rank
from astrai.parallel.setup import get_current_device, get_rank
from astrai.serialization import Checkpoint
from astrai.trainer.metric_util import (
ctx_get_grad_max,
@ -20,9 +26,12 @@ from astrai.trainer.metric_util import (
ctx_get_grad_std,
ctx_get_loss,
ctx_get_lr,
ctx_get_val_loss,
)
from astrai.trainer.train_context import TrainContext
logger = logging.getLogger(__name__)
@runtime_checkable
class TrainCallback(Protocol):
@ -42,18 +51,15 @@ class TrainCallback(Protocol):
def on_epoch_end(self, context: TrainContext):
"""Called at the end of each epoch."""
def on_step_begin(self, context: TrainContext):
"""Called at the beginning of each step."""
def on_step_end(self, context: TrainContext):
"""Called at the end of each step."""
def on_batch_begin(self, context: TrainContext):
"""Called at the beginning of each batch."""
def on_batch_end(self, context: TrainContext):
"""Called at the end of each batch."""
def on_optimizer_step(self, context: TrainContext):
"""Called on every optimizer step (sync step only)."""
def on_error(self, context: TrainContext):
"""Called when an error occurs during training."""
@ -79,54 +85,86 @@ class GradientClippingCallback(TrainCallback):
def __init__(self, max_grad_norm: float):
self.max_grad_norm = max_grad_norm
def on_step_end(self, context: TrainContext):
_ = context
def on_optimizer_step(self, context: TrainContext):
clip_grad_norm_(context.model.parameters(), self.max_grad_norm)
@CallbackFactory.register("gradient_checkpointing")
class GradientCheckpointingCallback(TrainCallback):
"""
Activation checkpointing callback trades compute for memory
by recomputing specified module activations during the backward pass.
Args:
modules: Module types to apply checkpointing to.
"""
def __init__(self, modules: Optional[List[type]] = None):
self.modules = tuple(modules) if modules else ()
def _enable(self, module: nn.Module):
if self.modules and isinstance(module, self.modules):
fn = module.forward
module._original_forward = fn
module.forward = lambda *a, **kw: torch_checkpoint(
fn, *a, use_reentrant=False, **kw
)
@staticmethod
def _disable(module: nn.Module):
if hasattr(module, "_original_forward"):
module.forward = module._original_forward
del module._original_forward
def on_train_begin(self, context: TrainContext):
context.model.apply(self._enable)
logger.info("Gradient checkpointing enabled")
def on_train_end(self, context: TrainContext):
context.model.apply(self._disable)
@CallbackFactory.register("checkpoint")
class CheckpointCallback(TrainCallback):
"""
Checkpoint callback for trainer.
"""
extra_keys = ("optimizer", "scheduler")
def __init__(
self,
save_dir: str,
interval: int,
weight_only: bool = False,
state_dict_fn: Optional[Callable[[nn.Module], dict]] = None,
save_extra_fn: Optional[Callable[["TrainContext"], dict]] = None,
):
self.save_dir = save_dir
self.interval = interval
self.weight_only = weight_only
self.state_dict_fn = state_dict_fn
self.save_extra_fn = save_extra_fn
self.save_extra_fn = save_extra_fn or CheckpointCallback.save_extra
self.last_ckpt_iter = 0
@only_on_rank(0)
def _save_checkpoint(self, context: TrainContext):
save_path = os.path.join(
self.save_dir, f"epoch_{context.epoch}_iter_{context.iteration}"
)
state_dict = (
self.state_dict_fn(context.model)
if self.state_dict_fn
else context.model.state_dict()
)
extra = self.save_extra_fn(context) if self.save_extra_fn else None
context.checkpoint = Checkpoint(
state_dict=state_dict,
epoch=context.epoch,
iteration=context.iteration,
extra=extra,
)
context.checkpoint.save(save_path)
state_dict = context.executor.unwrap_model(context.model)
self.last_ckpt_iter = context.iteration
if get_rank() == 0:
save_path = os.path.join(
self.save_dir, f"epoch_{context.epoch}_iter_{context.iteration}"
)
extra = self.save_extra_fn(context)
meta = context.config.to_dict()
context.checkpoint = Checkpoint(
state_dict=state_dict,
epoch=context.epoch,
iteration=context.iteration,
extra=extra,
meta=meta,
config=context.model_config,
)
context.checkpoint.save(save_path)
def on_batch_end(self, context: TrainContext):
if context.iteration - self.last_ckpt_iter >= self.interval:
self._save_checkpoint(context)
@ -138,6 +176,15 @@ class CheckpointCallback(TrainCallback):
def on_error(self, context: TrainContext):
self._save_checkpoint(context)
@staticmethod
def save_extra(context: TrainContext) -> dict:
extra = {}
for name in CheckpointCallback.extra_keys:
obj = getattr(context, name, None)
if obj:
extra[name] = obj.state_dict()
return extra
@CallbackFactory.register("progress_bar")
class ProgressBarCallback(TrainCallback):
@ -145,8 +192,12 @@ class ProgressBarCallback(TrainCallback):
Progress bar callback for trainer.
"""
def __init__(self, num_epoch: int):
def __init__(
self, num_epoch: int, log_interval: int = 100, file: Optional[IO[str]] = None
):
self.num_epoch = num_epoch
self.log_interval = log_interval
self.file = file
self.progress_bar: tqdm = None
@only_on_rank(0)
@ -155,16 +206,18 @@ class ProgressBarCallback(TrainCallback):
context.dataloader,
desc=f"Epoch {context.epoch + 1}/{self.num_epoch}",
dynamic_ncols=True,
file=self.file or sys.stdout,
)
@only_on_rank(0)
def on_batch_end(self, context: TrainContext):
self.progress_bar.set_postfix(
{
"loss": f"{context.loss:.4f}",
"lr": f"{context.optimizer.param_groups[-1]['lr']:.2e}",
}
)
postfix = {
"loss": f"{context.loss:.4f}",
"lr": f"{context.optimizer.param_groups[-1]['lr']:.2e}",
}
if context.val_loss is not None:
postfix["val_loss"] = f"{context.val_loss:.4f}"
self.progress_bar.set_postfix(postfix)
self.progress_bar.update(1)
@only_on_rank(0)
@ -196,6 +249,7 @@ class MetricLoggerCallback(TrainCallback):
self._metric_funcs = {
"loss": ctx_get_loss,
"lr": ctx_get_lr,
"val_loss": ctx_get_val_loss,
"grad_norm": ctx_get_grad_norm,
"grad_std": ctx_get_grad_std,
"grad_max": ctx_get_grad_max,
@ -205,12 +259,16 @@ class MetricLoggerCallback(TrainCallback):
}
def _get_log_data(self, context: TrainContext):
return {
"timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
data = {
"timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"),
"epoch": context.epoch,
"iter": context.iteration,
**{m: self._metric_funcs[m](context) for m in self.metrics},
}
for m in self.metrics:
val = self._metric_funcs[m](context)
if val is not None:
data[m] = val
return data
@only_on_rank(0)
def _add_log(self, log_data):
@ -219,6 +277,7 @@ class MetricLoggerCallback(TrainCallback):
@only_on_rank(0)
def _save_log(self, epoch, iter):
log_file = self.log_dir / f"epoch_{epoch}_iter_{iter}_metric.jsonl"
log_file.parent.mkdir(parents=True, exist_ok=True)
with open(log_file, "w") as f:
for log in self.log_cache:
@ -239,3 +298,43 @@ class MetricLoggerCallback(TrainCallback):
def on_error(self, context):
self._save_log(context.epoch, context.iteration)
@CallbackFactory.register("validation")
class ValidationCallback(TrainCallback):
def _run_validation(self, context: TrainContext):
context.model.eval()
total_loss = 0.0
num_batches = 0
with torch.no_grad():
for batch in context.val_dataloader:
loss = context.strategy(batch)
total_loss += loss.item()
num_batches += 1
avg_loss = total_loss / max(num_batches, 1)
if context.world_size > 1 and dist.is_initialized():
loss_tensor = torch.tensor([avg_loss], device=get_current_device())
dist.all_reduce(loss_tensor, op=dist.ReduceOp.AVG)
avg_loss = loss_tensor.item()
context.val_loss = avg_loss
context.model.train()
step_count = context.iteration // context.config.grad_accum_steps
logger.info(
f"Epoch {context.epoch + 1}, Step {step_count}, Val Loss: {avg_loss:.4f}"
)
def on_optimizer_step(self, context: TrainContext):
if context.val_dataloader is None:
return
cfg = context.config
if cfg.val_step <= 0:
return
step_count = context.iteration // cfg.grad_accum_steps
if step_count % cfg.val_step == 0:
self._run_validation(context)

View File

@ -1,15 +1,18 @@
from dataclasses import dataclass, field
from typing import Callable, Optional, Self
from pathlib import Path
from typing import Any, Dict, Optional, Self
import torch
import torch.nn as nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from torch.utils.data import DataLoader
from torch.utils.data import DataLoader, random_split
from astrai.config.train_config import TrainConfig
from astrai.dataset import ResumableDistributedSampler
from astrai.model.components.lora import inject_lora
from astrai.parallel.executor import BaseExecutor, ExecutorFactory
from astrai.parallel.setup import get_current_device, get_rank, get_world_size
from astrai.serialization import Checkpoint
from astrai.protocols import OptimizerProtocol, SchedulerProtocol
from astrai.serialization import Checkpoint, load_json
from astrai.trainer.strategy import BaseStrategy, StrategyFactory
@ -18,84 +21,158 @@ class TrainContext:
model: nn.Module = field(default=None)
strategy: BaseStrategy = field(default=None)
dataloader: DataLoader = field(default=None)
optimizer: Optimizer = field(default=None)
scheduler: LRScheduler = field(default=None)
optimizer: OptimizerProtocol = field(default=None)
scheduler: SchedulerProtocol = field(default=None)
checkpoint: Checkpoint = field(default=None)
config: TrainConfig = field(default=None)
model_config: dict = field(default_factory=dict)
executor: BaseExecutor = field(default=None)
epoch: int = field(default=0)
iteration: int = field(default=0)
loss: float = field(default=0.0)
val_dataloader: Optional[DataLoader] = field(default=None)
val_loss: Optional[float] = field(default=None)
world_size: int = field(default=1)
rank: int = field(default=0)
kwargs: dict = field(default_factory=dict)
kwargs: Dict[str, Any] = field(default_factory=dict)
class TrainContextBuilder:
def __init__(
self,
config: TrainConfig,
load_extra_fn: Optional[Callable[[dict, "TrainContext"], None]] = None,
):
self.config = config
self._checkpoint: Optional[Checkpoint] = None
self._load_extra_fn = load_extra_fn
self._resume_dir: Optional[str] = None
def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self:
self._checkpoint = checkpoint
def with_resume_dir(self, resume_dir: Optional[str]) -> Self:
self._resume_dir = resume_dir
return self
def build(self) -> TrainContext:
context = TrainContext(
model=self.config.model,
world_size=get_world_size(),
rank=get_rank(),
cfg = self.config
device = get_current_device()
executor = ExecutorFactory.create(
cfg.parallel_mode,
grad_accum_steps=cfg.grad_accum_steps,
**cfg.executor_kwargs,
)
device = get_current_device()
context.model = context.model.to(device=device)
model = cfg.model_fn()
model = model.to(device=device)
model.embed_tokens.neftune_noise_alpha = cfg.neftune_alpha
if self.config.nprocs > 1 and self.config.parallel_wrapper:
context.model = self.config.parallel_wrapper(context.model)
model_config = {}
if self._resume_dir:
config_path = Path(self._resume_dir) / "config.json"
if config_path.exists():
model_config = load_json(config_path)
if self._checkpoint is not None:
context.epoch = max(self._checkpoint.epoch, self.config.start_epoch)
context.iteration = max(self._checkpoint.iteration, self.config.start_batch)
context.model.load_state_dict(self._checkpoint.state_dict)
context.checkpoint = self._checkpoint
else:
context.checkpoint = Checkpoint(
state_dict=context.model.state_dict(),
if not model_config and hasattr(model, "config"):
model_config = model.config.to_dict()
context = TrainContext(
model=model,
world_size=get_world_size(),
rank=get_rank(),
config=cfg,
model_config=model_config,
executor=executor,
)
if self._resume_dir:
checkpoint = Checkpoint.load_any(self._resume_dir)
if checkpoint is not None:
model.load_state_dict(checkpoint.state_dict, strict=False)
if checkpoint.config:
context.model_config = checkpoint.config
context.epoch = checkpoint.epoch or cfg.start_epoch
context.iteration = checkpoint.iteration or cfg.start_batch
context.checkpoint = checkpoint
if cfg.lora is not None:
inject_lora(
model,
r=cfg.lora.r,
alpha=cfg.lora.alpha,
target_modules=set(cfg.lora.target_modules),
)
context.optimizer = self.config.optimizer_fn(context.model)
context.scheduler = self.config.scheduler_fn(context.optimizer)
context.optimizer = cfg.optimizer_fn(model)
context.scheduler = cfg.scheduler_fn(context.optimizer)
if self._checkpoint and self._checkpoint.extra and self._load_extra_fn:
self._load_extra_fn(self._checkpoint.extra, context)
train_dataset = cfg.dataset
val_dataset = cfg.val_dataset
cfg = self.config
sampler_offset = context.iteration * cfg.batch_size
if val_dataset is None and cfg.val_split is not None:
n_total = len(cfg.dataset)
n_val = max(1, int(n_total * cfg.val_split))
n_train = n_total - n_val
generator = torch.Generator().manual_seed(cfg.random_seed)
train_dataset, val_dataset = random_split(
cfg.dataset, [n_train, n_val], generator=generator
)
sampler_offset = context.iteration * cfg.batch_per_device
sampler = ResumableDistributedSampler(
data_source=cfg.dataset,
data_source=train_dataset,
start_epoch=context.epoch,
start_iter=sampler_offset,
seed=cfg.random_seed,
)
context.dataloader = DataLoader(
cfg.dataset,
batch_size=cfg.batch_size,
train_dataset,
batch_size=cfg.batch_per_device,
sampler=sampler,
num_workers=cfg.num_workers,
pin_memory=cfg.pin_memory,
prefetch_factor=cfg.prefetch_factor,
)
if val_dataset is not None:
val_sampler = ResumableDistributedSampler(
data_source=val_dataset,
start_epoch=0,
start_iter=0,
seed=cfg.random_seed,
shuffle=False,
)
context.val_dataloader = DataLoader(
val_dataset,
batch_size=cfg.batch_per_device,
sampler=val_sampler,
num_workers=cfg.num_workers,
pin_memory=cfg.pin_memory,
prefetch_factor=cfg.prefetch_factor,
)
context.model, context.optimizer, context.dataloader, context.scheduler = (
executor.prepare(
model,
context.optimizer,
context.dataloader,
context.scheduler,
)
)
if context.checkpoint and context.checkpoint.extra:
extra = context.checkpoint.extra
for name in ("optimizer", "scheduler"):
if name in extra:
obj = getattr(context, name, None)
if obj is not None:
obj.load_state_dict(extra[name])
context.strategy = StrategyFactory.create(
cfg.strategy,
model=context.model,
train_type=self.config.strategy,
device=device,
**self.config.extra_kwargs,
executor=executor,
model_fn=cfg.model_fn,
**cfg.extra_kwargs,
)
return context

View File

@ -1,10 +1,8 @@
import logging
from itertools import batched
from typing import List, Optional
from astrai.config import TrainConfig
from astrai.parallel.setup import spawn_parallel_fn
from astrai.serialization import Checkpoint
from astrai.trainer.train_callback import (
CallbackFactory,
TrainCallback,
@ -26,17 +24,28 @@ class Trainer:
def _get_default_callbacks(self) -> List[TrainCallback]:
cfg = self.train_config
return [
callbacks = [
CallbackFactory.create(
"gradient_checkpointing",
modules=cfg.gradient_checkpointing_modules,
),
CallbackFactory.create(
"checkpoint",
cfg.ckpt_dir,
cfg.ckpt_interval,
),
CallbackFactory.create(
"metric_logger",
log_dir=cfg.log_dir,
save_interval=cfg.ckpt_interval,
log_interval=cfg.log_interval,
metrics=cfg.metrics,
),
CallbackFactory.create("progress_bar", cfg.n_epoch),
CallbackFactory.create("checkpoint", cfg.ckpt_dir, cfg.ckpt_interval),
CallbackFactory.create("metric_logger", cfg.ckpt_dir, cfg.ckpt_interval),
CallbackFactory.create("gradient_clipping", cfg.max_grad_norm),
CallbackFactory.create("validation"),
]
def _build_context(self, checkpoint: Optional[Checkpoint]) -> TrainContext:
return (
TrainContextBuilder(self.train_config).with_checkpoint(checkpoint).build()
)
return callbacks
def _call_callbacks(self, method_name: str, context: TrainContext):
for callback in self.callbacks:
@ -44,56 +53,56 @@ class Trainer:
if method:
method(context)
def train(self, checkpoint: Optional[Checkpoint] = None):
config = self.train_config
spawn_parallel_fn(
self._train_impl,
backend=config.backend,
world_size=config.nprocs,
master_addr=config.master_addr,
master_port=config.master_port,
device_type=config.device_type,
checkpoint=checkpoint,
def _trainer_loop(self, resume_dir: Optional[str] = None):
context = (
TrainContextBuilder(self.train_config).with_resume_dir(resume_dir).build()
)
def _train_impl(self, checkpoint: Optional[Checkpoint] = None) -> Checkpoint:
context = self._build_context(checkpoint)
executor = context.executor
self._call_callbacks("on_train_begin", context)
try:
context.model.train()
accumulation_steps = max(self.train_config.accumulation_steps, 1)
for epoch in range(context.epoch, self.train_config.n_epoch):
for epoch in range(context.epoch, context.config.n_epoch):
context.epoch = epoch
self._call_callbacks("on_epoch_begin", context)
for steps in batched(context.dataloader, accumulation_steps):
self._call_callbacks("on_step_begin", context)
step_batch_nums = len(steps)
for batch in steps:
for batch in context.dataloader:
with executor.accumulate(context.model):
self._call_callbacks("on_batch_begin", context)
loss = context.strategy(batch)
context.loss = loss.item()
stand_loss = loss / executor.grad_accum_steps
executor.backward(stand_loss)
context.iteration += 1
stand_loss = loss / step_batch_nums
stand_loss.backward()
self._call_callbacks("on_batch_end", context)
self._call_callbacks("on_step_end", context)
context.optimizer.step()
context.optimizer.zero_grad()
if executor.sync_gradients:
self._call_callbacks("on_optimizer_step", context)
context.optimizer.step()
context.optimizer.zero_grad()
if context.scheduler:
context.scheduler.step()
if context.scheduler:
context.scheduler.step()
self._call_callbacks("on_epoch_end", context)
except Exception as e:
logger.error(f"Training failed: {str(e)}", exc_info=True)
logger.error("Training failed: %s", str(e), exc_info=True)
self._call_callbacks("on_error", context)
raise
finally:
self._call_callbacks("on_train_end", context)
def train(self, resume_dir: Optional[str] = None):
cfg = self.train_config
spawn_parallel_fn(
self._trainer_loop,
backend=cfg.backend,
world_size=cfg.nprocs,
master_addr=cfg.master_addr,
master_port=cfg.master_port,
device_type=cfg.device_type,
start_method=cfg.start_method,
resume_dir=resume_dir,
)

View File

@ -1,12 +1,13 @@
services:
server:
build: .
image: astrai:latest
build:
context: .
dockerfile: Dockerfile
user: "${UID:-1000}:${GID:-1000}"
ports:
- "8000:8000"
volumes:
- ./params:/app/params:ro
- ./checkpoints:/app/checkpoints
command: python -m scripts.tools.server --port 8000 --device cuda
deploy:
resources:
@ -25,13 +26,14 @@ services:
server-cpu:
profiles: [cpu]
build: .
image: astrai:latest
build:
context: .
dockerfile: Dockerfile
user: "${UID:-1000}:${GID:-1000}"
ports:
- "8000:8000"
volumes:
- ./params:/app/params:ro
- ./checkpoints:/app/checkpoints
command: python -m scripts.tools.server --port 8000 --device cpu
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:8000/health"]

View File

@ -11,7 +11,6 @@ PARAMETER_ROOT = Path(PROJECT_ROOT, "params")
def generate_text():
# Load model from pretrained
model = AutoModel.from_pretrained(PARAMETER_ROOT)
tokenizer = AutoTokenizer.from_pretrained(PARAMETER_ROOT)
model.to(device="cuda", dtype=torch.bfloat16)
@ -22,16 +21,15 @@ def generate_text():
model=model,
tokenizer=tokenizer,
)
response = engine.generate(
for token in engine.generate(
prompt=query,
stream=False,
stream=True,
max_tokens=2048,
temperature=0.8,
top_p=0.95,
top_k=50,
)
print(response)
):
print(token, end="", flush=True)
if __name__ == "__main__":

View File

@ -16,6 +16,7 @@ NC='\033[0m' # No Color
IMAGE_NAME="astrai"
IMAGE_TAG="latest"
REGISTRY=""
CONTAINER_ID=""
# Print colored messages
print_info() {
@ -175,6 +176,10 @@ main() {
PORT="$2"
shift 2
;;
--container)
CONTAINER_ID="$2"
shift 2
;;
--gpu)
GPU=true
shift
@ -197,6 +202,7 @@ main() {
echo " --dockerfile FILE Dockerfile path (default: Dockerfile)"
echo " --context PATH Build context (default: .)"
echo " --port PORT Port for run (default: 8000)"
echo " --container ID Container ID for logs"
echo " --gpu Enable GPU support"
echo " --help Show this help message"
echo ""
@ -205,6 +211,7 @@ main() {
echo " $0 build --tag v1.0.0"
echo " $0 run --port 8080"
echo " $0 run --gpu"
echo " $0 logs --container abc123"
echo " $0 push --registry ghcr.io/username"
exit 0
;;
@ -237,7 +244,7 @@ main() {
show_info
;;
logs)
show_logs "$2"
show_logs "$CONTAINER_ID"
;;
"")
print_error "No command specified. Use --help for usage"

View File

@ -0,0 +1,334 @@
"""HumanEval code generation benchmark.
Generates n completions per problem, extracts function bodies, executes
against hidden tests, and computes pass@k.
Usage::
python scripts/tools/evaluate_humaneval.py --param_path ./params \
--data_path HumanEval.jsonl.gz --output results.json \
--num_samples 200 --temperature 0.8 --max_tokens 512
"""
import argparse
import json
import os
import re
from math import prod
from multiprocessing import Process, Queue
from typing import Dict, List, Optional, Tuple
import numpy as np
import torch
import tqdm
from astrai.inference import InferenceEngine
from astrai.model import AutoModel
from astrai.tokenize import AutoTokenizer
HUMANEVAL_URL = (
"https://github.com/openai/human-eval/raw/master/data/HumanEval.jsonl.gz"
)
_STOP_SEQUENCES = [
"\nclass ",
"\ndef ",
"\n# ",
"\nif __name__",
"\nprint(",
"\n\n\n",
]
def _download_humaneval(data_path: str):
if os.path.exists(data_path):
return
import gzip
import urllib.request
os.makedirs(os.path.dirname(data_path) or ".", exist_ok=True)
print(f"Downloading HumanEval from {HUMANEVAL_URL} ...")
tmp = data_path + ".tmp"
urllib.request.urlretrieve(HUMANEVAL_URL, tmp)
with gzip.open(tmp, "rb") as f_in:
with open(data_path, "wb") as f_out:
f_out.write(f_in.read())
os.remove(tmp)
print(f" saved to {data_path}")
def _load_problems(data_path: str) -> List[dict]:
problems = []
with open(data_path, "r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if line:
problems.append(json.loads(line))
return problems
def _extract_function_body(code: str, entry_point: str) -> Optional[str]:
"""Extract the function body from a completion."""
pattern = rf"def\s+{re.escape(entry_point)}\b[^:]*:"
match = re.search(pattern, code)
if not match:
# Use the full code as-is if we can't find the function
return code
body_start = match.end()
lines = code[body_start:].split("\n")
body_lines = []
started = False
for line in lines:
stripped = line.rstrip()
if not stripped and not started:
continue
if not stripped and started:
body_lines.append("")
continue
if not started:
started = True
if stripped.lstrip() == stripped and started:
break
body_lines.append(stripped)
body = "\n".join(body_lines)
if not body.strip():
return None
return body
def _trim_stop_sequences(text: str) -> str:
for stop in _STOP_SEQUENCES:
idx = text.find(stop)
if idx != -1:
text = text[:idx]
return text
def _execute_code(problem: dict, completion: str, timeout: float = 3.0) -> bool:
"""Run the completion against hidden tests in a subprocess."""
def _worker(queue, full_code):
try:
namespace = {}
exec(full_code, namespace)
check = namespace.get("check")
if check is None:
queue.put(False)
return
check(namespace.get(problem["entry_point"]))
queue.put(True)
except Exception:
queue.put(False)
full_code = problem["prompt"] + completion + "\n" + problem["test"]
queue: Queue = Queue()
proc = Process(target=_worker, args=(queue, full_code))
proc.start()
proc.join(timeout)
if proc.is_alive():
proc.terminate()
proc.join()
return False
try:
return queue.get_nowait()
except Exception:
return False
def _pass_at_k(n: int, c: int, k: int) -> float:
"""Unbiased estimator of pass@k."""
if n - c < k:
return 1.0
return 1.0 - float(prod(1.0 - k / np.arange(n - c + 1, n + 1)))
def _deduplicate(completions: List[str]) -> List[str]:
seen = set()
unique = []
for c in completions:
if c not in seen:
seen.add(c)
unique.append(c)
return unique
def _generate(
engine: InferenceEngine,
prompt: str,
num_samples: int,
max_tokens: int,
temperature: float,
top_p: float,
top_k: int,
batch_size: int,
) -> List[str]:
batches = [prompt] * min(batch_size, num_samples)
completions = []
remaining = num_samples
while remaining > 0:
current = min(batch_size, remaining)
batch_prompts = batches[:current]
outputs = engine.generate(
prompt=batch_prompts,
stream=False,
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
top_k=top_k,
)
if isinstance(outputs, str):
outputs = [outputs]
completions.extend(outputs)
remaining -= current
return _deduplicate(completions)
def evaluate(
engine: InferenceEngine,
problems: List[dict],
num_samples: int,
max_tokens: int,
temperature: float,
top_p: float,
top_k: int,
batch_size: int,
k_values: Tuple[int, ...] = (1, 10, 100),
) -> Dict:
results = {}
all_pass_at_k = {k: [] for k in k_values}
for problem in tqdm.tqdm(problems, desc="HumanEval", unit="problem"):
task_id = problem["task_id"]
prompt = problem["prompt"]
entry_point = problem["entry_point"]
raw_completions = _generate(
engine,
prompt,
num_samples,
max_tokens,
temperature,
top_p,
top_k,
batch_size,
)
completions = []
for raw in raw_completions:
trimmed = _trim_stop_sequences(raw)
body = _extract_function_body(trimmed, entry_point)
if body:
completions.append(body)
passed = 0
for comp in completions:
if _execute_code(problem, comp):
passed += 1
n = len(completions)
c = passed
result = {"task_id": task_id, "n": n, "passed": c}
for k in k_values:
result[f"pass@{k}"] = round(_pass_at_k(n, c, k), 4)
all_pass_at_k[k].append(_pass_at_k(n, c, k))
results[task_id] = result
summary = {}
for k in k_values:
vals = all_pass_at_k[k]
summary[f"pass@{k}"] = round(float(np.mean(vals)), 4)
results["_summary"] = summary
return results
def main():
parser = argparse.ArgumentParser(description="HumanEval benchmark")
parser.add_argument(
"--param_path", type=str, default="./params", help="Model directory"
)
parser.add_argument(
"--data_path",
type=str,
default="./humaneval/HumanEval.jsonl",
help="HumanEval JSONL file (auto-download if missing)",
)
parser.add_argument("--output", type=str, default=None, help="Output JSON path")
parser.add_argument(
"--num_samples",
type=int,
default=200,
help="Completions per problem",
)
parser.add_argument(
"--max_tokens", type=int, default=512, help="Max generation tokens"
)
parser.add_argument(
"--temperature", type=float, default=0.8, help="Sampling temperature"
)
parser.add_argument("--top_p", type=float, default=0.95, help="Top-p sampling")
parser.add_argument("--top_k", type=int, default=50, help="Top-k sampling")
parser.add_argument(
"--batch_size", type=int, default=1, help="Inference batch size"
)
parser.add_argument(
"--problems",
type=int,
nargs="+",
default=None,
help="Specific problem indices (0-based)",
)
args = parser.parse_args()
_download_humaneval(args.data_path)
problems = _load_problems(args.data_path)
if args.problems:
problems = [problems[i] for i in args.problems if i < len(problems)]
model = AutoModel.from_pretrained(args.param_path)
tokenizer = AutoTokenizer.from_pretrained(args.param_path)
model.to(device="cuda", dtype=torch.bfloat16)
engine = InferenceEngine(
model=model,
tokenizer=tokenizer,
max_batch_size=args.batch_size,
)
results = evaluate(
engine=engine,
problems=problems,
num_samples=args.num_samples,
max_tokens=args.max_tokens,
temperature=args.temperature,
top_p=args.top_p,
top_k=args.top_k,
batch_size=args.batch_size,
k_values=(1, 10, 100),
)
summary = results.pop("_summary")
print(f"\n{'=' * 60}")
for k, v in summary.items():
print(f" {k}: {v:.2%}")
print(f"{'=' * 60}")
if args.output:
results["_summary"] = summary
with open(args.output, "w", encoding="utf-8") as f:
json.dump(results, f, indent=2, ensure_ascii=False)
print(f"Results saved to {args.output}")
engine.shutdown()
if __name__ == "__main__":
main()

View File

@ -0,0 +1,293 @@
"""IFD (Instruction Following Difficulty) data quality scoring.
Computes IFD scores for instruction-response pairs to guide data selection.
IFD = conditional_NLL / unconditional_NLL, where:
- conditional_NLL: average CE loss on response tokens given instruction context
- unconditional_NLL: average CE loss on response tokens alone
Higher IFD (close to 1) = instruction provides less help = harder sample.
Lower IFD (close to 0) = instruction provides strong guidance = easy sample.
IFD > 1 = instruction misleads the model = likely low-quality data.
Usage::
python scripts/eval/ifd.py --param_path ./params \
--input data.jsonl --output data_with_ifd.jsonl \
--instr_key instruction --resp_key response
Disable chat template::
python scripts/eval/ifd.py --param_path ./params \
--input data.jsonl --output data_with_ifd.jsonl \
--instr_key instruction --resp_key response \
--no_chat_template
"""
import argparse
import json
import torch
import torch.nn.functional as F
import tqdm
from astrai.model import AutoModel
from astrai.tokenize import AutoTokenizer
def compute_ifd(
model,
tokenizer,
instruction: str,
response: str,
device: str,
max_len: int = 2048,
use_chat_template: bool = False,
) -> dict:
if use_chat_template:
return _compute_ifd_with_template(
model, tokenizer, instruction, response, device, max_len
)
return _compute_ifd_raw(model, tokenizer, instruction, response, device, max_len)
def _compute_ifd_raw(model, tokenizer, instruction, response, device, max_len) -> dict:
instr_ids = tokenizer.encode(instruction)
resp_ids = tokenizer.encode(response)
if not resp_ids:
return {
"L_cond": None,
"L_uncond": None,
"ifd": None,
"error": "empty response",
}
qa_len = len(instr_ids) + len(resp_ids)
if qa_len > max_len:
overflow = qa_len - max_len
instr_ids = instr_ids[overflow:]
instr_len = len(instr_ids)
resp_len = len(resp_ids)
qa_ids = instr_ids + resp_ids
qa_tensor = torch.tensor([qa_ids], device=device, dtype=torch.long)
with torch.inference_mode():
logits_qa = model(qa_tensor)["logits"][0]
resp_logits = logits_qa[instr_len - 1 : -1]
resp_targets = torch.tensor(resp_ids, device=device, dtype=torch.long)
L_cond = F.cross_entropy(resp_logits, resp_targets, reduction="mean").item()
resp_tensor = torch.tensor([resp_ids], device=device, dtype=torch.long)
with torch.inference_mode():
logits_resp = model(resp_tensor)["logits"][0]
unp_logits = logits_resp[:-1]
unp_targets = resp_tensor[0, 1:]
L_uncond = F.cross_entropy(unp_logits, unp_targets, reduction="mean").item()
ifd = L_cond / L_uncond if L_uncond > 0 else None
return {
"L_cond": round(L_cond, 6),
"L_uncond": round(L_uncond, 6),
"ifd": round(ifd, 6) if ifd is not None else None,
"instr_len": instr_len,
"resp_len": resp_len,
"error": None,
}
def _compute_ifd_with_template(
model, tokenizer, instruction, response, device, max_len
) -> dict:
instr_prefix = tokenizer.apply_chat_template(
[{"role": "user", "content": instruction}],
tokenize=False,
add_generation_prompt=True,
)
full_text = tokenizer.apply_chat_template(
[
{"role": "user", "content": instruction},
{"role": "assistant", "content": response},
],
tokenize=False,
add_generation_prompt=False,
)
full_ids = tokenizer.encode(full_text)
prefix_ids = tokenizer.encode(instr_prefix)
resp_ids = tokenizer.encode(response)
if not resp_ids:
return {
"L_cond": None,
"L_uncond": None,
"ifd": None,
"error": "empty response",
}
if len(full_ids) > max_len:
overflow = len(full_ids) - max_len
full_ids = full_ids[overflow:]
prefix_len = len(prefix_ids) - overflow
prefix_len = max(0, prefix_len)
else:
prefix_len = len(prefix_ids)
cond_tensor = torch.tensor([full_ids], device=device, dtype=torch.long)
with torch.inference_mode():
logits_qa = model(cond_tensor)["logits"][0]
resp_start = prefix_len - 1
resp_end = len(full_ids) - 1
if resp_end <= resp_start:
return {
"L_cond": None,
"L_uncond": None,
"ifd": None,
"error": "response truncated entirely",
}
resp_logits = logits_qa[resp_start:resp_end]
resp_targets = torch.tensor(full_ids[prefix_len:], device=device, dtype=torch.long)
L_cond = F.cross_entropy(resp_logits, resp_targets, reduction="mean").item()
resp_tensor = torch.tensor([resp_ids], device=device, dtype=torch.long)
with torch.inference_mode():
logits_resp = model(resp_tensor)["logits"][0]
unp_logits = logits_resp[:-1]
unp_targets = resp_tensor[0, 1:]
L_uncond = F.cross_entropy(unp_logits, unp_targets, reduction="mean").item()
ifd = L_cond / L_uncond if L_uncond > 0 else None
return {
"L_cond": round(L_cond, 6),
"L_uncond": round(L_uncond, 6),
"ifd": round(ifd, 6) if ifd is not None else None,
"instr_len": prefix_len,
"resp_len": len(resp_ids),
"error": None,
}
def process_file(
param_path: str,
input_file: str,
output_file: str,
instr_key: str,
resp_key: str,
max_len: int,
use_chat_template: bool = False,
):
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16 if device == "cuda" else torch.float32
model = AutoModel.from_pretrained(param_path)
tokenizer = AutoTokenizer.from_pretrained(param_path)
model.to(device=device, dtype=dtype)
model.eval()
if use_chat_template and tokenizer._chat_template is None:
raise RuntimeError(
"--use_chat_template specified but tokenizer has no chat template. "
"Add a chat_template to tokenizer_config.json or omit the flag."
)
with open(input_file, "r", encoding="utf-8") as f:
data = [json.loads(line) for line in f if line.strip()]
results = []
ifd_values = []
with torch.inference_mode():
for item in tqdm.tqdm(data, desc="Computing IFD", unit="sample"):
instruction = item[instr_key]
response = item[resp_key]
scores = compute_ifd(
model,
tokenizer,
instruction,
response,
device,
max_len,
use_chat_template=use_chat_template,
)
ifd_values.append(scores["ifd"])
results.append({**item, "ifd": scores["ifd"], "ifd_detail": scores})
with open(output_file, "w", encoding="utf-8") as f:
for item in results:
f.write(json.dumps(item, ensure_ascii=False) + "\n")
valid_ifd = [v for v in ifd_values if v is not None]
if valid_ifd:
import statistics
print(f"\n{'=' * 50}")
print(f" Samples: {len(data)}")
print(f" Valid IFD: {len(valid_ifd)}")
print(f" Mean IFD: {statistics.mean(valid_ifd):.4f}")
print(f" Median IFD: {statistics.median(valid_ifd):.4f}")
print(f" Stdev IFD: {statistics.stdev(valid_ifd):.4f}")
print(f" Min IFD: {min(valid_ifd):.4f}")
print(f" Max IFD: {max(valid_ifd):.4f}")
print(f"{'=' * 50}")
print(f"Results saved to {output_file}")
def main():
parser = argparse.ArgumentParser(
description="Compute IFD scores for instruction-response data"
)
parser.add_argument("--param_path", type=str, required=True, help="Model directory")
parser.add_argument("--input", type=str, required=True, help="Input JSONL file")
parser.add_argument("--output", type=str, required=True, help="Output JSONL file")
parser.add_argument(
"--instr_key",
type=str,
default="instruction",
help="Key for instruction field",
)
parser.add_argument(
"--resp_key",
type=str,
default="response",
help="Key for response field",
)
parser.add_argument(
"--max_len",
type=int,
default=2048,
help="Max token length (instruction truncated to fit)",
)
parser.add_argument(
"--no_chat_template",
action="store_true",
default=False,
help="Disable chat template, use raw text concatenation",
)
args = parser.parse_args()
process_file(
args.param_path,
args.input,
args.output,
args.instr_key,
args.resp_key,
args.max_len,
use_chat_template=not args.no_chat_template,
)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,609 @@
"""IFEval instruction-following evaluation benchmark.
Evaluates model responses against regex-based constraint verifiers.
Supports all IFEval constraint types except language detection.
Usage::
python scripts/tools/evaluate_ifeval.py --param_path ./params \
--data_path ifeval.jsonl --output results.json \
--temperature 0.1 --max_tokens 512
"""
import argparse
import json
import os
import re
import urllib.request
from typing import Callable, Dict, List, Optional
import torch
import tqdm
from astrai.inference import InferenceEngine
from astrai.model import AutoModel
from astrai.tokenize import AutoTokenizer
IFEVAL_URL = (
"https://raw.githubusercontent.com/google-research/"
"google-research/master/instruction_following_eval/data/input_data.jsonl"
)
CONSTRAINT_VERIFIERS: Dict[str, Callable[[str, dict], bool]] = {}
def register(instruction_id: str):
def decorator(fn):
CONSTRAINT_VERIFIERS[instruction_id] = fn
return fn
return decorator
@register("keywords:existence")
def check_keyword_existence(response: str, kwargs: dict) -> bool:
for kw in kwargs["keywords"]:
if not re.search(re.escape(kw), response, re.IGNORECASE):
return False
return True
@register("keywords:frequency")
def check_keyword_frequency(response: str, kwargs: dict) -> bool:
keyword = kwargs["keyword"]
frequency = kwargs.get("frequency", 1)
relation = kwargs.get("relation", "at least")
count = len(re.findall(re.escape(keyword), response, re.IGNORECASE))
if relation == "less than":
return count < frequency
return count >= frequency
@register("keywords:forbidden_words")
def check_forbidden_words(response: str, kwargs: dict) -> bool:
for word in kwargs["forbidden_words"]:
if re.search(r"\b" + re.escape(word) + r"\b", response, re.IGNORECASE):
return False
return True
@register("keywords:letter_frequency")
def check_letter_frequency(response: str, kwargs: dict) -> bool:
letter = kwargs["letter"].lower()
frequency = kwargs.get("let_frequency", 1)
relation = kwargs.get("let_relation", "at least")
count = response.lower().count(letter)
if relation == "less than":
return count < frequency
return count >= frequency
@register("detectable_content:number_placeholders")
def check_placeholders(response: str, kwargs: dict) -> bool:
num = kwargs.get("num_placeholders", 1)
placeholders = re.findall(r"\[.*?\]", response)
return len(placeholders) >= num
@register("detectable_content:postscript")
def check_postscript(response: str, kwargs: dict) -> bool:
marker = kwargs.get("postscript_marker", "P.S.")
response_lower = response.lower()
if marker == "P.P.S":
return bool(re.search(r"p\.\s?p\.\s?s", response_lower))
elif marker == "P.S.":
return bool(re.search(r"p\.\s?s\.", response_lower))
else:
return bool(re.search(re.escape(marker.lower()), response_lower))
@register("detectable_format:number_bullet_lists")
def check_bullet_lists(response: str, kwargs: dict) -> bool:
num = kwargs.get("num_bullets", 1)
bullets = re.findall(r"^\s*\*[^\*].*$", response, re.MULTILINE)
dashes = re.findall(r"^\s*-.*$", response, re.MULTILINE)
return len(bullets) + len(dashes) == num
@register("detectable_format:number_highlighted_sections")
def check_highlighted_sections(response: str, kwargs: dict) -> bool:
num = kwargs.get("num_highlights", 1)
highlights = re.findall(r"\*[^\n\*]+\*", response)
count = 0
for h in highlights:
if h.strip("*").strip():
count += 1
return count >= num
@register("detectable_format:multiple_sections")
def check_multiple_sections(response: str, kwargs: dict) -> bool:
splitter = kwargs.get("section_spliter", "Section")
num = kwargs.get("num_sections", 1)
pattern = r"\s?" + re.escape(splitter) + r"\s?\d+\s?"
sections = re.split(pattern, response)
return len(sections) - 1 >= num
@register("detectable_format:title")
def check_title(response: str, kwargs: dict) -> bool:
titles = re.findall(r"<<[^>\n]+>>", response)
for title in titles:
if title.strip("<>").strip():
return True
return False
@register("detectable_format:json_format")
def check_json_format(response: str, kwargs: dict) -> bool:
value = response.strip()
for prefix in ("```json", "```Json", "```JSON", "```"):
if value.lower().startswith(prefix.lower()):
value = value[len(prefix) :].strip()
if value.endswith("```"):
value = value[:-3].strip()
try:
json.loads(value)
return True
except (ValueError, json.JSONDecodeError):
return False
@register("detectable_format:general_punctuation")
def check_general_punctuation(response: str, kwargs: dict) -> bool:
punctuation_blacklist = kwargs.get("punctuation_blacklist", [])
for punct in punctuation_blacklist:
if punct in response:
return False
return True
@register("detectable_format:number_highlighted_words")
def check_highlighted_words(response: str, kwargs: dict) -> bool:
num = kwargs.get("num_highlights", 1)
highlights = re.findall(r"\*[^\s\*][^\*]*[^\s\*]\*", response)
return len(highlights) >= num
@register("startend:end_checker")
def check_end_checker(response: str, kwargs: dict) -> bool:
end_phrase = kwargs["end_phrase"]
return (
response.strip()
.rstrip('"')
.rstrip()
.lower()
.endswith(end_phrase.strip().lower())
)
@register("startend:quotation")
def check_quotation(response: str, kwargs: dict) -> bool:
value = response.strip()
return value.startswith('"') and value.endswith('"')
@register("startend:start_checker")
def check_start_checker(response: str, kwargs: dict) -> bool:
starter = kwargs["starter"]
return bool(re.search(r"^\s*" + re.escape(starter), response, re.MULTILINE))
@register("change_case:english_capital")
def check_english_capital(response: str, kwargs: dict) -> bool:
return response.isupper()
@register("change_case:english_lowercase")
def check_english_lowercase(response: str, kwargs: dict) -> bool:
return response.islower()
@register("change_case:capital_word_frequency")
def check_capital_word_frequency(response: str, kwargs: dict) -> bool:
frequency = kwargs.get("capital_frequency", 1)
relation = kwargs.get("capital_relation", "at least")
capital_words = re.findall(r"\b[A-Z]{2,}\b", response)
count = len(capital_words)
if relation == "less than":
return count < frequency
return count >= frequency
@register("punctuation:no_comma")
def check_no_comma(response: str, kwargs: dict) -> bool:
return "," not in response
def count_words(text: str) -> int:
return len(re.findall(r"\b\w+\b", text))
def count_sentences(text: str) -> int:
text = text.strip()
if not text:
return 0
sentences = re.split(r"(?<=[.!?])\s+", text)
return len([s for s in sentences if s.strip()])
@register("length_constraints:number_words")
def check_number_words(response: str, kwargs: dict) -> bool:
num = kwargs.get("num_words", 100)
relation = kwargs.get("relation", "at least")
cnt = count_words(response)
if relation == "less than":
return cnt < num
return cnt >= num
@register("length_constraints:number_sentences")
def check_number_sentences(response: str, kwargs: dict) -> bool:
num = kwargs.get("num_sentences", 5)
relation = kwargs.get("relation", "at least")
cnt = count_sentences(response)
if relation == "less than":
return cnt < num
return cnt >= num
@register("length_constraints:number_paragraphs")
def check_number_paragraphs(response: str, kwargs: dict) -> bool:
num = kwargs.get("num_paragraphs", 1)
if "***" in response:
paragraphs = re.split(r"\s?\*\*\*\s?", response)
else:
paragraphs = re.split(r"\n\n+", response)
actual = len([p for p in paragraphs if p.strip()])
return actual == num
@register("length_constraints:nth_paragraph_first_word")
def check_nth_paragraph_first_word(response: str, kwargs: dict) -> bool:
num_paragraphs = kwargs.get("num_paragraphs", 1)
nth = kwargs.get("nth_paragraph", 1)
first_word = kwargs.get("first_word", "").lower()
paragraphs = re.split(r"\n\n+", response)
paragraphs = [p.strip() for p in paragraphs if p.strip()]
if len(paragraphs) != num_paragraphs:
return False
if nth > len(paragraphs):
return False
target = paragraphs[nth - 1]
words = target.split()
if not words:
return False
word = words[0].strip().lstrip("'\"").rstrip(".,!?:;\"'")
return word.lower() == first_word
@register("length_constraints:nth_word_checker")
def check_nth_word(response: str, kwargs: dict) -> bool:
nth = kwargs.get("nth_word", 1)
target = kwargs.get("target_word", "").lower()
words = re.findall(r"\b\w+\b", response)
if nth > len(words):
return False
return words[nth - 1].lower() == target
@register("combination:repeat_prompt")
def check_repeat_prompt(response: str, kwargs: dict) -> bool:
prompt = kwargs["prompt_to_repeat"]
return response.strip().lower().startswith(prompt.strip().lower())
@register("combination:two_responses")
def check_two_responses(response: str, kwargs: dict) -> bool:
parts = response.split("******")
valid = [p for p in parts if p.strip()]
if len(valid) != 2:
return False
return valid[0].strip() != valid[1].strip()
def download_ifeval(data_path: str):
if os.path.exists(data_path):
return
os.makedirs(os.path.dirname(data_path) or ".", exist_ok=True)
print(f"Downloading IFEval from {IFEVAL_URL} ...")
tmp = data_path + ".tmp"
urllib.request.urlretrieve(IFEVAL_URL, tmp)
with open(tmp, "rb") as f_in:
content = f_in.read()
with open(data_path, "wb") as f_out:
f_out.write(content)
os.remove(tmp)
print(f" saved to {data_path}")
def load_problems(data_path: str) -> List[dict]:
problems = []
with open(data_path, "r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if line:
problems.append(json.loads(line))
return problems
def verify_response(response: str, instruction_id: str, kwargs: dict) -> Optional[bool]:
verifier = CONSTRAINT_VERIFIERS.get(instruction_id)
if verifier is None:
return None
try:
return verifier(response, kwargs)
except Exception:
return False
def generate_one(
engine: InferenceEngine,
tokenizer: AutoTokenizer,
prompt: str,
max_tokens: int,
temperature: float,
top_p: float,
top_k: int,
) -> str:
formatted = tokenizer.apply_chat_template(
[{"role": "user", "content": prompt}],
tokenize=False,
add_generation_prompt=True,
)
output = engine.generate(
prompt=formatted,
stream=False,
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
top_k=top_k,
)
if isinstance(output, list):
return output[0]
return output
def evaluate(
engine: InferenceEngine,
tokenizer: AutoTokenizer,
problems: List[dict],
max_tokens: int,
temperature: float,
top_p: float,
top_k: int,
num_samples: int = 1,
) -> Dict:
results = {}
constraint_stats: Dict[str, Dict[str, int]] = {}
total_constraints = 0
total_passed = 0
for problem in tqdm.tqdm(problems, desc="IFEval", unit="problem"):
key = problem["key"]
prompt = problem["prompt"]
instruction_ids = problem["instruction_id_list"]
kwargs_list = problem["kwargs"]
samples = []
for _ in range(num_samples):
response = generate_one(
engine, tokenizer, prompt, max_tokens, temperature, top_p, top_k
)
samples.append(response)
constraint_results = []
passed = 0
verified = 0
for idx, instruction_id in enumerate(instruction_ids):
kwargs = kwargs_list[idx] if idx < len(kwargs_list) else {}
best_pass = False
for response in samples:
result = verify_response(response, instruction_id, kwargs)
if result is None:
continue
if result:
best_pass = True
break
verifier_exists = instruction_id in CONSTRAINT_VERIFIERS
if verifier_exists:
verified += 1
if best_pass:
passed += 1
constraint_results.append(
{
"instruction_id": instruction_id,
"passed": best_pass,
"supported": verifier_exists,
"kwargs": kwargs,
}
)
if verifier_exists:
if instruction_id not in constraint_stats:
constraint_stats[instruction_id] = {
"total": 0,
"passed": 0,
}
constraint_stats[instruction_id]["total"] += 1
if best_pass:
constraint_stats[instruction_id]["passed"] += 1
total_constraints += verified
total_passed += passed
accuracy = passed / verified if verified > 0 else None
results[str(key)] = {
"key": key,
"prompt": prompt,
"response": samples[0],
"num_samples": num_samples,
"num_constraints": len(instruction_ids),
"num_verified": verified,
"num_passed": passed,
"accuracy": round(accuracy, 4) if accuracy is not None else None,
"constraints": constraint_results,
}
overall_accuracy = (
round(total_passed / total_constraints, 4) if total_constraints > 0 else 0.0
)
type_summary = {}
for inst_id, stats in sorted(constraint_stats.items()):
type_summary[inst_id] = {
"total": stats["total"],
"passed": stats["passed"],
"accuracy": round(stats["passed"] / stats["total"], 4)
if stats["total"] > 0
else 0.0,
}
unsupported_count = sum(
1
for p in problems
for iid in p["instruction_id_list"]
if iid not in CONSTRAINT_VERIFIERS
)
results["_summary"] = {
"total_problems": len(problems),
"total_constraints": total_constraints,
"total_passed": total_passed,
"overall_accuracy": overall_accuracy,
"unsupported_constraints": unsupported_count,
"supported_types": sorted(CONSTRAINT_VERIFIERS.keys()),
"per_type_accuracy": type_summary,
}
return results
def main():
parser = argparse.ArgumentParser(description="IFEval benchmark")
parser.add_argument(
"--param_path", type=str, default="./params", help="Model directory"
)
parser.add_argument(
"--data_path",
type=str,
default="./ifeval/input_data.jsonl",
help="IFEval JSONL file (auto-download if missing)",
)
parser.add_argument("--output", type=str, default=None, help="Output JSON path")
parser.add_argument(
"--max_tokens", type=int, default=512, help="Max generation tokens"
)
parser.add_argument(
"--temperature",
type=float,
default=0.1,
help="Sampling temperature",
)
parser.add_argument("--top_p", type=float, default=0.95, help="Top-p sampling")
parser.add_argument("--top_k", type=int, default=50, help="Top-k sampling")
parser.add_argument(
"--num_samples",
type=int,
default=1,
help="Number of samples per problem (best-of-n scoring)",
)
parser.add_argument(
"--batch_size", type=int, default=1, help="Inference batch size"
)
parser.add_argument(
"--limit",
type=int,
default=None,
help="Limit to first N problems (for quick testing)",
)
parser.add_argument(
"--dump_responses",
type=str,
default=None,
help="Path to dump raw model responses (JSONL)",
)
args = parser.parse_args()
download_ifeval(args.data_path)
problems = load_problems(args.data_path)
if args.limit:
problems = problems[: args.limit]
print(f"Loaded {len(problems)} problems")
print(f"Supported constraint types: {len(CONSTRAINT_VERIFIERS)}")
model = AutoModel.from_pretrained(args.param_path)
tokenizer = AutoTokenizer.from_pretrained(args.param_path)
model.to(device="cuda", dtype=torch.bfloat16)
model.eval()
engine = InferenceEngine(
model=model,
tokenizer=tokenizer,
max_batch_size=args.batch_size,
)
results = evaluate(
engine=engine,
tokenizer=tokenizer,
problems=problems,
max_tokens=args.max_tokens,
temperature=args.temperature,
top_p=args.top_p,
top_k=args.top_k,
num_samples=args.num_samples,
)
summary = results.pop("_summary")
print(f"\n{'=' * 60}")
print(f" Problems: {summary['total_problems']}")
print(f" Constraints: {summary['total_constraints']}")
print(f" Passed: {summary['total_passed']}")
print(f" Accuracy: {summary['overall_accuracy']:.2%}")
print(f" Unsupported: {summary['unsupported_constraints']}")
print(f"{'=' * 60}")
print(f"\nPer-type accuracy:")
for inst_id, stats in sorted(summary["per_type_accuracy"].items()):
print(
f" {inst_id:50s} {stats['accuracy']:.2%} "
f"({stats['passed']}/{stats['total']})"
)
if args.output:
results["_summary"] = summary
with open(args.output, "w", encoding="utf-8") as f:
json.dump(results, f, indent=2, ensure_ascii=False)
print(f"\nResults saved to {args.output}")
if args.dump_responses:
with open(args.dump_responses, "w", encoding="utf-8") as f:
for k, v in results.items():
if k.startswith("_"):
continue
f.write(
json.dumps(
{
"key": v["key"],
"prompt": v["prompt"],
"response": v["response"],
},
ensure_ascii=False,
)
+ "\n"
)
print(f"Responses dumped to {args.dump_responses}")
engine.shutdown()
if __name__ == "__main__":
main()

View File

@ -0,0 +1,319 @@
"""MMLU evaluation via log-likelihood ranking."""
import argparse
import csv
import json
import os
import shutil
import tarfile
import requests
import torch
import torch.nn.functional as F
import tqdm
from astrai.model import AutoModel
from astrai.tokenize import AutoTokenizer
MMLU_URL = "https://people.eecs.berkeley.edu/~hendrycks/data.tar"
MMLU_SUBJECTS = [
"abstract_algebra",
"anatomy",
"astronomy",
"business_ethics",
"clinical_knowledge",
"college_biology",
"college_chemistry",
"college_computer_science",
"college_mathematics",
"college_medicine",
"college_physics",
"computer_security",
"conceptual_physics",
"econometrics",
"electrical_engineering",
"elementary_mathematics",
"formal_logic",
"global_facts",
"high_school_biology",
"high_school_chemistry",
"high_school_computer_science",
"high_school_european_history",
"high_school_geography",
"high_school_government_and_politics",
"high_school_macroeconomics",
"high_school_mathematics",
"high_school_microeconomics",
"high_school_physics",
"high_school_psychology",
"high_school_statistics",
"high_school_us_history",
"high_school_world_history",
"human_aging",
"human_sexuality",
"international_law",
"jurisprudence",
"logical_fallacies",
"machine_learning",
"management",
"marketing",
"medical_genetics",
"miscellaneous",
"moral_disputes",
"moral_scenarios",
"nutrition",
"philosophy",
"prehistory",
"professional_accounting",
"professional_law",
"professional_medicine",
"professional_psychology",
"public_relations",
"security_studies",
"sociology",
"us_foreign_policy",
"virology",
"world_religions",
]
def _download_and_extract(url: str, data_dir: str):
tar_path = os.path.join(data_dir, "data.tar")
os.makedirs(data_dir, exist_ok=True)
print(f"Downloading MMLU data from {url}...")
resp = requests.get(url, stream=True, timeout=300)
resp.raise_for_status()
total = int(resp.headers.get("content-length", 0))
with tqdm.tqdm(total=total, unit="B", unit_scale=True, desc=" Download") as bar:
with open(tar_path, "wb") as f:
for chunk in resp.iter_content(chunk_size=8192):
f.write(chunk)
bar.update(len(chunk))
print("Extracting...")
with tarfile.open(tar_path, "r") as tf:
tf.extractall(data_dir)
os.remove(tar_path)
def download_mmlu(data_dir: str):
_download_and_extract(MMLU_URL, data_dir)
src = os.path.join(data_dir, "data")
if os.path.exists(src):
for item in os.listdir(src):
src_item = os.path.join(src, item)
dst_item = os.path.join(data_dir, item)
if os.path.exists(dst_item):
if os.path.isdir(dst_item):
shutil.rmtree(dst_item)
else:
os.remove(dst_item)
os.rename(src_item, dst_item)
os.rmdir(src)
print(f"MMLU data saved to {data_dir}")
def _strip_prefix(text: str, prefix: str) -> str:
if text.startswith(prefix):
return text[len(prefix) :].strip()
return text
def load_csv(path: str) -> list[dict]:
data = []
with open(path, "r", encoding="utf-8") as f:
for row in csv.reader(f):
if len(row) < 6:
continue
if row[0].strip().lower() == "question":
continue
data.append(
{
"question": row[0].strip(),
"A": _strip_prefix(row[1].strip(), "A)"),
"B": _strip_prefix(row[2].strip(), "B)"),
"C": _strip_prefix(row[3].strip(), "C)"),
"D": _strip_prefix(row[4].strip(), "D)"),
"answer": row[5].strip(),
}
)
return data
def build_prompt(
question: str, choices: dict, subject: str, n_shot: int, dev_data: list[dict]
) -> str:
prompt = ""
if n_shot > 0 and dev_data:
prompt = f"The following are multiple choice questions (with answers) about {subject}.\n\n"
for item in dev_data[:n_shot]:
prompt += f"Question: {item['question']}\n"
for k in ("A", "B", "C", "D"):
prompt += f"{k}. {item[k]}\n"
prompt += f"Answer: {item['answer']}\n\n"
prompt += f"Question: {question}\n"
for k in ("A", "B", "C", "D"):
prompt += f"{k}. {choices[k]}\n"
prompt += "Answer:"
return prompt
def apply_chat(
tokenizer, raw_prompt: str, n_shot: int, dev_data: list[dict] | None
) -> str:
"""Wrap raw MMLU prompt in the model's chat template format.
For few-shot, prepend example Q&A pairs as a second user/assistant exchange.
"""
messages = []
if n_shot > 0 and dev_data:
for item in dev_data[:n_shot]:
q = f"Question: {item['question']}\n"
for k in ("A", "B", "C", "D"):
q += f"{k}. {item[k]}\n"
q += "Answer:"
messages.append({"role": "user", "content": q})
messages.append({"role": "assistant", "content": item["answer"]})
messages.append({"role": "user", "content": raw_prompt})
return tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
def choice_logprob(
model, tokenizer, context_ids: list[int], choice_letter: str, device: str
) -> float:
choice_text = choice_letter
choice_ids = tokenizer.encode(choice_text, add_special_tokens=False)
input_ids = context_ids + choice_ids
max_len = model.config.max_len
if len(input_ids) > max_len:
overflow = len(input_ids) - max_len
input_ids = input_ids[overflow:]
ctx_len = len(input_ids) - len(choice_ids)
else:
ctx_len = len(context_ids)
input_tensor = torch.tensor([input_ids], device=device, dtype=torch.long)
with torch.inference_mode():
logits = model(input_tensor)["logits"][0]
score = 0.0
for i, tid in enumerate(choice_ids):
pos = ctx_len - 1 + i
if pos >= len(logits):
break
score += F.log_softmax(logits[pos], dim=-1)[tid].item()
return score
def evaluate_subject(
model,
tokenizer,
subject: str,
test_data: list[dict],
dev_data: list[dict] | None,
device: str,
n_shot: int,
) -> tuple[float, int, int]:
correct = 0
total = 0
for item in tqdm.tqdm(test_data, desc=f"{subject:40s}", leave=False):
raw_prompt = build_prompt(
item["question"], item, subject, n_shot, dev_data or []
)
context = apply_chat(tokenizer, raw_prompt, n_shot, dev_data or [])
context_ids = tokenizer.encode(context)
scores = {
c: choice_logprob(model, tokenizer, context_ids, c, device)
for c in ("A", "B", "C", "D")
}
if max(scores, key=scores.get) == item["answer"]:
correct += 1
total += 1
return correct / total, correct, total
def main():
parser = argparse.ArgumentParser(description="MMLU evaluation")
parser.add_argument(
"--param_path", type=str, default="./params", help="Model directory"
)
parser.add_argument(
"--data_dir", type=str, default="./mmlu_data", help="MMLU data directory"
)
parser.add_argument("--download", action="store_true", help="Download MMLU data")
parser.add_argument(
"--n_shot", type=int, default=5, help="Few-shot examples (0 for zero-shot)"
)
parser.add_argument(
"--subjects", type=str, nargs="+", help="Specific subjects (default: all)"
)
parser.add_argument("--output", type=str, help="Output JSON path")
parser.add_argument("--split", type=str, default="test", choices=["test", "val"])
parser.add_argument(
"--device",
type=str,
default="cuda" if torch.cuda.is_available() else "cpu",
help="Device",
)
parser.add_argument(
"--dtype",
type=str,
default="bfloat16" if torch.cuda.is_available() else "float32",
help="Torch dtype",
)
args = parser.parse_args()
if args.download or not os.path.exists(args.data_dir):
download_mmlu(args.data_dir)
model = AutoModel.from_pretrained(args.param_path)
tokenizer = AutoTokenizer.from_pretrained(args.param_path)
device = args.device
dtype = getattr(torch, args.dtype)
model.to(device=device, dtype=dtype)
model.eval()
subjects = args.subjects or MMLU_SUBJECTS
results = {}
total_correct = 0
total_questions = 0
for subject in subjects:
dev_path = os.path.join(args.data_dir, "dev", f"{subject}_dev.csv")
test_path = os.path.join(
args.data_dir, args.split, f"{subject}_{args.split}.csv"
)
if not os.path.exists(test_path):
print(f" Skipping {subject}: test file not found")
continue
dev_data = load_csv(dev_path) if os.path.exists(dev_path) else None
test_data = load_csv(test_path)
acc, corr, tot = evaluate_subject(
model, tokenizer, subject, test_data, dev_data, device, args.n_shot
)
results[subject] = {"accuracy": round(acc, 4), "correct": corr, "total": tot}
total_correct += corr
total_questions += tot
print(f" {subject:40s} {acc:.2%} ({corr}/{tot})")
overall = total_correct / total_questions if total_questions else 0
print(f"\n{'=' * 70}")
print(f" Overall: {overall:.2%} ({total_correct}/{total_questions})")
results["_overall"] = {
"accuracy": round(overall, 4),
"correct": total_correct,
"total": total_questions,
}
if args.output:
with open(args.output, "w", encoding="utf-8") as f:
json.dump(results, f, indent=2)
print(f"Results saved to {args.output}")
if __name__ == "__main__":
main()

View File

@ -10,11 +10,11 @@ from astrai.tokenize import AutoTokenizer
def process_file(
model_dir: str, input_file: str, output_file: str, batch_size: int, text_key: str
param_path: str, input_file: str, output_file: str, batch_size: int, text_key: str
):
# Load model and tokenizer
model = AutoModel.from_pretrained(model_dir)
tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = AutoModel.from_pretrained(param_path)
tokenizer = AutoTokenizer.from_pretrained(param_path)
model.to(device="cuda", dtype=torch.bfloat16)
with open(input_file, "r", encoding="utf-8") as f:
@ -44,8 +44,8 @@ def process_file(
for seq in batch_encoded:
pad_len = max_len - len(seq)
padded_seq = [tokenizer.pad_id] * pad_len + seq
mask = [False] * pad_len + [True] * len(seq)
padded_seq = seq + [tokenizer.pad_id] * pad_len
mask = [True] * len(seq) + [False] * pad_len
padded_ids.append(padded_seq)
masks.append(mask)
@ -88,7 +88,7 @@ def process_file(
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run perplexity with a Khaosz model.")
parser.add_argument(
"--model_dir", type=str, required=True, help="Path to the model directory."
"--param_path", type=str, required=True, help="Path to the model directory."
)
parser.add_argument(
"--input_file", type=str, required=True, help="Path to the input file."

View File

@ -1,13 +1,13 @@
"""Benchmark Transformer with KVCache"""
"""Benchmark AutoRegressiveLM with KVCache"""
from dataclasses import dataclass
from typing import Any, Dict
import torch
from astrai.config import ModelConfig
from astrai.config import AutoRegressiveLMConfig
from astrai.inference import KVCache
from astrai.model.transformer import Transformer
from astrai.model.transformer import AutoRegressiveLM
@dataclass
@ -21,7 +21,7 @@ class BenchmarkResult:
class GenerationBenchmark:
def __init__(
self,
config: ModelConfig,
config: AutoRegressiveLMConfig,
device: str = "cuda",
dtype: torch.dtype = torch.bfloat16,
page_size: int = 128,
@ -29,7 +29,7 @@ class GenerationBenchmark:
self.config = config
self.device = device
self.dtype = dtype
self.model = Transformer(config).to(device=device, dtype=dtype)
self.model = AutoRegressiveLM(config).to(device=device, dtype=dtype)
self.model.eval()
head_dim = config.dim // config.n_heads
n_pages = (config.max_len * 4 + page_size - 1) // page_size
@ -216,7 +216,7 @@ def print_benchmark_result(result: BenchmarkResult):
if __name__ == "__main__":
config = ModelConfig(
config = AutoRegressiveLMConfig(
vocab_size=10000,
dim=1536,
n_heads=24,
@ -230,7 +230,7 @@ if __name__ == "__main__":
benchmark = GenerationBenchmark(config)
print("=" * 80)
print("Running Transformer Generation Benchmark (KVCache)")
print("Running AutoRegressiveLM Generation Benchmark (KVCache)")
print("=" * 80)
prefill_result = benchmark.run_prefill_benchmark(

View File

@ -0,0 +1,38 @@
"""CLI: JSONL → tokenized .h5/.bin via config-driven Pipeline."""
import argparse
from astrai.config.preprocess_config import PipelineConfig
from astrai.preprocessing.pipeline import Pipeline
def main():
parser = argparse.ArgumentParser(
description="Raw JSONL → tokenized .h5/.bin via config-driven Pipeline"
)
parser.add_argument(
"inputs", nargs="+", metavar="JSONL", help="One or more JSONL files"
)
parser.add_argument("--output_dir", "-o", required=True, help="Output directory")
parser.add_argument(
"--config", "-c", required=True, help="Path to pipeline config JSON"
)
parser.add_argument(
"--tokenizer_path",
default="params",
help="Path to tokenizer directory (default: params)",
)
args = parser.parse_args()
config = PipelineConfig.from_file(args.config)
Pipeline(
config=config,
input_paths=args.inputs,
output_dir=args.output_dir,
tokenizer_path=args.tokenizer_path,
).run()
if __name__ == "__main__":
main()

View File

@ -18,7 +18,7 @@ def main():
"--reload", action="store_true", help="Enable auto-reload for development"
)
parser.add_argument(
"--param-path",
"--param_path",
type=Path,
default=None,
help="Path to model parameters (default: project_root/params)",

View File

@ -2,22 +2,19 @@ import argparse
import os
from functools import partial
import safetensors.torch as st
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
from astrai.config import ModelConfig, TrainConfig
from astrai.config import AutoRegressiveLMConfig, TrainConfig
from astrai.dataset import DatasetFactory
from astrai.model import Transformer
from astrai.parallel import get_rank
from astrai.model import AutoRegressiveLM
from astrai.model.components.decoder_block import DecoderBlock
from astrai.trainer import SchedulerFactory, Trainer
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Train the Transformer model.")
parser = argparse.ArgumentParser(description="Train the AutoRegressiveLM model.")
parser.add_argument(
"--train_type",
@ -42,18 +39,20 @@ def parse_args() -> argparse.Namespace:
parser.add_argument(
"--n_epoch", type=int, default=1, help="Number of epochs to train."
)
parser.add_argument("--batch_size", type=int, default=1, help="Batch size per GPU.")
parser.add_argument(
"--accumulation_steps",
"--batch_per_device", type=int, default=1, help="Batch size per GPU."
)
parser.add_argument(
"--grad_accum_steps",
type=int,
default=1,
help="Number of iterations between each optimizer step.",
)
parser.add_argument(
"--warmup_steps",
type=int,
default=1000,
help="Number of warmup steps for LR scheduler.",
"--warmup_ratio",
type=float,
default=0.05,
help="Fraction of total steps used for LR warmup.",
)
parser.add_argument(
"--max_lr", type=float, default=3e-4, help="Max learning rate for training."
@ -68,13 +67,13 @@ def parse_args() -> argparse.Namespace:
"--adamw_beta1",
type=float,
default=0.9,
help="Beta values for AdamW optimizer.",
help="Beta1 for AdamW optimizer.",
)
parser.add_argument(
"--adamw_beta2",
type=float,
default=0.95,
help="Beta values for AdamW optimizer.",
help="Beta2 for AdamW optimizer.",
)
parser.add_argument(
"--adamw_weight_decay",
@ -114,9 +113,15 @@ def parse_args() -> argparse.Namespace:
parser.add_argument(
"--label_smoothing",
type=float,
default=0.1,
default=0.0,
help="cross_entropy function label smoothing parameter",
)
parser.add_argument(
"--gradient_checkpointing",
action=argparse.BooleanOptionalAction,
default=False,
help="Enable activation checkpointing for DecoderBlock modules.",
)
parser.add_argument(
"--ckpt_interval",
@ -130,6 +135,36 @@ def parse_args() -> argparse.Namespace:
default="checkpoint",
help="Directory to save checkpoints.",
)
parser.add_argument(
"--val_split",
type=float,
default=None,
help="Ratio to split from training dataset for validation (e.g. 0.05).",
)
parser.add_argument(
"--val_step",
type=int,
default=1000,
help="Number of optimizer steps between validation runs.",
)
parser.add_argument(
"--metrics",
nargs="*",
default=["loss", "lr"],
help="Metrics to log (e.g. --metrics loss lr val_loss). Default: loss lr.",
)
parser.add_argument(
"--log_dir",
type=str,
default="checkpoint/logs",
help="Directory for metric logs.",
)
parser.add_argument(
"--log_interval",
type=int,
default=100,
help="Number of batch iterations between metric logs.",
)
parser.add_argument(
"--grpo_sync_interval",
type=int,
@ -143,42 +178,84 @@ def parse_args() -> argparse.Namespace:
"--start_batch", type=int, default=0, help="Start batch for training."
)
parser.add_argument(
"--master_addr",
type=str,
default="localhost",
help="Master node address for distributed training.",
)
parser.add_argument(
"--master_port",
type=str,
default="29500",
help="Master node port for distributed training.",
)
parser.add_argument(
"--backend",
type=str,
default="nccl",
help="Distributed training backend.",
)
parser.add_argument("--nprocs", type=int, default=1, help="Number of GPUs to use.")
parser.add_argument(
"--parallel_mode",
type=str,
default="none",
choices=["none", "ddp", "fsdp"],
help="Parallel training strategy (none, ddp, fsdp).",
)
parser.add_argument(
"--device_type", type=str, default="cuda", help="Device type to use."
)
parser.add_argument(
"--start_method",
type=str,
default="spawn",
choices=["spawn", "fork", "forkserver"],
help="Multiprocessing start method.",
)
parser.add_argument(
"--neftune_alpha",
type=float,
default=0.0,
help="NEFTune noise alpha (0=disabled, typical: 5.0).",
)
args = parser.parse_args()
return args
def ddp_wrap(model: nn.Module):
local_rank = get_rank()
ddp_model = DDP(
model,
device_ids=[local_rank],
output_device=local_rank,
static_graph=True,
find_unused_parameters=False,
gradient_as_bucket_view=True,
broadcast_buffers=False,
)
return ddp_model
def create_model(config):
return AutoRegressiveLM(config).to(dtype=torch.bfloat16)
def create_optimizer(model: nn.Module, **kwargs) -> optim.Optimizer:
def create_optimizer(model, **kwargs) -> optim.Optimizer:
return optim.AdamW(model.parameters(), fused=True, **kwargs)
def create_scheduler(
optimizer: optim.Optimizer, **kwargs
) -> optim.lr_scheduler.LRScheduler:
return SchedulerFactory.create(optimizer, **kwargs)
schedule_type = kwargs.pop("schedule_type")
return SchedulerFactory.create(schedule_type, optimizer, **kwargs)
def prepare_checkpoint(model: nn.Module) -> dict:
return model.module.state_dict()
def compute_total_steps(
dataset_len: int,
n_epoch: int,
batch_per_device: int,
nprocs: int,
grad_accum_steps: int,
) -> int:
def ceil_div(a: int, b: int) -> int:
return (a + b - 1) // b
samples_per_replica = ceil_div(dataset_len, nprocs)
batches_per_replica = ceil_div(samples_per_replica, batch_per_device)
total_steps = (batches_per_replica // grad_accum_steps) * n_epoch
return total_steps
def train(
@ -187,13 +264,18 @@ def train(
data_root_path: str,
max_lr: float,
n_epoch: int,
batch_size: int,
batch_per_device: int,
start_epoch: int,
start_batch: int,
accumulation_steps: int,
warmup_steps: int,
grad_accum_steps: int,
warmup_ratio: float,
ckpt_interval: int,
ckpt_dir: str,
val_split: float,
val_step: int,
metrics: list[str],
log_dir: str,
log_interval: int,
dpo_beta: float,
grpo_clip_eps: float,
grpo_kl_coef: float,
@ -207,36 +289,32 @@ def train(
random_seed: int,
num_workers: int,
pin_memory: bool,
gradient_checkpointing: bool,
window_size: int,
stride: int,
nprocs: int,
parallel_mode: str,
device_type: str,
backend: str,
master_addr: str,
master_port: str,
start_method: str,
neftune_alpha: float,
):
assert train_type in ["seq", "sft", "dpo", "grpo"]
assert os.path.exists(param_path)
if nprocs > 1 and parallel_mode == "none":
raise ValueError("--nprocs > 1 requires --parallel_mode to be 'ddp' or 'fsdp'")
# Load config
config = ModelConfig()
config_path = os.path.join(param_path, "config.json")
if os.path.exists(config_path):
config.load(config_path)
config = AutoRegressiveLMConfig.from_file(config_path)
if window_size is None:
window_size = config.max_len
# Create bare Transformer (for training, no tokenizer needed)
model = Transformer(config)
# Load weights if available
weights_path = os.path.join(param_path, "model.safetensors")
if os.path.exists(weights_path):
state_dict = st.load_file(weights_path)
model.load_state_dict(state_dict, strict=False)
model = model.to(dtype=torch.bfloat16)
strategy_kwargs = {
"dpo_beta": dpo_beta,
"beta": dpo_beta,
"label_smoothing": label_smoothing,
"clip_eps": grpo_clip_eps,
"kl_coef": grpo_kl_coef,
@ -244,6 +322,12 @@ def train(
"sync_interval": grpo_sync_interval,
}
executor_kwargs = {
"gradient_as_bucket_view": True,
"broadcast_buffers": False,
}
model_fn = partial(create_model, config)
dataset = DatasetFactory.load(
train_type=train_type,
load_path=data_root_path,
@ -260,42 +344,59 @@ def train(
},
)
total_steps = len(dataset) * n_epoch // (batch_size * nprocs)
total_steps = compute_total_steps(
len(dataset), n_epoch, batch_per_device, nprocs, grad_accum_steps
)
warmup_steps = int(warmup_ratio * total_steps)
scheduler_fn = partial(
create_scheduler,
**{
"schedule_type": "cosine",
"warmup_steps": warmup_steps,
"lr_decay_steps": total_steps - warmup_steps,
"warmup_steps": min(warmup_steps, total_steps),
"lr_decay_steps": total_steps - min(warmup_steps, total_steps),
},
)
grad_ckpt_modules = [DecoderBlock] if gradient_checkpointing else []
train_config = TrainConfig(
model=model,
model_fn=model_fn,
strategy=train_type,
dataset=dataset,
optimizer_fn=optimizer_fn,
scheduler_fn=scheduler_fn,
ckpt_dir=ckpt_dir,
n_epoch=n_epoch,
batch_size=batch_size,
batch_per_device=batch_per_device,
start_epoch=start_epoch,
start_batch=start_batch,
ckpt_interval=ckpt_interval,
accumulation_steps=accumulation_steps,
grad_accum_steps=grad_accum_steps,
max_grad_norm=max_grad_norm,
random_seed=random_seed,
num_workers=num_workers,
pin_memory=pin_memory,
nprocs=nprocs,
parallel_wrapper=ddp_wrap,
state_dict_fn=prepare_checkpoint,
backend=backend,
master_addr=master_addr,
master_port=master_port,
parallel_mode=parallel_mode,
device_type=device_type,
start_method=start_method,
val_split=val_split,
val_step=val_step,
metrics=metrics,
log_dir=log_dir,
log_interval=log_interval,
gradient_checkpointing_modules=grad_ckpt_modules,
executor_kwargs=executor_kwargs,
extra_kwargs=strategy_kwargs,
neftune_alpha=neftune_alpha,
)
trainer = Trainer(train_config)
trainer.train()
trainer.train(resume_dir=param_path)
if __name__ == "__main__":

View File

@ -8,8 +8,8 @@ import torch
from tokenizers import Tokenizer, models, pre_tokenizers, trainers
from torch.utils.data import Dataset
from astrai.config.model_config import ModelConfig
from astrai.model.transformer import Transformer
from astrai.config.model_config import AutoRegressiveLMConfig
from astrai.model.transformer import AutoRegressiveLM
from astrai.tokenize import AutoTokenizer
@ -104,19 +104,19 @@ def test_tokenizer():
@pytest.fixture(scope="session")
def test_model():
"""Session-scoped small Transformer model, created once."""
config = ModelConfig(
"""Session-scoped small AutoRegressiveLM model, created once."""
config = AutoRegressiveLMConfig(
vocab_size=1000,
dim=16,
n_heads=4,
n_kv_heads=2,
dim_ffn=32,
max_len=1024,
n_layers=4,
dim=8,
n_heads=2,
n_kv_heads=1,
dim_ffn=16,
max_len=64,
n_layers=2,
norm_eps=1e-5,
)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = Transformer(config).to(device=device)
model = AutoRegressiveLM(config).to(device=device)
return {
"model": model,
@ -137,12 +137,12 @@ def base_test_env(test_model, test_tokenizer):
json.dump(
{
"vocab_size": 1000,
"dim": 16,
"n_heads": 4,
"n_kv_heads": 2,
"dim_ffn": 32,
"max_len": 1024,
"n_layers": 4,
"dim": 8,
"n_heads": 2,
"n_kv_heads": 1,
"dim_ffn": 16,
"max_len": 64,
"n_layers": 2,
"norm_eps": 1e-5,
},
f,

202
tests/data/conftest.py Normal file
View File

@ -0,0 +1,202 @@
import tempfile
import pytest
from tokenizers import Tokenizer, models, pre_tokenizers, trainers
from astrai.config.preprocess_config import (
InputConfig,
PipelineConfig,
ProcessingConfig,
)
from astrai.tokenize import AutoTokenizer
_SPECIAL_TOKENS_CONFIG = {
"bos_token": "<|begin_of_sentence|>",
"eos_token": "<|end_of_sentence|>",
"pad_token": "<|_pad_|>",
"unk_token": "<|_unk_|>",
"im_start": "<|im_start|>",
"im_end": "<|im_end|>",
}
_SPECIAL_TOKENS = list(_SPECIAL_TOKENS_CONFIG.values())
_CHAT_TEMPLATE = (
"{% for message in messages %}"
"{% if message['role'] == 'system' %}"
"<|im_start|>system\n{{ message['content'] }}<|im_end|>\n"
"{% elif message['role'] == 'user' %}"
"<|im_start|>user\n{{ message['content'] }}<|im_end|>\n"
"{% elif message['role'] == 'assistant' %}"
"<|im_start|>assistant\n{{ message['content'] }}<|im_end|>\n"
"{% endif %}"
"{% endfor %}"
"{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"
)
_CHAT_SECTIONS = [{"field": "messages", "action": "$role", "template": True}]
_INSTRUCTION_SECTIONS = [
{"field": "prompt", "action": "mask", "add_special_tokens": True},
{"field": "response", "action": "train"},
]
_TEXT_SECTIONS = [{"field": "text", "action": "train"}]
_GRPO_RESPONSE_SECTIONS = [{"field": "responses", "action": "train"}]
def _build_chat_tokenizer():
tok = Tokenizer(models.BPE())
tok.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
tr = trainers.BpeTrainer(
vocab_size=512,
min_frequency=1,
special_tokens=_SPECIAL_TOKENS,
)
train_data = [
"hello world",
"Hi there!",
"You are helpful.",
"What is 2+2?",
"Tell me a story about dragons and knights.",
"Sure, here is a tale.",
"Translate to French: Hello",
"Bonjour",
"Artificial Intelligence is a field of computer science.",
"system",
"user",
"assistant",
"<|im_start|>",
"<|im_end|>",
*[chr(i) for i in range(32, 127)],
]
tok.train_from_iterator(train_data, tr)
auto_tok = AutoTokenizer()
auto_tok._tokenizer = tok
auto_tok._special_token_map = {
"bos_token": "<|begin_of_sentence|>",
"eos_token": "<|end_of_sentence|>",
"pad_token": "<|_pad_|>",
"unk_token": "<|_unk_|>",
}
auto_tok.set_chat_template(_CHAT_TEMPLATE)
return auto_tok
@pytest.fixture(scope="session")
def chat_tokenizer():
return _build_chat_tokenizer()
@pytest.fixture
def temp_dir():
d = tempfile.mkdtemp()
yield d
import shutil
shutil.rmtree(d, ignore_errors=True)
def make_chat_config():
return PipelineConfig(
input=InputConfig(sections=_CHAT_SECTIONS),
mask={"system": "mask", "user": "mask", "assistant": "train"},
mask_default="mask",
preprocessing=ProcessingConfig(max_seq_len=2048),
)
def make_instruction_config():
return PipelineConfig(
input=InputConfig(sections=_INSTRUCTION_SECTIONS),
mask={"prompt": "mask", "response": "train"},
mask_default="mask",
preprocessing=ProcessingConfig(max_seq_len=2048),
)
def make_text_config():
return PipelineConfig(
input=InputConfig(sections=_TEXT_SECTIONS),
preprocessing=ProcessingConfig(
max_seq_len=2048, min_chars=1, max_chars=2_000_000
),
)
def make_dpo_chat_config():
return PipelineConfig(
input=InputConfig(
sources={
"chosen": {
"sections": [
{"field": "chosen", "action": "$role", "template": True}
]
},
"rejected": {
"sections": [
{"field": "rejected", "action": "$role", "template": True}
]
},
}
),
mask={"user": "mask", "assistant": "train"},
mask_default="mask",
preprocessing=ProcessingConfig(max_seq_len=2048),
)
def make_grpo_config():
return PipelineConfig(
input=InputConfig(
sources={
"prompts": {
"sections": [
{"field": "prompt", "action": "mask", "template": True}
]
},
"responses": {
"sections": _GRPO_RESPONSE_SECTIONS,
"list_field": True,
"mask_key": "masks",
},
"rewards": {
"sections": [{"field": "rewards", "action": "value"}],
},
}
),
mask={"user": "mask", "assistant": "train"},
mask_default="mask",
preprocessing=ProcessingConfig(max_seq_len=2048),
)
def make_grpo_no_template_config():
return PipelineConfig(
input=InputConfig(
sources={
"prompts": {
"sections": [
{
"field": "prompt",
"action": "mask",
"add_special_tokens": True,
}
]
},
"responses": {
"sections": _GRPO_RESPONSE_SECTIONS,
"list_field": True,
"mask_key": "masks",
},
"rewards": {
"sections": [{"field": "rewards", "action": "value"}],
},
}
),
mask={"user": "mask", "assistant": "train"},
mask_default="mask",
preprocessing=ProcessingConfig(max_seq_len=2048),
)

View File

@ -1,3 +1,4 @@
import os
import tempfile
import torch
@ -35,6 +36,30 @@ def test_single_process():
assert loaded_checkpoint.iteration == 30
def test_checkpoint_with_extra():
model = torch.nn.Linear(10, 5)
optimizer = AdamW(model.parameters(), lr=1e-3)
optimizer.step()
extra = {
"optimizer": optimizer.state_dict(),
"scheduler": {"last_epoch": 5},
}
checkpoint = Checkpoint(
state_dict=model.state_dict(), epoch=1, iteration=10, extra=extra
)
with tempfile.TemporaryDirectory() as tmpdir:
checkpoint.save(tmpdir)
assert os.path.exists(os.path.join(tmpdir, "optimizer.pt"))
assert os.path.exists(os.path.join(tmpdir, "scheduler.pt"))
loaded = Checkpoint.load(tmpdir)
assert loaded.extra["scheduler"]["last_epoch"] == 5
assert "state" in loaded.extra["optimizer"]
def simple_training():
model = torch.nn.Linear(10, 5)
optimizer = AdamW(model.parameters(), lr=1e-3)

View File

@ -1,4 +1,3 @@
import json
import os
import numpy as np
@ -7,12 +6,11 @@ import torch
from astrai.dataset.dataset import DatasetFactory, SEQDataset
from astrai.dataset.storage import (
BaseSegmentFetcher,
H5Storage,
MultiSegmentFetcher,
create_storage,
H5Store,
StoreFactory,
detect_format,
load_json,
load_bin,
save_bin,
save_h5,
)
@ -100,6 +98,7 @@ def test_sft_dataset_with_random_data(base_test_env):
dummy_data = {
"sequence": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64)],
"loss_mask": [torch.ones(seq_length, dtype=torch.bool)],
"position_ids": [torch.arange(seq_length, dtype=torch.int32)],
}
save_h5(test_dir, "sft_data", dummy_data)
@ -157,111 +156,6 @@ def test_dataset_with_custom_stride(base_test_env):
assert len(dataset) > len(default_stride_dataset)
# ============== JSON Storage Tests (raw text + tokenizer) ==============
def _make_tokenizer_fn(tokenizer):
"""Wrap tokenizer.encode() as a str -> List[int] callable."""
return lambda text: tokenizer.encode(text, add_special_tokens=False)
def test_seq_dataset_from_json_text(base_test_env):
"""Test loading SEQ dataset from raw-text JSON with tokenizer"""
tokenizer = base_test_env["tokenizer"]
tokenizer_fn = _make_tokenizer_fn(tokenizer)
test_dir = base_test_env["test_dir"]
data_dir = os.path.join(test_dir, "json_text")
os.makedirs(data_dir, exist_ok=True)
texts = [
"hello world this is a test sentence for tokenizer",
"another sentence with different words and tokens",
"machine learning is fascinating and powerful",
]
json_path = os.path.join(data_dir, "seq_data.json")
with open(json_path, "w", encoding="utf-8") as f:
json.dump({"sequence": texts}, f, ensure_ascii=False)
dataset = DatasetFactory.load(
train_type="seq",
load_path=data_dir,
window_size=16,
tokenizer=tokenizer_fn,
)
assert dataset is not None
assert len(dataset) > 0
assert dataset.count > 0
assert "sequence" in dataset.keys
item = dataset[0]
assert "input_ids" in item
assert "target_ids" in item
assert item["input_ids"].shape[0] == 16
def test_sft_dataset_from_json_text(base_test_env):
"""Test loading SFT dataset from raw-text JSON with tokenizer"""
tokenizer = base_test_env["tokenizer"]
tokenizer_fn = _make_tokenizer_fn(tokenizer)
test_dir = base_test_env["test_dir"]
data_dir = os.path.join(test_dir, "json_sft")
os.makedirs(data_dir, exist_ok=True)
texts = [
"user asks a question about the weather",
"assistant provides a helpful response to the user",
]
json_path = os.path.join(data_dir, "sft_data.json")
with open(json_path, "w", encoding="utf-8") as f:
json.dump(
{"sequence": texts, "loss_mask": texts},
f,
ensure_ascii=False,
)
dataset = DatasetFactory.load(
train_type="sft",
load_path=data_dir,
window_size=16,
tokenizer=tokenizer_fn,
)
assert dataset is not None
assert len(dataset) > 0
item = dataset[0]
assert "loss_mask" in item
def test_json_storage_explicit_tokenizer(base_test_env):
"""Test explicit JSON storage with tokenizer"""
tokenizer = base_test_env["tokenizer"]
tokenizer_fn = _make_tokenizer_fn(tokenizer)
test_dir = base_test_env["test_dir"]
data_dir = os.path.join(test_dir, "json_explicit")
os.makedirs(data_dir, exist_ok=True)
texts = ["abcdefghijklmnopqrstuvwxyz" * 10]
json_path = os.path.join(data_dir, "data.json")
with open(json_path, "w", encoding="utf-8") as f:
json.dump({"sequence": texts}, f, ensure_ascii=False)
token_count = len(tokenizer_fn(texts[0]))
dataset = DatasetFactory.load(
train_type="seq",
load_path=data_dir,
window_size=32,
storage_type="json",
tokenizer=tokenizer_fn,
)
assert dataset is not None
assert len(dataset) > 0
assert dataset.count == token_count
def test_dataset_count_property(base_test_env):
"""Test the count property returns correct raw token count"""
test_dir = base_test_env["test_dir"]
@ -318,37 +212,29 @@ def test_unloaded_dataset_len():
assert len(dataset) == 0
def test_base_segment_fetcher_empty():
"""BaseSegmentFetcher with empty segments list"""
fetcher = BaseSegmentFetcher([])
assert len(fetcher) == 0
with pytest.raises(ValueError, match="out of bounds"):
fetcher.fetch_data(0, 1)
def test_store_unloaded_len():
"""Unloaded Store has __len__ == 0"""
store = H5Store()
assert len(store) == 0
assert store.keys == []
def test_base_segment_fetcher_begin_equals_end(base_test_env):
"""fetch_data with begin == end returns empty tensor"""
def test_store_fetch_begin_equals_end(base_test_env):
"""Store.fetch with begin == end returns empty tensor"""
test_dir = base_test_env["test_dir"]
dummy = {"sequence": [torch.randint(0, 1000, (100,), dtype=torch.int64)]}
save_h5(test_dir, "empty_fetch", dummy)
dataset = DatasetFactory.load("seq", test_dir, window_size=32)
fetcher = dataset.storage._fetcher.multi_fetchers["sequence"]
result = fetcher.fetch_data(10, 10)
result = dataset.storage.fetch(10, 10, "sequence")
assert result.numel() == 0
def test_multi_segment_fetcher_empty_dict():
"""MultiSegmentFetcher with empty dict has __len__ == 0"""
fetcher = MultiSegmentFetcher({})
assert len(fetcher) == 0
def test_storage_fetch_before_load():
"""BaseStorage.fetch before load raises RuntimeError"""
storage = H5Storage()
def test_store_fetch_before_load():
"""Store.fetch before load raises RuntimeError"""
store = H5Store()
with pytest.raises(RuntimeError, match="not loaded"):
storage.fetch(0, 10, "sequence")
store.fetch(0, 10, "sequence")
def test_detect_format_nonexistent_path():
@ -367,54 +253,192 @@ def test_detect_format_unsupported_file(base_test_env):
detect_format(path)
def test_create_storage_invalid_type():
"""create_storage raises ValueError for unknown type"""
with pytest.raises(ValueError, match="Unknown storage type"):
create_storage("parquet")
def test_create_store_invalid_type():
"""StoreFactory.create raises ValueError for unknown type"""
with pytest.raises(ValueError, match="Unknown component"):
StoreFactory.create("parquet")
def test_json_pretokenized_without_tokenizer(base_test_env):
"""Pre-tokenized JSON (List[List[int]]) loads without tokenizer"""
def test_store_multi_segment_concat(base_test_env):
"""Multi-segment H5 data is concatenated into single tensor at load time"""
import os
test_dir = base_test_env["test_dir"]
data_dir = os.path.join(test_dir, "json_pretok")
data_dir = os.path.join(test_dir, "multi_seg")
os.makedirs(data_dir, exist_ok=True)
json_path = os.path.join(data_dir, "data.json")
with open(json_path, "w", encoding="utf-8") as f:
json.dump({"sequence": [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]}, f)
dataset = DatasetFactory.load("seq", data_dir, window_size=4, storage_type="json")
assert len(dataset) > 0
assert dataset.count == 10
item = dataset[0]
assert item["input_ids"].tolist() == [1, 2, 3, 4]
assert item["target_ids"].tolist() == [2, 3, 4, 5]
def test_load_json_skips_config_file(base_test_env):
"""load_json skips scalar-value config files"""
test_dir = base_test_env["test_dir"]
with open(os.path.join(test_dir, "config.json"), "w") as f:
json.dump({"vocab_size": 1000, "dim": 16}, f)
with open(os.path.join(test_dir, "data.json"), "w") as f:
json.dump({"sequence": [[1, 2, 3, 4, 5]]}, f)
result = load_json(test_dir)
assert "sequence" in result
assert "vocab_size" not in result
assert len(result["sequence"]) == 1
def test_base_segment_fetcher_multi_segment():
"""fetch_data across multiple segment boundaries"""
segs = [
torch.tensor([1, 2, 3]),
torch.tensor([4, 5, 6, 7]),
torch.tensor([8, 9]),
]
fetcher = BaseSegmentFetcher(segs)
assert len(fetcher) == 9
result = fetcher.fetch_data(2, 7)
save_h5(data_dir, "data", {"sequence": segs})
store = StoreFactory.create("h5")
store.load(data_dir)
assert len(store) == 9
result = store.fetch(2, 7, "sequence")
assert result.tolist() == [3, 4, 5, 6, 7]
def test_save_load_bin_roundtrip(base_test_env):
"""save_bin + load_bin roundtrip preserves data"""
test_dir = base_test_env["test_dir"]
data = {
"sequence": [torch.tensor([1, 2, 3, 4, 5], dtype=torch.int64)],
"loss_mask": [torch.tensor([0, 1, 1, 0, 1], dtype=torch.int64)],
}
save_bin(test_dir, data)
result = load_bin(test_dir)
assert "sequence" in result
assert "loss_mask" in result
assert result["sequence"][0].tolist() == [1, 2, 3, 4, 5]
assert result["loss_mask"][0].tolist() == [0, 1, 1, 0, 1]
def test_mmap_store_load_and_fetch(base_test_env):
"""MmapStore loads bin data and fetches correctly"""
test_dir = base_test_env["test_dir"]
data = {
"sequence": [torch.randint(0, 1000, (200,), dtype=torch.int64)],
}
save_bin(test_dir, data)
store = StoreFactory.create("bin")
store.load(test_dir)
assert len(store) == 200
assert "sequence" in store.keys
result = store.fetch(10, 20, "sequence")
assert result.tolist() == data["sequence"][0][10:20].tolist()
def test_mmap_dataset_load(base_test_env):
"""DatasetFactory.load auto-detects bin format"""
test_dir = base_test_env["test_dir"]
data = {
"sequence": [torch.randint(0, 1000, (200,), dtype=torch.int64)],
}
save_bin(test_dir, data)
dataset = DatasetFactory.load("seq", test_dir, window_size=64)
assert len(dataset) > 0
assert dataset.count == 200
assert dataset[0]["input_ids"].shape[0] == 64
def test_normalize_empty_key():
"""_normalize with empty tensor list does not crash"""
store = H5Store()
store._normalize({"sequence": []})
assert len(store) == 0
assert store.keys == ["sequence"]
def test_normalize_mixed_empty_key():
"""_normalize with empty + non-empty keys returns min=0"""
store = H5Store()
store._normalize({"sequence": [torch.tensor([1, 2, 3])], "loss_mask": []})
assert len(store) == 0
assert set(store.keys) == {"sequence", "loss_mask"}
def test_grpo_dataset_dtype(base_test_env):
"""GRPODataset returns correct dtypes"""
test_dir = base_test_env["test_dir"]
seq_len = 100
data = {
"prompts": [torch.randint(0, 100, (seq_len,), dtype=torch.int32)],
"responses": [torch.randint(0, 100, (seq_len,), dtype=torch.int32)],
"masks": [torch.ones(seq_len, dtype=torch.int32)],
"rewards": [torch.ones(seq_len, dtype=torch.float32)],
}
save_h5(test_dir, "grpo_dtype", data)
dataset = DatasetFactory.load("grpo", test_dir, window_size=32)
item = dataset[0]
assert item["prompts"].dtype == torch.long
assert item["responses"].dtype == torch.long
assert item["masks"].dtype == torch.bool
assert item["rewards"].dtype == torch.float32
def test_grpo_dataset_load(base_test_env):
"""GRPODataset loads and returns correct keys"""
test_dir = base_test_env["test_dir"]
seq_len = 200
data = {
"prompts": [torch.randint(0, 1000, (seq_len,), dtype=torch.int64)],
"responses": [torch.randint(0, 1000, (seq_len,), dtype=torch.int64)],
"masks": [torch.ones(seq_len, dtype=torch.int64)],
"rewards": [torch.rand(seq_len, dtype=torch.float32)],
}
save_h5(test_dir, "grpo_test", data)
dataset = DatasetFactory.load("grpo", test_dir, window_size=64)
assert len(dataset) > 0
item = dataset[0]
assert "prompts" in item
assert "responses" in item
assert "masks" in item
assert "rewards" in item
assert item["prompts"].shape[0] == 64
assert item["responses"].shape[0] == 64
def test_detect_format_bin_dir(base_test_env):
"""detect_format returns 'bin' for directory with .bin + meta.json"""
test_dir = base_test_env["test_dir"]
save_bin(test_dir, {"sequence": [torch.randint(0, 100, (10,))]})
assert detect_format(test_dir) == "bin"
def test_store_fetch_multi_key(base_test_env):
"""Store.fetch with List[str] returns Dict[str, Tensor]"""
test_dir = base_test_env["test_dir"]
save_h5(
test_dir,
"multi_key",
{
"sequence": [torch.randint(0, 100, (100,), dtype=torch.int64)],
"loss_mask": [torch.ones(100, dtype=torch.int64)],
},
)
store = StoreFactory.create("h5")
store.load(test_dir)
result = store.fetch(10, 20, ["sequence", "loss_mask"])
assert isinstance(result, dict)
assert result["sequence"].shape[0] == 10
assert result["loss_mask"].shape[0] == 10
def test_store_fetch_out_of_bounds(base_test_env):
"""Store.fetch raises ValueError for out-of-bounds indices"""
test_dir = base_test_env["test_dir"]
save_h5(test_dir, "bounds", {"sequence": [torch.randint(0, 100, (50,))]})
store = StoreFactory.create("h5")
store.load(test_dir)
with pytest.raises(ValueError, match="out of bounds"):
store.fetch(-1, 10, "sequence")
with pytest.raises(ValueError, match="out of bounds"):
store.fetch(0, 51, "sequence")
with pytest.raises(ValueError, match="out of bounds"):
store.fetch(50, 50, "sequence")
def test_dataset_load_explicit_storage_type(base_test_env):
"""DatasetFactory.load with explicit storage_type bypasses auto-detect"""
test_dir = base_test_env["test_dir"]
save_h5(test_dir, "explicit", {"sequence": [torch.randint(0, 100, (200,))]})
dataset = DatasetFactory.load("seq", test_dir, window_size=64, storage_type="h5")
assert len(dataset) > 0
assert dataset.count == 200

View File

@ -0,0 +1,396 @@
from astrai.config.preprocess_config import (
InputConfig,
OutputConfig,
PipelineConfig,
ProcessingConfig,
)
from astrai.preprocessing.builder import (
MaskBuilderFactory,
SectionedMaskBuilder,
)
from tests.data.conftest import (
_CHAT_SECTIONS,
_INSTRUCTION_SECTIONS,
_TEXT_SECTIONS,
make_chat_config,
make_dpo_chat_config,
make_grpo_config,
make_instruction_config,
make_text_config,
)
def test_chat_simple(chat_tokenizer):
config = make_chat_config()
builder = SectionedMaskBuilder()
item = {
"messages": [
{"role": "system", "content": "You are helpful."},
{"role": "user", "content": "Hello."},
{"role": "assistant", "content": "Hi there!"},
]
}
result = builder.build(item, config, chat_tokenizer)
assert result is not None
assert "sequence" in result
assert "loss_mask" in result
assert len(result["sequence"]) == len(result["loss_mask"])
ids = chat_tokenizer.decode(result["sequence"], skip_special_tokens=False)
assert "system" in ids.lower() or "<|im_start|>system" in ids
assert "assistant" in ids.lower() or "<|im_start|>assistant" in ids
total = len(result["sequence"])
trained = sum(result["loss_mask"])
assert trained > 0
assert trained < total
def test_chat_mask_only_assistant(chat_tokenizer):
config = make_chat_config()
builder = SectionedMaskBuilder()
item = {
"messages": [
{"role": "user", "content": "What is 2+2?"},
{"role": "assistant", "content": "4"},
]
}
result = builder.build(item, config, chat_tokenizer)
mask = result["loss_mask"]
ids = result["sequence"]
assert len(ids) == len(mask)
trained = [i for i, m in enumerate(mask) if m == 1]
masked = [i for i, m in enumerate(mask) if m == 0]
assert len(trained) > 0
assert len(masked) > 0
def test_chat_all_masked(chat_tokenizer):
config = PipelineConfig(
input=InputConfig(sections=_CHAT_SECTIONS),
mask={"system": "mask", "user": "mask", "assistant": "mask"},
mask_default="mask",
preprocessing=ProcessingConfig(max_seq_len=2048),
)
builder = SectionedMaskBuilder()
item = {
"messages": [
{"role": "system", "content": "You are helpful."},
{"role": "assistant", "content": "Hi there!"},
]
}
result = builder.build(item, config, chat_tokenizer)
assert sum(result["loss_mask"]) == 0
def test_chat_all_trained(chat_tokenizer):
config = PipelineConfig(
input=InputConfig(sections=_CHAT_SECTIONS),
mask={},
mask_default="train",
preprocessing=ProcessingConfig(max_seq_len=2048),
)
builder = SectionedMaskBuilder()
item = {
"messages": [
{"role": "system", "content": "You are helpful."},
{"role": "assistant", "content": "Hi there!"},
]
}
result = builder.build(item, config, chat_tokenizer)
assert sum(result["loss_mask"]) == len(result["sequence"]) - 1
def test_chat_empty_messages(chat_tokenizer):
config = make_chat_config()
builder = SectionedMaskBuilder()
assert builder.build({"messages": []}, config, chat_tokenizer) is None
assert builder.build({}, config, chat_tokenizer) is None
def test_chat_domain_extraction(chat_tokenizer):
config = PipelineConfig(
input=InputConfig(sections=_CHAT_SECTIONS),
mask={"assistant": "train"},
mask_default="mask",
preprocessing=ProcessingConfig(max_seq_len=2048),
output=OutputConfig(domain_key="source"),
)
builder = SectionedMaskBuilder()
item = {
"messages": [
{"role": "user", "content": "Hi"},
{"role": "assistant", "content": "Hello"},
],
"source": "wiki",
}
result = builder.build(item, config, chat_tokenizer)
assert result["domain"] == "wiki"
def test_chat_truncation(chat_tokenizer):
config = PipelineConfig(
input=InputConfig(sections=_CHAT_SECTIONS),
mask={"assistant": "train"},
mask_default="mask",
preprocessing=ProcessingConfig(max_seq_len=10),
)
builder = SectionedMaskBuilder()
item = {
"messages": [
{
"role": "user",
"content": "Tell me a very long story about dragons and knights and magic.",
},
{"role": "assistant", "content": "Sure! Here is a tale..."},
]
}
result = builder.build(item, config, chat_tokenizer)
assert len(result["sequence"]) <= 10
assert len(result["loss_mask"]) == len(result["sequence"])
def test_instruction_basic(test_tokenizer):
config = make_instruction_config()
builder = SectionedMaskBuilder()
item = {"prompt": "Translate to French: Hello", "response": "Bonjour"}
result = builder.build(item, config, test_tokenizer)
assert result is not None
assert len(result["sequence"]) == len(result["loss_mask"])
def test_instruction_prompt_masked(test_tokenizer):
config = make_instruction_config()
builder = SectionedMaskBuilder()
item = {"prompt": "hello", "response": "world"}
result = builder.build(item, config, test_tokenizer)
mask = result["loss_mask"]
ids = result["sequence"]
prompt_ids = test_tokenizer.encode("hello", add_special_tokens=True)
p_len = min(len(prompt_ids), len(ids))
assert all(m == 0 for m in mask[:p_len])
if p_len < len(ids):
assert all(m == 1 for m in mask[p_len:])
def test_instruction_train_on_prompt(test_tokenizer):
config = PipelineConfig(
input=InputConfig(
sections=[
{"field": "prompt", "action": "train", "add_special_tokens": True},
{"field": "response", "action": "mask"},
]
),
preprocessing=ProcessingConfig(max_seq_len=2048),
)
builder = SectionedMaskBuilder()
item = {"prompt": "hello", "response": "world"}
result = builder.build(item, config, test_tokenizer)
mask = result["loss_mask"]
ids = result["sequence"]
prompt_ids = test_tokenizer.encode("hello", add_special_tokens=True)
p_len = min(len(prompt_ids), len(ids))
assert all(m == 1 for m in mask[:p_len])
def test_text_basic(test_tokenizer):
config = make_text_config()
builder = SectionedMaskBuilder()
item = {"text": "Hello world. This is a test document."}
result = builder.build(item, config, test_tokenizer)
assert result is not None
assert "sequence" in result
assert len(result["sequence"]) > 0
assert "loss_mask" not in result
def test_text_empty(test_tokenizer):
config = make_text_config()
builder = SectionedMaskBuilder()
assert builder.build({"text": ""}, config, test_tokenizer) is None
assert builder.build({"text": " "}, config, test_tokenizer) is None
def test_text_too_short(test_tokenizer):
config = PipelineConfig(
input=InputConfig(sections=_TEXT_SECTIONS),
preprocessing=ProcessingConfig(min_chars=100),
)
builder = SectionedMaskBuilder()
assert builder.build({"text": "short"}, config, test_tokenizer) is None
def test_text_truncation(test_tokenizer):
config = PipelineConfig(
input=InputConfig(sections=_TEXT_SECTIONS),
preprocessing=ProcessingConfig(max_seq_len=3, min_chars=1),
)
builder = SectionedMaskBuilder()
item = {"text": "This is a very long text that should be truncated"}
result = builder.build(item, config, test_tokenizer)
assert len(result["sequence"]) <= 3
def test_sectioned_chat(chat_tokenizer):
config = PipelineConfig(
input=InputConfig(sections=_CHAT_SECTIONS),
mask={"system": "mask", "user": "mask", "assistant": "train"},
mask_default="mask",
preprocessing=ProcessingConfig(max_seq_len=2048),
)
builder = SectionedMaskBuilder()
item = {
"messages": [
{"role": "user", "content": "What is 2+2?"},
{"role": "assistant", "content": "4"},
]
}
result = builder.build(item, config, chat_tokenizer)
assert result is not None
assert len(result["sequence"]) == len(result["loss_mask"])
assert sum(result["loss_mask"]) > 0
assert 0 in result["loss_mask"]
def test_sectioned_instruction(test_tokenizer):
config = PipelineConfig(
input=InputConfig(sections=_INSTRUCTION_SECTIONS),
preprocessing=ProcessingConfig(max_seq_len=2048, min_chars=0),
)
builder = SectionedMaskBuilder()
item = {"prompt": "Q: Why?", "response": "A: Because."}
result = builder.build(item, config, test_tokenizer)
assert result is not None
mask = result["loss_mask"]
assert mask[0] == 0
assert mask[-1] == 1
def test_sectioned_text(test_tokenizer):
config = PipelineConfig(
input=InputConfig(sections=_TEXT_SECTIONS),
preprocessing=ProcessingConfig(max_seq_len=2048, min_chars=1),
)
builder = SectionedMaskBuilder()
item = {"text": "Hello world, this is a test."}
result = builder.build(item, config, test_tokenizer)
assert result is not None
assert "loss_mask" not in result
def test_sectioned_text_too_short(test_tokenizer):
config = PipelineConfig(
input=InputConfig(sections=_TEXT_SECTIONS),
preprocessing=ProcessingConfig(max_seq_len=2048, min_chars=100),
)
builder = SectionedMaskBuilder()
assert builder.build({"text": "short"}, config, test_tokenizer) is None
def test_factory_registered():
names = MaskBuilderFactory.list_registered()
assert "sectioned" in names
def test_factory_create():
builder = MaskBuilderFactory.create("sectioned")
assert isinstance(builder, SectionedMaskBuilder)
def test_dpo_chat_basic(chat_tokenizer):
config = make_dpo_chat_config()
builder = SectionedMaskBuilder()
item = {
"chosen": [
{"role": "user", "content": "What is 2+2?"},
{"role": "assistant", "content": "4"},
],
"rejected": [
{"role": "user", "content": "What is 2+2?"},
{"role": "assistant", "content": "5"},
],
}
result = builder.build(item, config, chat_tokenizer)
assert result is not None
assert "chosen" in result
assert "rejected" in result
assert "chosen_mask" in result
assert "rejected_mask" in result
assert "domain" in result
assert len(result["chosen"]) == len(result["chosen_mask"])
assert len(result["rejected"]) == len(result["rejected_mask"])
assert sum(result["chosen_mask"]) > 0
assert sum(result["rejected_mask"]) > 0
def test_dpo_chosen_only_trained(chat_tokenizer):
config = make_dpo_chat_config()
builder = SectionedMaskBuilder()
item = {
"chosen": [
{"role": "user", "content": "Hi"},
{"role": "assistant", "content": "Hello"},
],
"rejected": [
{"role": "user", "content": "Hi"},
{"role": "assistant", "content": "Go away"},
],
}
result = builder.build(item, config, chat_tokenizer)
assert 0 in result["chosen_mask"]
assert 1 in result["chosen_mask"]
assert 0 in result["rejected_mask"]
assert 1 in result["rejected_mask"]
def test_dpo_missing_field_is_none(chat_tokenizer):
config = make_dpo_chat_config()
builder = SectionedMaskBuilder()
assert builder.build({"chosen": [], "rejected": []}, config, chat_tokenizer) is None
def test_grpo_basic(chat_tokenizer):
config = make_grpo_config()
builder = SectionedMaskBuilder()
item = {
"prompt": [{"role": "user", "content": "What is 2+2?"}],
"responses": ["4", "The answer is four", "Four", "2+2=4"],
"rewards": [1.0, 0.5, 0.8, 0.2],
}
result = builder.build(item, config, chat_tokenizer)
assert result is not None
assert "prompts" in result
assert "responses" in result
assert "masks" in result
assert "rewards" in result
assert len(result["responses"]) == len(result["masks"])
assert result["rewards"] == [1.0, 0.5, 0.8, 0.2]
def test_grpo_response_tokens_all_trained(chat_tokenizer):
config = make_grpo_config()
builder = SectionedMaskBuilder()
item = {
"prompt": [{"role": "user", "content": "Q"}],
"responses": ["A", "B"],
"rewards": [0.8, 0.2],
}
result = builder.build(item, config, chat_tokenizer)
masks = result["masks"]
assert all(m == 1 for m in masks)
assert len(masks) == len(result["responses"])
def test_grpo_single_reward(chat_tokenizer):
config = make_grpo_config()
builder = SectionedMaskBuilder()
item = {
"prompt": [{"role": "user", "content": "Q"}],
"responses": ["A"],
"rewards": 0.9,
}
result = builder.build(item, config, chat_tokenizer)
assert result["rewards"] == [0.9]

View File

@ -0,0 +1,77 @@
import os
from astrai.config.preprocess_config import (
InputConfig,
PipelineConfig,
)
from tests.data.conftest import (
_INSTRUCTION_SECTIONS,
_TEXT_SECTIONS,
make_dpo_chat_config,
)
def test_default_values():
config = PipelineConfig()
assert config.version == 1
assert config.mask == {}
assert config.mask_default == "mask"
assert config.preprocessing.max_seq_len == 2048
assert config.output.storage_format == "bin"
assert config.input.sections is None
def test_from_dict_flat():
data = {
"version": 1,
"input": {
"sections": [{"field": "messages", "action": "$role", "template": True}]
},
"mask": {"system": "mask", "assistant": "train"},
"mask_default": "mask",
"preprocessing": {"max_seq_len": 1024},
"output": {"storage_format": "h5"},
}
config = PipelineConfig.from_dict(data)
assert config.input.sections == [
{"field": "messages", "action": "$role", "template": True}
]
assert config.mask == {"system": "mask", "assistant": "train"}
assert config.preprocessing.max_seq_len == 1024
assert config.output.storage_format == "h5"
def test_to_dict_roundtrip():
config = PipelineConfig(
input=InputConfig(sections=_INSTRUCTION_SECTIONS),
mask={"prompt": "mask", "response": "train"},
mask_default="mask",
)
d = config.to_dict()
config2 = PipelineConfig.from_dict(d)
assert config2.input.sections == _INSTRUCTION_SECTIONS
assert config2.mask == {"prompt": "mask", "response": "train"}
def test_to_file_from_file(temp_dir):
config = PipelineConfig(
input=InputConfig(sections=_TEXT_SECTIONS),
mask={"text": "train"},
mask_default="mask",
)
path = os.path.join(temp_dir, "config.json")
config.to_file(path)
loaded = PipelineConfig.from_file(path)
assert loaded.input.sections == _TEXT_SECTIONS
assert loaded.mask == {"text": "train"}
def test_dpo_config_roundtrip(temp_dir):
config = make_dpo_chat_config()
path = os.path.join(temp_dir, "config.json")
config.to_file(path)
loaded = PipelineConfig.from_file(path)
assert loaded.input.sources is not None
assert "chosen" in loaded.input.sources
assert "rejected" in loaded.input.sources
assert loaded.input.sections is None

View File

@ -0,0 +1,349 @@
import json
import os
from astrai.config.preprocess_config import (
InputConfig,
OutputConfig,
PipelineConfig,
ProcessingConfig,
)
from astrai.preprocessing.pipeline import Pipeline, filter_by_length
from tests.data.conftest import (
_CHAT_SECTIONS,
_CHAT_TEMPLATE,
_INSTRUCTION_SECTIONS,
_SPECIAL_TOKENS_CONFIG,
_TEXT_SECTIONS,
make_dpo_chat_config,
make_grpo_no_template_config,
)
def test_filter_by_length():
assert filter_by_length("hello world", min_len=5)
assert not filter_by_length("hi", min_len=5)
assert not filter_by_length("x" * 100, max_len=50)
assert filter_by_length("just right", min_len=5, max_len=20)
def test_full_chat_pipeline(temp_dir, chat_tokenizer):
tokenizer_dir = os.path.join(temp_dir, "tok")
os.makedirs(tokenizer_dir, exist_ok=True)
chat_tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f:
json.dump(
{
"special_tokens": _SPECIAL_TOKENS_CONFIG,
"chat_template": _CHAT_TEMPLATE,
},
f,
)
jsonl_path = os.path.join(temp_dir, "chat.jsonl")
with open(jsonl_path, "w", encoding="utf-8") as f:
f.write(
json.dumps(
{
"messages": [
{"role": "system", "content": "You are helpful."},
{"role": "user", "content": "Hi."},
{"role": "assistant", "content": "Hello!"},
]
}
)
+ "\n"
)
f.write(
json.dumps(
{
"messages": [
{"role": "user", "content": "What is 2+2?"},
{"role": "assistant", "content": "4"},
]
}
)
+ "\n"
)
config = PipelineConfig(
input=InputConfig(sections=_CHAT_SECTIONS),
mask={"system": "mask", "user": "mask", "assistant": "train"},
mask_default="mask",
preprocessing=ProcessingConfig(max_seq_len=2048),
output=OutputConfig(storage_format="bin", domain_key=None),
)
out_dir = os.path.join(temp_dir, "output")
Pipeline(
config=config,
input_paths=[jsonl_path],
output_dir=out_dir,
tokenizer_path=tokenizer_dir,
).run()
meta_path = os.path.join(out_dir, "__default__", "shard_0000", "meta.json")
assert os.path.exists(meta_path)
with open(meta_path, "r") as f:
meta = json.load(f)
assert "sequence" in meta
assert "loss_mask" in meta
assert meta["sequence"]["dtype"] == "int32"
assert meta["loss_mask"]["dtype"] == "int32"
def test_full_text_pipeline(temp_dir, test_tokenizer):
tokenizer_dir = os.path.join(temp_dir, "tok")
os.makedirs(tokenizer_dir, exist_ok=True)
test_tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f:
json.dump(
{
"special_tokens": {
"pad_token": "<|_pad_|>",
"unk_token": "<|_unk_|>",
}
},
f,
)
jsonl_path = os.path.join(temp_dir, "text.jsonl")
with open(jsonl_path, "w", encoding="utf-8") as f:
f.write(
json.dumps(
{
"text": "Hello world this is a test document with enough characters to pass the minimum length filter."
}
)
+ "\n"
)
f.write(
json.dumps(
{
"text": "Another document for testing purposes with sufficient length to be processed."
}
)
+ "\n"
)
config = PipelineConfig(
input=InputConfig(sections=_TEXT_SECTIONS),
preprocessing=ProcessingConfig(max_seq_len=2048, min_chars=10),
output=OutputConfig(storage_format="bin"),
)
out_dir = os.path.join(temp_dir, "output")
Pipeline(
config=config,
input_paths=[jsonl_path],
output_dir=out_dir,
tokenizer_path=tokenizer_dir,
).run()
meta_path = os.path.join(out_dir, "__default__", "shard_0000", "meta.json")
assert os.path.exists(meta_path)
with open(meta_path, "r") as f:
meta = json.load(f)
assert "sequence" in meta
assert "loss_mask" not in meta
assert meta["sequence"]["dtype"] == "int32"
def test_full_instruction_pipeline(temp_dir, test_tokenizer):
tokenizer_dir = os.path.join(temp_dir, "tok")
os.makedirs(tokenizer_dir, exist_ok=True)
test_tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f:
json.dump(
{
"special_tokens": {
"pad_token": "<|_pad_|>",
"unk_token": "<|_unk_|>",
}
},
f,
)
jsonl_path = os.path.join(temp_dir, "instruct.jsonl")
with open(jsonl_path, "w", encoding="utf-8") as f:
f.write(
json.dumps(
{
"prompt": "Tell me a joke",
"response": "Why did the chicken cross the road?",
}
)
+ "\n"
)
f.write(
json.dumps(
{
"prompt": "What is AI?",
"response": "Artificial Intelligence is a field of computer science.",
}
)
+ "\n"
)
config = PipelineConfig(
input=InputConfig(sections=_INSTRUCTION_SECTIONS),
mask={"prompt": "mask", "response": "train"},
mask_default="mask",
preprocessing=ProcessingConfig(max_seq_len=2048),
output=OutputConfig(storage_format="bin"),
)
out_dir = os.path.join(temp_dir, "output")
Pipeline(
config=config,
input_paths=[jsonl_path],
output_dir=out_dir,
tokenizer_path=tokenizer_dir,
).run()
meta_path = os.path.join(out_dir, "__default__", "shard_0000", "meta.json")
assert os.path.exists(meta_path)
with open(meta_path, "r") as f:
meta = json.load(f)
assert "sequence" in meta
assert "loss_mask" in meta
assert meta["sequence"]["dtype"] == "int32"
assert meta["loss_mask"]["dtype"] == "int32"
def test_dtype_override(temp_dir, test_tokenizer):
tokenizer_dir = os.path.join(temp_dir, "tok")
os.makedirs(tokenizer_dir, exist_ok=True)
test_tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f:
json.dump(
{
"special_tokens": {
"pad_token": "<|_pad_|>",
"unk_token": "<|_unk_|>",
}
},
f,
)
jsonl_path = os.path.join(temp_dir, "data.jsonl")
with open(jsonl_path, "w", encoding="utf-8") as f:
f.write(json.dumps({"prompt": "Q", "response": "A"}) + "\n")
config = PipelineConfig(
input=InputConfig(sections=_INSTRUCTION_SECTIONS),
mask={"prompt": "mask", "response": "train"},
mask_default="mask",
preprocessing=ProcessingConfig(max_seq_len=2048),
output=OutputConfig(storage_format="bin", dtype={"loss_mask": "bool"}),
)
out_dir = os.path.join(temp_dir, "output")
Pipeline(
config=config,
input_paths=[jsonl_path],
output_dir=out_dir,
tokenizer_path=tokenizer_dir,
).run()
meta_path = os.path.join(out_dir, "__default__", "shard_0000", "meta.json")
with open(meta_path, "r") as f:
meta = json.load(f)
assert meta["sequence"]["dtype"] == "int32"
assert meta["loss_mask"]["dtype"] == "bool"
def test_dpo_pipeline(temp_dir, chat_tokenizer):
tokenizer_dir = os.path.join(temp_dir, "tok")
os.makedirs(tokenizer_dir, exist_ok=True)
chat_tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f:
json.dump(
{
"special_tokens": _SPECIAL_TOKENS_CONFIG,
"chat_template": _CHAT_TEMPLATE,
},
f,
)
jsonl_path = os.path.join(temp_dir, "dpo.jsonl")
with open(jsonl_path, "w", encoding="utf-8") as f:
f.write(
json.dumps(
{
"chosen": [
{"role": "user", "content": "Hi."},
{"role": "assistant", "content": "Hello!"},
],
"rejected": [
{"role": "user", "content": "Hi."},
{"role": "assistant", "content": "Go away."},
],
}
)
+ "\n"
)
out_dir = os.path.join(temp_dir, "output")
Pipeline(
config=make_dpo_chat_config(),
input_paths=[jsonl_path],
output_dir=out_dir,
tokenizer_path=tokenizer_dir,
).run()
meta_path = os.path.join(out_dir, "__default__", "shard_0000", "meta.json")
assert os.path.exists(meta_path)
with open(meta_path, "r") as f:
meta = json.load(f)
assert "chosen" in meta
assert "rejected" in meta
assert "chosen_mask" in meta
assert "rejected_mask" in meta
assert "sequence" not in meta
def test_grpo_pipeline(temp_dir, test_tokenizer):
tokenizer_dir = os.path.join(temp_dir, "tok")
os.makedirs(tokenizer_dir, exist_ok=True)
test_tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f:
json.dump(
{
"special_tokens": {
"pad_token": "<|_pad_|>",
"unk_token": "<|_unk_|>",
}
},
f,
)
jsonl_path = os.path.join(temp_dir, "grpo.jsonl")
with open(jsonl_path, "w", encoding="utf-8") as f:
f.write(
json.dumps(
{
"prompt": "Question?",
"responses": ["Answer A", "Answer B"],
"rewards": [0.8, 0.3],
}
)
+ "\n"
)
out_dir = os.path.join(temp_dir, "output")
Pipeline(
config=make_grpo_no_template_config(),
input_paths=[jsonl_path],
output_dir=out_dir,
tokenizer_path=tokenizer_dir,
).run()
meta_path = os.path.join(out_dir, "__default__", "shard_0000", "meta.json")
assert os.path.exists(meta_path)
with open(meta_path, "r") as f:
meta = json.load(f)
assert "prompts" in meta
assert "responses" in meta
assert "masks" in meta
assert "rewards" in meta
assert "sequence" not in meta

View File

@ -5,21 +5,22 @@ from unittest.mock import MagicMock
import pytest
from fastapi.testclient import TestClient
from astrai.inference import app
from astrai.inference import get_app
@pytest.fixture
def client():
"""Provide a test client for the FastAPI app."""
app.state.server_config = {
_app = get_app()
_app.state.server_config = {
"device": "cpu",
"dtype": "bfloat16",
"param_path": None,
"max_batch_size": 1,
"_test": True,
}
app.state.engine = None
return TestClient(app)
_app.state.engine = None
return TestClient(_app)
@pytest.fixture
@ -49,5 +50,5 @@ def mock_engine():
@pytest.fixture
def loaded_model(client, mock_engine):
"""Simulate that the engine is loaded."""
app.state.engine = mock_engine
get_app().state.engine = mock_engine
return mock_engine

View File

@ -0,0 +1,286 @@
"""Unit tests for protocol builders, StopChecker, GenContext, StopInfo."""
import json
from unittest.mock import MagicMock
import pytest
from astrai.inference.api.anthropic import AnthropicResponseBuilder
from astrai.inference.api.openai import OpenAIResponseBuilder
from astrai.inference.api.protocol import GenContext, StopChecker, StopInfo
from astrai.inference.engine import GenerationRequest
def _make_ctx(**kwargs):
defaults = {
"resp_id": "test-123",
"created": 1000,
"model": "test-model",
"prompt_tokens": 10,
"completion_tokens": 5,
}
defaults.update(kwargs)
return GenContext(**defaults)
def _sse_payloads(events):
payloads = []
for chunk in events:
for line in chunk.strip().split("\n"):
if line.startswith("data: "):
try:
payloads.append(json.loads(line[6:]))
except json.JSONDecodeError:
pass
return payloads
class TestStopChecker:
def test_check_finds_match(self):
sc = StopChecker(["stop", "end"])
assert sc.check("hello stop world") == "stop"
def test_check_returns_none_when_no_match(self):
sc = StopChecker(["stop"])
assert sc.check("hello world") is None
def test_check_empty_sequences(self):
sc = StopChecker([])
assert sc.check("hello") is None
class TestGenContext:
def test_defaults(self):
ctx = GenContext(resp_id="a", created=1, model="m", prompt_tokens=10)
assert ctx.completion_tokens == 0
def test_fields_mutable(self):
ctx = GenContext(resp_id="a", created=1, model="m", prompt_tokens=10)
ctx.completion_tokens = 42
assert ctx.completion_tokens == 42
class TestStopInfo:
def test_defaults(self):
s = StopInfo()
assert s.matched is None
assert s.body == ""
assert s.yielded == ""
def test_with_values(self):
s = StopInfo(matched="stop", body="hello stop", yielded="hello ")
assert s.matched == "stop"
assert s.body == "hello stop"
assert s.yielded == "hello "
class TestOpenAIResponseBuilder:
@pytest.fixture
def builder(self):
builder = OpenAIResponseBuilder()
req = MagicMock()
req.messages = [MagicMock(role="user", content="Hello")]
req.stop = None
req.model = "astrai"
engine = MagicMock()
engine.tokenizer.apply_chat_template.return_value = "Hello"
builder.prepare(req, engine)
return builder
def test_prepare_returns_prompt_ctx_stops(self, builder):
req = MagicMock()
req.messages = [MagicMock(role="user", content="Hi")]
req.stop = ["END"]
req.model = "gpt"
engine = MagicMock()
engine.tokenizer.apply_chat_template.return_value = "Hi"
prompt, ctx, stops = builder.prepare(req, engine)
assert prompt == "Hi"
assert ctx.model == "gpt"
assert ctx.prompt_tokens == 0
assert stops == ["END"]
def test_prepare_no_stop_returns_empty_list(self, builder):
req = MagicMock()
req.messages = []
req.stop = None
req.model = "x"
engine = MagicMock()
engine.tokenizer.apply_chat_template.return_value = ""
_, _, stops = builder.prepare(req, engine)
assert stops == []
def test_format_stream_start(self, builder):
ctx = _make_ctx()
events = builder.format_stream_start(ctx)
payloads = _sse_payloads(events)
assert len(payloads) == 1
p = payloads[0]
assert p["object"] == "chat.completion.chunk"
assert p["choices"][0]["delta"]["role"] == "assistant"
assert p["choices"][0]["finish_reason"] is None
def test_format_chunk(self, builder):
events = builder.format_chunk("hello", body="hello")
payload = json.loads(events[0].split("data: ", 1)[1])
assert payload["choices"][0]["delta"]["content"] == "hello"
assert payload["choices"][0]["finish_reason"] is None
def test_format_stream_end(self, builder):
ctx = _make_ctx(completion_tokens=5)
stop = StopInfo(matched="stop")
events = builder.format_stream_end(ctx, stop)
payloads = _sse_payloads(events)
finish = payloads[0]
assert finish["choices"][0]["finish_reason"] == "stop"
usage = payloads[1]
assert usage["completion_tokens"] == 5
assert usage["total_tokens"] == 15
def test_format_response(self, builder):
ctx = _make_ctx()
stop = StopInfo()
resp = builder.format_response(ctx, "hello", stop)
assert resp["object"] == "chat.completion"
assert resp["choices"][0]["message"]["content"] == "hello"
assert resp["usage"]["prompt_tokens"] == 10
class TestAnthropicResponseBuilder:
@pytest.fixture
def builder(self):
builder = AnthropicResponseBuilder()
req = MagicMock()
req.messages = [MagicMock(role="user", content="Hello")]
req.model = "claude"
engine = MagicMock()
engine.tokenizer.apply_chat_template.return_value = "Hello"
req.system = None
builder.prepare(req, engine)
return builder
def test_prepare_messages(self, builder):
req = MagicMock()
req.messages = [MagicMock(role="user", content="Hi")]
req.model = "claude"
req.system = None
req.stop_sequences = None
engine = MagicMock()
engine.tokenizer.apply_chat_template.return_value = "Hi"
prompt, ctx, stops = builder.prepare(req, engine)
assert prompt == "Hi"
assert stops == []
def test_prepare_with_stop_sequences(self, builder):
req = MagicMock()
req.messages = []
req.model = "x"
req.stop_sequences = ["stop", "end"]
req.system = None
engine = MagicMock()
engine.tokenizer.apply_chat_template.return_value = ""
_, _, stops = builder.prepare(req, engine)
assert stops == ["stop", "end"]
def test_format_stream_start(self, builder):
ctx = _make_ctx(prompt_tokens=3)
events = builder.format_stream_start(ctx)
payloads = _sse_payloads(events)
assert len(payloads) == 2
assert payloads[0]["type"] == "message_start"
assert payloads[0]["message"]["usage"]["input_tokens"] == 3
assert payloads[1]["type"] == "content_block_start"
def test_format_chunk(self, builder):
events = builder.format_chunk("tok", body="tok")
payload = json.loads(events[0].split("data: ", 1)[1])
assert payload["type"] == "content_block_delta"
assert payload["delta"]["text"] == "tok"
def test_format_stream_end_no_stop(self, builder):
ctx = _make_ctx(completion_tokens=3)
stop = StopInfo()
events = builder.format_stream_end(ctx, stop)
payloads = _sse_payloads(events)
# content_block_stop, message_delta, message_stop
types = [p["type"] for p in payloads]
assert types == ["content_block_stop", "message_delta", "message_stop"]
assert payloads[1]["delta"]["stop_reason"] == "end_turn"
def test_format_stream_end_with_stop_trims_and_emits_remaining(self, builder):
ctx = _make_ctx(completion_tokens=7)
stop = StopInfo(
matched="END",
body="Hello world END extra",
yielded="Hello ",
)
events = builder.format_stream_end(ctx, stop)
payloads = _sse_payloads(events)
# unyielded delta, content_block_stop, message_delta, message_stop
types = [p["type"] for p in payloads]
assert types == [
"content_block_delta",
"content_block_stop",
"message_delta",
"message_stop",
]
assert payloads[0]["delta"]["text"] == "world "
assert payloads[2]["delta"]["stop_reason"] == "stop_sequence"
assert payloads[2]["delta"]["stop_sequence"] == "END"
def test_format_stream_end_stop_trimmed_already_yielded(self, builder):
ctx = _make_ctx()
stop = StopInfo(
matched="END",
body="Hello END",
yielded="Hello ",
)
events = builder.format_stream_end(ctx, stop)
payloads = _sse_payloads(events)
# No unyielded delta (everything already sent)
types = [p["type"] for p in payloads]
assert types == ["content_block_stop", "message_delta", "message_stop"]
def test_format_response_with_stop_trims_content(self, builder):
ctx = _make_ctx()
stop = StopInfo(matched="STOP", body="text STOP extra", yielded="text ")
resp = builder.format_response(ctx, "text STOP extra", stop)
assert resp["content"][0]["text"] == "text "
assert resp["stop_reason"] == "stop_sequence"
assert resp["stop_sequence"] == "STOP"
def test_format_response_no_stop(self, builder):
ctx = _make_ctx()
stop = StopInfo()
resp = builder.format_response(ctx, "full text", stop)
assert resp["content"][0]["text"] == "full text"
assert resp["stop_reason"] == "end_turn"
class TestGenerationRequestValidation:
def test_valid_params(self):
gr = GenerationRequest(
messages=[{"role": "user", "content": "hi"}],
top_k=50,
top_p=0.9,
temperature=0.7,
)
assert gr.top_k == 50
def test_invalid_top_p_raises(self):
with pytest.raises(ValueError, match="top_p"):
GenerationRequest(messages=[{"role": "user", "content": "hi"}], top_p=1.5)
def test_invalid_top_k_raises(self):
with pytest.raises(ValueError, match="top_k"):
GenerationRequest(messages=[{"role": "user", "content": "hi"}], top_k=-1)
def test_invalid_temperature_raises(self):
with pytest.raises(ValueError, match="temperature"):
GenerationRequest(
messages=[{"role": "user", "content": "hi"}], temperature=-0.1
)
def test_top_k_zero_valid(self):
gr = GenerationRequest(messages=[{"role": "user", "content": "hi"}], top_k=0)
assert gr.top_k == 0

View File

@ -173,3 +173,21 @@ def test_scheduler_concurrent_get_stats(mock_model_and_tokenizer):
for stats in results["stats"]:
assert "total_tasks" in stats
assert stats["total_tasks"] >= 0
def test_prefill_skips_fully_cached_tasks(mock_model_and_tokenizer):
"""Tasks whose entire prompt is cached skip the prefill phase."""
mock_model, mock_tokenizer = mock_model_and_tokenizer
with patch("astrai.inference.core.scheduler.AutoModel"):
with patch("astrai.inference.core.scheduler.AutoTokenizer"):
scheduler = InferenceScheduler(
model=mock_model,
tokenizer=mock_tokenizer,
max_batch_size=4,
device="cpu",
)
task_id = scheduler.add_task("short prompt", stream_callback=lambda t: None)
scheduler.stop()
assert task_id.startswith("task_")

View File

@ -2,12 +2,12 @@
import pytest
from astrai.inference import app
from astrai.inference import get_app
def test_health_no_model(client):
"""GET /health should return 200 even when engine not loaded."""
app.state.engine = None
get_app().state.engine = None
response = client.get("/health")
assert response.status_code == 200
data = response.json()
@ -30,7 +30,7 @@ def test_chat_completions_non_stream(client, loaded_model):
async def async_gen():
yield "Assistant reply"
app.state.engine = loaded_model
get_app().state.engine = loaded_model
loaded_model.generate_async.return_value = async_gen()
response = client.post(
"/v1/chat/completions",
@ -56,7 +56,7 @@ def test_chat_completions_stream(client, loaded_model):
yield "cumulative1"
yield "cumulative2"
app.state.engine = loaded_model
get_app().state.engine = loaded_model
loaded_model.generate_async.return_value = async_gen()
response = client.post(
"/v1/chat/completions",
@ -83,7 +83,7 @@ def test_messages_non_stream(client, loaded_model):
async def async_gen():
yield "Assistant reply"
app.state.engine = loaded_model
get_app().state.engine = loaded_model
loaded_model.generate_async.return_value = async_gen()
response = client.post(
"/v1/messages",
@ -111,7 +111,7 @@ def test_messages_stream(client, loaded_model):
yield "cumulative1"
yield "cumulative2"
app.state.engine = loaded_model
get_app().state.engine = loaded_model
loaded_model.generate_async.return_value = async_gen()
response = client.post(
"/v1/messages",
@ -141,7 +141,7 @@ def test_messages_with_system(client, loaded_model):
async def async_gen():
yield "Reply"
app.state.engine = loaded_model
get_app().state.engine = loaded_model
loaded_model.generate_async.return_value = async_gen()
response = client.post(
"/v1/messages",
@ -157,5 +157,60 @@ def test_messages_with_system(client, loaded_model):
assert data["type"] == "message"
def test_chat_completions_stop_sequence(client, loaded_model):
"""POST /v1/chat/completions with stop parameter truncates at stop sequence."""
async def async_gen():
yield "Hello"
yield "X"
yield "world"
get_app().state.engine = loaded_model
loaded_model.generate_async.return_value = async_gen()
response = client.post(
"/v1/chat/completions",
json={
"messages": [{"role": "user", "content": "Hello"}],
"max_tokens": 100,
"stream": False,
"stop": ["X"],
},
)
assert response.status_code == 200
data = response.json()
content = data["choices"][0]["message"]["content"]
assert "X" in content
assert "world" not in content
def test_chat_completions_stop_sequence_stream(client, loaded_model):
"""POST /v1/chat/completions with stop parameter truncates SSE stream."""
async def async_gen():
yield "Hello"
yield "X"
yield "world"
get_app().state.engine = loaded_model
loaded_model.generate_async.return_value = async_gen()
response = client.post(
"/v1/chat/completions",
json={
"messages": [{"role": "user", "content": "Hello"}],
"max_tokens": 100,
"stream": True,
"stop": ["X"],
},
headers={"Accept": "text/event-stream"},
)
assert response.status_code == 200
content = response.content.decode("utf-8")
assert "Hello" in content
assert "world" not in content
assert any(
"finish_reason" in line for line in content.split("\n") if "stop" in line
)
if __name__ == "__main__":
pytest.main([__file__, "-v"])

View File

@ -0,0 +1,691 @@
"""Unit tests for tool call parsers."""
import pytest
from astrai.inference.api.tool_parser import (
_TOOL_CALL_HEAD_RE,
BaseToolParser,
SimpleJsonToolParser,
ToolParserFactory,
_find_partial_tool_call,
_find_tool_calls,
_scan_json,
)
def test_scan_complete_simple():
end, complete = _scan_json('{"key": "value"}', 0)
assert complete is True
assert end == len('{"key": "value"}')
def test_scan_complete_nested():
text = '{"outer": {"inner": 1}}'
end, complete = _scan_json(text, 0)
assert complete is True
assert end == len(text)
def test_scan_incomplete_unclosed():
end, complete = _scan_json('{"key": "value"', 0)
assert complete is False
def test_scan_incomplete_nested():
end, complete = _scan_json('{"outer": {"inner": 1}', 0)
assert complete is False
def test_scan_string_braces_ignored():
text = '{"key": "a{b}c"} extra'
end, complete = _scan_json(text, 0)
assert complete is True
def test_scan_escaped_quote_ignored():
text = r'{"key": "a\"b"}'
end, complete = _scan_json(text, 0)
assert complete is True
def test_scan_deeply_nested():
text = '{"a": {"b": {"c": {"d": {"e": 5}}}}}'
end, complete = _scan_json(text, 0)
assert complete is True
assert end == len(text)
def test_scan_array_with_braces():
text = '{"items": [{"x": 1}, {"x": 2}]}'
end, complete = _scan_json(text, 0)
assert complete is True
assert end == len(text)
def test_scan_code_in_string():
text = '{"fn": "function() { return 1; }"}'
end, complete = _scan_json(text, 0)
assert complete is True
def test_scan_unicode_chars():
text = '{"key": "\u5317\u4eac"}'
end, complete = _scan_json(text, 0)
assert complete is True
def test_find_single_tool_call():
text = '{"name": "get_weather", "arguments": {"city": "Beijing"}}'
results = _find_tool_calls(text)
assert len(results) == 1
assert results[0]["name"] == "get_weather"
assert '"city"' in results[0]["args"]
assert results[0]["complete"] is True
def test_find_text_before_tool_call():
text = 'Some text {"name": "func", "arguments": {}}'
results = _find_tool_calls(text)
assert len(results) == 1
assert results[0]["start"] > 0
def test_find_multiple_tool_calls():
text = '{"name": "f1", "arguments": {"a": 1}}{"name": "f2", "arguments": {"b": 2}}'
results = _find_tool_calls(text)
assert len(results) == 2
assert results[0]["name"] == "f1"
assert results[1]["name"] == "f2"
def test_find_no_tool_call():
results = _find_tool_calls("Hello, how are you?")
assert len(results) == 0
def test_find_non_tool_json_skipped():
results = _find_tool_calls('{"not_a_tool": true}')
assert len(results) == 0
def test_find_no_arguments_field():
results = _find_tool_calls('{"name": "simple_func"}')
assert len(results) == 1
assert results[0]["name"] == "simple_func"
assert results[0]["args"] == ""
def test_find_deeply_nested_arguments():
text = '{"name": "deep", "arguments": {"a": {"b": {"c": {"d": 4}}}}}'
results = _find_tool_calls(text)
assert len(results) == 1
assert results[0]["name"] == "deep"
assert '"d": 4' in results[0]["args"]
def test_find_arguments_with_boolean_and_null():
text = '{"name": "flags", "arguments": {"active": true, "count": 0, "nick": null}}'
results = _find_tool_calls(text)
assert len(results) == 1
assert results[0]["name"] == "flags"
assert "true" in results[0]["args"]
assert "null" in results[0]["args"]
def test_find_arguments_with_array():
text = '{"name": "add_items", "arguments": {"items": [1, 2, 3], "name": "list"}}'
results = _find_tool_calls(text)
assert len(results) == 1
assert results[0]["name"] == "add_items"
assert "[1, 2, 3]" in results[0]["args"]
def test_find_arguments_with_nested_array_of_objects():
text = (
'{"name": "batch", '
'"arguments": {"rows": [{"id": 1, "val": "a"}, {"id": 2, "val": "b"}]}}'
)
results = _find_tool_calls(text)
assert len(results) == 1
assert '"rows"' in results[0]["args"]
assert '"id": 1' in results[0]["args"]
def test_find_arguments_as_string_not_object():
text = '{"name": "echo", "arguments": "just a string"}'
results = _find_tool_calls(text)
assert len(results) == 1
assert results[0]["name"] == "echo"
assert "just a string" in results[0]["args"]
def test_find_arguments_with_unicode():
text = (
'{"name": "translate", "arguments": {"text": "\u4f60\u597d\uff0c\u4e16\u754c"}}'
)
results = _find_tool_calls(text)
assert len(results) == 1
assert results[0]["name"] == "translate"
def test_find_arguments_with_escaped_quotes():
text = '{"name": "format", "arguments": {"template": "he said \\"hello\\""}}'
results = _find_tool_calls(text)
assert len(results) == 1
assert 'he said \\"hello\\"' in results[0]["args"]
def test_find_arguments_with_braces_in_string():
text = '{"name": "eval", "arguments": {"code": "function(x) { return x + 1; }"}}'
results = _find_tool_calls(text)
assert len(results) == 1
assert results[0]["name"] == "eval"
assert "function(x) { return x + 1; }" in results[0]["args"]
def test_find_many_properties():
args = ",".join(f'"{chr(97 + i % 26)}" : {i}' for i in range(20))
text = '{"name": "many", "arguments": {' + args + "}}"
results = _find_tool_calls(text)
assert len(results) == 1
assert results[0]["name"] == "many"
def test_find_empty_arguments():
results = _find_tool_calls('{"name": "ping", "arguments": {}}')
assert len(results) == 1
assert results[0]["name"] == "ping"
assert results[0]["args"] == ""
def test_find_extracts_correct_arg_start_position():
text = '{"name": "f", "arguments": {"x": 1}}'
results = _find_tool_calls(text)
assert len(results) == 1
json_str = text[results[0]["start"] : results[0]["end"]]
assert json_str == text
def test_partial_with_name():
result = _find_partial_tool_call('{"name": "func", "arguments": {"city"')
assert result is not None
assert result["name"] == "func"
assert result["complete"] is False
def test_partial_with_full_args():
result = _find_partial_tool_call('{"name": "func", "arguments": {"city": "BJ"}}')
assert result is not None
assert result["name"] == "func"
def test_partial_no_match():
assert _find_partial_tool_call("plain text") is None
def test_partial_no_name_yet():
assert _find_partial_tool_call('{"nam') is None
def test_partial_deeply_nested():
result = _find_partial_tool_call('{"name": "deep", "arguments": {"a": {"b": {"c": ')
assert result is not None
assert result["name"] == "deep"
assert '"a"' in result["args"]
def test_partial_array_incomplete():
result = _find_partial_tool_call('{"name": "batch", "arguments": {"items": [1, 2, ')
assert result is not None
assert result["name"] == "batch"
def test_feed_plain_text():
parser = SimpleJsonToolParser()
deltas = parser.feed("Hello")
assert len(deltas) == 1
assert deltas[0]["content"] == "Hello"
def test_feed_incremental_text():
parser = SimpleJsonToolParser()
assert parser.feed("He") == [{"content": "He"}]
assert parser.feed("Hello") == [{"content": "llo"}]
def test_feed_tool_call_name_delta():
parser = SimpleJsonToolParser()
text = '{"name": "get_weather", "arguments": {"city": "Beijing"}}'
deltas = parser.feed(text)
tc_deltas = [d for d in deltas if "tool_calls" in d]
assert len(tc_deltas) >= 1
name_delta = tc_deltas[0]["tool_calls"][0]
assert name_delta["function"]["name"] == "get_weather"
assert name_delta["type"] == "function"
assert "id" in name_delta
def test_feed_tool_call_args_streaming():
parser = SimpleJsonToolParser()
d1 = parser.feed('{"name": "f", "arguments": {"x":')
d2 = parser.feed('{"name": "f", "arguments": {"x": "1"}}')
args_deltas = [
d
for batch in (d1, d2)
for d in batch
if "tool_calls" in d
and "function" in d["tool_calls"][0]
and "arguments" in d["tool_calls"][0]["function"]
]
assert len(args_deltas) >= 1
def test_feed_text_before_tool_call():
parser = SimpleJsonToolParser()
text = 'Let me check. {"name": "func", "arguments": {"a": 1}}'
deltas = parser.feed(text)
content_deltas = [d for d in deltas if "content" in d]
assert any("Let me check" in d.get("content", "") for d in content_deltas)
def test_has_tool_calls_false_by_default():
assert SimpleJsonToolParser().has_tool_calls is False
def test_has_tool_calls_true_after_detection():
parser = SimpleJsonToolParser()
parser.feed('{"name": "f", "arguments": {}}')
assert parser.has_tool_calls is True
def test_feed_no_content_when_no_new_text():
parser = SimpleJsonToolParser()
parser.feed("Hello")
assert parser.feed("Hello") == []
def test_feed_multiple_tool_calls():
parser = SimpleJsonToolParser()
text = '{"name": "f1", "arguments": {"a": 1}}{"name": "f2", "arguments": {"b": 2}}'
deltas = parser.feed(text)
tc_deltas = [d for d in deltas if "tool_calls" in d]
names = set()
for batch in tc_deltas:
for tc in batch["tool_calls"]:
if "function" in tc and "name" in tc["function"]:
names.add(tc["function"]["name"])
assert "f1" in names
assert "f2" in names
def test_feed_with_tools_constructor():
tools = [{"type": "function", "function": {"name": "get_weather"}}]
parser = SimpleJsonToolParser(tools=tools, tool_choice="auto")
deltas = parser.feed('{"name": "get_weather", "arguments": {"city": "BJ"}}')
assert len(deltas) > 0
def test_feed_content_after_tool_call_is_not_emitted():
parser = SimpleJsonToolParser()
parser.feed('{"name": "f", "arguments": {}} trailing text')
assert parser.has_tool_calls
def _collect_args_deltas(parser):
args_parts = []
for d in parser.feed(parser._text_buffer):
if "tool_calls" in d:
for tc in d["tool_calls"]:
fn = tc.get("function", {})
if "arguments" in fn and fn["arguments"]:
args_parts.append(fn["arguments"])
return args_parts
def _simulate_streaming(parser, text):
all_delta_names = []
all_args_chunks = []
for i in range(1, len(text) + 1):
deltas = parser.feed(text[:i])
for d in deltas:
if "tool_calls" in d:
for tc in d["tool_calls"]:
fn = tc.get("function", {})
if "name" in fn:
all_delta_names.append(fn["name"])
if "arguments" in fn and fn["arguments"]:
all_args_chunks.append(fn["arguments"])
return all_delta_names, all_args_chunks
def test_streaming_token_by_token_full_build():
parser = SimpleJsonToolParser()
text = '{"name": "get_weather", "arguments": {"city": "Beijing"}}'
names, args_chunks = _simulate_streaming(parser, text)
assert "get_weather" in names
joined_args = "".join(args_chunks)
assert '"city"' in joined_args
assert "Beijing" in joined_args
def test_streaming_token_by_token_text_then_tool():
parser = SimpleJsonToolParser()
parts = [
"I'll ",
"check ",
"that. ",
'{"',
'name": "search", ',
'"arguments": {"q": "hello"}}',
]
body = ""
content_chunks = []
tool_names = []
for part in parts:
body += part
deltas = parser.feed(body)
for d in deltas:
if "content" in d:
content_chunks.append(d["content"])
if "tool_calls" in d:
for tc in d["tool_calls"]:
fn = tc.get("function", {})
if "name" in fn:
tool_names.append(fn["name"])
full_content = "".join(content_chunks)
assert "I'll check that." in full_content
assert "search" in tool_names
def test_streaming_multiple_tool_calls_incremental():
parser = SimpleJsonToolParser()
text = '{"name": "f1", "arguments": {"a": 1}}{"name": "f2", "arguments": {"b": 2}}'
names, _ = _simulate_streaming(parser, text)
assert names[0] == "f1"
assert "f2" in names
def test_streaming_deeply_nested_args():
parser = SimpleJsonToolParser()
text = '{"name": "deep", "arguments": {"a": {"b": {"c": 42}}}}'
_, args_chunks = _simulate_streaming(parser, text)
joined = "".join(args_chunks)
assert '"c": 42' in joined
def test_streaming_args_with_unicode():
parser = SimpleJsonToolParser()
text = (
'{"name": "translate", "arguments": {"text": "\u4f60\u597d\uff0c\u4e16\u754c"}}'
)
_, args_chunks = _simulate_streaming(parser, text)
joined = "".join(args_chunks)
assert "\u4f60\u597d" in joined
def test_streaming_args_with_array():
parser = SimpleJsonToolParser()
text = '{"name": "add", "arguments": {"items": [1, 2, 3]}}'
_, args_chunks = _simulate_streaming(parser, text)
joined = "".join(args_chunks)
assert "[1, 2, 3]" in joined
def test_streaming_empty_arguments():
parser = SimpleJsonToolParser()
text = '{"name": "ping", "arguments": {}}'
deltas = parser.feed(text)
tc_deltas = [d for d in deltas if "tool_calls" in d]
assert len(tc_deltas) >= 1
name_delta = tc_deltas[0]["tool_calls"][0]
assert name_delta["function"]["name"] == "ping"
assert "arguments" in name_delta["function"]
def test_streaming_args_diff_only_emits_new_bytes():
parser = SimpleJsonToolParser()
step1 = parser.feed('{"name": "f", "arguments": {"city": "Bei')
step2 = parser.feed('{"name": "f", "arguments": {"city": "Beijing"}}')
all_args = []
for step in (step1, step2):
for d in step:
if "tool_calls" in d:
for tc in d["tool_calls"]:
fn = tc.get("function", {})
if "arguments" in fn and fn["arguments"]:
all_args.append(fn["arguments"])
joined = "".join(all_args)
assert "city" in joined
assert "Beijing" in joined
assert joined.startswith('"city":')
assert all_args[0] != all_args[1]
def test_streaming_distinct_tool_call_ids():
parser = SimpleJsonToolParser()
text = '{"name": "f1", "arguments": {"a": 1}}{"name": "f2", "arguments": {"b": 2}}'
all_ids = []
for i in range(1, len(text) + 1):
deltas = parser.feed(text[:i])
for d in deltas:
if "tool_calls" in d:
for tc in d["tool_calls"]:
if "id" in tc:
all_ids.append(tc["id"])
unique = list(dict.fromkeys(all_ids))
assert len(unique) == 2
def test_parse_complete_basic():
parser = SimpleJsonToolParser()
body = '{"name": "get_weather", "arguments": {"city": "Beijing"}}'
result = parser.parse_complete(body)
assert result is not None
assert result["tool_calls"][0]["function"]["name"] == "get_weather"
assert "Beijing" in result["tool_calls"][0]["function"]["arguments"]
def test_parse_complete_no_tool_call():
assert SimpleJsonToolParser().parse_complete("Hello world") is None
def test_parse_complete_with_content():
parser = SimpleJsonToolParser()
result = parser.parse_complete('Prefix text. {"name": "f", "arguments": {}}')
assert result is not None
assert result["content"] == "Prefix text."
def test_parse_complete_multiple_tool_calls():
parser = SimpleJsonToolParser()
body = (
'{"name": "get_weather", "arguments": {"city": "Beijing"}}'
'{"name": "get_time", "arguments": {"tz": "Asia/Shanghai"}}'
)
result = parser.parse_complete(body)
assert result is not None
assert len(result["tool_calls"]) == 2
assert result["tool_calls"][0]["function"]["name"] == "get_weather"
assert result["tool_calls"][1]["function"]["name"] == "get_time"
assert "Beijing" in result["tool_calls"][0]["function"]["arguments"]
assert "Asia/Shanghai" in result["tool_calls"][1]["function"]["arguments"]
def test_parse_complete_complex_real_world():
parser = SimpleJsonToolParser()
body = (
'{"name": "send_email", '
'"arguments": {'
'"to": ["a@b.com", "c@d.com"], '
'"cc": null, '
'"subject": "Hello World", '
'"body": "This is a test email.", '
'"priority": 1, '
'"attachments": false'
"}}"
)
result = parser.parse_complete(body)
assert result is not None
tc = result["tool_calls"][0]
assert tc["function"]["name"] == "send_email"
args = tc["function"]["arguments"]
assert '"to"' in args
assert "a@b.com" in args
assert "null" in args
assert "false" in args
def test_parse_complete_content_with_multiple_tool_calls():
parser = SimpleJsonToolParser()
body = (
"I will do two things. "
'{"name": "f1", "arguments": {"a": 1}}'
'{"name": "f2", "arguments": {"b": 2}}'
)
result = parser.parse_complete(body)
assert result is not None
assert result["content"] == "I will do two things."
assert len(result["tool_calls"]) == 2
def test_parse_complete_no_arguments_field():
parser = SimpleJsonToolParser()
result = parser.parse_complete('{"name": "ping"}')
assert result is not None
assert result["tool_calls"][0]["function"]["name"] == "ping"
assert result["tool_calls"][0]["function"]["arguments"] == ""
def test_parse_complete_content_is_none_when_pure_tool_call():
parser = SimpleJsonToolParser()
result = parser.parse_complete('{"name": "f", "arguments": {"x": 1}}')
assert result is not None
assert result["content"] is None
def test_parse_complete_tool_calls_have_ids():
parser = SimpleJsonToolParser()
result = parser.parse_complete(
'{"name": "f1", "arguments": {}}{"name": "f2", "arguments": {}}'
)
assert result is not None
ids = [tc["id"] for tc in result["tool_calls"]]
assert len(ids) == 2
assert all(isinstance(i, str) and i.startswith("call_") for i in ids)
assert ids[0] != ids[1]
def test_feed_then_parse_complete_same_instance():
parser = SimpleJsonToolParser()
parser.feed('{"name": "get_weather", "arguments": {"city": "Beijing"}}')
result = parser.parse_complete(
'{"name": "get_weather", "arguments": {"city": "Beijing"}}'
)
assert result is not None
assert result["tool_calls"][0]["function"]["name"] == "get_weather"
assert parser.has_tool_calls
def test_pattern_matches_basic():
assert _TOOL_CALL_HEAD_RE.search('{"name": "f"}')
def test_pattern_matches_with_whitespace():
assert _TOOL_CALL_HEAD_RE.search('{ "name" : "f"}')
def test_pattern_no_match_without_name():
assert _TOOL_CALL_HEAD_RE.search('{"other": 1}') is None
def test_pattern_match_mid_text():
assert _TOOL_CALL_HEAD_RE.search('prefix {"name": "f", "args": {}}') is not None
def test_pattern_name_at_start():
assert _TOOL_CALL_HEAD_RE.match('{"name": "f"}')
def test_pattern_leading_whitespace():
assert _TOOL_CALL_HEAD_RE.search(' {"name": "f"}') is not None
def test_factory_register_and_create():
parser = ToolParserFactory.create("simple_json")
assert isinstance(parser, BaseToolParser)
assert isinstance(parser, SimpleJsonToolParser)
def test_factory_create_passes_tools():
parser = ToolParserFactory.create(
"simple_json", tools=[{"type": "function"}], tool_choice="required"
)
assert parser.tool_choice == "required"
def test_factory_list_registered():
assert "simple_json" in ToolParserFactory.list_registered()
def test_factory_create_with_no_extra_kwargs():
assert isinstance(ToolParserFactory.create("simple_json"), BaseToolParser)
def test_factory_create_with_tools_only():
tools = [
{
"type": "function",
"function": {"name": "test", "parameters": {"type": "object"}},
}
]
parser = ToolParserFactory.create("simple_json", tools=tools)
assert parser.tools == tools
assert parser.tool_choice == "auto"
def test_feed_accepts_token_ids_and_ignores_them():
parser = SimpleJsonToolParser()
text = '{"name": "get_weather", "arguments": {"city": "Beijing"}}'
deltas_with = parser.feed(text, current_token_ids=[123, 456], delta_token_ids=[456])
assert len(deltas_with) > 0
def test_feed_token_ids_do_not_affect_parsing():
parser_no_ids = SimpleJsonToolParser()
parser_with_ids = SimpleJsonToolParser()
text = '{"name": "get_weather", "arguments": {"city": "Beijing"}}'
result_no = parser_no_ids.feed(text)
result_with = parser_with_ids.feed(
text, current_token_ids=[1, 2, 3], delta_token_ids=[3]
)
assert len(result_no) == len(result_with)
assert len(result_no) > 0
assert (
result_no[0]["tool_calls"][0]["function"]["name"]
== result_with[0]["tool_calls"][0]["function"]["name"]
)
def test_parser_uses_token_ids_for_detection():
class TokenIdParser(BaseToolParser):
def __init__(self, tools=None, tool_choice="auto"):
super().__init__(tools, tool_choice)
self._detections = 0
def feed(self, body, current_token_ids=None, delta_token_ids=None):
if current_token_ids and 999 in current_token_ids:
self._detections += 1
return []
def parse_complete(self, body):
return None
@property
def has_tool_calls(self):
return self._detections > 0
parser = TokenIdParser()
parser.feed("hello", current_token_ids=[1, 999, 3])
assert parser.has_tool_calls

View File

@ -0,0 +1,166 @@
import torch
from astrai.config.model_config import EncoderConfig
from astrai.model.encoder import EmbeddingEncoder
TINY_CONFIG = dict(
vocab_size=128,
dim=8,
n_heads=2,
n_kv_heads=1,
dim_ffn=16,
max_len=64,
n_layers=2,
norm_eps=1e-5,
)
def test_encoder_forward_mean():
config = EncoderConfig(**TINY_CONFIG)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = EmbeddingEncoder(config).to(device=device)
model.eval()
batch_size, seq_len = 2, 8
input_ids = torch.randint(
0, config.vocab_size, (batch_size, seq_len), device=device
)
with torch.no_grad():
output = model(input_ids)
assert output.shape == (batch_size, config.dim)
assert not torch.isnan(output).any()
def test_encoder_forward_cls():
config = EncoderConfig(**{**TINY_CONFIG, "pooling_type": "cls"})
device = "cuda" if torch.cuda.is_available() else "cpu"
model = EmbeddingEncoder(config).to(device=device)
model.eval()
batch_size, seq_len = 2, 8
input_ids = torch.randint(
0, config.vocab_size, (batch_size, seq_len), device=device
)
with torch.no_grad():
output = model(input_ids)
assert output.shape == (batch_size, config.dim)
assert not torch.isnan(output).any()
def test_encoder_forward_last():
config = EncoderConfig(**{**TINY_CONFIG, "pooling_type": "last"})
device = "cuda" if torch.cuda.is_available() else "cpu"
model = EmbeddingEncoder(config).to(device=device)
model.eval()
batch_size, seq_len = 2, 8
input_ids = torch.randint(
0, config.vocab_size, (batch_size, seq_len), device=device
)
with torch.no_grad():
output = model(input_ids)
assert output.shape == (batch_size, config.dim)
assert not torch.isnan(output).any()
def test_encoder_forward_with_padding():
config = EncoderConfig(**TINY_CONFIG)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = EmbeddingEncoder(config).to(device=device)
model.eval()
batch_size, seq_len = 2, 8
input_ids = torch.randint(
0, config.vocab_size, (batch_size, seq_len), device=device
)
input_mask = torch.ones(batch_size, seq_len, dtype=torch.bool, device=device)
input_mask[:, 4:] = False
with torch.no_grad():
output = model(input_ids, input_mask=input_mask)
assert output.shape == (batch_size, config.dim)
assert not torch.isnan(output).any()
def test_encoder_normalize():
config = EncoderConfig(
**{**TINY_CONFIG, "pooling_type": "mean", "normalize_embeddings": True}
)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = EmbeddingEncoder(config).to(device=device)
model.eval()
batch_size, seq_len = 2, 8
input_ids = torch.randint(
0, config.vocab_size, (batch_size, seq_len), device=device
)
with torch.no_grad():
output = model(input_ids)
norms = output.norm(p=2, dim=-1)
assert torch.allclose(norms, torch.ones_like(norms), atol=1e-4)
def test_encoder_register():
from astrai.model.automodel import AutoModel
assert AutoModel.is_registered("embedding")
cls = AutoModel.get_component_class("embedding")
assert cls is EmbeddingEncoder
def test_encoder_from_transformer_checkpoint():
config = EncoderConfig(**TINY_CONFIG)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = EmbeddingEncoder(config).to(device=device)
state_dict = model.state_dict()
state_dict["lm_head.weight"] = torch.randn(
config.vocab_size, config.dim, device=device
)
new_model = EmbeddingEncoder(config).to(device=device)
new_model.load_state_dict(state_dict, strict=True)
for key in model.state_dict():
assert torch.equal(new_model.state_dict()[key], model.state_dict()[key])
def test_encoder_save_load():
import json
import os
import tempfile
import safetensors.torch as st
test_dir = tempfile.mkdtemp(prefix="encoder_test_")
config_path = os.path.join(test_dir, "config.json")
weights_path = os.path.join(test_dir, "model.safetensors")
try:
config_data = {**TINY_CONFIG, "pooling_type": "mean"}
with open(config_path, "w") as f:
json.dump(config_data, f)
config = EncoderConfig.from_file(config_path)
original = EmbeddingEncoder(config)
st.save_file(original.state_dict(), weights_path)
loaded = EmbeddingEncoder(config)
loaded.load_state_dict(st.load_file(weights_path))
for key in original.state_dict():
assert torch.equal(original.state_dict()[key], loaded.state_dict()[key])
finally:
if os.path.exists(test_dir):
for f in os.listdir(test_dir):
os.remove(os.path.join(test_dir, f))
os.rmdir(test_dir)

View File

@ -0,0 +1,108 @@
import pytest
import torch
from astrai.config.model_config import AutoRegressiveLMConfig
from astrai.model.transformer import AutoRegressiveLM
TINY_CONFIG = dict(
vocab_size=128,
dim=8,
n_heads=2,
n_kv_heads=1,
dim_ffn=16,
max_len=64,
n_layers=2,
norm_eps=1e-5,
)
CONFIGS = [
pytest.param(
{**TINY_CONFIG, "attn_type": "gqa", "ffn_type": "mlp"},
id="gqa_mlp",
),
pytest.param(
{
**TINY_CONFIG,
"attn_type": "mla",
"ffn_type": "mlp",
"kv_lora_rank": 4,
"qk_nope_head_dim": 2,
"qk_rope_head_dim": 2,
},
id="mla_mlp",
),
pytest.param(
{
**TINY_CONFIG,
"attn_type": "gqa",
"ffn_type": "moe",
"n_routed_experts": 4,
"n_shared_experts": 1,
"n_activated_experts": 2,
"topk_method": "greedy",
},
id="gqa_moe",
),
pytest.param(
{
**TINY_CONFIG,
"attn_type": "gqa",
"ffn_type": "mlp",
"rope_theta": 100000.0,
},
id="gqa_rope_theta",
),
pytest.param(
{**TINY_CONFIG, "attn_type": "gqa", "ffn_type": "mlp", "use_qk_norm": True},
id="gqa_qk_norm",
),
pytest.param(
{**TINY_CONFIG, "attn_type": "gqa", "ffn_type": "mlp", "tie_weight": True},
id="gqa_tie_weight",
),
]
@pytest.mark.parametrize("config_kwargs", CONFIGS)
def test_model_forward(config_kwargs):
config = AutoRegressiveLMConfig(**config_kwargs)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoRegressiveLM(config).to(device=device)
model.eval()
batch_size, seq_len = 2, 8
input_ids = torch.randint(
0, config.vocab_size, (batch_size, seq_len), device=device
)
with torch.no_grad():
output = model(input_ids)
assert "logits" in output
assert "hidden_states" in output
assert output["logits"].shape == (batch_size, seq_len, config.vocab_size)
assert output["hidden_states"].shape == (batch_size, seq_len, config.dim)
assert not torch.isnan(output["logits"]).any()
assert not torch.isnan(output["hidden_states"]).any()
@pytest.mark.parametrize("config_kwargs", CONFIGS)
def test_model_forward_with_padding(config_kwargs):
config = AutoRegressiveLMConfig(**config_kwargs)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoRegressiveLM(config).to(device=device)
model.eval()
batch_size, seq_len = 2, 8
input_ids = torch.randint(
0, config.vocab_size, (batch_size, seq_len), device=device
)
input_mask = torch.ones(batch_size, seq_len, dtype=torch.bool, device=device)
input_mask[:, 4:] = False
with torch.no_grad():
output = model(input_ids, input_mask=input_mask)
assert output["logits"].shape == (batch_size, seq_len, config.vocab_size)
assert not torch.isnan(output["logits"]).any()

355
tests/module/test_lora.py Normal file
View File

@ -0,0 +1,355 @@
import tempfile
import pytest
import torch
from astrai.config.model_config import AutoRegressiveLMConfig
from astrai.model import AutoRegressiveLM
from astrai.model.components.linear import Linear
from astrai.model.components.lora import (
LoRAConfig,
LoRALinear,
_collect_lora_info,
_get_lora_count,
inject_lora,
load_lora,
merge_lora,
save_lora,
)
MODEL_KWARGS = dict(
vocab_size=1000,
dim=64,
n_heads=4,
n_kv_heads=2,
dim_ffn=128,
n_layers=2,
max_len=32,
norm_eps=1e-5,
)
def _make_model(**kwargs):
kw = {**MODEL_KWARGS, **kwargs}
config = AutoRegressiveLMConfig(**kw)
model = AutoRegressiveLM(config)
model.eval()
return model
def test_loralinear_init():
base = Linear(64, 128)
lora = LoRALinear(base, r=8, alpha=16)
assert lora.weight is base.weight
assert not lora.weight.requires_grad
assert lora.lora_A.shape == (8, 64)
assert lora.lora_B.shape == (128, 8)
assert lora.scaling == 2.0
assert not lora._merged
assert lora.lora_A.requires_grad
assert lora.lora_B.requires_grad
def test_loralinear_forward_init_zero_delta():
base = Linear(4, 4)
with torch.no_grad():
base.weight.zero_()
x = torch.randn(2, 4)
lora = LoRALinear(base, r=2, alpha=2)
base_out = base(x)
lora_out = lora(x)
assert torch.allclose(base_out, lora_out)
def test_loralinear_forward_with_delta():
base = Linear(4, 4)
with torch.no_grad():
base.weight.zero_()
x = torch.randn(2, 4)
lora = LoRALinear(base, r=2, alpha=2)
base_out = base(x)
with torch.no_grad():
lora.lora_B.fill_(1.0)
lora_out = lora(x)
assert not torch.allclose(base_out, lora_out)
def test_loralinear_merge():
base = Linear(4, 4)
with torch.no_grad():
base.weight.zero_()
x = torch.randn(2, 4)
lora = LoRALinear(base, r=2, alpha=2)
with torch.no_grad():
lora.lora_B.fill_(1.0)
out_before = lora(x).clone()
lora.merge()
out_after = lora(x)
torch.testing.assert_close(out_before, out_after)
assert lora._merged
assert not hasattr(lora, "lora_A")
def test_loralinear_merge_is_idempotent():
base = Linear(4, 4)
with torch.no_grad():
base.weight.zero_()
lora = LoRALinear(base, r=2, alpha=2)
with torch.no_grad():
lora.lora_B.fill_(1.0)
lora.merge()
lora.merge()
def test_inject_lora_default_target():
model = _make_model()
n_before = sum(1 for m in model.modules() if isinstance(m, Linear))
inject_lora(model, r=4, alpha=8)
lora_count = _get_lora_count(model)
assert lora_count > 0
assert lora_count < n_before
def test_inject_lora_ffn():
model = _make_model()
from astrai.model.components.lora import TARGET_MODULES_FFN
inject_lora(model, r=4, alpha=8, target_modules=TARGET_MODULES_FFN)
assert _get_lora_count(model) > 0
def test_inject_lora_returns_config():
model = _make_model()
cfg = inject_lora(model, r=8, alpha=32)
assert isinstance(cfg, LoRAConfig)
assert cfg.r == 8
assert cfg.alpha == 32
def test_inject_lora_no_matching_targets_warns(caplog):
model = _make_model()
inject_lora(model, r=4, alpha=8, target_modules={"nonexistent"})
assert "No LoRA layers injected" in caplog.text
def test_inject_lora_preserves_base_output():
model = _make_model()
x = torch.randint(0, 1000, (2, 16))
with torch.no_grad():
out_before = model(x)["logits"].clone()
inject_lora(model, r=4, alpha=8)
with torch.no_grad():
out_after = model(x)["logits"]
torch.testing.assert_close(out_before, out_after)
def test_inject_lora_does_not_reinject():
model = _make_model()
inject_lora(model, r=4, alpha=8, target_modules={"q_proj"})
first_count = _get_lora_count(model)
inject_lora(model, r=2, alpha=4, target_modules={"q_proj"})
assert _get_lora_count(model) == first_count
def test_inject_lora_adds_new_modules():
model = _make_model()
inject_lora(model, r=4, alpha=8, target_modules={"q_proj"})
first = _get_lora_count(model)
inject_lora(model, r=4, alpha=8, target_modules={"v_proj"})
assert _get_lora_count(model) > first
def test_inject_lora_on_mla_model():
model = _make_model(
attn_type="mla", kv_lora_rank=16, qk_nope_head_dim=16, qk_rope_head_dim=16
)
inject_lora(model, r=4, alpha=8, target_modules={"q_proj", "o_proj"})
assert _get_lora_count(model) > 0
def test_inject_lora_on_moe_model():
model = _make_model(
ffn_type="moe",
n_routed_experts=4,
n_shared_experts=1,
n_activated_experts=2,
dim_ffn=32,
)
inject_lora(model, r=4, alpha=8, target_modules={"up", "gate", "down"})
assert _get_lora_count(model) > 0
def test_state_dict_key_format():
model = _make_model()
inject_lora(model, r=4, alpha=8, target_modules={"q_proj"})
sd = model.state_dict()
assert "layers.0.attention.q_proj.weight" in sd
assert "layers.0.attention.q_proj.lora_A" in sd
assert "layers.0.attention.q_proj.lora_B" in sd
def test_only_lora_params_trainable():
model = _make_model()
inject_lora(model, r=4, alpha=8, target_modules={"q_proj", "v_proj"})
for name, param in model.named_parameters():
if isinstance(name.split(".")[-1], str) and "lora" in name:
assert param.requires_grad, f"lora param should be trainable: {name}"
elif any(name.endswith(f".{t}.weight") for t in ("q_proj", "v_proj")):
assert not param.requires_grad, f"injected weight should be frozen: {name}"
def test_state_dict_after_inject_consistent_with_original():
model = _make_model()
sd_before = {k: v for k, v in model.state_dict().items()}
inject_lora(model, r=4, alpha=8, target_modules={"q_proj"})
sd_after = model.state_dict()
# original keys unchanged
for k in sd_before:
assert k in sd_after
assert sd_before[k].shape == sd_after[k].shape
# new lora keys present
lora_keys = [k for k in sd_after if "lora" in k]
assert len(lora_keys) > 0
def test_save_load_roundtrip():
model = _make_model()
cfg = inject_lora(model, r=4, alpha=8, target_modules={"q_proj"})
with torch.no_grad():
for m in model.modules():
if isinstance(m, LoRALinear):
m.lora_B.fill_(0.5)
x = torch.randint(0, 1000, (2, 16))
with torch.no_grad():
out_src = model(x)["logits"].clone()
tmpdir = tempfile.mkdtemp()
save_lora(model, tmpdir, cfg)
model2 = _make_model()
model2.load_state_dict(model.state_dict(), strict=False)
load_lora(model2, tmpdir)
with torch.no_grad():
out_dst = model2(x)["logits"]
torch.testing.assert_close(out_src, out_dst)
def test_save_after_merge_raises():
model = _make_model()
cfg = inject_lora(model, r=4, alpha=8, target_modules={"q_proj"})
with torch.no_grad():
for m in model.modules():
if isinstance(m, LoRALinear):
m.lora_B.fill_(0.5)
tmpdir = tempfile.mkdtemp()
save_lora(model, tmpdir, cfg)
merge_lora(model)
tmpdir2 = tempfile.mkdtemp()
with pytest.raises(RuntimeError, match="No LoRA parameters"):
save_lora(model, tmpdir2, cfg)
def test_load_lora_on_already_injected():
model = _make_model()
inject_lora(model, r=4, alpha=8, target_modules={"q_proj"})
with torch.no_grad():
for m in model.modules():
if isinstance(m, LoRALinear):
m.lora_B.fill_(0.5)
tmpdir = tempfile.mkdtemp()
save_lora(model, tmpdir, LoRAConfig(r=4, alpha=8, target_modules=("q_proj",)))
model2 = _make_model()
model2.load_state_dict(model.state_dict(), strict=False)
inject_lora(model2, r=4, alpha=8, target_modules={"q_proj"})
# load onto already-injected model
load_lora(model2, tmpdir)
assert _get_lora_count(model2) > 0
def test_load_lora_mismatched_r_raises():
model = _make_model()
cfg = inject_lora(model, r=8, alpha=16, target_modules={"q_proj"})
with torch.no_grad():
for m in model.modules():
if isinstance(m, LoRALinear):
m.lora_B.fill_(0.5)
tmpdir = tempfile.mkdtemp()
save_lora(model, tmpdir, cfg)
model2 = _make_model()
model2.load_state_dict(model.state_dict(), strict=False)
inject_lora(model2, r=4, alpha=8, target_modules={"q_proj"})
with pytest.raises(RuntimeError, match="size mismatch"):
load_lora(model2, tmpdir) # strict=False, only lora keys
def test_merge_preserves_output():
model = _make_model()
inject_lora(model, r=4, alpha=8, target_modules={"q_proj"})
with torch.no_grad():
for m in model.modules():
if isinstance(m, LoRALinear):
m.lora_B.fill_(0.5)
x = torch.randint(0, 1000, (2, 16))
with torch.no_grad():
out_before = model(x)["logits"].clone()
merge_lora(model)
with torch.no_grad():
out_after = model(x)["logits"]
torch.testing.assert_close(out_before, out_after)
def test_merge_no_lora_warns(caplog):
model = _make_model()
merge_lora(model)
assert "No LoRA layers to merge" in caplog.text
def test_collect_lora_info():
model = _make_model()
info = _collect_lora_info(model)
assert "q_proj" in info
assert "o_proj" in info
assert "q_proj" in info # each layer has one

View File

@ -6,8 +6,8 @@ import pytest
import safetensors.torch as st
import torch
from astrai.config.model_config import ModelConfig
from astrai.model.transformer import Transformer
from astrai.config.model_config import AutoRegressiveLMConfig
from astrai.model.transformer import AutoRegressiveLM
@pytest.fixture
@ -17,10 +17,10 @@ def transformer_test_env():
config = {
"vocab_size": 1000,
"dim": 128,
"n_heads": 4,
"n_kv_heads": 2,
"dim_ffn": 256,
"dim": 8,
"n_heads": 2,
"n_kv_heads": 1,
"dim_ffn": 16,
"max_len": 64,
"n_layers": 2,
"norm_eps": 1e-5,
@ -50,8 +50,8 @@ def test_tie_weight_init(transformer_test_env):
with open(config_path, "w") as f:
json.dump(config_data, f)
config = ModelConfig().load(config_path)
model = Transformer(config)
config = AutoRegressiveLMConfig.from_file(config_path)
model = AutoRegressiveLM(config)
assert torch.equal(model.lm_head.weight, model.embed_tokens.weight)
assert model.lm_head.weight.data_ptr() == model.embed_tokens.weight.data_ptr()
@ -68,8 +68,8 @@ def test_tie_weight_init(transformer_test_env):
with open(config_path, "w") as f:
json.dump(config_data, f)
config = ModelConfig().load(config_path)
model = Transformer(config)
config = AutoRegressiveLMConfig.from_file(config_path)
model = AutoRegressiveLM(config)
assert not torch.equal(model.lm_head.weight, model.embed_tokens.weight)
assert model.lm_head.weight.data_ptr() != model.embed_tokens.weight.data_ptr()
@ -94,13 +94,13 @@ def test_model_save_load_with_tie_weight(transformer_test_env):
with open(config_path, "w") as f:
json.dump(config_data, f)
config = ModelConfig().load(config_path)
original_model = Transformer(config)
config = AutoRegressiveLMConfig.from_file(config_path)
original_model = AutoRegressiveLM(config)
st.save_file(original_model.state_dict(), model_path)
loaded_config = ModelConfig().load(config_path)
model = Transformer(loaded_config)
loaded_config = AutoRegressiveLMConfig.from_file(config_path)
model = AutoRegressiveLM(loaded_config)
model.load_state_dict(st.load_file(model_path))
assert torch.equal(model.lm_head.weight, model.embed_tokens.weight)
@ -112,8 +112,8 @@ def test_model_save_load_with_tie_weight(transformer_test_env):
with open(config_path, "w") as f:
json.dump(config_data, f)
loaded_config = ModelConfig().load(config_path)
model = Transformer(loaded_config)
loaded_config = AutoRegressiveLMConfig.from_file(config_path)
model = AutoRegressiveLM(loaded_config)
model.load_state_dict(st.load_file(model_path))
assert torch.equal(model.lm_head.weight, model.embed_tokens.weight)

Some files were not shown because too many files have changed in this diff Show More