Compare commits

...

350 Commits
v1.2.2 ... main

Author SHA1 Message Date
ViperEkura 01ce1fb9e3 refactor : Pipeline 去除去重,ids 重命名为 sequence,泛型透传
- 移除 Pipeline 内置去重逻辑及 dedup_signature 工具函数
- 删除 ProcessingConfig.deduplicate 字段
- builder 返回 'sequence' 替代 'ids',与 dataset 层统一
- pipeline 纯透传,泛型处理任意 key 补齐默认值
2026-05-31 15:14:27 +08:00
ViperEkura 14f83cbdac perf : 预编译 Jinja2 Template,避免每次 render 重新构建 2026-05-31 14:50:16 +08:00
ViperEkura dbe5891201 refactor : 统一 SectionedMaskBuilder,支持可配置 dtype
- 三合一 MaskBuilder,移除 chat/instruction/text,统一为 sections 配置
- OutputConfig 增加 dtype 字段 (per-key,默认 int32)
- 移除 from __future__ import annotations
- 测试适配新配置格式
2026-05-31 14:24:10 +08:00
ViperEkura 2a65c3314c fix : 修复 created 时间戳、bin 多 shard 覆盖与文档遗漏
- openai.py/anthropic.py: created 从 0 改为 int(time.time())
- openai.py: ChatCompletionRequest 不支持参数非默认值时 warning
- pipeline.py: bin 多 shard 使用子目录避免静默覆盖
- storage.py: MmapStore/detect_format 支持多 shard 聚合加载
- architecture.md: mermaid 类图新增 Pipeline 类
- preprocessing.md: 新增多 shard 输出布局与 Python API 示例
- protocol.py: docstring "6 methods" 改为 "5 methods"
2026-05-30 23:03:42 +08:00
ViperEkura 1c2ff05a6d docs : 三轮深度验证修复文档与代码不一致
- architecture.md: 修正 unwrap_model 返回类型、Config Optional 标注、方法签名错误、类名错误
- training.md: 补充 on_error 回调、修正训练循环顺序、补全策略参数、model.safetensors
- inference.md: 修正 GenerationRequest 参数顺序、async 语法、KVCache 描述、temperature 约束
- dataflow.md: 补充 Store.load/fetch 流程、修正可选参数默认值
- README/params: 多 GPU 示例补全 --parallel_mode、文档表补充 preprocessing.md
- preprocessing.md: Chat 模式算法补全 BOS token 步骤
2026-05-30 21:41:06 +08:00
ViperEkura 31ae2deeba refactor : BaseConfig 提供 from_json/to_json,嵌套 config 自动反序列化
- from_json/to_json 上提至 BaseConfig,所有子类自动继承
- _coerce 新增 dict 到 BaseConfig 子类的递归反序列化,消除子类 from_dict 重载
- PipelineConfig 等子类仅声明字段,零样板代码
- 测试 tokenizer 改为自包含 BPE(含 chat template),不依赖 params/ 目录
- 特殊 token 改用 ASCII 字符,兼容所有平台
2026-05-30 21:04:19 +08:00
ViperEkura 69207e2c57 refactor : 基于声明式 JSON 配置的预处理管线重构
- 用工厂注册的 MaskBuilder(chat/instruction/text)替换硬编码的 _transform_* 方法
- mask 规则以 role-to-action 映射声明在配置中,与 chat_template 完全解耦
- 单次编码 + role-span 追踪替代两次编码 + 长度差计算 mask 的方式
- 支持多轮对话训练:所有 assistant 轮次参与训练,而非仅最后一轮
- 新建 astrai.preprocessing 包(builder.py + pipeline.py),删除 astrai/preprocess.py
- CLI 精简为 --config 参数,所有参数通过 PipelineConfig JSON 配置
- 新增 PipelineConfig、InputConfig、ProcessingConfig、OutputConfig dataclass
- 文档:assets/docs/preprocessing.md
- 27 个测试覆盖 mask builder、pipeline、配置序列化、工厂注册
2026-05-30 20:45:09 +08:00
ViperEkura 138c5bcc08 feat : 添加 JSONL 预处理管线
- Pipeline 模板, Reader 加 transform 加 Writer 可组合
- 自动检测 JSONL 格式, 支持 messages 文本 prompt 加 response 三种
- chat 数据通过 apply_chat_template 适配, 自动生成 loss_mask
- 输出对齐 Store 和 DatasetFactory, 直接用于训练
- 默认 bin 格式, CLI 入口 scripts/tools/preprocess.py
2026-05-30 17:12:42 +08:00
ViperEkura a923e0a23a fix : 修复 MMLU 评测脚本数据源和依赖
- 数据源改为 Berkeley data.tar(GitHub zip 不含数据文件)
- urllib 替换为 requests,支持代理下载
- zip 解压替换为 tar,增加目录 flatten 逻辑
- 添加 model.eval() 确保推理模式正确
2026-05-30 16:51:24 +08:00
ViperEkura f521a30b22 fix : FSDP 优化器顺序、温度除零、调度器静默死亡、ref模型设备
- executor: use_orig_params 硬编码 True,FSDP 不替换 Parameter 对象
- strategy: DPO/GRPO ref 模型创建后移到 device
- sample: TemperatureStrategy clamp 1e-8,engine 验证改为 >0
- scheduler: 异常不 re-raise 避免 daemon 静默死亡,stop() 发回调给 waiting 任务
2026-05-29 21:57:44 +08:00
ViperEkura d4451f6afb fix : 并行训练 state_dict 收集与训练/推理并发缺陷
- FSDPExecutor: unwrap_model 返回全量 state_dict (state_dict_type FULL);use_orig_params=True
- DDPExecutor/BaseExecutor: unwrap_model 统一返回 model.module.state_dict() / model.state_dict()
- CheckpointCallback: 走 executor.unwrap_model 拿完整 state_dict
- strategy.py: 移除 FSDP/DDp 依赖;create_ref_model(model_fn, state_dict) 纯函数
- TrainContextBuilder: 传递 model_fn + executor 到 strategy
- GRPOStrategy.sync_ref_model: 通过 executor.unwrap_model 获取完整权重
- TaskManager.wait_for_tasks: 锁内检查队列,消除 clear/set 竞态
- ProtocolHandler: stop token 不再计入 completion_tokens(流式/非流式)
2026-05-29 21:12:52 +08:00
ViperEkura a3275423a4 release : v1.3.7
Features
- FSDP parallel backend with zero-redundancy sharded training
- LoRA fine-tuning module with low-rank adapter injection and persistence
- NTK-Aware RoPE dynamic scaling, extending context window limit
- MMLU evaluation script for standardized model knowledge assessment
- load_json/load_safetensors broadcast mechanism for cross-node distributed loading

Refactors
- Storage layer refactored to Store pattern, removed Fetcher layer, supporting multi-segment data with explicit length
- Training backend refactored to Executor pattern (none/ddp/fsdp), decoupling parallel logic
- Inference protocol layer refactored to Strategy/Builder pattern with independent OpenAI/Anthropic responders
- Unified serialization layer, eliminating scattered I/O paths
- Removed JSONStore from data pipeline, unified to H5/Bin dual format
- Simplified _disable_random_init, moved scheduler into sync block
- Removed -> None return annotations, split FSDP parameters

Fixes
- Disabled DDP static_graph to prevent no_sync/backward conflict under PyTorch 2.7.1
- Checkpoint resume restores optimizer/scheduler state and sampler remaining length
- Unwrap DDP/FSDP on checkpoint save to avoid module. prefix
- start_epoch/start_batch determined by user args, no longer overridden by checkpoint
- Left padding in perplexity.py causing incorrect PPL with batch>1
- Storage multi-segment bug, switched JSON to JSONL
- Early abort on task_extend failure after decode, notify waiting tasks on scheduler crash

Docs
- Synced architecture/training/inference/dataflow/params docs to actual code

Tests
- Completed inference protocol layer unit test coverage
- Added LoRA module tests
- Filled storage layer test gaps
2026-05-29 17:46:03 +08:00
ViperEkura b37c3d000c docs : 同步文档与实际代码
- 移除 JSONStore 引用(该类不存在)
- 修正 Store.load() 和 DatasetFactory.load() 签名(无 tokenizer 参数)
- 修正 TrainContextBuilder.with_resume_dir() 命名
- 修正 Checkpoint config 字段和 meta.json 描述
- 修正 ProtocolHandler.handle() 异步签名
- 修正采样继承图(平行子类,非线性)
- 修正训练循环:回调移入 accumulate 块内
- 更新文档日期至 2026-05-28
2026-05-28 21:01:47 +08:00
ViperEkura 6031020e37 feat : load_json/load_safetensors 支持 broadcast,跨节点分布式加载
- load_json/load_safetensors/load_state_dict 新增 broadcast 参数
- broadcast=True 时 rank-0 读取后 broadcast_object_list 分发到所有 rank
- load_state_dict 改为逐张量 broadcast,避免大模型 pickle 内存瓶颈
- 删除 _get_meta/_get_config wrapper,Checkpoint.load 直接调用 load_json
- 参数注解 str | Path 统一为 Union[str, Path]
2026-05-28 20:44:58 +08:00
ViperEkura c424dfc293 feat : checkpoint 支持保存 config.json
- Checkpoint.save 写入独立的 config.json(模型架构参数)
- Checkpoint.load 读取 config.json,恢复时覆盖 context.model_config
- TrainContext 新增 model_config 字段,builder 从 resume_dir/config.json 加载
- BaseConfig.to_dict 支持 tuple 和嵌套 dataclass(如 LoRAConfig)
- 删除 _get_meta/_get_config wrapper,直接使用 load_json
2026-05-28 20:21:51 +08:00
ViperEkura 3a28e52e98 fix : start_epoch/start_batch 由用户参数决定,不再被 checkpoint 覆盖 2026-05-28 18:24:22 +08:00
ViperEkura e371908b54 fix : 保存 checkpoint 时 unwrap DDP/FSDP 避免 module. 前缀
- 移除 state_dict_fn 参数
- _save_checkpoint 中先 unwrap_model 再 state_dict()
2026-05-28 18:10:04 +08:00
ViperEkura 7c99da155c refactor: 删除数据流中的 JSONStore
- 移除 JSONStore 及相关函数,训练框架不再依赖 tokenizer
- Store 层只保留 H5Store 和 MmapStore 两种后端
2026-05-28 15:54:26 +08:00
ViperEkura 629e72385b fix : 修复存储层 bug,JSON 切换为 JSONL,补齐测试覆盖
- save_bin/load_bin: save_json/load_json 替换为直接 json.dump/json.load,修复致命 bug
- _normalize: 空 cum 列表 guard,防止 IndexError
- load_json: 改为仅支持 JSONL 逐行解析 (json.loads),移除 .json 支持
- detect_format: 只匹配 *.jsonl,不再匹配 *.json
- save_json: 输出扩展名改为 .jsonl
- GRPODataset.__getitem__: 补齐 .to(dtype=torch.long/bool) 与其他数据集一致
- load_bin: np.memmap mode='r+' 消除 PyTorch 不可写 tensor 警告
- 新增 16 个测试: bin roundtrip, mmap load, 空 key, JSONL 多行/文本, GRPO dtype/load, detect_format bin/jsonl, fetch multi-key/越界, json_to_bin 转换, DPO from JSONL, 显式 storage_type
2026-05-28 15:29:46 +08:00
ViperEkura 0a708fff24 docs : 更新架构文档与 storage 注释,同步 Store 重构
- architecture.md: 类图/关系线全部更新 (BaseStorage→Store, StorageFactory→StoreFactory, 新增 MmapStore)
- architecture.md: 移除 BaseSegmentFetcher/MultiSegmentFetcher 类图与关系
- dataflow.md: 管线加入 .bin 格式, Store._data + _cum 架构
- storage.py: module docstring 改用缩进式注释风格
2026-05-28 14:36:18 +08:00
ViperEkura 6e150ea6d0 refactor : Storage 层重构为 Store,移除 Fetcher 中间层,支持多段数据与显式长度
- 合并 BaseStorage + MultiSegmentFetcher + BaseSegmentFetcher 三层为 Store ABC
- Store._data 直接持有 Dict[str, List[Tensor]],不做强制拼接避免 OOM
- _fetch_key 统一用 bisect 跨段切片,单段多段同一路径
- _length 显式存储(min total across keys),__len__ 返回 O(1)
- MmapStore/H5Store/JSONStore 统一走 _normalize() 注册分段并预计算累积长度
- 所有 I/O 函数 (save_h5/load_h5/json_to_bin 等) 保持不变
2026-05-28 14:23:49 +08:00
ViperEkura cb8dcb97ea refactor : 移除 -> None 返回值标注,拆分 FSDP 参数,新增 mmap 数据集存储
- 删除所有 def 函数 -> None 返回值类型标注
- FSDPExecutor 参数从 **kwargs 拆为显式声明,None 值自动过滤
- 新增 MmapStorage (bin) 存储后端,基于 numpy.memmap 零拷贝加载
- 新增 save_bin/load_bin/json_to_bin 工具函数
- detect_format 支持 bin 格式自动检测
2026-05-28 13:57:06 +08:00
ViperEkura 2d5dc93b3d fix : 修正类型标注与统一 CLI 参数命名
- AutoRegressiveLM.forward 返回类型标注 -> Dict[str, Tensor]
- EmbeddingEncoder 移除冗余 position_ids 自动创建
- CLI 脚本模型目录参数统一为 --param_path
2026-05-27 20:49:44 +08:00
ViperEkura 4145d35e3c refactor: 检查点加载重构,路径替代对象传递
- model: nn.Module -> model_fn 工厂函数,spawn 边界只传字符串
- Trainer.train(resume_dir=path) — Checkpoint 不再通过 pickle 传递
- TrainContextBuilder.with_resume_dir(path) — 自动检测 meta.json 分流 resume/from-scratch
- CheckpointCallback: 拆分 state_dict 收集(全 rank)与磁盘写入(rank-0),修复 FSDP 死锁
- serialization: load_torch 支持 broadcast,消除 _load_extra/_load_torch_broadcast
- optimizer/scheduler 恢复逻辑内联到 build(),在 executor.prepare() 之后执行
- pyproject.toml: ruff exclude build/ 避免 CI 扫描构建产物
2026-05-27 20:15:29 +08:00
ViperEkura 34c6c45bd6 feat: 初步实现 MMLU 评测脚本
- 支持 few-shot (log-likelihood ranking) 与 zero-shot
- 自动下载 Hendrycks MMLU 数据集
- --device / --dtype 可配置,默认 GPU bf16
2026-05-26 20:23:31 +08:00
ViperEkura e9def84ce7 fix : perplexity.py left padding 导致 batch>1 时 PPL 计算错误 2026-05-26 19:59:57 +08:00
ViperEkura 836e02a166 docs: 同步 architecture/inference/training 文档至实际代码,CLI 补充 fsdp 选项
- 修正 ProtocolHandler 架构:concrete + ResponseBuilder(ABC) 策略模式
- 修正训练循环 scheduler.step() 在 sync_gradients 块内
- 修正组合/聚合关系:注入组件改为 o--,删除不持有引用的关联
- --parallel_mode CLI choices 加入 fsdp
- nprocs > 1 且 parallel_mode=none 时 raise error
2026-05-26 19:37:00 +08:00
ViperEkura b558e61f63 refactor: 简化 _disable_random_init,scheduler 移入同步块
- _disable_random_init: enable=False 提前返回,dict 推导替代空字典
- scheduler.step() 移入 sync_gradients 守卫内
2026-05-26 17:05:25 +08:00
ViperEkura 65ab69543b refactor: 统一序列化层,消除分散的 I/O 路径
- Checkpoint 改为 @dataclass,内聚 save/load 方法
- 提取 save_safetensors/load_safetensors/save_json/load_json 共享工具
- 新增 save_model/load_model_config/load_model_weights 模块函数
- automodel 和 lora 统一委托到 serialization 模块
2026-05-26 16:44:40 +08:00
ViperEkura 1d26aa2e93 fix: 禁用DDP static_graph避免PyTorch 2.7.1下no_sync与backward冲突
- static_graph=True时DDP.no_sync() + loss.backward()触发expect_autograd_hooks_内部断言
- PyTorch 2.7.1中no_sync上下文切换与静态图hook状态管理存在兼容性bug
- 将static_graph设为False恢复梯度累积正常执行
- find_unused_parameters保持False(模型无不参与计算的参数)
2026-05-26 15:08:01 +08:00
ViperEkura a548d4553e fix: 断点续训恢复优化器/调度器状态及采样器剩余长度
- 使用Checkpoint.load()替代手动加载model.safetensors,恢复optimizer/scheduler状态
- TrainContextBuilder从checkpoint.extra恢复优化器和调度器state_dict
- ResumableDistributedSampler.__len__返回剩余样本数而非总数
- 训练前对state_dict置空避免mp.spawn pickle 7GB大对象
2026-05-26 13:50:25 +08:00
ViperEkura dd1b39f435 fix: ProgressBar默认输出到stdout
- file参数默认值改为None, 内部用 or sys.stdout 兜底
- 清理inference API中未使用的import (Optional, time, field)
- 删除test_protocol中未使用的ctx变量
2026-05-26 13:27:05 +08:00
ViperEkura 94d6e713e9 test: 补充推理协议层单测覆盖
- StopChecker、GenContext、StopInfo 单测
- OpenAIResponseBuilder / AnthropicResponseBuilder 全部方法
- Anthropic 停止序列裁剪逻辑(含 unyielded 边界)
- GenerationRequest 参数校验含负值边界
- Scheduler prefill 短路验证
2026-05-26 00:21:52 +08:00
ViperEkura 47c37e4876 refactor: 推理协议层重构为策略/建造者模式
- ProtocolHandler 改为具体类,格式化委托给 ResponseBuilder
- 新增 api/protocols/ 目录,含 OpenAIResponseBuilder、AnthropicResponseBuilder
- GenContext、StopInfo 参数对象替代 StreamContext
- 消除 Builder 的实例可变状态(accumulated、_yielded)
- SSE 工具和停止检测收归 ProtocolHandler 统一管理
- prepare() 方法合并原来的 build_prompt、create_response_id
- 参数校验去重:仅 GenerationRequest.init 负责校验
- Prefill 阶段提前短路完全命中的缓存任务
2026-05-26 00:12:57 +08:00
ViperEkura 737585a32a feat: 新增NTK-Aware RoPE缩放支持
- RotaryEmbedding接受rope_scaling配置,自动计算scaled base
- AutoRegressiveLMConfig和EncoderConfig新增rope_scaling字段
2026-05-25 21:22:07 +08:00
ViperEkura a4688021bf feat: 新增LoRA微调模块
- LoRALinear基于register_parameter托管base weight,state_dict路径不变
- inject_lora/merge_lora/save_lora/load_lora完备封装
- 24个单元测试覆盖注入、合并、存取、边界场景
2026-05-25 20:15:31 +08:00
ViperEkura 7df6eb9211 feat: 新增FSDP并行后端
- FSDPExecutor通过**fsdp_kwargs直传FSDP参数
- unwrap_model同时支持DDP和FSDP
- parallel_mode新增fsdp选项
2026-05-25 19:43:14 +08:00
ViperEkura 82a3f2626f docs: 更新文档与代码同步(Executor/训练循环/参数)
- architecture.md: TrainConfig 移除旧 parallel_wrapper/state_dict_fn
- architecture.md: 新增 ExecutorFactory/BaseExecutor/DDPExecutor 等类图
- architecture.md: MLA 新增 use_qk_norm/q_norm/k_norm
- architecture.md: 新增 protocols 命名空间
- training.md: 修复训练循环 hook 名和 scheduler.step 位置
- training.md: 替换 parallel_wrapper 为 parallel_mode/executor.prepare
- training.md: 修复默认回调顺序和 Callback 生命周期表
- params.md: 新增 --parallel_mode 和 --start_method
2026-05-24 22:17:49 +08:00
ViperEkura 7fa69572c0 fix: 测试日志写入临时目录避免冗余文件 2026-05-24 20:54:59 +08:00
ViperEkura 3ab4f237e5 refactor: 重构训练后端为 Executor 模式
- backend.py → executor.py,BaseTrainingBackend → BaseExecutor
- 新增 NoneExecutor(单卡)和 DDPExecutor(DDP,world_size=1 自动降级)
- 新增 GradientState 分离梯度同步状态,AccumOptimizer/AccumScheduler 包裹拦截
- 新增 astrai/protocols.py:OptimizerProtocol/SchedulerProtocol 结构子类型
- TrainContext.backend → executor,TrainConfig 移除 parallel_wrapper/state_dict_fn,新增 parallel_mode/executor_kwargs
- 训练循环用 accumulate() 包裹,on_optimizer_step 命名约定=gate
- scripts/tools/train.py 移除 ddp_wrap/prepare_checkpoint,新增 --parallel_mode
2026-05-24 20:35:44 +08:00
ViperEkura 8cbf3f36e2 feat: 新增训练后端工厂框架
- BaseTrainingBackend 定义 prepare/accumulate/unwrap_model 抽象
- DDPTrainingBackend 支持全部 DDP 参数并通过 BackendFactory 注册
- unwrap_model 改为实例方法,由子类各自实现
2026-05-24 15:15:14 +08:00
ViperEkura 0594ce1017 perf: Muon step 改用 torch._foreach_* 批处理并移除 NS 迭代的冗余 bf16 转换 2026-05-23 19:50:12 +08:00
ViperEkura ff509ff39f fix: decode后task_extend失败时提前中止,scheduler崩溃时通知waiting任务 2026-05-20 19:23:13 +08:00
ViperEkura 785d65436c fix: 修复 to_dict list 类型丢失与 OpenAI stop 参数失效
- to_dict() 增加 list 类型序列化支持,metrics 等字段不再丢失
- OpenAIHandler 补充 get_stop_sequences/on_token,读取 request.stop 并检测停止序列
- 文档类图补充缺失字段、修正关系分类、ChatCompletionRequest 字段增加 Optional
2026-05-19 21:07:07 +08:00
ViperEkura 64be81b7b3 feat: ProgressBarCallback 支持日志行输出到 stdout
- serialization 和 metric_logger 的 timestamp 统一使用 ISO 8601 格式
- ProgressBarCallback 新增 log_interval/file 参数,默认输出到 sys.stdout
2026-05-19 19:12:38 +08:00
ViperEkura 45479b5731 feat: metric 参数通过 TrainConfig 传递
- TrainConfig 新增 log_dir/log_interval/metrics 配置字段

- metric_logger 调用改用 **kwargs 传递,BaseFactory.create 自动过滤
2026-05-19 17:50:24 +08:00
ViperEkura e0a3337c22 docs: 更新视频链接 2026-05-19 17:34:01 +08:00
ViperEkura 812238060b fix: docker-compose UID/GID 添加默认值,修复 docker.sh logs 命令 2026-05-18 14:24:00 +08:00
ViperEkura 14b0d56197 fix: 修复无法创建子进程的问题
- mp.start_processes daemon=False
2026-05-18 09:40:32 +08:00
ViperEkura 6c8533f1d2 docs: 修正文档中类名/字段名与代码不一致之处
- ModelConfig → AutoRegressiveLMConfig, Transformer → AutoRegressiveLM
- 新增缺失类: EncoderConfig, EmbeddingEncoder, ConfigFactory, StorageFactory, ValidationCallback
- TrainConfig/TrainContext/ChatCompletionRequest 补充缺失字段
- dataflow.md 中 create_storage → StorageFactory.create
- 示例 --train_type=pt → seq 与代码一致
2026-05-17 21:02:21 +08:00
ViperEkura 2c2697390d feat: 新增 GradientCheckpointingCallback
- TrainConfig.gradient_checkpointing_modules 指定模块类型
- apply 递归遍历,兼容 DDP,不硬编码模型结构
- modules=None 时静默跳过,零开销
2026-05-17 18:21:05 +08:00
ViperEkura 7621f05d3f docs: AdamW beta 默认值改为 (0.9, 0.95)
- 与 Muon 优化器的 AdamW 子优化器保持一致
- 同步更新 train.py/training.md/params.md/README
2026-05-17 17:08:31 +08:00
ViperEkura 10ebd7211f feat: 新增 Muon 优化器
- 2D 参数用 Newton-Schulz 正交化 + Nesterov 动量更新
- 1D 参数用 AdamW 更新
- 支持 lr/momentum/weight_decay/ns_steps 配置
2026-05-17 16:44:03 +08:00
ViperEkura 42a391f0fb feat: 训练中新增验证循环
- TrainConfig 添加 val_dataset/val_step 字段
- TrainContext 添加 val_dataloader/val_loss 字段
- 新增 ValidationCallback 按 step 触发验证 + 训练结束时验证
- ProgressBar/MetricLogger 支持 val_loss 展示与记录
2026-05-17 16:12:42 +08:00
ViperEkura 97c7ac0f4f refactor: Transformer更名为AutoRegressiveLM并新增EmbeddingEncoder
- AutoRegressiveLM 注册名改为 autoregressive_lm
- 新增 EmbeddingEncoder 支持 mean/cls/last pooling
- ModelConfig 增加 pooling_type / normalize_embeddings 字段
- 导入、注释、测试全部同步更新
2026-05-17 15:29:20 +08:00
ViperEkura 8f1b32f2b6 fix: 移除多余 request 参数并增强 tokenizer 健壮性
- 路由和 _get_engine 不再需要 request 参数,直接引用模块级 app
- from_pretrained 增加文件完整性校验,缺 tokenizer.json 则抛 FileNotFoundError
- 移除 from_pretrained 中未使用的 **kwargs
2026-05-17 12:52:18 +08:00
ViperEkura c241a5dcef refactor: 优化并行训练配置与启动管理
- 配置新增 start_method 支持 spawn/fork/forkserver 选择
- 启动方式 mp.spawn 改为 mp.start_processes,支持 daemon=True
- validate() 改为基于 metadata 的反射式校验,不再硬编码字段列表
- CLI 新增 --start_method 参数
2026-05-17 12:33:10 +08:00
ViperEkura 44dab27fdc feat: 数据集加载时校验必填字段
- BaseDataset.required_keys 属性声明所需存储 key
- load() 时自动校验,缺失立即抛 KeyError
- SEQ/SFT/DPO/GRPO 各自声明 required_keys
2026-05-17 11:50:38 +08:00
ViperEkura a44fd22a99 fix: 修复训练与模型参数传递问题
- state_dict_fn 传入 CheckpointCallback,修复多卡 DDP 下 key 前缀丢失
- MLA 增加 use_qk_norm 支持,消除参数静默丢失
- moe_topk_method 统一命名为 topk_method
- checkpoint 回调移至最前
2026-05-17 11:20:13 +08:00
ViperEkura 8a11a7d444 fix: 修复训练脚本两处参数传递问题
- prepare_checkpoint 增加 DDP 判断,单卡时不访问 .module
- dpo_beta 改为 beta,对齐 DPOStrategy 参数名
2026-05-17 11:04:40 +08:00
ViperEkura 1d54491809 refactor: 改用递归子模块 init 替代统一 normal_(0.006)
- Embedding.reset_parameters: normal_(std=0.02)
- Linear.reset_parameters: kaiming_uniform_ + uniform_ bias
- Transformer._init_weights 通过 apply 递归调用子模块 reset_parameters
- 移除全局 normal_(0.006) 覆盖,各模块使用更合适的分布
2026-05-17 10:44:18 +08:00
ViperEkura ad9f4d9cf6 refactor: generate_ar 改用流式输出并去除冗余注释 2026-05-17 10:23:42 +08:00
ViperEkura e1638a7ade fix: 修正AdamW超参数默认值与文档示例
- 交换adamw_beta1/adamw_beta2默认值:beta1=0.95, beta2=0.99
- label_smoothing默认值改为0.05
- 文档示例统一更新:train_type=pt, weight_decay=0.01
- 移除文档中过时的strategy default标注
2026-05-16 22:46:17 +08:00
ViperEkura f91bfee33e refactor: Config序列化统一BaseConfig基类
- 新增astrai/config/base.py,提供to_dict/from_dict基类
- 统一命名:load/save → from_file/to_file
- Checkpoint.meta合并训练配置到meta.json
- sys.stderr.warn → warnings.warn
- from_file改为classmethod
2026-05-16 22:06:39 +08:00
ViperEkura d7a7f570ed refactor: 训练循环改为两重迭代并统一参数命名
- 训练循环从三重(epoch→batched→batch)改为二重(epoch→batch)
- batch_size → batch_per_device, accumulation_steps → grad_accum_steps
- scheduler 移入 step block 对齐 optimizer 更新步
- GradientClippingCallback 改用 on_step_begin 避免零梯度裁剪
- 移除 _train_impl 误导性的 -> Checkpoint 标注
- total_steps 修除为向下取整并精简为一行
- warmup_steps 改为 warmup_ratio (默认0.05)
2026-05-16 21:27:35 +08:00
ViperEkura 7dea929788 refactor: checkpoint 按 HF 方式存独立 .pt 文件,callback 接管恢复
- Checkpoint.save/load: extra 逐 key 写为 {key}.pt 而非单个 extra.pt
- meta.json 新增 timestamp
- CheckpointCallback: save_extra/load_extra 静态方法 + extra_keys 类属性
- on_train_begin 接管 optimizer/scheduler 恢复,TrainContextBuilder 不再传 load_extra_fn
2026-05-16 18:29:04 +08:00
ViperEkura 026d1fc33d fix: total_steps 改用 ceiling 匹配实际步数
原公式全用 floor 少算 optimizer step,改用逐层 ceiling
(ceil_div via (a+b-1)//b)对齐 DDP sampler padding +
DataLoader drop_last=False 尾批 + batched 尾组截断。
2026-05-16 17:53:18 +08:00
ViperEkura 7242eedbf4 fix: 学习率调度按 optimizer step 计数并防止 warmup 越界
- total_steps 除以 accumulation_steps,匹配 optimizer.step() 频率
- warmup_steps 用 min 截断,避免 lr_decay_steps 为负
2026-05-16 17:07:36 +08:00
ViperEkura 04c0dc7a47 refactor: Storage 改用工厂模式,server reload 接入 uvicorn
- 新增 StorageFactory(BaseFactory[BaseStorage]) 替代手写 dict 注册
- H5Storage / JSONStorage 通过 @StorageFactory.register 注册
- dataset.py 使用 StorageFactory.create() 替代 create_storage()
- 删除 create_storage / available_storage_types 死函数
- server.py reload 参数正式传入 uvicorn.run()
2026-05-16 17:00:26 +08:00
ViperEkura 48a53121ba refactor: 工厂 kwargs 过滤及组件参数清理
- BaseFactory.create() 按 __init__ 签名过滤多余 kwargs
- 移除 GQA/MLA/MLP/DeepSeekMoE 中多余的 **kwargs
- MLP/DeepSeekMoE 参数名统一为 dim_ffn
- scheduler max_seq_len 增加 None 显式判断
- 默认 max_prompt_len 提升至 2048
2026-05-16 16:47:41 +08:00
ViperEkura 0ba8c70ce1 fix: 修复 MLA 多个 bug 并缩小测试模型参数
- MLA kv_b_proj 输出维度和 q_rope 切分偏移修复
- 打通 MLA 配置从 ModelConfig 到 DecoderBlock 的传递路径
- rope_theta 配置不再被忽略,MLA 使用 qk_rope_head_dim
- tie_weight 使用 is True 避免 None 隐式生效
- norm_eps/rope base 类型标注修正
- 测试模型参数缩小 (dim=8, head_dim=4)
- 新增 6 种架构配置 × 2 场景的前向传播测试
2026-05-16 14:57:43 +08:00
ViperEkura 3d12a03909 docs : 拆分文档并补充类图缺失类和关系线
- 将 design.md 拆分为 architecture.md / inference.md / training.md
- 精简 dataflow.md 为纯数据管道
- 删除 design.md 和 introduction.md
- 更新 README.md 和 README-zh-CN.md 链接
- 补充 ChatMessage / AnthropicMessage 等 6 条孤立类关系线
- 补充 BaseModelConfig 和 TaskManager 两个缺失类
2026-05-15 23:38:26 +08:00
ViperEkura c169659611 docs: 修正 assets/docs/ 类图、数据流、参数文档及贡献指南
- design.md: 新增 ProtocolHandler/OpenAIHandler/AnthropicHandler 等缺失类
- design.md: 新增 Template Method、Storage 设计模式
- dataflow.md: 修正 GQA/MLA 为独立条目,补充 JSON 存储后端
- params.md: 标注 label_smoothing CLI 默认与 strategy 默认差异
- introduction.md: 修正 max_tokens 默认值 1024→2048
- CONTRIBUTING.md: 重写(纯 Python 无 conda、补充 CI 步骤与常见问题)
- .github/PULL_REQUEST_TEMPLATE.md: 修正 lint 命令,去除多余注释要求
- .github/ISSUE_TEMPLATE/bug_report.md: 修正 label(enhancement→bug)
2026-05-15 22:54:41 +08:00
ViperEkura e12f1a7ee5 feat: BaseModelConfig + DeepSeekMoE + 工厂模式替代 if/else
- BaseModelConfig: fields() 精确字段匹配 + 类型矫正 + 未知key警告
- DeepSeekMoE: 共享专家 + 路由专家 + top-K 门控
- AttnFactory/FFNFactory: 装饰器注册,DecoderBlock 零分支
- config 用 attn_type/ffn_type 驱动组件选择
2026-05-15 20:34:52 +08:00
ViperEkura ef25efffa2 refactor: 拆分 module.py 为 components 子包
- rope/linear/norm/embedding/mlp/attention/decoder_block 各自独立文件
- 依赖单向无循环
- 公开接口不变,外部无需修改
2026-05-15 20:08:36 +08:00
ViperEkura 19532440b4 chore: 版本号升至 1.3.5 2026-05-15 18:23:27 +08:00
ViperEkura 9096e413c3 refactor: RotaryEmbedding 合并 cos/sin 为单一复数缓存
- get_rotary_emb() 返回复数张量替代 Tuple[cos, sin]
- RotaryEmbedding 存储单一 freqs_cis buffer 替代分离的 cos_cached/sin_cached
- forward 中 view_as_complex 重建复数
2026-05-15 18:03:59 +08:00
ViperEkura 9d5e9fa6c4 perf: DDP 加 gradient_as_bucket_view/static_graph/broadcast_buffers,AdamW fused
- gradient_as_bucket_view=True 零拷贝梯度归并
- static_graph=True 跳过每轮 bucket 重建
- broadcast_buffers=False 省 buffer 广播
- AdamW fused=True 融合优化器 kernel
2026-05-15 15:30:24 +08:00
ViperEkura 08dde46778 fix: 修复训练循环 step/backward 顺序,重构为三重循环嵌套
- 训练循环改用 itertools.batched 实现 epoch→step→batch 三重嵌套
- on_step_begin 包裹 batch 循环,on_step_end 后接 optimizer.step/scheduler.step
- 修复首次 iteration=0 时 optimizer.step() 在 backward 之前触发的 bug
- GradientClippingCallback 改为 on_step_end(梯度已累积,step 前裁剪)
- SchedulerCallback 移除,schduler.step 由 trainer 在 optimizer.step 后直接调用
- metric_util 提取 _grad_stat 公共 helper,if param.grad: 修正为 is not None
2026-05-15 14:44:44 +08:00
ViperEkura 513f1f7826 perf: waiting_queue 改用 deque,pull_candidates 从 O(n²) 降到 O(1)
- list.pop(0) 每次左移全部元素,改 deque.popleft() 指针操作
- return_to_waiting 从 slice 整体复制改 appendleft 逐个插入
- 热路径 refill 阶段不再卡顿
2026-05-14 21:38:00 +08:00
ViperEkura e3382f6bb5 fix: 修复推理引擎 batch decode 中多项正确性与并发问题
- scheduler: decode 分组由幂次分桶改为精确 next_pos,消除 KV cache 位置错乱
- task: activate() 加锁操作 active_tasks,消除数据竞争
- engine: wait_completion 加超时,防止分配失败时永久死锁
- sample: TopKStrategy 向量化为 per-sample threshold,尊重各 task 的 top_k
- cache: Storage.write/gather 中 -1 页改用 mask 处理,防数据污染
- executor: prefill 逐 task 循环改为单次 tensor 调用
2026-05-14 21:31:39 +08:00
ViperEkura f0339022c1 fix: batch 推理示例添加 chat template 和 system prompt
- 新增 prompts 列表,对每个输入应用 apply_chat_template
- 添加 system message 到对话模板
2026-05-14 20:59:01 +08:00
ViperEkura d8da2cf17c docs: 修复文档中与源码不符的类名、方法签名和模块归属
- CONTRIBUTING.md: ruff/pytest 命令改为 conda 方式
- params.md: max_len → max_tokens
- introduction.md: max_len=1024 → max_tokens=None
- dataflow.md: PagedCache/CacheView → KVCache/KvcacheView
- design.md: 全面修正类图(PagedCache→Allocator等6个新类、删除position_ids误参、修正BaseDataset字段和25+条关系线、Module Overview更新)
2026-05-14 20:26:24 +08:00
ViperEkura 205b40bd28 refactor: 重构 cache 和 inference 参数体系,分离存储与分配
- 合并 GenerationRequest/GenerationParams,统一 max_tokens 参数名
- PagePool/PrefixCache 分离为 Allocator + PrefixCache + PagePool
- 拆分 KV 存储为独立 Storage 类,PagedCache → KVCache,CacheView → KvcacheView
- Allocator.inc_ref 移除 LRU 防止竞争,Storage.write 增加负页防御
- Allocator/PrefixCache/TaskTable 加 threading.Lock 保证线程安全
- server.py uvicorn.run 改为传 app 对象修复导入错误
- benchmark.py 适配 KVCache 新 API
2026-05-14 20:05:08 +08:00
ViperEkura 18fe6e9339 refactor: 消除多处重复模式,统一工厂和参数传递
- AutoModel 继承 BaseFactory,消除自建 Registry(-30 行)
- executor.execute_prefill 删除重复 forward 代码块(bug)
- train_callback 移除 Protocol 上矛盾的 issubclass 检查
- engine.py 内部方法统一传 GenerationParams,校验内聚
- protocol.py SSEBuilder 类→函数,handle() 用 GenerationParams
- StreamContext 动态属性改为显式 dataclass 字段
- BaseFactory 新增 get_component_class 方法
2026-05-14 18:00:50 +08:00
ViperEkura 2196c34c52 refactor: 重构 inference 模块架构,引入设计模式并分组文件
- 新增 protocol.py 协议层,Template Method 模式消除流/非流分支 45% 重复
- SSEBuilder 统一 SSE 构造,StopChecker 独立 stop_sequence 检测
- AnthropicHandler 追踪已产出文本,修复 stop 时重复 delta
- server.py 路由从约 100 行缩减至 3 行
- 拆分为 core/(cache/executor/scheduler/task)和 api/(protocol/server)
- 外部保持二级导入路径(from astrai.inference import Name)
- 删除所有分隔线注释,代码按语义自然分组
2026-05-14 17:42:37 +08:00
ViperEkura 466c2e1efd fix: process_attention_mask 中 expand 后的 inplace 写导致 alias 报错
- pad.view.expand 产生的视图多元素指向同一内存,attend &= 写入报错
- 改为 .expand().clone() 独立内存后再 inplace
2026-05-14 16:30:31 +08:00
ViperEkura 7e26d848ab perf: apply_rotary_emb 改用复数乘法
- get_rotary_emb 保留 cos/sin 实数存储,forward 组合为 complex
- apply_rotary_emb 用 view_as_complex 复数乘法替代多次 view mul stack
- 移除 GQA MLA DecoderBlock 中的 Tuple Tensor Tensor 类型
- 解码从 4.24s 降到 3.49s
2026-05-14 16:20:16 +08:00
ViperEkura ed95ef245c perf: 消除 RotaryEmbedding.forward 中 position_ids GPU 同步
- cos/sin 缓存预分配到 max_len,移除运行时动态扩容逻辑

- 移除未使用的 max_len_cached 属性

- 解码累计从 4.23s → 3.99s(+5.7%)
2026-05-14 15:53:21 +08:00
ViperEkura 6d6ef99e66 perf: 消除 PagedCache.write 中的 position_ids GPU 同步,解码提速 15%
- CacheView.write 用 total_len - k.size(1) 推导 start_pos,替代 position_ids[0,0].item()

- 移除 GQA/MLA/DecoderBlock 中不再使用的 position_ids 参数

- PagedCache.write 参数 position_ids:Tensor → start_pos:int
2026-05-14 15:37:48 +08:00
ViperEkura a8e2a1ba45 docs: 修正文档中与源码不符的类名、方法签名和模块归属
- Transformer/DecoderBlock/GQA/RotaryEmbedding forward 签名 start_pos → position_ids

- _Result → GenerateResult

- save_h5/load_h5 从 serialization 移至 dataset 模块

- PagedCache UML 移除内部 PagePool 属性

- 修正 Layer 数不一致(24 vs 32)及 decode 位置分组描述

- 更新文档时间为 2026-05-14
2026-05-14 15:04:53 +08:00
ViperEkura 6269bacfc3 refactor: decode 按页分桶批处理,position_ids 改为 per-task 构建 2026-05-14 14:22:11 +08:00
ViperEkura c0effc9f5b refactor: 位置编码改用 position_ids [B,S],简化 attention mask 构建
- RotaryEmbedding/CacheView 接受 position_ids 替代 start_pos

- process_attention_mask 用 position_ids >= arange 做逐位置 causal

- 训练/无 KV cache 时 position_ids=None 内部自动处理

- 移除 executor/benchmark 中冗余的 input_mask 构造
2026-05-14 13:26:31 +08:00
ViperEkura df0845e916 chore: 解耦 Executor/Scheduler/TaskManager,修复 stop 页泄漏,移除 ServerState 全局单例 2026-05-12 13:47:55 +08:00
ViperEkura 7440e9c809 style: 重命名 test_scheduler_concurrency 为 test_scheduler 2026-05-12 12:24:36 +08:00
ViperEkura 7d4029c2a4 test: inference 模块补全单元测试,cache/sample/engine/task
- test_cache: page_hash, PagePool, PrefixCache, TaskTable, PagedCache write/gather
- test_sample: TemperatureStrategy, TopKStrategy, TopPStrategy, SamplingPipeline, sample()
- test_engine: _Result 线程安全, generate stream/non-stream batch/single
- test_task: Task 生命周期, TaskManager 队列操作
- 4 新文件, +771 行, 116 total tests
2026-05-12 12:17:57 +08:00
ViperEkura 0ca6c9e6eb test: 增加 13 个边界条件测试,不需要 base_test_env 的函数移除该参数
- Fetcher 空/边界/跨段测试
- Storage 未加载 fetch 异常
- detect_format 无效路径/不支持格式
- create_storage 无效类型
- JSON pre-tokenized 无 tokenizer
- load_json 跳过 config.json
- Dataset 未加载/数据过短
- 所有 import 提到文件顶部
2026-05-12 11:47:30 +08:00
ViperEkura 6e49d27057 fix: MultiSegmentFetcher 空 dict 崩溃 + BaseDataset assert 替换为显式 raise
- MultiSegmentFetcher.__len__: min([]) → 加空检查返回 0
- BaseDataset.get_index: assert 替换为 RuntimeError / ValueError
- BaseDataset.__len__: assert 替换为 early return 0
2026-05-12 11:41:45 +08:00
ViperEkura 5203b7f53e perf: 测试优化,model 改为 session 共享,scheduler 用 Event 替代 sleep
- 拆出 session-scoped test_tokenizer + test_model,14 次创建 → 1 次
- 删除无用 test_env fixture
- 固定模型维度,消除随机性
- 添加 pytest markers 配置
2026-05-12 11:35:18 +08:00
ViperEkura 5889179c54 refactor: 抽取 BaseStorage 存储抽象,支持 JSON 原始文本数据加载
- 新增 astrai/dataset/storage.py:BaseStorage/H5Storage/JSONStorage + Fetchers + 序列化函数
- BaseDataset.load() 接入存储抽象,自动检测 HDF5/JSON 格式
- JSON 支持原始文本 + tokenizer callable 加载时 tokenize
- 新增 BaseDataset.count / keys 属性进行长度观测
- serialization.py 精简为只保留 Checkpoint 类
- 函数放前、类放后,删除分隔注释
2026-05-12 11:17:24 +08:00
ViperEkura 38e18fdfd3 refactor: PagedCache Facade 模式,提取 PagePool/PrefixCache/TaskTable
- cache.py: 提取 PagePool (位图+LRU)、PrefixCache (前缀哈希)、TaskTable (任务页表)
  PagedCache 降为 Facade 组合三者 + 张量存储,公开 API 不变
- executor.py: 移除 allocate_pages_for_activation/free_task_pages/get_cached_tokens
  三冗余委托方法,去掉 page_size 构造参数(改用 page_cache.page_size)
- scheduler.py: 直接调用 self._page_cache.* 代替已移除的 Executor 委托
- 移除 CacheView.__slots__、PagePool.ref_count、PagedCache.alloc/pages_needed/inc_ref
  PrefixCache.evict 等死/冗余方法
2026-05-11 15:22:21 +08:00
ViperEkura 4753958f92 refactor: 页状态移入 PagedCache,Task 纯化为域对象
- PagedCache 增 task_alloc/task_free/task_extend/task_cached/task_record_hashes/make_table_tensor
- Task 移除 page_table/n_pages/_prefix_cached_tokens/_pages_freed
- Executor 移除 _PageState,页操作全部委托 PagedCache
- CacheView.gather 截断逻辑下沉到 PagedCache.gather
- 各类补充单行职责 docstring
2026-05-11 14:42:39 +08:00
ViperEkura 73d6cc0f26 refactor: TaskManager 剥离页管理,STOP 移至 task.py
- TaskManager 移除 page_cache/page_size 依赖,增 pull_candidates/activate/return_to_waiting
- Executor 增 allocate_pages_for_activation/free_task_pages,承接全部页操作
- STOP 从 cache.py 移至 task.py
- scheduler loop 显式装配: 清理→释页 / 拉取→分配→激活
- sampling.py → sample.py
2026-05-11 14:04:31 +08:00
ViperEkura 317ed90bac refactor: 拆分 scheduler 为 TaskManager + Executor
- InferenceScheduler 退化为编排器,委托 TaskManager 管理任务生命周期 + Executor 执行模型前向
- Task/TaskStatus/TaskManager 移至 task.py
- Executor 移至 executor.py (原 BatchExecutor)
- scheduler.py 437 行 -> 142 行
2026-05-11 13:50:11 +08:00
ViperEkura 951df8155c perf: gather 向量化 2026-05-10 21:01:03 +08:00
ViperEkura a58fab8d6e fix: max_seq_len 检查改为仅 prompt 超限发 STOP,max_tokens 超出部分 clamp 2026-05-10 20:17:47 +08:00
ViperEkura a3c8296135 fix: page cache 分配失败越界崩溃 + 长度超限终止
- astrai/inference/scheduler.py: add_task 增加 max_seq_len 检查,超限时直接发 STOP 信号终止
- astrai/inference/scheduler.py: _maybe_alloc_page 返回 bool,alloc 失败时标记 ABORTED + 发 STOP
- astrai/inference/scheduler.py: _execute_decode 过滤分配失败任务,避免 page_table 越界
- astrai/inference/scheduler.py: _remove_finished_tasks 清理 ABORTED 任务并释放 pages
- astrai/inference/scheduler.py: _execute_prefill input_mask 改为覆盖全部 prompt_len
- astrai/model/transformer.py: seq_mask is None 分支补全 start_pos + seq_len 列
2026-05-10 20:14:38 +08:00
ViperEkura c95ace41aa fix: prefill 时 attention mask 长度不足导致 expand 崩溃
- astrai/inference/scheduler.py: prefill input_mask 由 [batch, seq_len] 改为 [batch, prompt_len],覆盖全部 KV 位置
- astrai/model/transformer.py: seq_mask is None 分支补全 start_pos + seq_len 列,避免 expand 非 singleton 维度不匹配
2026-05-10 19:56:41 +08:00
ViperEkura 3da428e0e4 perf: PagedCache 持久前缀缓存 + LRU 逐出
- astrai/inference/cache.py: refcount 归零时保留 hash 映射,页加入 LRU evictable 池
- alloc() 无空闲页时从 LRU 逐出,优先释放 _free_mask
- lookup_prefix/inc_ref 触发 _touch 更新 LRU 序
- record_page 设置 pin 标记并从 LRU 移除
2026-05-10 18:05:11 +08:00
ViperEkura 133a9de98f feat: _generate_streaming 支持 batch 模式
- _Result.append 存储 (idx, token) 元组,pop_all 返回对应列表
- 单 prompt: Generator[str](向后兼容)
- 多 prompt: Generator[Tuple[int, str]],token 交错到达,调用方自行分流
- 不使用 dispatch 线程 / Queue,避免同步开销和内存积压
2026-05-10 17:42:20 +08:00
ViperEkura 523eacf5fe release: v1.3.4
- refactor: 分页 KV cache(PagedCache+CacheView)替换固定 slot,删除 PrefixCache
- refactor: 推理引擎控制逻辑重写,修复连续批处理核心缺陷、线程安全问题
- refactor: KV 缓存槽位下沉到注意力层,移除 _remap_kv / _writeback_kv
- refactor: 统一采样路径为 SamplingPipeline batch tensor,删除 apply_sampling_strategies
- refactor: 设计模式优化 inference 模块导入结构(cache/sampling 独立)
- feat: 推理引擎前缀缓存(KV cache 复用)
- feat: OpenAI 兼容 chat completion API(流式+非流式+usage)
- feat: Anthropic 兼容 /v1/messages API,移除旧版 /generate 端点
- feat: GRPO CLI 接入 + on-policy,OpenAI API top_k 参数化
- feat: Checkpoint 支持 extra 通用扩展数据
- feat: Docker Compose 一键部署(GPU/CPU 双模式)
- feat: GRPO 训练参数补充,批处理训练参数表
- fix: 调度器延迟优化 — 移除 5ms 睡眠,修复 refill 任务丢失
- fix: CLI 参数缺失/重复、device_ids 越界、generate 参数名不一致
- fix: 长对话截断方向错误,保留最新 token 而非最早
- fix: remove_task 未释放 KV cache slot 导致第二轮对话死锁
- fix: KV cache 槽位索引错位、版本校验缺失、注意力掩码
- fix: scheduler 越界 bug,SchedulerCallback 回调阶段修正
- perf: _Result 改用 Condition.wait_for 消除非流式 CPU 空转
- perf: decode 每步张量预分配;input_ids 改用一次构建代替逐元素赋值
- refactor: 移除 device_ids 参数,统一 CUDA_VISIBLE_DEVICES
- docs: 更新文档以匹配分页 KV cache 等代码重构
- docs: 修正多处文档错误、补充训练参数说明
2026-05-10 15:59:18 +08:00
ViperEkura cffedaad5e perf: 消除非流式推理 CPU 空转并减少 decode GPU 张量冗余分配
- engine.py: _Result 改用 threading.Condition.wait_for 替代
  Event busy-wait,非流式模式线程被内核挂起而非 1760 万次空转
- scheduler.py: _execute_decode 将 temperature/top_k/top_p 张量
  移至循环外预先分配,避免每步重复 torch.tensor();input_ids
  改用 torch.empty 避免不必要的 zero 初始化(两处均为完全覆盖)
- _execute_prefill: input_ids 同改为 torch.empty
2026-05-10 15:32:11 +08:00
ViperEkura 3583c46b66 feat: 推理引擎前缀缓存(KV cache 复用)
- cache.py: 新增模块级 page_hash() 多项式滚动哈希函数;PagedCache 新增
  record_page/lookup_prefix/inc_ref,free() 自动清理哈希映射
- scheduler.py: Task 新增 _prefix_cached_tokens;_refill_active_batch 先查
  缓存命中页(inc_ref)再分配剩余页;合并 _execute_prefill 为单一方法,
  按 (prompt_len, start_pos) 分组批量执行全量/部分 prefill;
  _record_page_hashes 注册完整页哈希;修复 device/dtype 默认值从硬编码
  改为 None(自动检测模型设备)
- test: mock model 补充 dtype/device 适配自动检测
2026-05-09 23:53:57 +08:00
ViperEkura ca4e6b907c feat: Checkpoint 支持 extra 通用扩展数据,用户通过函数自定义保存/恢复优化器等状态
- serialization.py: Checkpoint 新增 extra: dict 字段,
  save() 写入 extra.pt,load() 自动恢复
- train_callback.py: CheckpointCallback 新增 save_extra_fn
  参数,用户传入 (context) -> dict 决定保存哪些额外状态
- train_context.py: TrainContextBuilder 新增 load_extra_fn
  参数,用户传入 (extra, context) 从 checkpoint 恢复状态
2026-05-09 15:50:38 +08:00
ViperEkura db99d8b254 fix: 修复文档多处不准确 + inference scheduler 越界 bug + SchedulerCallback 回调阶段修正
文档 (6 个文件):
- design.md: 15+ 处修正 — persistent_key_values→paged_cache,
  MLA 字段重写, Server/ParallelSetup 不存在类移除,
  关系箭头方向修复, SchedulerCallback 阶段修正等
- dataflow.md: 重写数据流图和描述, 修复训练回调顺序、
  数据键名、MLA 归属、MetricTracker 等错误
- introduction.md: 层数 32→24, MLP 图双 Linear 修正,
  默认值/响应字段/health 端点修复
- params.md: 补充 grpo 及 4 个 GRPO 参数
- README.md / README-zh-CN.md: generate.py 补全必需参数,
  删除重复注释, HuggingFace 声明修正

代码 (2 个文件):
- scheduler.py: n_pages 池加 page_size 余量防止越界;
  decode 前预分配页
- train_callback.py: SchedulerCallback 从 on_step_end 改
  回 on_batch_end (按 batch 步进学习率)
2026-05-09 15:40:17 +08:00
ViperEkura b98c9cefdc refactor: 移除 device_ids 参数设计,统一通过 CUDA_VISIBLE_DEVICES 控制 GPU 分配;更新 README 训练示例
- setup.py: 移除 device_ids 参数,setup_parallel 直接用 rank 作为设备索引
- train_config.py: 移除 device_ids 字段
- trainer.py: 不再传递 device_ids
- train.py: ddp_wrap 用 get_rank() 直接取值
- README.md, README-zh-CN.md: 训练示例改为多行命令风格,去掉参数表格
2026-05-09 14:55:43 +08:00
ViperEkura 283bcaf2ff fix: 修复 CLI 参数缺失/重复、device_ids 越界、generate 参数名不一致、scheduler 时序、非流式截断等 bug
- train.py: 补上 --batch_size、--grpo_clip_eps,删除 3 处重复 --group_size
- generate.py: --model_dir 改为 --param_path 对齐 README
- automodel.py: from_pretrained 新增 strict 参数(默认 True)
- parallel/setup.py: 修复 device_ids 索引越界
- train_callback.py: scheduler.step() 移至 on_step_end
- test_train_strategy.py: 测试中补 optimizer.step()
- engine.py: 非流式改为循环等待所有任务完成,补 remove_task 清理
- scheduler.py: Task 添加 _pages_freed 标志,杜绝双重释放
- trainer.py: accumulation_steps=0 时 clamp 为 1
- tokenizer.py: save_pretrained 添加 _tokenizer is None 检查
- benchmark.py: 修复 ModelConfig 过时 import 路径
- inference/__init__.py: 修复 stale docstring
2026-05-09 14:36:42 +08:00
ViperEkura bc7c82977e feat: GRPO CLI 接入 + on-policy,OpenAI API top_k 参数化,补充训练参数表
- train.py 新增 --train_type=grpo 及参数 (--grpo_clip_eps, --grpo_kl_coef, --group_size, --grpo_sync_interval, --start_epoch)
- GRPOStrategy 统一 on-policy 模式,ratio = exp(logπ_θ - logπ_ref),PPO 裁剪目标,sync_interval 自动同步 ref_model
- ChatCompletionRequest 新增 top_k 参数,不再硬编码
- 补充 README 完整训练参数表(含此前缺失的 max_grad_norm / adamw / window_size / stride 等)
2026-05-09 12:22:33 +08:00
ViperEkura 34a511e36e feat: 新增 Docker Compose 一键部署,支持 GPU/CPU 双模式 2026-05-09 11:57:46 +08:00
ViperEkura d73f52a2f8 feat: 新增 Anthropic 兼容 /v1/messages API,移除旧版 /generate 端点
- 新增 /v1/messages 端点,兼容 Anthropic Messages API 格式
- 支持流式 SSE(message_start → content_block_delta → message_stop)
- 支持 system 顶层提示词与 stop_sequences 停止序列
- 新增 AnthropicMessage / MessagesRequest Pydantic 模型
- 移除旧版 /generate 端点及相关测试用例
- 更新 README.md / README-zh-CN.md / introduction.md 文档
2026-05-09 11:47:22 +08:00
ViperEkura 9d96b0431d docs: 更新文档以匹配分页 KV cache 等代码重构 2026-05-08 22:41:13 +08:00
ViperEkura f81e2b4a73 feat: OpenAI 兼容的 chat completion API(流式+非流式+usage) 2026-05-08 21:54:55 +08:00
ViperEkura 4e324d8f26 fix: benchmark 改用 PagedCache 替代已删除的 persistent_key_values 2026-05-08 21:26:55 +08:00
ViperEkura 6ed0506491 fix: 减少调度器延迟 — 移除解码路径 5ms 睡眠,修复 refill 任务丢失 bug 2026-05-08 21:13:52 +08:00
ViperEkura 30cc2d67a4 refactor: 分页 KV cache 替换固定 slot,删除 PrefixCache 及相关死代码
- 用 PagedCache + CacheView 替换固定 slot 式 KV cache,attention 层只通过 page_table 间接索引
- 删除 PrefixCache(radix tree)及 scheduler 中所有 prefix cache 命中/插入/释放逻辑
- 删除无用函数:pin、version、free_count、_mark_seq_mask 及 seq_mask 分配
- 修复 write 在多页 prefill 时 offset 为负导致 chunk 计算错误
- _make_page_table_tensor 改用 list 拼接一次 tensor,去掉逐元素赋值
- 清理 model 接口参数:kv_cache, slot_indices → paged_cache(CacheView)
- 精简 docstring 为单行,删除冗余 section 注释和旧代码
- 修复 test_scheduler_concurrency.py 缺少 import pytest
2026-05-08 20:44:05 +08:00
ViperEkura 7ddebf2cd9 refactor: 统一采样路径为 Strategy + batch tensor,删除 apply_sampling_strategies
- TemperatureStrategy / TopKStrategy / TopPStrategy 支持 Union[float, Tensor]
- SamplingPipeline.sample() 一条调用完成 apply + softmax + multinomial
- 新增 sample() 独立函数作为 scheduler 入口
- scheduler decode 改为 batch tensor 参数传递,支持任意 batch size
- 删除 apply_sampling_strategies(被 sample() 取代)
2026-05-08 19:07:14 +08:00
ViperEkura 78dc2bd41c docs: 修正文档错误并补充训练参数说明
- README: 补充训练参数速查表,完善训练命令示例
- design.md: 同步 inference 类图(SlotAllocator、GenerationParams、采样策略等
  新增类),修正参数名和类型错误,统一泛型符号
- params.md: 修正默认值(batch_size=1、num_workers=4),移除不存在参数
  (grpo_*、model_type、resume_dir),补充完整示例
- dataflow.md: _RadixNode 命名修正
2026-05-08 18:07:57 +08:00
ViperEkura 44d7a4e959 refactor: 设计模式优化 inference 模块导入结构
- 新建 cache.py:SlotAllocator 对象池 + PrefixCacheManager

- 新建 sampling.py:Temperature/TopK/TopP 可组合策略

- TaskStatus 改用 Enum,GenerationParams 值对象模式

- _STOP 移至 cache.py,解除 engine→scheduler 轻量耦合

- 更新测试导入路径,ruff 格式检查通过
2026-05-08 16:57:57 +08:00
ViperEkura c4401512f2 fix: 修复长对话截断方向错误,保留最新 token 而非最早
- add_task 中 prompt 超长时改为保留末尾 token(prompt_ids[-max_prompt_len:])
  而非开头 token,确保多轮对话时模型能看到最近的提问上下文
2026-05-08 15:52:48 +08:00
ViperEkura a6f5ff3b37 fix: 修复 remove_task 未释放 KV cache slot 导致第二轮对话死锁
- remove_task() 现在释放 KV cache slot 和 prefix cache 引用
- _refill_active_batch 中 alloc 失败时将剩余 task 推回 waiting_queue
- 主循环增加 try/except 异常兜底,发送 _STOP 给所有 task
- 重构:server.py 全局变量改为 ServerState 类;automodel.py
  使用 Registry 替代裸 dict;合并 TrainContextBuilder 的 with_*
  方法到 build()
2026-05-08 14:53:04 +08:00
ViperEkura ffff05b2c6 refactor: 替换魔法字符串为_STOP sentinel,修复generator清理逻辑 2026-05-06 20:37:16 +08:00
ViperEkura b89f8436ea refactor: 将KV缓存槽位映射下沉到模型注意力层,移除_remap_kv和_writeback_kv 2026-05-06 20:01:22 +08:00
ViperEkura 123f25e339 fix: 修复KV缓存槽位索引错位、版本校验缺失与注意力掩码问题,合并预填充方法 2026-05-06 19:51:14 +08:00
ViperEkura 520de3ebe8 refactor: 重构推理引擎控制逻辑,修复连续批处理核心缺陷
- 修复 decode 阶段新任务覆盖已有任务的严重缺陷
- 修复线程安全问题(热路径无锁竞争)
- 修复前缀缓存引用计数管理不当导致缓存被驱逐
- 修复 pad_id 缺失导致全量 prefill 崩溃
- 修复 RoPE 位置错乱(不同位置任务共用 start_pos)
- 新增 slot 版本追踪实现前缀缓存零拷贝复用
- 新增异步流式生成接口避免阻塞事件循环
- 添加完整英文文档字符串
2026-05-06 16:04:06 +08:00
ViperEkura 466c34d7a8 ci: 添加 Docker 镜像自动构建工作流 2026-04-10 13:09:58 +08:00
ViperEkura 6831a15424 docs: 更新镜像构建部分说明 2026-04-10 12:59:50 +08:00
ViperEkura 0f9e5c5049 build: 修改docker 配置 2026-04-10 12:53:08 +08:00
ViperEkura cb0e7f2a80 build: 修改docker 构建流程 2026-04-10 11:25:00 +08:00
ViperEkura 296db909aa docs: 更新设计文档 2026-04-09 20:05:54 +08:00
ViperEkura a2ae742988 chore: 增加并发测试 2026-04-09 18:10:28 +08:00
ViperEkura 29beb174a5 fix: 修复删除节点问题 2026-04-09 16:58:29 +08:00
ViperEkura bbeaff4c60 refactor: 精简推理引擎代码,优化参数传递规范 2026-04-09 14:17:48 +08:00
ViperEkura ab5e207f42 feat: 增加缓存处理 2026-04-08 20:54:14 +08:00
ViperEkura b0eff02446 chore: 修改RMSNorm 实现 2026-04-06 20:27:01 +08:00
ViperEkura 408f0cb513 docs: 更新网络接口文档 2026-04-06 13:39:51 +08:00
ViperEkura 64b78ecce3 fix: 增加旋转位置编码扩展 2026-04-06 13:29:39 +08:00
ViperEkura f2ffdf60d0 chore: 修改错误拼写 2026-04-06 10:37:19 +08:00
ViperEkura ace8f6ee68 chore: 优化未使用的模块 2026-04-06 09:54:17 +08:00
ViperEkura a57a16430d fix: 修复tokenizer存储的问题 2026-04-06 09:36:29 +08:00
ViperEkura 3fee87897d chore: 修改拼写错误问题 2026-04-06 09:28:16 +08:00
ViperEkura 3f67e53088 fix: 修复tokenizer 参数问题 2026-04-06 09:22:46 +08:00
ViperEkura bf7adb35b3 docs: 更新文档 2026-04-06 00:50:37 +08:00
ViperEkura feaa3fca36 ci: 优化 GitHub Actions 工作流 2026-04-05 22:40:16 +08:00
ViperEkura 39766aa1dc chore: 修改类名,优化导入顺序 2026-04-05 22:27:57 +08:00
ViperEkura 9b22b1651e refactor: 优化工具脚本接口并修复批处理问题 2026-04-05 21:56:22 +08:00
ViperEkura e58dbd7c57 chore: 精简实现代码部分 2026-04-05 21:16:38 +08:00
ViperEkura d2fe8afbd1 chore: 更新文档, 修正代码格式 2026-04-05 20:59:52 +08:00
ViperEkura 23ce4bc3ae fix: 修复异常处理问题 2026-04-05 20:44:35 +08:00
ViperEkura d2b36cc85d fix: 修复特殊token 的问题 2026-04-05 20:09:47 +08:00
ViperEkura fc278d17ab feat: 实现模型动态注册机制 2026-04-05 19:38:12 +08:00
ViperEkura ff43a2fab8 docs: 更新设计文档 2026-04-05 00:17:35 +08:00
ViperEkura 2b26f03bd3 refactor: 拆分engine.py 文件 2026-04-05 00:07:21 +08:00
ViperEkura 861d33b1a1 refactor: 更新inference 部分的实现 2026-04-04 23:49:18 +08:00
ViperEkura 99b821ebf5 docs: 更新文档类图等 2026-04-04 18:11:36 +08:00
ViperEkura c94a246c71 chore: 重命名目录 2026-04-04 17:03:22 +08:00
ViperEkura 2dc9545d7f refactor: 实现 chat template 分派设置 2026-04-04 16:56:31 +08:00
ViperEkura 9c31d78a22 chore: 将data 模块命名为dataset 2026-04-04 16:16:27 +08:00
ViperEkura bd9741dc5f refactor: 从data 模块分离tokenizer 2026-04-04 16:12:58 +08:00
ViperEkura b531232a9b style: 修改为显式导入 2026-04-04 16:02:49 +08:00
ViperEkura 3346c75584 feat: 优化工厂模式的实现 2026-04-04 15:49:46 +08:00
ViperEkura aa5e03d7f6 fix: 修复工厂模式问题并增加chat-template设置 2026-04-04 12:05:05 +08:00
ViperEkura 073baf105c chore: 修复docker配置问题 2026-04-04 11:35:14 +08:00
ViperEkura e97536758f refactor: 优化工厂模式结构 2026-04-04 11:33:58 +08:00
ViperEkura 7861af12e4 chore: 增加docker 配置 2026-04-04 10:59:32 +08:00
ViperEkura 7f0552013a chore: 增加提交检测脚本 2026-04-04 10:43:24 +08:00
ViperEkura 3535de5cc4 fix: 同步device 和 dtype 2026-04-04 10:25:39 +08:00
ViperEkura 26989e54aa feat: 优化server 部分设置 2026-04-04 01:41:01 +08:00
ViperEkura 70d52935f0 fix: 修复参数问题 2026-04-03 23:34:21 +08:00
ViperEkura c0e0e6afd9 docs: 更新文档 2026-04-03 22:11:19 +08:00
ViperEkura 0852b852f8 refactor: 优化参数传递,清理导入样式 2026-04-03 22:06:32 +08:00
ViperEkura 3a7d98a950 fix: 修复测试部分导入问题 2026-04-03 15:01:39 +08:00
ViperEkura c5560740b6 refactor: 修改分词器部分结构, 更新特殊token等 2026-04-03 14:52:35 +08:00
ViperEkura 94c6a015c8 chore: 更新ignore 2026-04-03 14:31:05 +08:00
ViperEkura 8b6509b305 docs: 更新 design.md 项目结构和模块文档 2026-04-02 20:11:19 +08:00
ViperEkura 912d7c7f54 chore: 更新脚本并且修改gitignore 2026-04-02 15:40:31 +08:00
ViperEkura 475de51c7d feat: 增加server, 并且修改测试单元 2026-04-02 15:05:07 +08:00
ViperEkura 9f1561afe7 reafactor: 修改ModelParameter 2026-03-31 16:00:55 +08:00
ViperEkura 80c0b20877
Update issue templates 2026-03-31 15:20:21 +08:00
ViperEkura e7721eafc6 docs: 更新说明内容 2026-03-31 15:18:49 +08:00
ViperEkura 4ead0a20cf chore: 修改文件夹结构 2026-03-31 10:14:08 +08:00
ViperEkura b1527d9575 docs: 优化文档结构并添加 GitHub 模板 2026-03-31 10:00:49 +08:00
ViperEkura 2e009cf59a chore: 更新项目名称 2026-03-31 09:34:11 +08:00
ViperEkura 780b9e1855 fix: 修复参数传递问题 2026-03-31 01:23:29 +08:00
ViperEkura aef7615abd docs: 更新README 2026-03-31 00:50:01 +08:00
ViperEkura 50488bd659 chore: 简化格式并更新文档 2026-03-31 00:28:58 +08:00
ViperEkura eb57e55fca chore: 更新计算顺序 2026-03-30 23:35:22 +08:00
ViperEkura 426af2d75f style: 使用ruff 工具优化代码风格 2026-03-30 23:32:28 +08:00
ViperEkura 345fd2f091 fix: 修复参数传递问题 2026-03-30 22:22:36 +08:00
ViperEkura e1f9901384 build: 更新设置 2026-03-30 21:44:50 +08:00
ViperEkura 0e7fc623b4 fix: 修复部分已知问题 2026-03-30 21:42:00 +08:00
ViperEkura 3e33c14376 reafactor: 统一并增强项目中的工厂模式实现 2026-03-30 01:33:14 +08:00
ViperEkura 60f4df95bd fix: 修复一些已知问题 2026-03-30 01:08:19 +08:00
ViperEkura c01791ff54 feat: 增加推理部分工厂模式 2026-03-30 00:55:15 +08:00
ViperEkura 980299cd54 fix: 修复参数传递问题 2026-03-20 21:54:13 +08:00
ViperEkura 3e8f2eba81 fix: 修复路径问题 2026-03-20 21:14:02 +08:00
ViperEkura 361cdeb296 chore: 修改策略命名 2026-03-19 23:08:41 +08:00
ViperEkura 50f76cd7c7 refactor: 重构数据模块中的数据集类命名和文件结构 2026-03-19 22:37:32 +08:00
ViperEkura 0f518473af fix: 修复强化学习算法问题 2026-03-19 22:23:51 +08:00
ViperEkura a5574f92e2 feat: 初步实现grpo 算法逻辑 2026-03-19 20:56:53 +08:00
ViperEkura abcedf892e feat: 增加 MLA 模块 2026-03-18 16:41:46 +08:00
ViperEkura abc3a06266 chore: 增加ppl计算工具并优化代码格式 2026-03-18 16:16:02 +08:00
ViperEkura 62fba9a298 refactor: 优化接口设置, 去除冗余代码 2026-03-18 15:07:35 +08:00
ViperEkura e23a5ca426 fix: 修复metric 保存时机的问题 2026-03-16 20:07:36 +08:00
ViperEkura e55b57d771 fix: 修复梯度平均问题 2026-03-13 23:00:26 +08:00
ViperEkura c4feab96fe fix: 统一state_dict 处理方式 2026-03-13 22:41:56 +08:00
ViperEkura e35cb0d84a feat: 增加 label smoothing 设置 2026-03-13 22:37:27 +08:00
ViperEkura 6d6ef6dbb6 refactor: 修改project logo 2026-03-06 12:15:49 +08:00
ViperEkura 493fe4e84b feat: 增加 label smothing 2026-03-06 11:41:14 +08:00
ViperEkura 82d22c5742 fix: 修复callback 时机不一致的问题 2026-03-06 10:51:22 +08:00
ViperEkura 96744ac2d2 refactor: 修改metric_util.py 2026-03-06 10:33:44 +08:00
ViperEkura 2331713fde refactor: 修改训练脚本 2026-03-05 14:40:26 +08:00
ViperEkura c74fbf84b7 build: 增加h5py 版本号 2026-03-04 21:29:37 +08:00
ViperEkura 5a8c442315 docs: 修改 README 2026-03-04 20:51:09 +08:00
ViperEkura c7d0448822 fix: 修复StepMonitorCallback序列化问题 2026-03-04 20:38:07 +08:00
ViperEkura 1d43a1785e build: 修改dependencies 以及版本号 2026-03-04 20:13:38 +08:00
ViperEkura 5713b55500 refactor: 修改 StepMonitorCallback, 分离职责 2026-03-04 19:45:39 +08:00
ViperEkura b53e10aac4 refactor: 修改metric 监测部分 2026-03-03 16:08:50 +08:00
ViperEkura dff58468d6 fix: 修复 load_h5 丢失文件的问题 2026-03-02 17:37:28 +08:00
ViperEkura 8a8d6369bc fix: 修复 dataset 和 checkpoint 的 bug 2026-03-02 11:12:21 +08:00
ViperEkura 80e17418b4 fix: 修复一些运行时问题 2026-03-01 15:47:07 +08:00
ViperEkura 6089a12cef fix: 修复参数传递问题并更新测试单元 2026-02-28 19:01:16 +08:00
ViperEkura b17cc6a6fb refactor: 修改参数传递方案 2026-02-28 18:09:00 +08:00
ViperEkura a33d086883 build: 修改build 方式 2026-02-27 17:52:28 +08:00
ViperEkura e9f42ec8b1 Change license from Apache 2.0 to GPL v3.0 2026-02-22 21:20:34 +08:00
ViperEkura 582d4ae9a7 refactor(data): 修改文件加载方案 2026-02-22 21:14:10 +08:00
ViperEkura 0ca4871e80 ci(spell-check): 修改检查流程 2026-02-11 16:01:53 +08:00
ViperEkura 99ef8fda71 feat(inference): 增加cuda_graph 装饰器 2026-02-07 21:14:39 +08:00
ViperEkura dbd57e30e5 feat(inference): 增加cuda graph 设置 2026-02-07 15:42:41 +08:00
ViperEkura a5869d89ba feat(trainer): 增加state_dict 存储设定 2026-02-04 19:47:21 +08:00
ViperEkura 7a9b9d0659 docs(architecture): 添加系统架构文档并修复KV缓存数学公式 2026-01-18 14:10:31 +08:00
ViperEkura 75758ead46 docs(data): 修改内存映射文件扩展名为.pt 2026-01-16 21:02:26 +08:00
ViperEkura 7dfa5cc0ac refactor(data): 重构MmapFileHandler类并改进数据加载机制 2026-01-11 19:37:28 +08:00
ViperEkura 9dab96c31f test(checkpoint): 添加多进程检查点测试功能 2026-01-08 22:04:39 +08:00
ViperEkura ff5c8a71f5 fix(trainer): 修复回调函数合并逻辑 2026-01-08 21:56:44 +08:00
ViperEkura 4da70785b5 refactor(tests): 重构测试文件目录结构 2026-01-08 21:34:52 +08:00
ViperEkura d407962ffa fix(trainer): 更新检查点保存和加载逻辑 2026-01-08 19:04:08 +08:00
ViperEkura 3d8047fa1b feat(trainer): 重构检查点系统支持分布式训练 2026-01-08 15:01:19 +08:00
ViperEkura d21682f97a fix(trainer): 修复检查点回调参数顺序和权重保存选项 2026-01-05 17:08:09 +08:00
ViperEkura eba99e1f5e feat(model): 添加QK归一化和门控注意力支持 2026-01-05 16:14:44 +08:00
ViperEkura fd7ee2895a refactor(paralell): 优化并行设备指定方法 2025-12-26 20:54:33 +08:00
ViperEkura cfa3cf7daa feat(train): 支持分布式训练的优化器与调度器工厂配置 2025-12-22 20:41:03 +08:00
ViperEkura 7623b1e5fd feat(khaosz/data/tokenizer): 优化BPE分词器的预处理和训练配置 2025-12-22 20:02:10 +08:00
ViperEkura 573f041c51 feat(trainer): 支持分布式训练配置与检查点加载优化 2025-12-19 19:34:39 +08:00
ViperEkura eab7a51bb6 feat(parallel): 改进设备策略注册表与并行设置功能 2025-12-19 15:25:31 +08:00
ViperEkura 3ac38a7ebc feat(parallel/device): 引入设备策略注册机制以支持多种后端 2025-12-15 13:58:59 +08:00
ViperEkura 831933fb66 fix(mmap): 修复样本数与键值计算逻辑并增强错误处理 2025-12-15 09:27:29 +08:00
ViperEkura 701fb9bf78 refactor(data): 将内存映射文件加载逻辑移至独立的 MmapFileHander 类 2025-12-15 09:12:42 +08:00
ViperEkura d882f65579 refactor(parallel): 重构parallel模块 2025-12-13 22:16:17 +08:00
ViperEkura a30ddca517 fix(data): 修改 Sampler 的长度计算方式, 避免提前初始化 2025-12-10 18:57:53 +08:00
ViperEkura 8e975017d3 fix(demo): 修复拼写错误 2025-12-10 15:22:26 +08:00
ViperEkura fed4d64cea ci(spell-check): 添加拼写检查工作流 2025-12-10 15:17:59 +08:00
ViperEkura 110efd2a21 fix(trainer): 修复训练上下文构建逻辑并修正拼写错误 2025-12-10 15:02:39 +08:00
ViperEkura 530fb50352 feat(parallel): 重构并重命名并行工具函数以提升灵活性 2025-12-10 14:43:35 +08:00
ViperEkura c86e573195 feat(trainer): 改进模型输入和损失计算中的数据类型精度 2025-12-08 14:10:08 +08:00
ViperEkura 0093ba7bb8 build(requirements): 升级 urllib3 版本从 2.5.0 到 2.6.0 2025-12-08 13:48:50 +08:00
ViperEkura c934210066 fix(trainer): 修复参数传递问题和检查点保存问题 2025-12-08 13:28:11 +08:00
ViperEkura c98b175cd5 refactor(trainer): 优化trainer 结构 2025-12-07 21:23:05 +08:00
ViperEkura 82e65ccc21 fix(tools/train): 修复参数传递错误 2025-12-05 13:53:50 +08:00
ViperEkura d52685facd feat(paralell): 添加分布式训练配置与并行工具支持 2025-12-05 13:52:17 +08:00
ViperEkura d31137a2db feat(config): 重构模型参数状态加载 2025-12-04 20:23:23 +08:00
ViperEkura 6270415590 feat(khaosz/parallel): 添加对多种设备后端的支持并优化并行初始化逻辑 2025-12-03 17:24:32 +08:00
ViperEkura 08c5a52dc8
Merge pull request #15 from ViperEkura/dependabot/pip/fonttools-4.61.0
build(deps): bump fonttools from 4.59.0 to 4.61.0
2025-12-03 16:59:35 +08:00
dependabot[bot] ac1fefb363
build(deps): bump fonttools from 4.59.0 to 4.61.0
Bumps [fonttools](https://github.com/fonttools/fonttools) from 4.59.0 to 4.61.0.
- [Release notes](https://github.com/fonttools/fonttools/releases)
- [Changelog](https://github.com/fonttools/fonttools/blob/main/NEWS.rst)
- [Commits](https://github.com/fonttools/fonttools/compare/4.59.0...4.61.0)

---
updated-dependencies:
- dependency-name: fonttools
  dependency-version: 4.61.0
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
2025-12-01 22:28:11 +00:00
ViperEkura 8b20982933 refactor(parallel): 重命名并重新组织并行模块文件结构 2025-11-30 17:56:47 +08:00
ViperEkura d5cc9f065d feat(khaosz/parallel): 添加并行训练设置功能 2025-11-30 16:44:04 +08:00
ViperEkura db53cc5001 feat(tools/train): 优化训练参数传递 2025-11-30 13:49:24 +08:00
ViperEkura 3ee84b31a0 feat(data): 重构数据集加载逻辑,修复计数错误 2025-11-28 20:59:24 +08:00
ViperEkura 567c55685e docs(data/dataset): 更新 load_mmap_files 函数的文档 2025-11-28 20:27:57 +08:00
ViperEkura 1f5cba889b fix(data): 修复数据加载模块中的拼写错误并优化内存映射加载逻辑 2025-11-28 20:21:53 +08:00
ViperEkura 019bfe4e05 fix(data/sampler): 修正拼写错误并增强采样器功能 2025-11-27 19:43:36 +08:00
ViperEkura 36b410384b fix(data/sampler): 增加sampler边界情况处理 2025-11-27 19:32:40 +08:00
ViperEkura 09963a3beb refactor(data): 重构数据模块结构并优化可恢复采样器实现 2025-11-27 18:16:35 +08:00
ViperEkura 5daf63a7a4 fix(model): 修复加载状态字典时的键存在性检查 2025-11-25 21:03:10 +08:00
ViperEkura fb85aaf6a6 fix(parallel): 修改列并行线性层结果聚合方式 2025-11-21 13:37:08 +08:00
ViperEkura 6fb6a15e81 feat(model): 添加并行线性层模型支持 2025-11-21 12:54:59 +08:00
ViperEkura d9ff662e3a fix(model): 调整 KV Cache 的维度顺序以匹配新的索引逻辑 2025-11-19 18:26:15 +08:00
ViperEkura e12ed0a72b fix(khaosz): 为其他模组添加init文件 2025-11-19 18:25:51 +08:00
ViperEkura 3bf2468905 fix(tools): 修正训练脚本中的嵌入层参数分组判断条件 2025-11-19 17:47:33 +08:00
ViperEkura 3c7ed84516 test(test_tie_weight): 添加测试以验证权重绑定后的数据修改行为 2025-11-19 17:47:22 +08:00
ViperEkura 1c3a693d79 feat(model): 优化RMSNorm实现方式 2025-11-15 13:54:04 +08:00
ViperEkura e99ef9d6d8 refactor(demo): 重构示例脚本目录结构 2025-11-10 21:35:04 +08:00
ViperEkura 4c289e974a refactor(tools): 将工具脚本移动到tools目录下 2025-11-10 21:26:02 +08:00
ViperEkura f31bf5a959 test(transformer): 更新 tie_weight 相关测试逻辑 2025-11-09 17:23:33 +08:00
ViperEkura 7a21f5d72e build(setup): 更新版本号并调整 Python 版本要求 2025-11-09 16:40:20 +08:00
ViperEkura 0b45e8666e fix(scripts): 修复stream_chat.py中的拼写错误 2025-11-09 16:30:24 +08:00
ViperEkura 6f3386f02c fix(transformer): 优化state_dict 处理逻辑, 优化attention_mask的处理方式 2025-11-09 16:25:17 +08:00
ViperEkura d25202a329 feat(model): 实现旋转位置编码缓存动态扩展 2025-11-09 14:35:29 +08:00
ViperEkura 254ec934be feat(transformer): 简化权重绑定逻辑并增加测试单元 2025-11-07 15:14:54 +08:00
ViperEkura 7e5ecf3b7d refactor(config): 重命名 TransformerConfig 为 ModelConfig 2025-11-07 07:31:12 +08:00
ViperEkura 66a551217e refactor(generator): 优化生成逻辑 2025-11-07 07:24:00 +08:00
ViperEkura bdc3f4dc63 feat(module): 重构旋转位置编码实现以提升性能和可读性 2025-11-06 17:52:47 +08:00
ViperEkura 805773c7fe docs(transformer): 更新process_attention_mask函数文档 2025-11-05 23:41:11 +08:00
ViperEkura 7ccc4ab9ac fix(model): 修复加载状态字典时的权重共享问题 2025-11-05 23:38:45 +08:00
ViperEkura 69d9374f51 feat(model): 添加 tie_weight 配置选项并优化模型模块实现 2025-11-05 23:26:57 +08:00
ViperEkura b260f5581d fix(benchmark): 优化 KV 缓存初始化并更正基准测试类型标识 2025-11-05 15:44:29 +08:00
ViperEkura 0a754e3341 feat(scripts): 调整文本生成参数以提升多样性 2025-11-05 13:56:58 +08:00
ViperEkura 144b9598ad feat(model): 添加 Linear 和 Embedding 模块的自定义参数初始化支持 2025-10-31 22:43:12 +08:00
ViperEkura 877669b799 feat(inference): 添加generate_loop方法并优化KVCacheManager初始化 2025-10-31 21:15:15 +08:00
ViperEkura cdb47a62dc test: 统一重构数据集和调度器测试模块 2025-10-31 20:24:01 +08:00
ViperEkura e86328b753 fix(tokenizer): 修复stop_ids属性返回错误的token ID列表 2025-10-31 19:19:38 +08:00
ViperEkura 5d3799b715 refactor(data): 修改变量命名方式 2025-10-30 16:32:25 +08:00
ViperEkura 6a3135f401 fix(data_util): 修复数据集索引计算逻辑并提取通用方法 2025-10-29 20:58:33 +08:00
ViperEkura 12850d403c fix(config): 修改Checkpoint类中tokenizer和config字段的默认值初始化方式 2025-10-29 13:24:20 +08:00
ViperEkura bad6243b53 fix(train): 更新训练函数参数传递方式 2025-10-29 13:23:53 +08:00
ViperEkura f2448a5147 feat(benchmark): 优化KV缓存初始化逻辑 2025-10-29 12:41:32 +08:00
ViperEkura 46b2a0f86f feat(train): 添加 max_len 和 step_size 参数支持 2025-10-29 12:32:17 +08:00
ViperEkura d94fc5a87a feat(data, inference): 使用chatML格式 2025-10-29 12:02:43 +08:00
ViperEkura 38b2725cd1 feat(KVCacheManager): 优化KV缓存结构为元组形式以提升性能 2025-10-29 12:01:28 +08:00
ViperEkura bc5ef72001 fix(config): 修正 SGDRScheduleConfig 类名拼写错误 2025-10-20 18:21:46 +08:00
ViperEkura e051005334 test(test_module): 更新测试用例以使用新的generate_iterator接口 2025-10-20 13:52:31 +08:00
ViperEkura 0db046f8d9 feat(khaosz/trainer): 更新梯度裁剪回调 2025-10-20 13:30:26 +08:00
ViperEkura 05b012820b refactor(khaosz): 重构模块导出结构并重命名主模块文件 2025-10-20 13:07:02 +08:00
ViperEkura e72e244df6 feat(inference): 实现采样策略并优化生成器逻辑 2025-10-20 13:00:41 +08:00
ViperEkura 98efca7b9d feat(trainer): 添加训练起始轮次和批次配置支持 2025-10-19 21:47:10 +08:00
ViperEkura 613edd7a14 test(early_stopping, train_strategy): 更新测试配置以提高稳定性 2025-10-18 22:07:11 +08:00
ViperEkura 622982364b fix(trainer): 修复检查点加载逻辑 2025-10-18 21:45:23 +08:00
ViperEkura b67bc9865d refactor(trainer): 重构学习率调度器实现并分离配置与工厂逻辑 2025-10-18 16:42:37 +08:00
ViperEkura c51b203fde refactor(khaosz): 重构项目结构 2025-10-18 13:56:59 +08:00
ViperEkura 8434c19923 fix(khaosz/trainer): 修复数据获取中的索引范围错误和参数传递问题 2025-10-09 19:53:52 +08:00
ViperEkura 68a15005cb feat(train.py): 支持从检查点恢复训练并优化数据加载配置 2025-10-07 22:02:50 +08:00
ViperEkura efbe3de9d3 fix(khaosz/trainer/data_util): 修复数据集索引范围错误 2025-10-07 20:04:45 +08:00
ViperEkura 12793bc2d3 feat(khaosz/trainer): 新增梯度统计工具函数并重构训练回调机制 2025-10-07 13:03:32 +08:00
ViperEkura 0764cb8296 fix(khaosz/trainer/train_callback): 修复基类函数命名错误 2025-10-07 11:43:51 +08:00
ViperEkura 57cd7b921e feat(khaosz/trainer): 改进训练循环中的损失归一化处理 2025-10-06 20:17:47 +08:00
ViperEkura c1bf22b6ec refactor(khaosz/trainer): 使用 TrainContext 替代 kwargs 传递训练上下文 2025-10-06 20:12:08 +08:00
ViperEkura f9b6331ad7 refactor(khaosz/core/parameter): 修改参数名称 2025-10-06 20:11:46 +08:00
ViperEkura 183f481692 build(khaosz): 更新版本号至1.3.0 2025-10-06 17:12:12 +08:00
ViperEkura ec0c054d26 test(early_stopping): 移除未使用的torch.utils.data导入 2025-10-06 17:10:10 +08:00
ViperEkura 4ffa7454f2 feat(strategy): 支持模型输入可调用对象并优化损失计算 2025-10-06 17:08:56 +08:00
ViperEkura 8c9e973179 fix(train.py): 修复数据集加载时的参数传递问题 2025-10-06 16:44:02 +08:00
ViperEkura fc98d9b7e6 refactor(khaosz/trainer): 移除未使用的导入模块 2025-10-04 21:45:53 +08:00
ViperEkura 9d5aa952e0 feat(tests): 重构测试环境, 便于pickle 序列化 2025-10-04 21:31:39 +08:00
ViperEkura 2ccd7bd583 refactor(khaosz/trainer): 重构训练器模块结构以提升可维护性 2025-10-04 21:31:15 +08:00
ViperEkura e7d29ca2d5 feat(tests): 改进测试环境配置与设备管理 2025-10-04 12:12:42 +08:00
ViperEkura 465a1a9373 refactor(khaosz/tainer): 修改设备参数传递发生阶段 2025-10-04 12:12:21 +08:00
ViperEkura 240ee00221 feat(khaosz/trainer): 引入 TrainContext 和 TrainContextBuilder 优化训练上下文管理 2025-10-03 22:42:11 +08:00
ViperEkura 6e1a497c04 test(sampler): 删除冗余的训练恢复测试用例 2025-10-03 22:18:31 +08:00
ViperEkura 85aeec9e55 test(conftest): 添加matplotlib后端设置以避免GUI问题 2025-10-03 22:11:54 +08:00
ViperEkura 9a452dd34e fix(khaosz/trainer/data_util.py): 修复 RandomSampler 中迭代计数器位置错误 2025-10-03 22:08:28 +08:00
ViperEkura 28b01220b6 test(trainer): 拆分测试文件 2025-10-03 22:08:11 +08:00
148 changed files with 17156 additions and 4555 deletions

9
.dockerignore Normal file
View File

@ -0,0 +1,9 @@
# Ignore everything
*
# Allow necessary files
!astrai/
!scripts/
!assets/
!pyproject.toml
!README.md

19
.gitattributes vendored Normal file
View File

@ -0,0 +1,19 @@
# Auto detect text files
* text=auto
# Files that MUST use LF (Unix/Linux execution)
*.sh text eol=lf
*.py text eol=lf
*.md text eol=lf
*.yml text eol=lf
Dockerfile text eol=lf
.dockerignore text eol=lf
.gitignore text eol=lf
.gitattributes text eol=lf
# Windows scripts - use CRLF
*.bat text eol=crlf
*.cmd text eol=crlf
*.ps1 text eol=crlf

27
.github/ISSUE_TEMPLATE/bug_report.md vendored Normal file
View File

@ -0,0 +1,27 @@
---
name: Bug report
about: Create a report to help us improve
title: "[BUG]"
labels: bug
assignees: ''
---
## Description
A clear and concise description of what the bug is.
## Steps to Reproduce
1. ...
2. ...
3. ...
## Expected Behavior
What you expected to happen.
## Actual Behavior
What actually happened.
## Environment
- Python version:
- AstrAI version (or commit hash):
- Operating System:
- GPU (if applicable):
- CUDA/cuDNN version (if applicable):
## Additional Context
Add any other context, screenshots, or logs here.

10
.github/ISSUE_TEMPLATE/custom.md vendored Normal file
View File

@ -0,0 +1,10 @@
---
name: Custom issue template
about: Describe this issue template's purpose here.
title: ''
labels: ''
assignees: ''
---

View File

@ -0,0 +1,19 @@
---
name: Feature request
about: Suggest an idea for this project
title: "[FEAT]"
labels: ''
assignees: ''
---
## Description
A clear and concise description of the feature you'd like to see.
## Problem Statement
What problem does this feature solve? Why is it needed?
## Proposed Solution
Describe the solution you'd like. Include any design ideas, API changes, or implementation details.
## Alternatives Considered
Describe any alternative solutions or features you've considered.
## Additional Context
Add any other context, screenshots, or references here.

26
.github/PULL_REQUEST_TEMPLATE.md vendored Normal file
View File

@ -0,0 +1,26 @@
## Description
Please include a summary of the change and which issue is fixed. Please also include relevant motivation and context.
Fixes # (issue number)
## Type of Change
Please delete options that are not relevant.
- [ ] Bug fix (non-breaking change which fixes an issue)
- [ ] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
- [ ] Documentation update
- [ ] Other (please describe):
## How Has This Been Tested?
Please describe the tests that you ran to verify your changes. Provide instructions so we can reproduce.
## Checklist:
- [ ] My code follows the style guidelines of this project (run `ruff format .` and `ruff check . --select I`)
- [ ] I have performed a self-review of my own code
- [ ] Code is self-documenting (no unnecessary comments)
- [ ] I have made corresponding changes to the documentation
- [ ] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my feature works
- [ ] New and existing unit tests pass locally with my changes
- [ ] Any dependent changes have been merged and published in downstream modules

50
.github/workflows/docker.yml vendored Normal file
View File

@ -0,0 +1,50 @@
name: Build and Push Docker Image
on:
push:
tags:
- 'v*'
jobs:
build:
runs-on: ubuntu-latest
permissions:
contents: read
packages: write
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Set up QEMU
uses: docker/setup-qemu-action@v3
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to GitHub Container Registry
uses: docker/login-action@v3
with:
registry: ghcr.io
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Extract metadata
id: meta
uses: docker/metadata-action@v5
with:
images: ghcr.io/${{ github.repository }}
tags: |
type=ref,event=tag
type=raw,value=latest
- name: Build and push
uses: docker/build-push-action@v5
with:
context: .
platforms: linux/amd64
push: true
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}
cache-from: type=gha
cache-to: type=gha,mode=max

31
.github/workflows/lint.yml vendored Normal file
View File

@ -0,0 +1,31 @@
name: Lint
on:
push:
branches: [main]
pull_request:
branches: [main]
jobs:
lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Set up Python 3.12
uses: actions/setup-python@v5
with:
python-version: "3.12"
- name: Install dependencies
run: |
pip install --upgrade pip
pip install .[dev]
- name: Check formatting with ruff
run: |
ruff format --check .
- name: Check import sorting
run: |
ruff check . --select I

31
.github/workflows/tests.yml vendored Normal file
View File

@ -0,0 +1,31 @@
name: Tests
on:
push:
branches: [main]
pull_request:
branches: [main]
jobs:
test:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.12"]
steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
pip install --upgrade pip
pip install .[dev]
- name: Run tests with pytest
run: |
python -m pytest tests/ -v

30
.gitignore vendored
View File

@ -1,13 +1,23 @@
# cache # Ignore everything
__pycache__ *
.pytest_cache
# params # Allow directories to be traversed
params/* !*/
# vscode file # Allow specific file types and root files
.vscode !*.py
!*.sh
# build file # Allow GitHub files
build !/.github/**
*.egg-info
# Allow root files
!/.gitattributes
!/.dockerignore
!/Dockerfile
!/docker-compose.yml
!/assets/**
!/CONTRIBUTING.md
!/LICENSE
!/pyproject.toml
!/README.md

100
CONTRIBUTING.md Normal file
View File

@ -0,0 +1,100 @@
# Contributing to AstrAI
Thank you for your interest in contributing! This document provides step-by-step guidelines.
## Quick Start
```bash
git clone https://github.com/your-username/AstrAI.git
cd AstrAI
pip install -e ".[dev]" # install with dev dependencies (pytest, ruff)
```
## Before You Commit
Run the following checks **in order** — CI will reject if any fail.
### 1. Format
```bash
ruff format .
```
> **Note**: `ruff format` may rename parameters (e.g. `mask``attn_mask`).
> Always review the diff after formatting.
### 2. Import sorting
```bash
ruff check . --select I
```
If this fails, **manually fix** import ordering (ruff does not auto-fix in this project's CI):
```bash
ruff check . --select I --fix .
ruff format . # re-format after fix
```
### 3. Run tests
```bash
python -u -m pytest tests/ -v
```
> Failed tests may leave orphan tempdirs under `%TEMP%`. Clean them manually if needed.
### 4. (Optional) Full pre-commit check
If you have Git Bash available:
```bash
bash scripts/pre_commit.sh
```
This runs format check, import sort check, and tests in one go.
## Commit Style
```
fix/feat/chore/docs/refactor/perf/test/style/ci/build/revert : short description (~50 chars)
- bullet point body (each ~60 chars)
```
- **Type** must be one of: `fix`, `feat`, `chore`, `docs`, `refactor`, `perf`, `test`, `style`, `ci`, `build`, `revert`.
- **Subject line** ends with no period.
- **Body** uses bullet points starting with `-`.
- No `(scope)` parentheses.
## Common Issues
| Problem | Cause | Fix |
|---------|-------|-----|
| `ruff check --select I` fails | Wrong import order | `ruff check . --select I --fix .` then `ruff format .` |
| `ruff format` changed many files | Not formatted before commit | Review diff carefully before staging |
| Pre-commit hook rejects | Tests or lint failed | Fix individually, do not `--no-verify` |
| Tests fail with tempdir left | Test crash | Clean `%TEMP%` manually |
## Submitting Changes
1. Fork the repo.
2. Create a feature branch: `git checkout -b feat/my-feature`
3. Make changes following the steps above.
4. Commit with the commit style above.
5. Push: `git push origin feat/my-feature`
6. Open a Pull Request against `main`.
## Code Review
- All PRs are reviewed. We may request changes.
- CI runs `ruff format --check .` then `ruff check . --select I` (no `--fix` in CI).
- Ensure all tests pass.
## License
By contributing, you agree that your contributions will be licensed under the [GPL-3.0 License](LICENSE).
---
Questions? Ask in [GitHub Discussions](https://github.com/ViperEkura/AstrAI/discussions) or open an issue.

55
Dockerfile Normal file
View File

@ -0,0 +1,55 @@
# AstrAI Dockerfile - Multi-stage Build (Optimized)
# Build stage - use base image with minimal build tools
FROM ubuntu:24.04 AS builder
WORKDIR /app
# Install Python 3.12 and minimal build dependencies
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
python3.12 \
python3.12-dev \
python3.12-venv \
gcc \
g++ \
&& rm -rf /var/lib/apt/lists/*
# Create isolated virtual environment
RUN python3.12 -m venv --copies /opt/venv
ENV PATH="/opt/venv/bin:$PATH"
# Copy source code and install (deps read from pyproject.toml)
COPY astrai/ ./astrai/
COPY pyproject.toml .
RUN pip install --no-cache-dir --upgrade pip \
&& pip install --no-cache-dir . \
--extra-index-url https://download.pytorch.org/whl/cu126
# Production stage
FROM ubuntu:24.04 AS production
WORKDIR /app
# Install Python 3.12 runtime and healthcheck dependency
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
python3.12 \
curl \
&& rm -rf /var/lib/apt/lists/*
# Copy virtual environment from builder
COPY --from=builder /opt/venv /opt/venv
ENV PATH="/opt/venv/bin:$PATH"
# Copy application code
COPY astrai/ ./astrai/
COPY scripts/ ./scripts/
COPY assets/ ./assets/
COPY pyproject.toml .
COPY README.md .
# Create non-root user
RUN useradd -m astrai && chown -R astrai:astrai /app
USER astrai
ENV PYTHONUNBUFFERED=1 \
PYTHONDONTWRITEBYTECODE=1

811
LICENSE
View File

@ -1,201 +1,674 @@
Apache License GNU GENERAL PUBLIC LICENSE
Version 2.0, January 2004 Version 3, 29 June 2007
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
Everyone is permitted to copy and distribute verbatim copies
of this license document, but changing it is not allowed.
1. Definitions. Preamble
"License" shall mean the terms and conditions for use, reproduction, The GNU General Public License is a free, copyleft license for
and distribution as defined by Sections 1 through 9 of this document. software and other kinds of works.
"Licensor" shall mean the copyright owner or entity authorized by The licenses for most software and other practical works are designed
the copyright owner that is granting the License. to take away your freedom to share and change the works. By contrast,
the GNU General Public License is intended to guarantee your freedom to
share and change all versions of a program--to make sure it remains free
software for all its users. We, the Free Software Foundation, use the
GNU General Public License for most of our software; it applies also to
any other work released this way by its authors. You can apply it to
your programs, too.
"Legal Entity" shall mean the union of the acting entity and all When we speak of free software, we are referring to freedom, not
other entities that control, are controlled by, or are under common price. Our General Public Licenses are designed to make sure that you
control with that entity. For the purposes of this definition, have the freedom to distribute copies of free software (and charge for
"control" means (i) the power, direct or indirect, to cause the them if you wish), that you receive source code or can get it if you
direction or management of such entity, whether by contract or want it, that you can change the software or use pieces of it in new
otherwise, or (ii) ownership of fifty percent (50%) or more of the free programs, and that you know you can do these things.
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity To protect your rights, we need to prevent others from denying you
exercising permissions granted by this License. these rights or asking you to surrender the rights. Therefore, you have
certain responsibilities if you distribute copies of the software, or if
you modify it: responsibilities to respect the freedom of others.
"Source" form shall mean the preferred form for making modifications, For example, if you distribute copies of such a program, whether
including but not limited to software source code, documentation gratis or for a fee, you must pass on to the recipients the same
source, and configuration files. freedoms that you received. You must make sure that they, too, receive
or can get the source code. And you must show them these terms so they
know their rights.
"Object" form shall mean any form resulting from mechanical Developers that use the GNU GPL protect your rights with two steps:
transformation or translation of a Source form, including but (1) assert copyright on the software, and (2) offer you this License
not limited to compiled object code, generated documentation, giving you legal permission to copy, distribute and/or modify it.
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or For the developers' and authors' protection, the GPL clearly explains
Object form, made available under the License, as indicated by a that there is no warranty for this free software. For both users' and
copyright notice that is included in or attached to the work authors' sake, the GPL requires that modified versions be marked as
(an example is provided in the Appendix below). changed, so that their problems will not be attributed erroneously to
authors of previous versions.
"Derivative Works" shall mean any work, whether in Source or Object Some devices are designed to deny users access to install or run
form, that is based on (or derived from) the Work and for which the modified versions of the software inside them, although the manufacturer
editorial revisions, annotations, elaborations, or other modifications can do so. This is fundamentally incompatible with the aim of
represent, as a whole, an original work of authorship. For the purposes protecting users' freedom to change the software. The systematic
of this License, Derivative Works shall not include works that remain pattern of such abuse occurs in the area of products for individuals to
separable from, or merely link (or bind by name) to the interfaces of, use, which is precisely where it is most unacceptable. Therefore, we
the Work and Derivative Works thereof. have designed this version of the GPL to prohibit the practice for those
products. If such problems arise substantially in other domains, we
stand ready to extend this provision to those domains in future versions
of the GPL, as needed to protect the freedom of users.
"Contribution" shall mean any work of authorship, including Finally, every program is threatened constantly by software patents.
the original version of the Work and any modifications or additions States should not allow patents to restrict development and use of
to that Work or Derivative Works thereof, that is intentionally software on general-purpose computers, but in those that do, we wish to
submitted to Licensor for inclusion in the Work by the copyright owner avoid the special danger that patents applied to a free program could
or by an individual or Legal Entity authorized to submit on behalf of make it effectively proprietary. To prevent this, the GPL assures that
the copyright owner. For the purposes of this definition, "submitted" patents cannot be used to render the program non-free.
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity The precise terms and conditions for copying, distribution and
on behalf of whom a Contribution has been received by Licensor and modification follow.
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of TERMS AND CONDITIONS
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of 0. Definitions.
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the "This License" refers to version 3 of the GNU General Public License.
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or "Copyright" also means copyright-like laws that apply to other kinds of
Derivative Works a copy of this License; and works, such as semiconductor masks.
(b) You must cause any modified files to carry prominent notices "The Program" refers to any copyrightable work licensed under this
stating that You changed the files; and License. Each licensee is addressed as "you". "Licensees" and
"recipients" may be individuals or organizations.
(c) You must retain, in the Source form of any Derivative Works To "modify" a work means to copy from or adapt all or part of the work
that You distribute, all copyright, patent, trademark, and in a fashion requiring copyright permission, other than the making of an
attribution notices from the Source form of the Work, exact copy. The resulting work is called a "modified version" of the
excluding those notices that do not pertain to any part of earlier work or a work "based on" the earlier work.
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its A "covered work" means either the unmodified Program or a work based
distribution, then any Derivative Works that You distribute must on the Program.
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and To "propagate" a work means to do anything with it that, without
may provide additional or different license terms and conditions permission, would make you directly or secondarily liable for
for use, reproduction, or distribution of Your modifications, or infringement under applicable copyright law, except executing it on a
for any such Derivative Works as a whole, provided Your use, computer or modifying a private copy. Propagation includes copying,
reproduction, and distribution of the Work otherwise complies with distribution (with or without modification), making available to the
the conditions stated in this License. public, and in some countries other activities as well.
5. Submission of Contributions. Unless You explicitly state otherwise, To "convey" a work means any kind of propagation that enables other
any Contribution intentionally submitted for inclusion in the Work parties to make or receive copies. Mere interaction with a user through
by You to the Licensor shall be under the terms and conditions of a computer network, with no transfer of a copy, is not conveying.
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade An interactive user interface displays "Appropriate Legal Notices"
names, trademarks, service marks, or product names of the Licensor, to the extent that it includes a convenient and prominently visible
except as required for reasonable and customary use in describing the feature that (1) displays an appropriate copyright notice, and (2)
origin of the Work and reproducing the content of the NOTICE file. tells the user that there is no warranty for the work (except to the
extent that warranties are provided), that licensees may convey the
work under this License, and how to view a copy of this License. If
the interface presents a list of user commands or options, such as a
menu, a prominent item in the list meets this criterion.
7. Disclaimer of Warranty. Unless required by applicable law or 1. Source Code.
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory, The "source code" for a work means the preferred form of the work
whether in tort (including negligence), contract, or otherwise, for making modifications to it. "Object code" means any non-source
unless required by applicable law (such as deliberate and grossly form of a work.
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing A "Standard Interface" means an interface that either is an official
the Work or Derivative Works thereof, You may choose to offer, standard defined by a recognized standards body, or, in the case of
and charge a fee for, acceptance of support, warranty, indemnity, interfaces specified for a particular programming language, one that
or other liability obligations and/or rights consistent with this is widely used among developers working in that language.
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS The "System Libraries" of an executable work include anything, other
than the work as a whole, that (a) is included in the normal form of
packaging a Major Component, but which is not part of that Major
Component, and (b) serves only to enable use of the work with that
Major Component, or to implement a Standard Interface for which an
implementation is available to the public in source code form. A
"Major Component", in this context, means a major essential component
(kernel, window system, and so on) of the specific operating system
(if any) on which the executable work runs, or a compiler used to
produce the work, or an object code interpreter used to run it.
APPENDIX: How to apply the Apache License to your work. The "Corresponding Source" for a work in object code form means all
the source code needed to generate, install, and (for an executable
work) run the object code and to modify the work, including scripts to
control those activities. However, it does not include the work's
System Libraries, or general-purpose tools or generally available free
programs which are used unmodified in performing those activities but
which are not part of the work. For example, Corresponding Source
includes interface definition files associated with source files for
the work, and the source code for shared libraries and dynamically
linked subprograms that the work is specifically designed to require,
such as by intimate data communication or control flow between those
subprograms and other parts of the work.
To apply the Apache License to your work, attach the following The Corresponding Source need not include anything that users
boilerplate notice, with the fields enclosed by brackets "[]" can regenerate automatically from other parts of the Corresponding
replaced with your own identifying information. (Don't include Source.
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner] The Corresponding Source for a work in source code form is that
same work.
Licensed under the Apache License, Version 2.0 (the "License"); 2. Basic Permissions.
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 All rights granted under this License are granted for the term of
copyright on the Program, and are irrevocable provided the stated
conditions are met. This License explicitly affirms your unlimited
permission to run the unmodified Program. The output from running a
covered work is covered by this License only if the output, given its
content, constitutes a covered work. This License acknowledges your
rights of fair use or other equivalent, as provided by copyright law.
Unless required by applicable law or agreed to in writing, software You may make, run and propagate covered works that you do not
distributed under the License is distributed on an "AS IS" BASIS, convey, without conditions so long as your license otherwise remains
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. in force. You may convey covered works to others for the sole purpose
See the License for the specific language governing permissions and of having them make modifications exclusively for you, or provide you
limitations under the License. with facilities for running those works, provided that you comply with
the terms of this License in conveying all material for which you do
not control copyright. Those thus making or running the covered works
for you must do so exclusively on your behalf, under your direction
and control, on terms that prohibit them from making any copies of
your copyrighted material outside their relationship with you.
Conveying under any other circumstances is permitted solely under
the conditions stated below. Sublicensing is not allowed; section 10
makes it unnecessary.
3. Protecting Users' Legal Rights From Anti-Circumvention Law.
No covered work shall be deemed part of an effective technological
measure under any applicable law fulfilling obligations under article
11 of the WIPO copyright treaty adopted on 20 December 1996, or
similar laws prohibiting or restricting circumvention of such
measures.
When you convey a covered work, you waive any legal power to forbid
circumvention of technological measures to the extent such circumvention
is effected by exercising rights under this License with respect to
the covered work, and you disclaim any intention to limit operation or
modification of the work as a means of enforcing, against the work's
users, your or third parties' legal rights to forbid circumvention of
technological measures.
4. Conveying Verbatim Copies.
You may convey verbatim copies of the Program's source code as you
receive it, in any medium, provided that you conspicuously and
appropriately publish on each copy an appropriate copyright notice;
keep intact all notices stating that this License and any
non-permissive terms added in accord with section 7 apply to the code;
keep intact all notices of the absence of any warranty; and give all
recipients a copy of this License along with the Program.
You may charge any price or no price for each copy that you convey,
and you may offer support or warranty protection for a fee.
5. Conveying Modified Source Versions.
You may convey a work based on the Program, or the modifications to
produce it from the Program, in the form of source code under the
terms of section 4, provided that you also meet all of these conditions:
a) The work must carry prominent notices stating that you modified
it, and giving a relevant date.
b) The work must carry prominent notices stating that it is
released under this License and any conditions added under section
7. This requirement modifies the requirement in section 4 to
"keep intact all notices".
c) You must license the entire work, as a whole, under this
License to anyone who comes into possession of a copy. This
License will therefore apply, along with any applicable section 7
additional terms, to the whole of the work, and all its parts,
regardless of how they are packaged. This License gives no
permission to license the work in any other way, but it does not
invalidate such permission if you have separately received it.
d) If the work has interactive user interfaces, each must display
Appropriate Legal Notices; however, if the Program has interactive
interfaces that do not display Appropriate Legal Notices, your
work need not make them do so.
A compilation of a covered work with other separate and independent
works, which are not by their nature extensions of the covered work,
and which are not combined with it such as to form a larger program,
in or on a volume of a storage or distribution medium, is called an
"aggregate" if the compilation and its resulting copyright are not
used to limit the access or legal rights of the compilation's users
beyond what the individual works permit. Inclusion of a covered work
in an aggregate does not cause this License to apply to the other
parts of the aggregate.
6. Conveying Non-Source Forms.
You may convey a covered work in object code form under the terms
of sections 4 and 5, provided that you also convey the
machine-readable Corresponding Source under the terms of this License,
in one of these ways:
a) Convey the object code in, or embodied in, a physical product
(including a physical distribution medium), accompanied by the
Corresponding Source fixed on a durable physical medium
customarily used for software interchange.
b) Convey the object code in, or embodied in, a physical product
(including a physical distribution medium), accompanied by a
written offer, valid for at least three years and valid for as
long as you offer spare parts or customer support for that product
model, to give anyone who possesses the object code either (1) a
copy of the Corresponding Source for all the software in the
product that is covered by this License, on a durable physical
medium customarily used for software interchange, for a price no
more than your reasonable cost of physically performing this
conveying of source, or (2) access to copy the
Corresponding Source from a network server at no charge.
c) Convey individual copies of the object code with a copy of the
written offer to provide the Corresponding Source. This
alternative is allowed only occasionally and noncommercially, and
only if you received the object code with such an offer, in accord
with subsection 6b.
d) Convey the object code by offering access from a designated
place (gratis or for a charge), and offer equivalent access to the
Corresponding Source in the same way through the same place at no
further charge. You need not require recipients to copy the
Corresponding Source along with the object code. If the place to
copy the object code is a network server, the Corresponding Source
may be on a different server (operated by you or a third party)
that supports equivalent copying facilities, provided you maintain
clear directions next to the object code saying where to find the
Corresponding Source. Regardless of what server hosts the
Corresponding Source, you remain obligated to ensure that it is
available for as long as needed to satisfy these requirements.
e) Convey the object code using peer-to-peer transmission, provided
you inform other peers where the object code and Corresponding
Source of the work are being offered to the general public at no
charge under subsection 6d.
A separable portion of the object code, whose source code is excluded
from the Corresponding Source as a System Library, need not be
included in conveying the object code work.
A "User Product" is either (1) a "consumer product", which means any
tangible personal property which is normally used for personal, family,
or household purposes, or (2) anything designed or sold for incorporation
into a dwelling. In determining whether a product is a consumer product,
doubtful cases shall be resolved in favor of coverage. For a particular
product received by a particular user, "normally used" refers to a
typical or common use of that class of product, regardless of the status
of the particular user or of the way in which the particular user
actually uses, or expects or is expected to use, the product. A product
is a consumer product regardless of whether the product has substantial
commercial, industrial or non-consumer uses, unless such uses represent
the only significant mode of use of the product.
"Installation Information" for a User Product means any methods,
procedures, authorization keys, or other information required to install
and execute modified versions of a covered work in that User Product from
a modified version of its Corresponding Source. The information must
suffice to ensure that the continued functioning of the modified object
code is in no case prevented or interfered with solely because
modification has been made.
If you convey an object code work under this section in, or with, or
specifically for use in, a User Product, and the conveying occurs as
part of a transaction in which the right of possession and use of the
User Product is transferred to the recipient in perpetuity or for a
fixed term (regardless of how the transaction is characterized), the
Corresponding Source conveyed under this section must be accompanied
by the Installation Information. But this requirement does not apply
if neither you nor any third party retains the ability to install
modified object code on the User Product (for example, the work has
been installed in ROM).
The requirement to provide Installation Information does not include a
requirement to continue to provide support service, warranty, or updates
for a work that has been modified or installed by the recipient, or for
the User Product in which it has been modified or installed. Access to a
network may be denied when the modification itself materially and
adversely affects the operation of the network or violates the rules and
protocols for communication across the network.
Corresponding Source conveyed, and Installation Information provided,
in accord with this section must be in a format that is publicly
documented (and with an implementation available to the public in
source code form), and must require no special password or key for
unpacking, reading or copying.
7. Additional Terms.
"Additional permissions" are terms that supplement the terms of this
License by making exceptions from one or more of its conditions.
Additional permissions that are applicable to the entire Program shall
be treated as though they were included in this License, to the extent
that they are valid under applicable law. If additional permissions
apply only to part of the Program, that part may be used separately
under those permissions, but the entire Program remains governed by
this License without regard to the additional permissions.
When you convey a copy of a covered work, you may at your option
remove any additional permissions from that copy, or from any part of
it. (Additional permissions may be written to require their own
removal in certain cases when you modify the work.) You may place
additional permissions on material, added by you to a covered work,
for which you have or can give appropriate copyright permission.
Notwithstanding any other provision of this License, for material you
add to a covered work, you may (if authorized by the copyright holders of
that material) supplement the terms of this License with terms:
a) Disclaiming warranty or limiting liability differently from the
terms of sections 15 and 16 of this License; or
b) Requiring preservation of specified reasonable legal notices or
author attributions in that material or in the Appropriate Legal
Notices displayed by works containing it; or
c) Prohibiting misrepresentation of the origin of that material, or
requiring that modified versions of such material be marked in
reasonable ways as different from the original version; or
d) Limiting the use for publicity purposes of names of licensors or
authors of the material; or
e) Declining to grant rights under trademark law for use of some
trade names, trademarks, or service marks; or
f) Requiring indemnification of licensors and authors of that
material by anyone who conveys the material (or modified versions of
it) with contractual assumptions of liability to the recipient, for
any liability that these contractual assumptions directly impose on
those licensors and authors.
All other non-permissive additional terms are considered "further
restrictions" within the meaning of section 10. If the Program as you
received it, or any part of it, contains a notice stating that it is
governed by this License along with a term that is a further
restriction, you may remove that term. If a license document contains
a further restriction but permits relicensing or conveying under this
License, you may add to a covered work material governed by the terms
of that license document, provided that the further restriction does
not survive such relicensing or conveying.
If you add terms to a covered work in accord with this section, you
must place, in the relevant source files, a statement of the
additional terms that apply to those files, or a notice indicating
where to find the applicable terms.
Additional terms, permissive or non-permissive, may be stated in the
form of a separately written license, or stated as exceptions;
the above requirements apply either way.
8. Termination.
You may not propagate or modify a covered work except as expressly
provided under this License. Any attempt otherwise to propagate or
modify it is void, and will automatically terminate your rights under
this License (including any patent licenses granted under the third
paragraph of section 11).
However, if you cease all violation of this License, then your
license from a particular copyright holder is reinstated (a)
provisionally, unless and until the copyright holder explicitly and
finally terminates your license, and (b) permanently, if the copyright
holder fails to notify you of the violation by some reasonable means
prior to 60 days after the cessation.
Moreover, your license from a particular copyright holder is
reinstated permanently if the copyright holder notifies you of the
violation by some reasonable means, this is the first time you have
received notice of violation of this License (for any work) from that
copyright holder, and you cure the violation prior to 30 days after
your receipt of the notice.
Termination of your rights under this section does not terminate the
licenses of parties who have received copies or rights from you under
this License. If your rights have been terminated and not permanently
reinstated, you do not qualify to receive new licenses for the same
material under section 10.
9. Acceptance Not Required for Having Copies.
You are not required to accept this License in order to receive or
run a copy of the Program. Ancillary propagation of a covered work
occurring solely as a consequence of using peer-to-peer transmission
to receive a copy likewise does not require acceptance. However,
nothing other than this License grants you permission to propagate or
modify any covered work. These actions infringe copyright if you do
not accept this License. Therefore, by modifying or propagating a
covered work, you indicate your acceptance of this License to do so.
10. Automatic Licensing of Downstream Recipients.
Each time you convey a covered work, the recipient automatically
receives a license from the original licensors, to run, modify and
propagate that work, subject to this License. You are not responsible
for enforcing compliance by third parties with this License.
An "entity transaction" is a transaction transferring control of an
organization, or substantially all assets of one, or subdividing an
organization, or merging organizations. If propagation of a covered
work results from an entity transaction, each party to that
transaction who receives a copy of the work also receives whatever
licenses to the work the party's predecessor in interest had or could
give under the previous paragraph, plus a right to possession of the
Corresponding Source of the work from the predecessor in interest, if
the predecessor has it or can get it with reasonable efforts.
You may not impose any further restrictions on the exercise of the
rights granted or affirmed under this License. For example, you may
not impose a license fee, royalty, or other charge for exercise of
rights granted under this License, and you may not initiate litigation
(including a cross-claim or counterclaim in a lawsuit) alleging that
any patent claim is infringed by making, using, selling, offering for
sale, or importing the Program or any portion of it.
11. Patents.
A "contributor" is a copyright holder who authorizes use under this
License of the Program or a work on which the Program is based. The
work thus licensed is called the contributor's "contributor version".
A contributor's "essential patent claims" are all patent claims
owned or controlled by the contributor, whether already acquired or
hereafter acquired, that would be infringed by some manner, permitted
by this License, of making, using, or selling its contributor version,
but do not include claims that would be infringed only as a
consequence of further modification of the contributor version. For
purposes of this definition, "control" includes the right to grant
patent sublicenses in a manner consistent with the requirements of
this License.
Each contributor grants you a non-exclusive, worldwide, royalty-free
patent license under the contributor's essential patent claims, to
make, use, sell, offer for sale, import and otherwise run, modify and
propagate the contents of its contributor version.
In the following three paragraphs, a "patent license" is any express
agreement or commitment, however denominated, not to enforce a patent
(such as an express permission to practice a patent or covenant not to
sue for patent infringement). To "grant" such a patent license to a
party means to make such an agreement or commitment not to enforce a
patent against the party.
If you convey a covered work, knowingly relying on a patent license,
and the Corresponding Source of the work is not available for anyone
to copy, free of charge and under the terms of this License, through a
publicly available network server or other readily accessible means,
then you must either (1) cause the Corresponding Source to be so
available, or (2) arrange to deprive yourself of the benefit of the
patent license for this particular work, or (3) arrange, in a manner
consistent with the requirements of this License, to extend the patent
license to downstream recipients. "Knowingly relying" means you have
actual knowledge that, but for the patent license, your conveying the
covered work in a country, or your recipient's use of the covered work
in a country, would infringe one or more identifiable patents in that
country that you have reason to believe are valid.
If, pursuant to or in connection with a single transaction or
arrangement, you convey, or propagate by procuring conveyance of, a
covered work, and grant a patent license to some of the parties
receiving the covered work authorizing them to use, propagate, modify
or convey a specific copy of the covered work, then the patent license
you grant is automatically extended to all recipients of the covered
work and works based on it.
A patent license is "discriminatory" if it does not include within
the scope of its coverage, prohibits the exercise of, or is
conditioned on the non-exercise of one or more of the rights that are
specifically granted under this License. You may not convey a covered
work if you are a party to an arrangement with a third party that is
in the business of distributing software, under which you make payment
to the third party based on the extent of your activity of conveying
the work, and under which the third party grants, to any of the
parties who would receive the covered work from you, a discriminatory
patent license (a) in connection with copies of the covered work
conveyed by you (or copies made from those copies), or (b) primarily
for and in connection with specific products or compilations that
contain the covered work, unless you entered into that arrangement,
or that patent license was granted, prior to 28 March 2007.
Nothing in this License shall be construed as excluding or limiting
any implied license or other defenses to infringement that may
otherwise be available to you under applicable patent law.
12. No Surrender of Others' Freedom.
If conditions are imposed on you (whether by court order, agreement or
otherwise) that contradict the conditions of this License, they do not
excuse you from the conditions of this License. If you cannot convey a
covered work so as to satisfy simultaneously your obligations under this
License and any other pertinent obligations, then as a consequence you may
not convey it at all. For example, if you agree to terms that obligate you
to collect a royalty for further conveying from those to whom you convey
the Program, the only way you could satisfy both those terms and this
License would be to refrain entirely from conveying the Program.
13. Use with the GNU Affero General Public License.
Notwithstanding any other provision of this License, you have
permission to link or combine any covered work with a work licensed
under version 3 of the GNU Affero General Public License into a single
combined work, and to convey the resulting work. The terms of this
License will continue to apply to the part which is the covered work,
but the special requirements of the GNU Affero General Public License,
section 13, concerning interaction through a network will apply to the
combination as such.
14. Revised Versions of this License.
The Free Software Foundation may publish revised and/or new versions of
the GNU General Public License from time to time. Such new versions will
be similar in spirit to the present version, but may differ in detail to
address new problems or concerns.
Each version is given a distinguishing version number. If the
Program specifies that a certain numbered version of the GNU General
Public License "or any later version" applies to it, you have the
option of following the terms and conditions either of that numbered
version or of any later version published by the Free Software
Foundation. If the Program does not specify a version number of the
GNU General Public License, you may choose any version ever published
by the Free Software Foundation.
If the Program specifies that a proxy can decide which future
versions of the GNU General Public License can be used, that proxy's
public statement of acceptance of a version permanently authorizes you
to choose that version for the Program.
Later license versions may give you additional or different
permissions. However, no additional obligations are imposed on any
author or copyright holder as a result of your choosing to follow a
later version.
15. Disclaimer of Warranty.
THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
16. Limitation of Liability.
IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
SUCH DAMAGES.
17. Interpretation of Sections 15 and 16.
If the disclaimer of warranty and limitation of liability provided
above cannot be given local legal effect according to their terms,
reviewing courts shall apply local law that most closely approximates
an absolute waiver of all civil liability in connection with the
Program, unless a warranty or assumption of liability accompanies a
copy of the Program in return for a fee.
END OF TERMS AND CONDITIONS
How to Apply These Terms to Your New Programs
If you develop a new program, and you want it to be of the greatest
possible use to the public, the best way to achieve this is to make it
free software which everyone can redistribute and change under these terms.
To do so, attach the following notices to the program. It is safest
to attach them to the start of each source file to most effectively
state the exclusion of warranty; and each file should have at least
the "copyright" line and a pointer to where the full notice is found.
<one line to give the program's name and a brief idea of what it does.>
Copyright (C) <year> <name of author>
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
Also add information on how to contact you by electronic and paper mail.
If the program does terminal interaction, make it output a short
notice like this when it starts in an interactive mode:
<program> Copyright (C) <year> <name of author>
This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
This is free software, and you are welcome to redistribute it
under certain conditions; type `show c' for details.
The hypothetical commands `show w' and `show c' should show the appropriate
parts of the General Public License. Of course, your program's commands
might be different; for a GUI interface, you would use an "about box".
You should also get your employer (if you work as a programmer) or school,
if any, to sign a "copyright disclaimer" for the program, if necessary.
For more information on this, and how to apply and follow the GNU GPL, see
<https://www.gnu.org/licenses/>.
The GNU General Public License does not permit incorporating your program
into proprietary programs. If your program is a subroutine library, you
may consider it more useful to permit linking proprietary applications with
the library. If this is what you want to do, use the GNU Lesser General
Public License instead of this License. But first, please read
<https://www.gnu.org/licenses/why-not-lgpl.html>.

488
README.md
View File

@ -1,333 +1,255 @@
![image-20250306182014120](/assets/images/project_logo_clipped.png) <div align="center">
<div style="display: flex; flex-direction: column; align-items: center; justify-content: center; text-align: center; font-size: 16px; font-weight: bold; margin-top: 50px;">
<div> <img src="assets/images/logo.png" width="auto" alt="Logo">
<a href="#english" style="text-decoration: none; margin: 0 10px; color: blue;">English</a> | <p>
<a href="#chinese" style="text-decoration: none; margin: 0 10px; color: blue;">中文</a> <strong>A lightweight Transformer training & inference framework</strong>
</div> </p>
<h1 style="margin: 20px 0 0 0; font-size: 2.5em; font-weight: bold;">KHAOSZ </h1>
</div> </div>
<h2 id="english">English Version</h2> <div align="center">
<img src="https://img.shields.io/badge/python-3.12+-blue.svg" alt="python">
<img src="https://img.shields.io/badge/license-GPL--3.0-blue.svg" alt="license">
<img src="https://img.shields.io/github/v/release/ViperEkura/AstrAI?color=76bad9" alt="release">
<img src="https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.github.com%2Frepos%2FViperEkura%2FAstrAI&query=%24.stargazers_count&label=stars&suffix=%20stars&color=76bad9" alt="stars">
<img src="https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.github.com%2Frepos%2FViperEkura%2FAstrAI&query=%24.forks_count&label=forks&suffix=%20forks&color=76bad9" alt="forks">
</div>
<br>
This is a Chinese-English bilingual Transformer model supporting both languages. It contains model configurations and training workflows, completing training by loading parameters defined in `param_path/config.json`. The training script `train.py` parses command-line arguments, including dataset root directory, number of training epochs, batch size, checkpoint interval, and checkpoint directory. <div align="center">
<a href="#english">English</a>
<a href="assets/docs/README-zh-CN.md">中文</a>
<a href="https://github.com/ViperEkura/AstrAI/issues">Issue Tracker</a>
<a href="https://github.com/ViperEkura/AstrAI/discussions">Discussions</a>
<a href="https://huggingface.co/ViperEk/">HuggingFace</a>
</div>
**Model Download Options (Choose One):** <br>
1. Visit [HuggingFace](https://huggingface.co/ViperEk/KHAOSZ) to access **Files and versions** ## 📖 Table of Contents
2. Run `scripts/download.py` to download parameters
**Demo Video:** [bilibili](https://www.bilibili.com/video/BV1z5RPYHEkd) - [Features](#features)
- [Quick Start](#quick-start)
- [Documentation](#documentation)
- [Contributing](#contributing)
- [Community](#community)
- [License](#license)
Training dataset sources are listed in the **Model Card** section of the HuggingFace download link. ---
**License:** Code follows Apache-2.0 protocol. Please credit the source code when used. <a id="english"></a>
## English
- **📊 Device Selection:** Code defaults to CUDA training ### Features
- **🌐 Performance Optimization:** `dtype=torch.bfloat16` is enabled to accelerate training and reduce memory usage. Ensure hardware supports this feature.
- **🤖 Language Support:** Model supports Chinese and English training. The BBPE tokenizer was trained without multilingual text, so OOV (out-of-vocabulary) issues are minimized for these languages but may exist for others.
### 📌 Training Guide - 🚀 **High Performance**: Optimized for both training and inference with efficient parallelization.
- 🔧 **Flexible**: Support for seq/sft/dpo/grpo training, customizable model architectures.
- 💡 **Easy to Use**: Simple API with comprehensive examples and demos.
- 📦 **Lightweight**: Minimal dependencies, easy to deploy.
- 🔬 **ResearchFriendly**: Modular design, easy to experiment with new ideas.
- 🤗 **HuggingFace-Style API**: AutoModel/AutoTokenizer APIs inspired by HuggingFace for easy model and tokenizer loading.
- 🔌 **Dual API Compatibility**: Supports both OpenAI and Anthropic chat completion APIs out of the box.
To train this Transformer model, follow these steps: ### Quick Start
**(1). Prepare Dataset:** #### Installation
Place datasets in the designated root directory. Files should be text documents in Chinese, English, or mixed. Format should align with model input requirements - preferably pre-tokenized token_ids stored as `torch.Tensor` (using `torch.Tensor` saves memory compared to Python lists, which default to 64-bit precision).
**(2). Install Dependencies:**
```bash ```bash
pip install -r requirements.txt git clone https://github.com/ViperEkura/AstrAI.git
pip install . cd AstrAI
pip install -e .
``` ```
**(3). Run Training Script:** For development dependencies:
```bash ```bash
python train.py \ pip install -e ".[dev]"
--train_type=train_type[seq, sft, dpo] \
--data_root_path=/path/to/dataset \
--param_path=/path/to/param_path \
--n_epoch=5 \
--batch_size=8 \
--max_lr=2e-4 \
--checkpoint_interval=10000 \
--checkpoint_dir=checkpoints
``` ```
**Parameters Explanation:** #### Download Pre-trained Model
- `--train_type`: Training type (seq, sft, dpo)
- `--data_root_path`: Root directory of the dataset
- `--param_path`: Path to the model training parameters
- `--n_epoch`: Total number of training epochs
- `--batch_size`: Batch size
- `--accumulation_steps`: Number of batches per training step
- `--warmup_steps`: Number of warmup steps
- `--max_lr`: Maximum learning rate (using warmup + cosine decay)
- `--checkpoint_interval`: Checkpoint saving interval
- `--checkpoint_dir`: Directory to save checkpoints
- `--resume_dir`: Resume training from the specified path
Training logs will be saved in `train_log.txt`. Checkpoints will be saved in the specified directory for resuming training or evaluation. Download pre-trained model weights (1B bilingual checkpoint) to `params/`:
### 👉 Usage Guide
**(1). Chatting with the Model:**
Open `chat.py` or use streaming/non-streaming interfaces:
**Streaming Output:**
```python
import torch
from khaosz import Khaosz
model_dir = "your_model_parameter_dir"
model = Khaosz(model_dir).to(device='cuda', dtype=torch.bfloat16)
history = []
while True:
query = input(">> ")
if query == "!exit":
break
response_size = 0
for response, history in model.stream_generate(
query=query,
history=history,
temperature=0.85,
top_p=0.95,
top_k=50
):
print(response[response_size:], end="")
response_size = len(response)
```
**Non-streaming Output:**
```python
import torch
from khaosz import Khaosz
model_dir = "your_model_parameter_dir"
model = Khaosz(model_dir).to(device='cuda', dtype=torch.bfloat16)
history = []
while True:
query = input(">> ")
if query == "!exit":
break
response = model.generate(
query=query,
history=history,
temperature=0.85,
top_p=0.95,
top_k=50
)
print(response)
```
**(2) Retrieval-Augmented Generation (RAG):**
```python
import torch
from khaosz import Khaosz
model_dir = "your_model_parameter_dir"
model = Khaosz(model_dir).to(device='cuda', dtype=torch.bfloat16)
retrieved_content = model.retrieve_generate(
query=query,
retrieve_top_k=5,
temperature=0.6,
top_k=30,
top_p=0.95
)
print(retrieved_content)
```
### 📌 Model Specifications
This model is based on a 24-layer Transformer with parameters defined in `config.json`, totaling approximately 1.0 billion (1.0B) parameters.
**Key Design Choices:**
- Weight tying between embedding and final linear layers (standard for small models to save parameters)
- Embedding layer optimization: Without weight tying, a 10,000-word vocabulary would consume ~102M parameters (0.1B)
**Limitations:**
- May struggle with complex language phenomena due to smaller parameter size
- Prone to overfitting on specialized datasets
- Limited multilingual capabilities
**Advantages:**
- Runs efficiently on lower-spec hardware
- Shorter training time compared to larger models
**Training Pipeline:**
The model has completed pre-training + SFT (Supervised Fine-Tuning) + DPO (Direct Preference Optimization) workflows. All corresponding training code is included in the repository.
<h2 id="chinese">中文版本</h2>
这是一个支持中英文双语的 Transformer 模型,能够处理两种语言。模型包含配置文件和训练流程,通过加载 `param_path/config.json` 中定义的参数完成训练。训练脚本 `train.py` 支持命令行参数解析包括数据集根目录、训练轮数epochs、批量大小batch size、检查点保存间隔、检查点目录等。
**模型下载选项(任选其一):**
1. 访问 [HuggingFace](https://huggingface.co/ViperEk/KHAOSZ) 查看 **Files and versions**
2. 运行 `scripts/download.py` 下载模型参数
**演示视频:** [bilibili](https://www.bilibili.com/video/BV1z5RPYHEkd)
训练数据来源请参见 HuggingFace 下载页面中的 **Model Card** 部分。
**许可证:** 代码遵循 Apache-2.0 协议,使用时请注明出处。
- **📊 设备选择:** 默认使用 CUDA 进行训练
- **🌐 性能优化:** 启用 `dtype=torch.bfloat16` 以加速训练并减少内存占用,请确保硬件支持该特性
- **🤖 语言支持:** 模型支持中文和英文训练。由于 BBPE 分词器未使用多语言文本训练,因此中英文的 OOV未登录词问题较少其他语言可能存在 OOV 问题
### 📌 训练指南
要训练该 Transformer 模型,请按照以下步骤操作:
#### **(1). 准备数据集:**
将数据集放置在指定的根目录下。文件应为包含中文、英文或混合文本的文本文档。格式应符合模型输入要求——建议使用预分词后的 `token_ids` 并以 `torch.Tensor` 格式保存(使用 `torch.Tensor` 相比 Python 列表更节省内存,列表默认为 64 位精度)。
#### **(2). 安装依赖:**
```bash ```bash
pip install -r requirements.txt python scripts/demo/download.py
pip install .
``` ```
#### **(3). 运行训练脚本:** Or download manually from [HuggingFace](https://huggingface.co/ViperEk/KHAOSZ) into `params/`.
#### Train a Model
```bash ```bash
python train.py \ export CUDA_VISIBLE_DEVICES=0,1,2,3
--train_type=train_type[seq, sft, dpo] \
--data_root_path=/path/to/dataset \ nohup python scripts/tools/train.py \
--param_path=/path/to/param_path \ --nprocs=4 \
--n_epoch=5 \ --parallel_mode=ddp \
--batch_size=8 \ --train_type=seq \
--max_lr=2e-4 \ --data_root_path=/path/to/dataset \
--checkpoint_interval=10000 \ --param_path=/path/to/model \
--checkpoint_dir=checkpoints --batch_per_device=4 \
--grad_accum_steps=8 \
--warmup_ratio=0.05 \
--max_lr=1e-4 \
--max_grad_norm=1.0 \
--adamw_beta1=0.9 \
--adamw_beta2=0.95 \
--adamw_weight_decay=0.01 \
--window_size=2048 \
--ckpt_interval=10000 \
--ckpt_dir=./checkpoint \
--random_seed=3407 \
--label_smoothing=0.05 \
> out.log 2> err.log &
``` ```
**参数说明:** Full reference at [Parameter Guide](assets/docs/params.md).
- `--train_type`: 训练类型seq, sft, dpo
- `--data_root_path`: 数据集根目录
- `--param_path`: 模型训练参数路径
- `--n_epoch`: 总训练轮数
- `--batch_size`: 批量大小
- `--accumulation_steps`: 每个训练步骤的 batch 数量
- `--warmup_steps`: 预热步数warmup steps
- `--max_lr`: 最大学习率(使用预热 + 余弦衰减)
- `--checkpoint_interval`: 检查点保存间隔
- `--checkpoint_dir`: 检查点保存目录
- `--resume_dir`: 从指定路径恢复训练
训练日志将保存在 `train_log.txt` 中。检查点将保存在指定目录,用于恢复训练或评估。 #### Generate Text
```bash
python scripts/tools/generate.py \
### 👉 使用指南 --param_path /path/to/model \
--input_json_file /path/to/input.jsonl \
#### **(1). 与模型对话:** --output_json_file /path/to/output.jsonl
打开 `chat.py` 或使用流式/非流式接口:
**流式输出:**
```python
import torch
from khaosz import Khaosz
model_dir = "your_model_parameter_dir"
model = Khaosz(model_dir).to(device='cuda', dtype=torch.bfloat16)
history = []
while True:
query = input(">> ")
if query == "!exit":
break
response_size = 0
for response, history in model.stream_generate(
query=query,
history=history,
temperature=0.85,
top_p=0.95,
top_k=50
):
print(response[response_size:], end="")
response_size = len(response)
``` ```
**非流式输出:** #### Docker
```python
import torch
from khaosz import Khaosz
model_dir = "your_model_parameter_dir" Build and run with Docker (recommended for GPU environments):
model = Khaosz(model_dir).to(device='cuda', dtype=torch.bfloat16)
history = []
while True: ```bash
query = input(">> ") # Build image
if query == "!exit": docker build -t astrai:latest .
break
# Run with GPU support
response = model.generate( docker run --gpus all -it astrai:latest
query=query,
history=history, # Run with specific GPUs
temperature=0.85, docker run --gpus '"device=0,1"' -it astrai:latest
top_p=0.95,
top_k=50 # Run inference server
) docker run --gpus all -p 8000:8000 astrai:latest \
print(response) python -m scripts.tools.server --port 8000 --device cuda
# Run with volume mount for data
docker run --gpus all -v /path/to/data:/data -it astrai:latest
# Docker Compose (GPU, default)
docker compose up -d
# Docker Compose (CPU only)
docker compose --profile cpu up -d
``` ```
#### **(2). 基于检索的生成RAG** > **Note**: `--gpus all` is required for CUDA support. Without it, `torch.cuda.is_available()` will return `False`.
```python #### Start HTTP Server
import torch
from khaosz import Khaosz
model_dir = "your_model_parameter_dir" Start the inference server with OpenAI and Anthropic-compatible HTTP API:
model = Khaosz(model_dir).to(device='cuda', dtype=torch.bfloat16)
retrieved_content = model.retrieve_generate( ```bash
query=query, python -m scripts.tools.server --port 8000 --device cuda
retrieve_top_k=5,
temperature=0.6,
top_k=30,
top_p=0.95
)
print(retrieved_content)
``` ```
Make requests:
```bash
# OpenAI-compatible
curl -X POST http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"messages": [{"role": "user", "content": "Hello"}],
"max_tokens": 512
}'
### 📌 模型规格说明(重复部分) # OpenAI-compatible streaming
curl -X POST http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"messages": [{"role": "user", "content": "Tell a story"}],
"stream": true,
"max_tokens": 500
}'
该模型基于一个 24 层的 Transformer 架构,参数配置定义在 `config.json` 中,总参数量约为 10 亿1.0B)。 # Anthropic-compatible
curl -X POST http://localhost:8000/v1/messages \
-H "Content-Type: application/json" \
-d '{
"model": "astrai",
"system": "You are a helpful assistant.",
"messages": [{"role": "user", "content": "Hello"}],
"max_tokens": 512
}'
**关键设计选择:** # Anthropic-compatible streaming with stop sequences
- 在嵌入层embedding与最终线性层之间进行权重绑定weight tying这是小型模型中常见的节省参数量的做法 curl -X POST http://localhost:8000/v1/messages \
- 嵌入层优化:若不进行权重绑定,一个包含 10,000 个词的词汇表将消耗约 1.02 亿0.1B)参数 -H "Content-Type: application/json" \
-d '{
"model": "astrai",
"messages": [{"role": "user", "content": "Write a story"}],
"max_tokens": 500,
"stream": true,
"stop_sequences": ["The end"]
}'
**局限性:** # Health check
- 由于参数规模较小,可能在处理复杂语言现象时表现受限 curl http://localhost:8000/health
- 在特定领域的数据集上容易出现过拟合 ```
- 多语言能力有限
**优势:** #### Demo
- 可在低配置硬件上高效运行
- 相较于大型模型,训练时间更短
**训练流程:** Check out the demos in the `scripts/demo/` folder:
该模型已完成预训练pre-training+ 监督微调SFT, Supervised Fine-Tuning+ 直接偏好优化DPO, Direct Preference Optimization的全流程。所有相关的训练代码均已包含在代码库中。
```bash
# Download preprocessed data (required before running demos)
python scripts/demo/download.py
# Interactive streaming chat
python scripts/demo/stream_chat.py
# Batch generation
python scripts/demo/generate_batch.py
# Autoregressive generation
python scripts/demo/generate_ar.py
```
Watch a video walkthrough on [bilibili](https://www.bilibili.com/video/BV1fuLB6yEj6).
### Documentation
| Document | Description |
|----------|-------------|
| [Parameter Guide](./assets/docs/params.md) | Training & inference parameters |
| [Architecture](./assets/docs/architecture.md) | System architecture, class diagram & design patterns |
| [Training](./assets/docs/training.md) | Training loop, strategies & formulas |
| [Inference](./assets/docs/inference.md) | KVCache, continuous batching, sampling & HTTP API |
| [Data Flow](./assets/docs/dataflow.md) | Data pipeline, storage backends & dataset architecture |
| [Preprocessing](./assets/docs/preprocessing.md) | Declarative JSON-driven data preprocessing |
### Contributing
We welcome contributions! Please see our [Contributing Guidelines](CONTRIBUTING.md) for details.
1. Fork the repository.
2. Create a feature branch.
3. Commit your changes.
4. Open a Pull Request.
For major changes, please open an issue first to discuss what you would like to change.
### Community
- **GitHub Issues**: [Issue Tracker](https://github.com/ViperEkura/AstrAI/issues)
- **Discussions**: [GitHub Discussions](https://github.com/ViperEkura/AstrAI/discussions)
- **HuggingFace**: [Model Hub](https://huggingface.co/ViperEk)
### License
This project is licensed under the [GPL-3.0 License](LICENSE).
---
<div align="center">
<em>A lightweight Transformer framework designed for both high performance and ease of use.</em>
</div>

261
assets/docs/README-zh-CN.md Normal file
View File

@ -0,0 +1,261 @@
<div align="center">
<img src="../images/logo.png" width="auto" alt="Logo">
<div>
<a href="../../README.md">English</a>
<a href="#chinese">中文</a>
</div>
<p>
<strong>轻量级 Transformer 训练与推理框架</strong>
</p>
</div>
<div align="center">
<img src="https://img.shields.io/badge/python-3.12+-blue.svg" alt="python">
<img src="https://img.shields.io/badge/license-GPL--3.0-blue.svg" alt="license">
<img src="https://img.shields.io/github/v/release/ViperEkura/AstrAI?color=76bad9" alt="release">
<img src="https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.github.com%2Frepos%2FViperEkura%2FAstrAI&query=%24.stargazers_count&label=stars&suffix=%20stars&color=76bad9" alt="stars">
<img src="https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.github.com%2Frepos%2FViperEkura%2FAstrAI&query=%24.forks_count&label=forks&suffix=%20forks&color=76bad9" alt="forks">
</div>
<br>
<div align="center">
<a href="../../README.md">English</a>
<a href="#chinese">中文</a>
<a href="https://github.com/ViperEkura/AstrAI/issues">问题追踪</a>
<a href="https://github.com/ViperEkura/AstrAI/discussions">讨论区</a>
<a href="https://huggingface.co/ViperEk">HuggingFace</a>
</div>
<br>
## 📖 目录
- [特性](#特性)
- [快速开始](#快速开始)
- [文档](#文档)
- [贡献](#贡献)
- [社区](#社区)
- [许可证](#许可证)
---
<a id="chinese"></a>
## 中文
### 特性
- 🚀 **高性能**: 训练与推理双向优化,高效并行。
- 🔧 **灵活**: 支持 seq/sft/dpo/grpo 多种训练方式,可定制模型架构。
- 💡 **易用**: 简洁的 API 与丰富的示例、演示。
- 📦 **轻量**: 依赖少,部署简单。
- 🔬 **研究友好**: 模块化设计,便于实验新想法。
- 🤗 **HuggingFace 风格 API**: 类 HuggingFace 的 AutoModel/AutoTokenizer 接口,方便加载模型和分词器。
- 🔌 **双 API 兼容**: 同时支持 OpenAI 和 Anthropic 聊天补全 API开箱即用。
### 快速开始
#### 安装
```bash
git clone https://github.com/ViperEkura/AstrAI.git
cd AstrAI
pip install -e .
```
安装开发依赖:
```bash
pip install -e ".[dev]"
```
#### 下载预训练模型
下载预训练模型权重1B 双语检查点)到 `params/` 目录:
```bash
python scripts/demo/download.py
```
或从 [HuggingFace](https://huggingface.co/ViperEk/KHAOSZ) 手动下载放入 `params/`
#### 训练模型
```bash
export CUDA_VISIBLE_DEVICES=0,1,2,3
nohup python scripts/tools/train.py \
--nprocs=4 \
--parallel_mode=ddp \
--train_type=seq \
--data_root_path=/path/to/dataset \
--param_path=/path/to/model \
--batch_per_device=4 \
--grad_accum_steps=8 \
--warmup_ratio=0.05 \
--max_lr=1e-4 \
--max_grad_norm=1.0 \
--adamw_beta1=0.9 \
--adamw_beta2=0.95 \
--adamw_weight_decay=0.01 \
--window_size=2048 \
--ckpt_interval=10000 \
--ckpt_dir=./checkpoint \
--random_seed=3407 \
--label_smoothing=0.05 \
> out.log 2> err.log &
```
完整参数列表见[参数说明](./params.md)。
#### 文本生成
```bash
python scripts/tools/generate.py \
--param_path /path/to/model \
--input_json_file /path/to/input.jsonl \
--output_json_file /path/to/output.jsonl
```
#### Docker
使用 Docker 构建和运行(推荐用于 GPU 环境):
```bash
# 构建镜像
docker build -t astrai:latest .
# 启用 GPU 运行
docker run --gpus all -it astrai:latest
# 指定特定 GPU
docker run --gpus '"device=0,1"' -it astrai:latest
# 运行推理服务
docker run --gpus all -p 8000:8000 astrai:latest \
python -m scripts.tools.server --port 8000 --device cuda
# 挂载数据卷
docker run --gpus all -v /path/to/data:/data -it astrai:latest
# Docker ComposeGPU默认
docker compose up -d
# Docker Compose仅 CPU
docker compose --profile cpu up -d
```
> **注意**: 必须使用 `--gpus all` 才能启用 CUDA 支持,否则 `torch.cuda.is_available()` 将返回 `False`
#### 启动 HTTP 服务
启动推理服务器,支持 OpenAI 和 Anthropic 兼容的 HTTP API
```bash
python -m scripts.tools.server --port 8000 --device cuda
```
发起请求:
```bash
# OpenAI 兼容
curl -X POST http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"messages": [{"role": "user", "content": "你好"}],
"max_tokens": 512
}'
# OpenAI 兼容流式
curl -X POST http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"messages": [{"role": "user", "content": "讲个故事"}],
"stream": true,
"max_tokens": 500
}'
# Anthropic 兼容
curl -X POST http://localhost:8000/v1/messages \
-H "Content-Type: application/json" \
-d '{
"model": "astrai",
"system": "你是一个乐于助人的助手。",
"messages": [{"role": "user", "content": "你好"}],
"max_tokens": 512
}'
# Anthropic 兼容流式并设置停止序列
curl -X POST http://localhost:8000/v1/messages \
-H "Content-Type: application/json" \
-d '{
"model": "astrai",
"messages": [{"role": "user", "content": "写个故事"}],
"max_tokens": 500,
"stream": true,
"stop_sequences": ["结束"]
}'
# 健康检查
curl http://localhost:8000/health
```
#### 演示
查看 `scripts/demo/` 文件夹中的演示:
```bash
# 下载预处理数据(运行演示前必需)
python scripts/demo/download.py
# 交互式流式聊天
python scripts/demo/stream_chat.py
# 批量生成
python scripts/demo/generate_batch.py
# 自回归生成
python scripts/demo/generate_ar.py
```
观看 [bilibili](https://www.bilibili.com/video/BV1fuLB6yEj6) 上的视频演示。
### 文档
| 文档 | 说明 |
|------|------|
| [参数说明](./params.md) | 训练与推理参数配置 |
| [架构文档](./architecture.md) | 系统架构、类图与设计模式 |
| [训练文档](./training.md) | 训练循环、策略与公式 |
| [推理文档](./inference.md) | KVCache、连续批处理、采样与 HTTP API |
| [数据流程](./dataflow.md) | 数据管道、存储后端与数据集架构 |
| [数据预处理](./preprocessing.md) | 声明式 JSON 驱动数据预处理 |
### 贡献
我们欢迎贡献!请参阅[贡献指南](../../CONTRIBUTING.md)了解详情。
1. Fork 本仓库。
2. 创建功能分支。
3. 提交更改。
4. 发起 Pull Request。
重大更改请先开 issue 讨论。
### 社区
- **GitHub Issues**: [问题追踪](https://github.com/ViperEkura/AstrAI/issues)
- **Discussions**: [GitHub 讨论区](https://github.com/ViperEkura/AstrAI/discussions)
- **HuggingFace**: [模型中心](https://huggingface.co/ViperEk)
### 许可证
本项目采用 [GPL-3.0 许可证](../../LICENSE)。
---
<div align="center">
<em>专为高性能与易用性设计的轻量级 Transformer 框架。</em>
</div>

1208
assets/docs/architecture.md Normal file

File diff suppressed because it is too large Load Diff

64
assets/docs/dataflow.md Normal file
View File

@ -0,0 +1,64 @@
# Data Flow
This document describes the data pipeline: from raw text to model input tensors.
## Overview
```
Raw Text → AutoTokenizer → Token IDs → .h5/.bin → Store.load() → Store.fetch() → Dataset → Sampler → DataLoader → Training/Inference
```
## Data Preparation
Raw text is tokenized via `AutoTokenizer.encode()` and saved as HDF5 (`.h5`) or binary (`.bin` + `meta.json`) files with keyed tensor groups.
Storage format is auto-detected by `detect_format()`; backends are dispatched via registry:
```
StoreFactory.create("h5") → H5Store
StoreFactory.create("bin") → MmapStore
```
H5 backend supports shared memory via `.share_memory_()`. Bin (mmap) uses OS page-cache sharing natively.
## Data Keys by Training Type
| Type | Storage Keys |
|------|-------------|
| `seq` | `sequence` (→ input_ids, target_ids via offset-by-1) |
| `sft` | `sequence`, `loss_mask` |
| `dpo` | `chosen`, `rejected`, `chosen_mask`, `rejected_mask` |
| `grpo` | `prompts`, `responses`, `masks`, `rewards` |
## Dataset Architecture
```
DatasetFactory.load(train_type, load_path, window_size, stride=None, storage_type=None)
→ BaseDataset.load(load_path, storage_type=None)
→ detect_format(load_path)
→ StoreFactory.create(storage_type)
→ Store.load(load_path)
→ H5Store._normalize() / MmapStore._normalize()
→ Store._data[Dict[str, List[Tensor]]] + _cum[Dict[str, List[int]]]
→ BaseDataset.__getitem__(idx)
→ get_index(idx) → [begin, end)
→ Store.fetch(begin, end, keys) → Tensor / Dict[str, Tensor]
```
`window_size` = max input length, `stride` = step between consecutive samples (defaults to `window_size`, optional). `storage_type` defaults to `None` (auto-detect via `detect_format`).
`Store.fetch(begin, end, keys)` accepts a single key (`str`) returning a `Tensor`, or a list of keys returning `Dict[str, Tensor]`. Internally uses `bisect` across multi-segment tensors. Raises `RuntimeError("Store not loaded")` if called before `load()`.
## Sampler
`ResumableDistributedSampler` supports checkpoint-aware distributed sampling:
- Tracks `start_epoch` / `start_iter` for resume
- Shuffle via `torch.Generator(seed + epoch)`
- Per-replica index slicing for DDP
## DataLoader
Standard PyTorch `DataLoader` with configurable `batch_size`, `num_workers`, `pin_memory`, `prefetch_factor`. Sampler produces indices; dataloader fetches tensor batches via `__getitem__`.
> Document Update Time: 2026-05-30

152
assets/docs/inference.md Normal file
View File

@ -0,0 +1,152 @@
# Inference
## KV Cache
At decode time, only the last query token matters. All previous K/V are cached to avoid recomputation:
$$
o_n = \sum_j \text{softmax}\left(\frac{q_n k_j}{\sqrt{d_k}}\right) v_j
$$
RoPE is applied **before** KV cache write, not after — otherwise position encoding drift occurs.
## KVCache System
Six classes (plus two helpers) working together:
```
KVCache (facade)
├── PagePool orchestrates page allocation + prefix matching
│ ├── Allocator bitmask-based page allocator + ref-count + LRU eviction (inside PagePool)
│ └── PrefixCache hash-based prefix matching (page_hash via polynomial hash) (inside PagePool)
├── TaskTable maps task_id → page_table + cached token count
├── Storage k_cache / v_cache tensors (n_layers × n_pages × page_size × n_kv_heads × head_dim)
└── KvcacheView bundles Storage + page_table + total_len for attention layers (returned by bind())
```
`KVCache.bind(page_table, total_len)` returns a `KvcacheView` used by attention layers via `write()` / `gather()`.
## Continuous Batching
`InferenceScheduler` runs a daemon thread with a 4-phase loop:
```
1. Cleanup → Remove finished tasks, free KV pages
2. Refill → Pop from waiting_queue, task_alloc pages, activate
3. Prefill → Group by (prompt_len, start_pos), run full forward
4. Decode → Pick largest same-position group, single-token forward
```
## Sampling (Strategy Pattern)
```
BaseSamplingStrategy (ABC)
├── TemperatureStrategy
├── TopKStrategy
├── TopPStrategy
└── SamplingPipeline
```
`SamplingPipeline` composes them: Temperature → Top-K → Top-P → softmax → multinomial.
`sample()` is a convenience shortcut for one-shot usage.
## Protocol Handlers (Strategy Pattern)
```python
class ProtocolHandler: # concrete orchestrator
def __init__(self, request, engine, builder): ...
async def handle(self):
prompt, ctx, stops = builder.prepare(request, engine)
agen = engine.generate_async(prompt, ...)
if stream: self._handle_stream(agen, ctx, stops)
else: return await self._handle_non_stream(agen, ctx, stops)
```
`ResponseBuilder` (ABC): `prepare()`, `format_stream_start()`, `format_chunk()`, `format_stream_end()`, `format_response()`.
`OpenAIResponseBuilder``/v1/chat/completions`, `AnthropicResponseBuilder``/v1/messages`.
Adding a protocol = one builder file, no handler subclassing needed.
## Engine & GenerateResult
```
InferenceEngine
├── generate(prompt, stream, ...) → str | List[str] | Generator
├── generate_with_request(req) → same
├── generate_async(prompt, ...) → AsyncGenerator
├── get_stats() → Dict
└── shutdown()
```
`GenerateResult` uses `Condition` for non-streaming (`wait_completion()`) and `Event` for streaming (`wait()`). Stream callback is `cb(token)`.
## HTTP API
```
POST /v1/chat/completions OpenAI
POST /v1/messages Anthropic
GET /health {"status":"ok","model_loaded":true}
GET /stats scheduler statistics
```
### OpenAI
```bash
curl -X POST http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{"messages":[{"role":"user","content":"Hello"}],"max_tokens":512}'
```
Response:
```json
{
"id": "chatcmpl-abc123",
"object": "chat.completion",
"created": 1717000000,
"model": "astrai",
"choices": [{"index": 0, "message": {"role": "assistant", "content": "Hello!"}, "finish_reason": "stop"}],
"usage": {"prompt_tokens": 5, "completion_tokens": 10, "total_tokens": 15}
}
```
Streaming SSE: `object: "chat.completion.chunk"` — starts with role delta, then token chunks, ends with finish chunk + usage stats, then `data: [DONE]`.
### Anthropic
```bash
curl -X POST http://localhost:8000/v1/messages \
-H "Content-Type: application/json" \
-d '{"model":"astrai","system":"You are helpful.","messages":[{"role":"user","content":"Hello"}],"max_tokens":512}'
```
Supports `stop_sequences` and streaming via `event: content_block_delta`.
### GenerationRequest Parameters
| Param | Type | Default | Description |
|-------|------|---------|-------------|
| `messages` | List[dict] | required | Chat messages (role, content) |
| `top_k` | int | 50 | Top-k count |
| `top_p` | float | 1.0 | Nucleus threshold |
| `temperature` | float | 1.0 | Sampling temperature (> 0.0) |
| `max_tokens` | Optional[int] | None | Max generation length |
| `stream` | bool | False | Stream output |
## Engine API
```python
# Non-streaming
engine.generate("Hello", stream=False) # -> str
engine.generate(["A", "B"], stream=False) # -> List[str]
# Streaming
engine.generate("Hello", stream=True) # -> Generator[str]
engine.generate(["A", "B"], stream=True) # -> Generator[Tuple[int, str]]
# Async
async for token in engine.generate_async("Hello", ...): # -> AsyncGenerator[str]
print(token)
```
> Document Update Time: 2026-05-30

View File

@ -1,89 +0,0 @@
## 模型介绍
### 1. 模型搭建
本模型采用Transformer架构 使用GQAq_head=24, kv_head=4 机制相较于传统的MHA可以节省KV cache 的显存占用但是目前没有做KV cache通过堆叠24层Transformer实现模型的搭建 参数量为1.0b。Transformer 是自回归模型, 是通过计算前面所有的token的关系得到下一个token的概率分布
![structure](../images/structure.png)
什么是自回归模型呢, 在把句子拆分成token之后, 模型会预测下一个token的概率分布。这意味着模型会根据给定的上下文即已经出现的tokens序列计算出下一个可能的token及其对应的概率。
#### 1. 自回归
假设我们有一个句子被拆分成如下tokens列表
```
["你好", "" "今天", "天气"]
```
接下来模型会基于这个序列预测下一个可能出现的token。这通常以概率分布的形式给出比如
```
-> {"token": "不错", "probability": 0.4}
-> {"token": "晴朗", "probability": 0.2}
-> ......
```
这里“不错”和“晴朗”是两个可能跟随在“天气”之后的tokens并且给出了每个token成为下一个token的可能性大小。
之后我们通过采样通过top_k, top_p, temperature参数调整采样后的结果得到下一个token并且将下一个token加入序列作为输入
```
["你好", "" "今天", "天气", "不错"]
```
之后都是在重复这个流程, 直到遇到控制流程结束的token<|end_of_seqence|>模型停止处理一般模型都会设置控制token 不然模型会一直输出到显存爆炸)。
#### 2. 因果掩码
transformer 中采用注意力机制,输入的形状一般为[bsz, seq_len] 输出为[bsz, seq_lenn_dim] 为了实现预测下一个token 模型的输入和输出必须错开来一个位置。模型预测的target必须错开一个位置 在训练的时候我们也采用错开一个位置的方法
```
sequence : [[1, 2, 3, 4, 5, 6]]
input_ids: [[1, 2, 3, 4, 5]]
target_ids: [[2, 3, 4, 5, 6]]
```
注意力得分计算的公式为
$$ s_{ij} = softmax(\frac{q_i^Tk_j}{\sqrt{d_k}}) $$
$$ s_{ij} := s_{ij} + mask_{ij} $$
其中注意力得分代表了模型对两个token之间相似程度的关注程度
对于decoder only结构的模型 为了防止模型从未来的位置偷到信息, 在注意力的计算过程中需要增加掩码我们需要在注意力得分计算之前应用一个掩码。这个掩码通常是一个下三角矩阵对于长度为n的序列它的形状是[n, n]。下面以一个长度为5的序列为例展示如何创建这样的因果掩码矩阵
```
[[0, -inf, -inf, -inf, -inf],
[0, 0, -inf, -inf, -inf],
[0, 0, 0, -inf, -inf],
[0, 0, 0, 0, -inf],
[0, 0, 0, 0, 0]]
```
在这个矩阵中0表示可以注意到的位置而-inf表示应该被掩盖即不应注意到的位置。因为这个句子保证了注意力得分中 $j > i$ 的部分通过softmax 之后由`inf` 变成0 也就是模型不能看到未来的信息
#### 3. 旋转位置编码
旋转位置编码Rotary Position Embedding, RoPE是一种为了解决Transformer模型中缺乏对序列位置信息直接建模的问题而设计的位置编码方法。与传统的位置编码如正弦和余弦函数的位置编码不同RoPE通过将位置信息直接嵌入到查询Query, Q和键Key, K向量中来实现使得模型能够更自然地处理序列中的相对位置关系。
$$ q_i = R_i W_q x_i $$
$$ k_j = R_j W_k x_j $$
$$ q_i^T k_j = (R_i W_q x_i)^T( R_j W_k x_j) = x_i^T W_q^T R_{i-j} W_k x_j $$
其中的 $R_{i-j}$ 控制了模型的不同token 在不同相对距离上注意力的衰减,在 $i - j$ 绝对值越大的时候, 衰减的程度越强, 通过这种方式能让模型学习到相对位置关系, 从而使得模型可以扩展和适应长序列

View File

@ -1,27 +0,0 @@
## kv_cache 实现
根据注意力的计算公式
$$
\begin{align*}
o_i &= \sum_j s_{ij} v_{j} \\
s_{ij} &= \text{softmax}\left( \sum_n \frac{q_{i,n} k_{j,n}}{\sqrt{d_k}} \right)
\end{align*}
$$
由于模型是自回归模型, 我们只用求序列最后一个部分,也就是说 $ i $ 的下标是确定的, 是序列最后一个元素, 我们求的是 $o_{n} $
$$
\begin{align*}
o_n &= \sum_j s_{j}v_{j,n} \\
s_j &= \text{softmax}\left(\sum_n\frac{q_n k_{j,n}}{\sqrt{d_k}} \right)
\end{align*}
$$
如果我们把式子展开
$$
o_n = \sum_j \sum_n \text{softmax}\left(\frac{q_n k_{j,n}}{\sqrt{d_k}}\right)v_{j,n}
$$
以上表达式只有k和v存在长度下标, 而 $q$ 没有, 所以计算过程中 $q$ 的输入是确定的上次输入的最后一个token, 而 $k, v$ 是需要对不同长度的部分进行缓存的同时缓存的时候应该注意位置编码的计算应该在kvcache的计算之前进行否则会存在位置编码的计算错误

100
assets/docs/params.md Normal file
View File

@ -0,0 +1,100 @@
# Parameter Documentation
## Training Parameters
### Basic Parameters
| Parameter | Description | Default |
|-----------|-------------|---------|
| `--train_type` | Training type (`seq`, `sft`, `dpo`, `grpo`) | required |
| `--data_root_path` | Dataset root directory | required |
| `--param_path` | Model parameters or checkpoint path | required |
| `--n_epoch` | Total training epochs | 1 |
| `--batch_per_device` | Batch size per device | 1 |
| `--grad_accum_steps` | Gradient accumulation steps between optimizer steps | 1 |
### Learning Rate Scheduling
| Parameter | Description | Default |
|-----------|-------------|---------|
| `--warmup_ratio` | Fraction of total steps used for LR warmup | 0.05 |
| `--max_lr` | Maximum learning rate (cosine decay after warmup) | 3e-4 |
| `--max_grad_norm` | Maximum gradient norm for clipping | 1.0 |
### Optimizer (AdamW)
| Parameter | Description | Default |
|-----------|-------------|---------|
| `--adamw_beta1` | AdamW beta1 | 0.9 |
| `--adamw_beta2` | AdamW beta2 | 0.95 |
| `--adamw_weight_decay` | AdamW weight decay | 0.01 |
### Data Loading
| Parameter | Description | Default |
|-----------|-------------|---------|
| `--window_size` | Max input sequence length | model config `max_len` |
| `--stride` | Stride for sliding window over sequences | None |
| `--random_seed` | Random seed for reproducibility | 3407 |
| `--num_workers` | DataLoader worker processes | 4 |
| `--no_pin_memory` | Disable pin_memory (enabled by default) | (flag) |
### Checkpoint & Resume
| Parameter | Description | Default |
|-----------|-------------|---------|
| `--ckpt_interval` | Iterations between checkpoints | 5000 |
| `--ckpt_dir` | Checkpoint save directory | checkpoint |
| `--start_epoch` | Resume from epoch (0 = from scratch) | 0 |
| `--start_batch` | Resume from batch iteration | 0 |
### Distributed Training
| Parameter | Description | Default |
|-----------|-------------|---------|
| `--nprocs` | Number of GPUs / processes | 1 |
| `--parallel_mode` | Parallel strategy (`none`, `ddp`, or `fsdp`) | none |
| `--device_type` | Device type | cuda |
| `--start_method` | Multiprocessing start method (`spawn`, `fork`, `forkserver`) | spawn |
### Strategy-specific
| Parameter | Description | Default | Used by |
|-----------|-------------|---------|---------|
| `--dpo_beta` | DPO beta value | 0.1 | `dpo` |
| `--label_smoothing` | Label smoothing for cross-entropy loss | 0.05 | `seq`, `sft` |
| `--group_size` | GRPO group size | 4 | `grpo` |
| `--grpo_clip_eps` | GRPO clipping epsilon | 0.2 | `grpo` |
| `--grpo_kl_coef` | GRPO KL penalty coefficient | 0.01 | `grpo` |
| `--grpo_sync_interval` | GRPO ref_model sync interval (steps) | 200 | `grpo` |
### Usage Example
```bash
export CUDA_VISIBLE_DEVICES=0,1,2,3
nohup python scripts/tools/train.py \
--nprocs=4 \
--parallel_mode=ddp \
--train_type=seq \
--data_root_path=/path/to/dataset \
--param_path=/path/to/model \
--batch_per_device=4 \
--grad_accum_steps=8 \
--warmup_ratio=0.05 \
--max_lr=1e-4 \
--max_grad_norm=1.0 \
--adamw_beta1=0.9 \
--adamw_beta2=0.95 \
--adamw_weight_decay=0.01 \
--window_size=2048 \
--ckpt_interval=10000 \
--ckpt_dir=./checkpoint \
--random_seed=3407 \
--label_smoothing=0.05 \
> out.log 2> err.log &
```
---
> Document Update Time: 2026-05-24

View File

@ -0,0 +1,283 @@
# Preprocessing Pipeline
Declarative JSON-driven data preprocessing. No code needed -- describe your input format and mask rules in a config file, the engine does the rest.
## Philosophy
| Component | Responsibility |
|-----------|---------------|
| `tokenizer_config.json` (`chat_template`) | Formatting -- how roles become tokens |
| `pipeline.json` (`mask`) | Masking -- which roles participate in training |
The two are fully decoupled. A single config file captures the entire pipeline, reusable and version-controllable. Extension is via factory registration (`@MaskBuilderFactory.register`) -- no need to touch existing code.
## Quick Start
### SFT Chat
```json
{
"version": 1,
"input": {
"type": "chat",
"messages_key": "messages"
},
"mask": {
"system": "mask",
"user": "mask",
"assistant": "train"
},
"mask_default": "mask",
"preprocessing": {
"max_seq_len": 2048,
"deduplicate": true
},
"output": {
"domain_key": "source",
"storage_format": "bin",
"max_tokens_per_shard": 100000000
}
}
```
Three lines of mask rules cover the most common SFT case: train on assistant turns, mask everything else.
### Instruction Tuning
```json
{
"version": 1,
"input": {
"type": "instruction",
"prompt_key": "instruction",
"response_key": "output"
},
"mask": {
"prompt": "mask",
"response": "train"
},
"mask_default": "mask",
"preprocessing": {
"max_seq_len": 2048
},
"output": {
"storage_format": "bin"
}
}
```
Mask splits at the prompt/response field boundary.
### Pretraining
```json
{
"version": 1,
"input": {
"type": "text",
"text_key": "content"
},
"mask": {},
"preprocessing": {
"max_seq_len": 2048,
"min_chars": 50
},
"output": {
"storage_format": "bin"
}
}
```
No mask -- train on all tokens.
### Run
```bash
python scripts/tools/preprocess.py data/*.jsonl -o output/ -c sft.json
```
## Configuration Reference
### `input`
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `type` | string | yes | `"chat"` | Format: `"chat"`, `"instruction"`, or `"text"` |
| `messages_key` | string | no | `"messages"` | JSON key for messages array (chat) |
| `prompt_key` | string | no | `"prompt"` | JSON key for prompt field (instruction) |
| `response_key` | string | no | `"response"` | JSON key for response field (instruction) |
| `text_key` | string | no | `"text"` | JSON key for text field |
### `mask`
A map of `{role_or_field: "mask" | "train"}`. The engine uses this to build `loss_mask`:
- `"mask"` -- tokens in this span are ignored during training (`loss_mask=0`)
- `"train"` -- tokens in this span contribute to the loss (`loss_mask=1`)
For chat mode, keys are role names (`system`, `user`, `assistant`, ...).
For instruction mode, keys are `"prompt"` and `"response"`.
| Field | Type | Default | Description |
|-------|------|---------|-------------|
| `mask` | dict | `{}` | Role/field to action mapping |
| `mask_default` | string | `"mask"` | Default action for unlisted roles |
### `preprocessing`
| Field | Type | Default | Description |
|-------|------|---------|-------------|
| `max_seq_len` | int | `2048` | Maximum token length; truncated if exceeded |
| `min_chars` | int | `50` | Minimum character length; dropped if shorter (text mode only) |
| `max_chars` | int | `2000000` | Maximum character length; dropped if longer (text mode only) |
| `deduplicate` | bool | `true` | Remove exact duplicates via MD5 of first 200 chars |
| `max_items` | int or null | `null` | Maximum items to process; `null` = unlimited |
### `output`
| Field | Type | Default | Description |
|-------|------|---------|-------------|
| `domain_key` | string or null | `null` | JSON key for domain grouping; `null` = all output to `__default__` |
| `storage_format` | string | `"bin"` | `"bin"` (mmap, zero-copy) or `"h5"` (HDF5) |
| `max_tokens_per_shard` | int | `100000000` | Max tokens per output shard |
## Mask Algorithm
### Chat Mode (role-span tracking)
For each message in the `messages` array:
1. Prepend BOS token (position 0, always masked)
2. Render through the chat template for that single message
3. Encode the rendered text, record token span `(start, end, role)`
4. Concatenate all spans — special tokens from the chat template naturally prevent BPE merging across message boundaries
5. Fill `loss_mask` from the mask rules
**Multi-turn example**:
```
Data:
[system: "You are helpful."]
[user: "What is 2+2?"]
[assistant: "4"]
[user: "What is 3+3?"]
[assistant: "6"]
Config:
"mask": {"system": "mask", "user": "mask", "assistant": "train"}
Result:
tokens: <bos> [system span] [user span] [assistant:4 span] [user span] [assistant:6 span]
mask: 0 0 0 1 0 1
```
Both assistant turns are trained. All system and user tokens are masked.
### Instruction Mode (field boundary)
Encode the prompt and response fields independently, then split the mask at the field boundary.
- `"prompt": "mask", "response": "train"` -- mask the left half, train the right half
- `"prompt": "train", "response": "mask"` -- the reverse
### Text Mode (no mask)
Pure tokenization. No `loss_mask` is produced. Used for pretraining.
## Output Layout
### Single-Shard (`bin`)
```
output_dir/
__default__/ # when domain_key is null
meta.json # {"sequence": {"shape": [N], "dtype": "int64"}, ...}
sequence.bin # int64 raw bytes, mmap-able for zero-copy reads
loss_mask.bin # int64 raw bytes
wiki/ # when domain_key="source" and item["source"]="wiki"
meta.json
sequence.bin
loss_mask.bin
```
### Multi-Shard (`bin`)
When `max_tokens_per_shard` is exceeded, bin output is split into numbered shard subdirectories:
```
output_dir/
__default__/
shard_0000/
meta.json
sequence.bin
loss_mask.bin
shard_0001/
meta.json
sequence.bin
loss_mask.bin
```
`MmapStore` automatically discovers and merges all shards under the domain directory.
### H5 Output
HDF5 files are always named with a shard index, avoiding overwrite regardless of `max_tokens_per_shard`:
```
output_dir/
__default__/
data_0000.h5 # each H5 contains key→dataset groups
data_0001.h5
wiki/
data_0000.h5
```
## Python API Usage
```python
from astrai.preprocessing.pipeline import Pipeline
from astrai.config.preprocess_config import PipelineConfig
config = PipelineConfig.from_json("sft_pipeline.json")
Pipeline(
config,
["data_part1.jsonl", "data_part2.jsonl"],
output_dir="output/",
tokenizer_path="params"
).run()
```
Or from the CLI:
```bash
python scripts/tools/preprocess.py data/*.jsonl -o output/ -c sft.json
```
## Extension
Register a custom builder for new formats:
```python
from astrai.preprocessing.builder import BaseMaskBuilder, MaskBuilderFactory
@MaskBuilderFactory.register("my_format")
class MyFormatBuilder(BaseMaskBuilder):
def build(self, item: dict, config, tokenizer) -> dict | None:
# Return {"ids": [...], "loss_mask": [...], "domain": "..."}
# Return None to skip this item
...
```
Then set `"input": {"type": "my_format"}` in your config.
## Compared to Old Pipeline
| Old (`astrai.preprocess.Pipeline`) | New (`astrai.preprocessing.pipeline.Pipeline`) |
|---|---|
| Configured via constructor arguments | Configured via JSON file |
| Hardcoded `_transform_chat` / `_transform_text` | Factory-registered `Builder` with declarative mask rules |
| Auto-detects format via magic key lists | Explicit `input.type` declaration |
| Double-encodes (full + prompt), uses length diff for mask | Single-encode with role-span tracking |
| Only trains the last assistant turn | Configurable: multi-turn, single-turn, or no mask |
> Document Update Time: 2026-05-30

201
assets/docs/training.md Normal file
View File

@ -0,0 +1,201 @@
# Training
### Autoregression
Given a token sequence, the model predicts the probability of the next token. Each generated token is appended to the input and fed back, repeating until an end-of-sequence token or max length.
### Causal Mask
```
sequence : [[1, 2, 3, 4, 5, 6]]
input_ids: [[1, 2, 3, 4, 5]]
target_ids: [[2, 3, 4, 5, 6]]
```
Lower-triangular mask prevents attending to future positions:
```
[[0, -inf, -inf, -inf, -inf],
[0, 0, -inf, -inf, -inf],
[0, 0, 0, -inf, -inf],
[0, 0, 0, 0, -inf],
[0, 0, 0, 0, 0]]
```
### Rotary Position Embedding (RoPE)
RoPE embeds position into Q/K vectors via complex rotation:
$$ q_i = R_i W_q x_i, \quad k_j = R_j W_k x_j, \quad q_i^T k_j = x_i^T W_q^T R_{i-j} W_k x_j $$
The complex rotation `freqs_cis` is pre-computed once (`cos, sin` pairs per position). `apply_rotary_emb` multiplies Q/K as complex numbers.
## Training Loop
Two-level loop: **epoch****batch**. Optimizer step fires every `grad_accum_steps` batches.
```
on_train_begin
model.train()
on_epoch_begin
for batch in dataloader:
on_batch_begin
with executor.accumulate(model):
loss = strategy.compute_loss(batch)
context.loss = loss.item()
stand_loss = loss / executor.grad_accum_steps
executor.backward(stand_loss)
context.iteration += 1
on_batch_end
if executor.sync_gradients:
on_optimizer_step
optimizer.step()
optimizer.zero_grad()
if scheduler:
scheduler.step()
on_epoch_end
on_train_end
```
### Callback Lifecycle
| Hook | Fires | Default callback |
|------|-------|-----------------|
| `on_train_begin` | Before training starts | `GradientCheckpointingCallback` |
| `on_epoch_begin` | Start of each epoch | `ProgressBarCallback` |
| `on_batch_begin` | Every batch | — |
| `on_optimizer_step` | Every accumulation window | `GradientClippingCallback`, `ValidationCallback` |
| `on_batch_end` | Every batch | `CheckpointCallback`, `MetricLoggerCallback`, `ProgressBarCallback` |
| `on_epoch_end` | End of each epoch | `ProgressBarCallback` |
| `on_error` | On exception during training | `CheckpointCallback`, `MetricLoggerCallback` |
| `on_train_end` | Training ends (always via finally) | `CheckpointCallback`, `MetricLoggerCallback`, `GradientCheckpointingCallback` |
Default callbacks (in order): `gradient_checkpointing` (activation checkpointing, optional), `checkpoint` (safetensors, rank-0), `metric_logger` (JSONL, rank-0), `progress_bar` (tqdm), `gradient_clipping`, `validation` (periodic validation on val_dataset).
## Strategies
### SEQ (Pre-training)
Next-token cross-entropy with optional label smoothing:
$$
L_{\text{PT}} = -\sum_{t=1}^{T} \log P(x_t \mid x_{\lt t}; \theta)
$$
Keys: `input_ids`, `target_ids`. Optional: `label_smoothing`.
### SFT (Supervised Fine-Tuning)
Masked cross-entropy (`ignore_index=-100`) over response tokens:
$$
L_{\text{SFT}} = -\sum_{t=P+1}^{P+L} \log P(s_t \mid s_{\lt t}; \theta)
$$
Keys: `input_ids`, `target_ids`, `loss_mask`. Optional: `label_smoothing`.
### DPO (Direct Preference Optimization)
Frozen reference model, preference margin via log-ratio:
$$
L_{\text{DPO}} = -\mathbb{E}\left[\log\sigma\left(\beta\log\frac{\pi_\theta(y_w\mid x)}{\pi_{\text{ref}}(y_w\mid x)} - \beta\log\frac{\pi_\theta(y_l\mid x)}{\pi_{\text{ref}}(y_l\mid x)}\right)\right]
$$
Parameters: `beta=0.1`, `reduction="mean"`. Keys: `chosen`, `rejected`, `chosen_mask`, `rejected_mask`.
### GRPO (Group Relative Policy Optimization)
On-policy PPO with group-normalized advantages:
$$
\text{Advantage}_i = \frac{r_i - \mu}{\sigma + \epsilon}
$$
$$
L_{\text{GRPO}} = -\mathbb{E}\left[\min\left(\frac{\pi_\theta}{\pi_{\text{ref}}}A,\; \text{clip}\left(\frac{\pi_\theta}{\pi_{\text{ref}}}, 1-\epsilon, 1+\epsilon\right)A\right)\right] + \lambda \cdot \mathbb{E}\left[(\log\pi_\theta - \log\pi_{\text{ref}})^2\right]
$$
Parameters: `group_size=4`, `clip_eps=0.2`, `kl_coef=0.01`, `sync_interval=200`, `reduction="mean"`.
Keys: `prompts`, `responses`, `masks`, `rewards`.
## LR Schedulers
| Type | Class | Description |
|------|-------|-------------|
| Cosine | `CosineScheduler` | Linear warmup → cosine decay to `min_rate` |
| SGDR | `SGDRScheduler` | Cosine annealing with warm restarts (`t_mult=2`) |
Created by `SchedulerFactory.create(optimizer, schedule_type, **kwargs)`. Valid types: `"cosine"`, `"sgdr"`. Omit to use no scheduler.
## Gradient Checkpointing
Trades compute for memory by recomputing activations during backward pass. Specify module types via `gradient_checkpointing_modules`:
```python
from astrai.model.components.decoder_block import DecoderBlock
config = TrainConfig(..., gradient_checkpointing_modules=[DecoderBlock])
```
Callback wraps each `DecoderBlock.forward` with `torch.utils.checkpoint.checkpoint(use_reentrant=False)`, compatible with `torch.compile`. Uses `nn.Module.apply()` for traversal — works through DDP wrappers without manual unwrap. Empty list (default) means no-op.
## Checkpoint
```
Checkpoint(state_dict, epoch, iteration, extra, meta, config)
├── save(save_dir) rank-0 only: meta.json (epoch/iteration/timestamp) + config.json (model config) + model.safetensors + optional {key}.pt (optimizer.pt, scheduler.pt)
└── load(save_dir, broadcast=False) loads from local disk; set broadcast=True to broadcast metadata from rank-0
```
Optimizer/scheduler state persisted by default via `Checkpoint.extra`.
Model config (`context.model_config`) saved into `config.json` during training via `CheckpointCallback`.
## TrainContextBuilder (Builder Pattern)
```python
context = (
TrainContextBuilder(config)
.with_resume_dir(resume_dir)
.build()
)
# Returns TrainContext with model, strategy, optimizer, scheduler, dataloader, checkpoint
```
- Loads checkpoint weights if provided
- Creates executor via `ExecutorFactory.create(cfg.parallel_mode, grad_accum_steps=cfg.grad_accum_steps, **cfg.executor_kwargs)`
- Calls `executor.prepare(model, optimizer, dataloader, scheduler)` for model distribution (e.g. DDP) + gradient accumulation wrappers
- Creates `ResumableDistributedSampler` for shuffle+resume
- Builds strategy via `StrategyFactory.create(train_type, model, device, **kwargs)`
## Training CLI
```bash
export CUDA_VISIBLE_DEVICES=0,1,2,3
nohup python scripts/tools/train.py \
--nprocs=4 \
--parallel_mode=ddp \
--train_type=seq \
--data_root_path=/path/to/dataset \
--param_path=/path/to/model \
--batch_per_device=4 \
--grad_accum_steps=8 \
--warmup_ratio=0.05 \
--max_lr=1e-4 \
--max_grad_norm=1.0 \
--adamw_beta1=0.9 \
--adamw_beta2=0.95 \
--adamw_weight_decay=0.01 \
--window_size=2048 \
--ckpt_interval=10000 \
--ckpt_dir=./checkpoint \
--random_seed=3407 \
--label_smoothing=0.05 \
> out.log 2> err.log &
```
Full parameter reference at [params.md](params.md).
> Document Update Time: 2026-05-30

BIN
assets/images/logo.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 281 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 21 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 11 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 590 KiB

34
astrai/__init__.py Normal file
View File

@ -0,0 +1,34 @@
__version__ = "1.3.7"
__author__ = "ViperEkura"
from astrai.config import (
AutoRegressiveLMConfig,
EncoderConfig,
TrainConfig,
)
from astrai.dataset import DatasetFactory
from astrai.factory import BaseFactory
from astrai.inference import (
GenerationRequest,
InferenceEngine,
)
from astrai.model import AutoModel, AutoRegressiveLM
from astrai.tokenize import AutoTokenizer
from astrai.trainer import CallbackFactory, SchedulerFactory, StrategyFactory, Trainer
__all__ = [
"AutoRegressiveLM",
"AutoRegressiveLMConfig",
"EncoderConfig",
"TrainConfig",
"DatasetFactory",
"AutoTokenizer",
"GenerationRequest",
"InferenceEngine",
"Trainer",
"CallbackFactory",
"StrategyFactory",
"SchedulerFactory",
"BaseFactory",
"AutoModel",
]

25
astrai/config/__init__.py Normal file
View File

@ -0,0 +1,25 @@
from astrai.config.model_config import (
AutoRegressiveLMConfig,
BaseModelConfig,
ConfigFactory,
EncoderConfig,
)
from astrai.config.preprocess_config import (
InputConfig,
OutputConfig,
PipelineConfig,
ProcessingConfig,
)
from astrai.config.train_config import TrainConfig
__all__ = [
"BaseModelConfig",
"AutoRegressiveLMConfig",
"EncoderConfig",
"ConfigFactory",
"TrainConfig",
"InputConfig",
"OutputConfig",
"PipelineConfig",
"ProcessingConfig",
]

98
astrai/config/base.py Normal file
View File

@ -0,0 +1,98 @@
import json
from dataclasses import MISSING, dataclass, fields
from pathlib import Path
from typing import Any, Dict, Optional, Self, Union, get_type_hints
@dataclass
class BaseConfig:
def to_dict(self) -> Dict[str, Any]:
d = {}
for fld in fields(self):
v = getattr(self, fld.name)
if isinstance(v, (str, int, float, bool)):
d[fld.name] = v
elif v is None:
d[fld.name] = None
elif isinstance(v, (dict, list, tuple)):
try:
val = list(v) if isinstance(v, tuple) else v
json.dumps(val)
d[fld.name] = val
except (TypeError, ValueError):
pass
elif isinstance(v, BaseConfig):
d[fld.name] = v.to_dict()
elif hasattr(v, "__dataclass_fields__"):
sub = {}
for f in fields(v):
a = getattr(v, f.name)
sub[f.name] = list(a) if isinstance(a, tuple) else a
d[fld.name] = sub
return d
@classmethod
def from_dict(cls, d: Dict[str, Any]) -> Self:
hints = get_type_hints(cls)
inst = cls.__new__(cls)
for fld in fields(cls):
if fld.name in d:
v = d[fld.name]
target = cls._unwrap_optional(hints.get(fld.name))
if target is not None:
try:
v = cls._coerce(v, target)
except (TypeError, ValueError):
pass
object.__setattr__(inst, fld.name, v)
elif fld.default is not MISSING:
object.__setattr__(inst, fld.name, fld.default)
elif fld.default_factory is not MISSING:
object.__setattr__(inst, fld.name, fld.default_factory())
else:
object.__setattr__(inst, fld.name, None)
return inst
@staticmethod
def _unwrap_optional(tp) -> Optional[type]:
if tp is None:
return None
origin = getattr(tp, "__origin__", None)
if origin is not None:
args = getattr(tp, "__args__", ())
non_none = [a for a in args if a is not type(None)]
return non_none[0] if non_none else None
return tp
@staticmethod
def _coerce(value: Any, target_type: type) -> Any:
if target_type is bool and isinstance(value, bool):
return value
if (
target_type is int
and isinstance(value, (int, float))
and not isinstance(value, bool)
):
return int(value)
if (
target_type is float
and isinstance(value, (int, float))
and not isinstance(value, bool)
):
return float(value)
if target_type is str and isinstance(value, str):
return value
if isinstance(value, target_type):
return value
if isinstance(value, dict) and issubclass(target_type, BaseConfig):
return target_type.from_dict(value)
raise TypeError
@classmethod
def from_json(cls, path: Union[str, Path]) -> Self:
with open(path, "r", encoding="utf-8") as f:
return cls.from_dict(json.load(f))
def to_json(self, path: Union[str, Path]):
with open(path, "w", encoding="utf-8") as f:
json.dump(self.to_dict(), f, indent=2, ensure_ascii=False)

View File

@ -0,0 +1,92 @@
import json
from dataclasses import dataclass
from typing import Any, Dict, Optional, Self
from astrai.config.base import BaseConfig
from astrai.factory import BaseFactory
class ConfigFactory(BaseFactory[BaseConfig]):
"""Factory that dispatches config classes by ``model_type``."""
@classmethod
def load(cls, raw: Dict[str, Any]) -> BaseConfig:
model_type = raw.get("model_type") or "autoregressive_lm"
config_cls = cls.get_component_class(model_type)
return config_cls.from_dict(raw)
@dataclass
class BaseModelConfig(BaseConfig):
"""Base config with ``model_type`` dispatch and file I/O."""
model_type: Optional[str] = None
@classmethod
def from_file(cls, config_path: str) -> Self:
with open(config_path, "r") as f:
raw: Dict[str, Any] = json.load(f)
return cls.from_dict(raw)
def to_file(self, config_path: str):
d = self.to_dict()
config_dict = {k: v for k, v in d.items() if v is not None}
with open(config_path, "w") as f:
json.dump(config_dict, f, indent=4)
@dataclass
@ConfigFactory.register("autoregressive_lm")
class AutoRegressiveLMConfig(BaseModelConfig):
"""Configuration for autoregressive language model."""
vocab_size: Optional[int] = None
dim: Optional[int] = None
n_layers: Optional[int] = None
norm_eps: Optional[float] = None
dim_ffn: Optional[int] = None
tie_weight: Optional[bool] = None
max_len: Optional[int] = None
rope_theta: Optional[float] = None
rope_scaling: Optional[dict] = None
attn_type: str = "gqa"
n_heads: Optional[int] = None
n_kv_heads: Optional[int] = None
use_qk_norm: Optional[bool] = None
use_gated_attention: Optional[bool] = None
kv_lora_rank: Optional[int] = None
qk_nope_head_dim: Optional[int] = None
qk_rope_head_dim: Optional[int] = None
ffn_type: str = "mlp"
n_routed_experts: Optional[int] = None
n_shared_experts: Optional[int] = None
n_activated_experts: Optional[int] = None
topk_method: Optional[str] = None
@dataclass
@ConfigFactory.register("embedding")
class EncoderConfig(BaseModelConfig):
"""Configuration for embedding encoder model."""
vocab_size: Optional[int] = None
dim: Optional[int] = None
n_layers: Optional[int] = None
norm_eps: Optional[float] = None
dim_ffn: Optional[int] = None
max_len: Optional[int] = None
rope_theta: Optional[float] = None
rope_scaling: Optional[dict] = None
n_heads: Optional[int] = None
n_kv_heads: Optional[int] = None
use_qk_norm: Optional[bool] = None
use_gated_attention: Optional[bool] = None
pooling_type: Optional[str] = None
normalize_embeddings: Optional[bool] = None

View File

@ -0,0 +1,37 @@
"""Pipeline configuration for JSONL preprocessing."""
from dataclasses import dataclass, field
from typing import Dict, List, Optional
from astrai.config.base import BaseConfig
@dataclass
class InputConfig(BaseConfig):
sections: Optional[List[Dict]] = None
@dataclass
class ProcessingConfig(BaseConfig):
max_seq_len: int = 2048
min_chars: int = 50
max_chars: int = 2_000_000
max_items: Optional[int] = None
@dataclass
class OutputConfig(BaseConfig):
domain_key: Optional[str] = None
storage_format: str = "bin"
max_tokens_per_shard: int = 100_000_000
dtype: Dict[str, str] = field(default_factory=dict)
@dataclass
class PipelineConfig(BaseConfig):
version: int = 1
input: InputConfig = field(default_factory=InputConfig)
mask: Dict[str, str] = field(default_factory=dict)
mask_default: str = "mask"
preprocessing: ProcessingConfig = field(default_factory=ProcessingConfig)
output: OutputConfig = field(default_factory=OutputConfig)

View File

@ -0,0 +1,140 @@
from dataclasses import dataclass, field, fields
from typing import Callable, List, Optional
import torch.nn as nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from torch.utils.data import Dataset
from astrai.config.base import BaseConfig
from astrai.model.components.lora import LoRAConfig
def required(**kw):
return {"required": True, **kw}
@dataclass
class TrainConfig(BaseConfig):
# basic setting
model_fn: Callable[[], nn.Module] = field(
default=None, metadata=required(help="Model factory for training.")
)
strategy: str = field(default=None, metadata=required(help="Training strategy."))
dataset: Dataset = field(
default=None, metadata=required(help="Dataset for training.")
)
optimizer_fn: Callable[[nn.Module], Optimizer] = field(
default=None, metadata=required(help="Optimizer factory for training.")
)
scheduler_fn: Callable[[Optimizer], LRScheduler] = field(
default=None, metadata=required(help="Scheduler factory for training.")
)
n_epoch: int = field(default=1, metadata={"help": "Number of epochs for training."})
batch_per_device: int = field(
default=4, metadata={"help": "Batch size per device."}
)
grad_accum_steps: int = field(
default=1, metadata={"help": "Number of iterations between steps."}
)
max_grad_norm: float = field(
default=1.0, metadata={"help": "Maximum gradient norm."}
)
gradient_checkpointing_modules: list = field(
default_factory=list,
metadata={"help": "Module types to enable activation checkpointing for."},
)
# checkpoint setting
start_epoch: int = field(default=0, metadata={"help": "Start epoch for training."})
start_batch: int = field(
default=0, metadata={"help": "Start batch iteration for training."}
)
ckpt_dir: str = field(
default="./checkpoint", metadata={"help": "Checkpoint directory."}
)
ckpt_interval: int = field(
default=5000, metadata={"help": "Number of iterations between checkpoints."}
)
# lora setting
lora: Optional[LoRAConfig] = field(
default=None,
metadata={"help": "LoRA config. None means full fine-tuning."},
)
# metric setting
log_dir: str = field(
default="./checkpoint/logs", metadata={"help": "Directory for metric logs."}
)
log_interval: int = field(
default=100,
metadata={"help": "Number of batch iterations between metric logs."},
)
metrics: List[str] = field(
default_factory=lambda: ["loss", "lr"],
metadata={"help": "Metrics to record during training."},
)
# dataloader setting
random_seed: int = field(default=3407, metadata={"help": "Random seed."})
num_workers: int = field(
default=0, metadata={"help": "Number of workers for dataloader."}
)
prefetch_factor: Optional[int] = field(
default=None, metadata={"help": "Prefetch factor for dataloader."}
)
pin_memory: bool = field(
default=False, metadata={"help": "Pin memory for dataloader."}
)
# distributed training
nprocs: int = field(
default=1, metadata={"help": "Number of processes for distributed training."}
)
backend: str = field(
default="nccl", metadata={"help": "Distributed training backend."}
)
master_addr: str = field(
default="localhost",
metadata={"help": "Master address for distributed training."},
)
master_port: str = field(
default="29500", metadata={"help": "Master port for distributed training."}
)
parallel_mode: str = field(
default="none",
metadata={"help": "Parallel strategy: none, ddp, fsdp."},
)
start_method: str = field(
default="spawn",
metadata={"help": "Multiprocessing start method (spawn/fork/forkserver)."},
)
# others
device_type: str = field(
default="cuda", metadata={"help": "Device type for distributed training."}
)
val_dataset: Optional[Dataset] = field(
default=None, metadata={"help": "Dataset for validation."}
)
val_step: int = field(
default=1000,
metadata={"help": "Number of optimizer steps between validation runs."},
)
executor_kwargs: dict = field(
default_factory=dict,
metadata={"help": "Extra kwargs passed to ExecutorFactory.create()."},
)
extra_kwargs: dict = field(
default_factory=dict, metadata={"help": "Other arguments."}
)
def __post_init__(self):
self.validate()
def validate(self):
for fld in fields(self):
if fld.metadata.get("required") and getattr(self, fld.name) is None:
raise ValueError(f"TrainConfig.{fld.name} is required but got None.")

View File

@ -0,0 +1,31 @@
from astrai.dataset.dataset import (
BaseDataset,
DatasetFactory,
)
from astrai.dataset.sampler import ResumableDistributedSampler
from astrai.dataset.storage import (
H5Store,
MmapStore,
Store,
StoreFactory,
detect_format,
load_bin,
load_h5,
save_bin,
save_h5,
)
__all__ = [
"BaseDataset",
"DatasetFactory",
"Store",
"StoreFactory",
"H5Store",
"MmapStore",
"detect_format",
"save_h5",
"load_h5",
"save_bin",
"load_bin",
"ResumableDistributedSampler",
]

308
astrai/dataset/dataset.py Normal file
View File

@ -0,0 +1,308 @@
"""Dataset implementations with factory pattern for training."""
from abc import ABC, abstractmethod
from typing import Dict, List, Optional
import torch
from torch import Tensor
from torch.utils.data import Dataset
from astrai.dataset.storage import (
Store,
StoreFactory,
detect_format,
)
from astrai.factory import BaseFactory
class BaseDataset(Dataset, ABC):
"""Abstract base class for all dataset types.
Implements common functionality for window-based data fetching.
Uses a storage abstraction for format-agnostic data loading.
"""
def __init__(self, window_size: int, stride: int):
super().__init__()
self.window_size = window_size
self.stride = stride
self.storage: Optional[Store] = None
@property
def required_keys(self) -> List[str]:
"""Return required storage keys for this dataset type.
Subclasses should override to specify expected keys.
"""
return []
def _validate_keys(self):
if not self.required_keys:
return
actual_keys = set(self.storage.keys)
missing = [k for k in self.required_keys if k not in actual_keys]
if missing:
raise KeyError(
f"Dataset {type(self).__name__} requires keys {self.required_keys}, "
f"but storage at {self._load_path} only has {sorted(actual_keys)}. "
f"Missing: {missing}"
)
def load(self, load_path: str, storage_type: Optional[str] = None):
"""Load dataset from the given path.
Auto-detects the storage format if not specified.
Args:
load_path: Path to the data directory or file
storage_type: Force a specific storage type ("h5", "bin"),
or None for auto-detection
Raises:
KeyError: If the loaded storage is missing required keys.
"""
if storage_type is None:
storage_type = detect_format(load_path)
self.storage = StoreFactory.create(storage_type)
self._load_path = load_path
self.storage.load(load_path)
self._validate_keys()
@property
def count(self) -> int:
"""Return the total number of raw elements (tokens) in the dataset."""
if self.storage is None:
return 0
return len(self.storage)
@property
def keys(self) -> List[str]:
"""Return the available data keys."""
if self.storage is None:
return []
return self.storage.keys
def get_index(self, index: int) -> tuple:
"""Calculate begin and end indices for a sample.
Args:
index: Sample index
Returns:
Tuple of (begin_idx, end_idx)
"""
if self.storage is None:
raise RuntimeError("Dataset not loaded, call load() first")
total = len(self.storage)
if total <= self.window_size:
raise ValueError(
f"Data too short: {total} tokens <= window_size {self.window_size}"
)
begin_idx = min(index * self.stride, total - 1 - self.window_size)
end_idx = min(begin_idx + self.window_size, total - 1)
return begin_idx, end_idx
@abstractmethod
def __getitem__(self, index: int) -> Dict[str, Tensor]:
"""Get a single sample by index.
Must be implemented by subclasses.
"""
raise NotImplementedError
def __len__(self) -> int:
if self.storage is None:
return 0
total = len(self.storage)
if total <= self.window_size:
return 0
return (total - 1 - self.window_size) // self.stride + 1
class DatasetFactory(BaseFactory["BaseDataset"]):
"""Factory class for creating dataset instances.
Supports decorator-based registration for extensible dataset types.
All default dataset types (seq, sft, dpo, grpo) are registered automatically
when their classes are defined with the decorator.
Example usage:
@DatasetFactory.register("custom")
class CustomDataset(BaseDataset):
...
dataset = DatasetFactory.create("custom", window_size, stride)
"""
@classmethod
def _validate_component(cls, dataset_cls: type):
"""Validate that the dataset class inherits from BaseDataset."""
if not issubclass(dataset_cls, BaseDataset):
raise TypeError(f"{dataset_cls.__name__} must inherit from BaseDataset")
@classmethod
def create(cls, train_type: str, window_size: int, stride: int) -> "BaseDataset":
"""Create a dataset instance.
Args:
train_type: Type of training ("seq", "sft", "dpo", "grpo")
window_size: Window size for data sampling
stride: Stride between consecutive samples
Returns:
Dataset instance
"""
return super().create(train_type, window_size, stride)
@classmethod
def load(
cls,
train_type: str,
load_path: str,
window_size: int,
stride: Optional[int] = None,
storage_type: Optional[str] = None,
) -> "BaseDataset":
"""Create and load a dataset in one step.
Args:
train_type: Type of training dataset
load_path: Path to the data file
window_size: Window size for data sampling
stride: Stride between consecutive samples (default: same as window_size)
storage_type: Storage type ("h5", "bin") or None for auto-detection
Returns:
Loaded dataset instance
"""
if stride is None:
stride = window_size
dataset = cls.create(train_type, window_size, stride)
dataset.load(load_path, storage_type=storage_type)
return dataset
@classmethod
def available_types(cls) -> list:
"""Return list of registered dataset type names."""
return cls.list_registered()
@DatasetFactory.register("seq")
class SEQDataset(BaseDataset):
"""Dataset for sequential next-token prediction training."""
def __init__(self, window_size: int, stride: int):
super().__init__(window_size, stride)
@property
def required_keys(self) -> List[str]:
return ["sequence"]
def _fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
return self.storage.fetch(begin_idx, end_idx, "sequence")
def __getitem__(self, index):
begin_idx, end_idx = self.get_index(index)
x = self._fetch_data(begin_idx, end_idx).to(dtype=torch.long)
y = self._fetch_data(begin_idx + 1, end_idx + 1).to(dtype=torch.long)
return {"input_ids": x, "target_ids": y}
@DatasetFactory.register("sft")
class SFTDataset(BaseDataset):
"""Dataset for supervised fine-tuning with loss masking."""
def __init__(self, window_size: int, stride: int):
super().__init__(window_size, stride)
@property
def required_keys(self) -> List[str]:
return ["sequence", "loss_mask"]
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
return self.storage.fetch(begin_idx, end_idx, key)
def __getitem__(self, index):
begin_idx, end_idx = self.get_index(index)
x = self._fetch_data(begin_idx, end_idx, "sequence").to(dtype=torch.long)
y = self._fetch_data(begin_idx + 1, end_idx + 1, "sequence").to(
dtype=torch.long
)
loss_mask = self._fetch_data(begin_idx + 1, end_idx + 1, "loss_mask").to(
dtype=torch.bool
)
return {"input_ids": x, "target_ids": y, "loss_mask": loss_mask}
@DatasetFactory.register("dpo")
class DPODataset(BaseDataset):
"""Dataset for Direct Preference Optimization training."""
def __init__(self, window_size: int, stride: int):
super().__init__(window_size, stride)
@property
def required_keys(self) -> List[str]:
return ["chosen", "rejected", "chosen_mask", "rejected_mask"]
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
return self.storage.fetch(begin_idx, end_idx, key)
def __getitem__(self, index: int):
begin_idx, end_idx = self.get_index(index)
chosen = self._fetch_data(begin_idx, end_idx, "chosen").to(dtype=torch.long)
rejected = self._fetch_data(begin_idx, end_idx, "rejected").to(dtype=torch.long)
chosen_mask = self._fetch_data(begin_idx, end_idx, "chosen_mask").to(
dtype=torch.bool
)
rejected_mask = self._fetch_data(begin_idx, end_idx, "rejected_mask").to(
dtype=torch.bool
)
return {
"chosen": chosen,
"rejected": rejected,
"chosen_mask": chosen_mask,
"rejected_mask": rejected_mask,
}
@DatasetFactory.register("grpo")
class GRPODataset(BaseDataset):
"""Dataset for Group Relative Policy Optimization training."""
def __init__(self, window_size: int, stride: int):
super().__init__(window_size, stride)
@property
def required_keys(self) -> List[str]:
return ["prompts", "responses", "masks", "rewards"]
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
return self.storage.fetch(begin_idx, end_idx, key)
def __getitem__(self, index: int) -> Dict[str, Tensor]:
begin_idx, end_idx = self.get_index(index)
prompts = self._fetch_data(begin_idx, end_idx, "prompts").to(dtype=torch.long)
responses = self._fetch_data(begin_idx, end_idx, "responses").to(
dtype=torch.long
)
masks = self._fetch_data(begin_idx, end_idx, "masks").to(dtype=torch.bool)
rewards = self._fetch_data(begin_idx, end_idx, "rewards")
return {
"prompts": prompts,
"responses": responses,
"masks": masks,
"rewards": rewards,
}

84
astrai/dataset/sampler.py Normal file
View File

@ -0,0 +1,84 @@
from typing import Optional
import torch
import torch.distributed as dist
from torch.utils.data import Dataset, Sampler
class ResumableDistributedSampler(Sampler[int]):
def __init__(
self,
data_source: Dataset,
start_epoch: int = 0,
start_iter: int = 0,
seed: int = 42,
drop_last: bool = False,
shuffle: bool = True,
process_group: Optional[dist.ProcessGroup] = None,
):
self.epoch = start_epoch
self.iter = start_iter
self.seed = seed
self.num_samples = len(data_source)
if process_group is not None:
# input process group
self.rank = dist.get_rank(process_group)
self.num_replicas = dist.get_world_size(process_group)
elif dist.is_available() and dist.is_initialized():
# use default process group
process_group = dist.group.WORLD
self.rank = dist.get_rank()
self.num_replicas = dist.get_world_size()
else:
# single process
self.rank = 0
self.num_replicas = 1
self.drop_last = drop_last
self.shuffle = shuffle
offset = 0 if drop_last else self.num_replicas - 1
self.num_samples_per_replica = (self.num_samples + offset) // self.num_replicas
self.total_size = self.num_samples_per_replica * self.num_replicas
self.iter = self.iter % self.num_samples_per_replica
self._indices = None
def _get_indices(self):
if self.shuffle:
generator = torch.Generator()
generator.manual_seed(self.seed + self.epoch)
indices = torch.randperm(self.num_samples, generator=generator).tolist()
else:
indices = torch.arange(self.num_samples).tolist()
if not self.drop_last and self.num_samples < self.total_size:
padding_size = self.total_size - len(indices)
indices += indices[:padding_size]
local_indices = indices[self.rank : self.total_size : self.num_replicas]
self.iter = self.iter % self.num_samples_per_replica
self._indices = local_indices[self.iter :]
def __iter__(self):
if self._indices is None:
self._get_indices()
for i in self._indices:
self.iter += 1
yield i
self.epoch += 1
self._indices = None
@property
def _remaining(self):
remaining = self.num_samples_per_replica - self.iter
return max(remaining, 0)
def __len__(self):
return self._remaining

264
astrai/dataset/storage.py Normal file
View File

@ -0,0 +1,264 @@
"""Storage backends for different data formats.
Layers:
- I/O layer: save_* / load_* functions, read/write raw files (HDF5/bin)
return Dict[str, List[Tensor]] format-specific, no state
- Store (ABC): central abstraction, normalizes multi-segment into
Dict[str, List[Tensor]] per key via _normalize(),
fetch() uses bisect across segments no forced concat
- Dataset layer: BaseDataset owns a Store, only calls store.fetch(begin, end, key)
Key properties:
- Multi-segment: segments kept as-is, no forced concatenation safe for
datasets larger than RAM
- Explicit length: _length = min(total elements across keys), set at load,
__len__ returns O(1)
- Zero-copy mmap: MmapStore wraps np.memmap(mode="r"), all DataLoader
workers share OS page-cache pages
"""
import bisect
import json
import os
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Dict, List, Union
import h5py
import numpy as np
import torch
from torch import Tensor
from astrai.factory import BaseFactory
def save_h5(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]):
os.makedirs(file_path, exist_ok=True)
full_file_path = os.path.join(file_path, f"{file_name}.h5")
with h5py.File(full_file_path, "w") as f:
for key, tensors in tensor_group.items():
grp = f.create_group(key)
for idx, tensor in enumerate(tensors):
arr = tensor.cpu().numpy()
grp.create_dataset(f"data_{idx}", data=arr)
def load_h5(file_path: str, share_memory=True) -> Dict[str, List[Tensor]]:
tensor_group: Dict[str, List[Tensor]] = {}
root_path = Path(file_path)
h5_files = list(root_path.rglob("*.h5")) + list(root_path.rglob("*.hdf5"))
for h5_file in h5_files:
with h5py.File(h5_file, "r") as f:
for key in f.keys():
grp = f[key]
dsets = []
for dset_name in grp.keys():
dset = grp[dset_name]
tensor = torch.from_numpy(dset[:])
if share_memory:
tensor = tensor.share_memory_()
dsets.append(tensor)
if tensor_group.get(key) is None:
tensor_group[key] = []
tensor_group[key].extend(dsets)
return tensor_group
def save_bin(file_path: str, tensor_group: Dict[str, List[Tensor]]):
os.makedirs(file_path, exist_ok=True)
meta = {}
for key, tensors in tensor_group.items():
cat = torch.cat(tensors, dim=0)
meta[key] = {"shape": list(cat.shape), "dtype": str(cat.dtype).split(".")[-1]}
np.asarray(cat.cpu().numpy()).tofile(os.path.join(file_path, f"{key}.bin"))
with open(os.path.join(file_path, "meta.json"), "w") as f:
json.dump(meta, f)
def load_bin(file_path: str) -> Dict[str, List[Tensor]]:
with open(os.path.join(file_path, "meta.json"), "r") as f:
meta = json.load(f)
segments: Dict[str, List[Tensor]] = {}
for key, info in meta.items():
arr = np.memmap(
os.path.join(file_path, f"{key}.bin"),
dtype=info["dtype"],
mode="r+",
shape=tuple(info["shape"]),
)
segments[key] = [torch.from_numpy(arr)]
return segments
def detect_format(load_path: str) -> str:
"""Auto-detect storage format from files in the directory.
Args:
load_path: Directory or file path
Returns:
Format string ("h5" or "bin")
Raises:
FileNotFoundError: If no supported data files are found
"""
root = Path(load_path)
if root.is_file():
suffix = root.suffix.lower()
if suffix in (".h5", ".hdf5"):
return "h5"
raise ValueError(f"Unsupported file format: {suffix}")
h5_files = list(root.rglob("*.h5")) + list(root.rglob("*.hdf5"))
if h5_files:
return "h5"
bin_files = list(root.rglob("*.bin"))
if bin_files:
has_meta = (root / "meta.json").exists() or len(
list(root.rglob("meta.json"))
) > 0
if has_meta:
return "bin"
raise FileNotFoundError(f"No supported data files found at {load_path}")
class Store(ABC):
"""String keys -> segmented tensors with ``fetch(begin, end, keys)``.
Each key maps to one or more tensor segments (no forced concatenation).
``len(store)`` returns ``self._length`` (explicit, O(1)), the minimum
total element count across all keys.
Subclasses fill ``self._data`` and ``self._cum`` during ``load()``
via ``_normalize()``.
"""
def __init__(self):
self._data: Dict[str, List[Tensor]] = {}
self._cum: Dict[str, List[int]] = {}
self._length: int = 0
@abstractmethod
def load(self, path: str) -> None:
raise NotImplementedError
@property
def keys(self) -> List[str]:
return list(self._data.keys())
def __len__(self) -> int:
return self._length
def fetch(
self,
begin: int,
end: int,
keys: Union[str, List[str]],
):
if not self._data:
raise RuntimeError("Store not loaded")
if not (0 <= begin < self._length and 0 <= end <= self._length):
raise ValueError(
f"Index out of bounds: begin={begin}, end={end}, length={self._length}"
)
if isinstance(keys, str):
return self._fetch_key(keys, begin, end)
return {k: self._fetch_key(k, begin, end) for k in keys}
def _fetch_key(self, key: str, begin: int, end: int) -> Tensor:
"""Fetch slice [begin, end) across potentially multiple segments."""
segments = self._data[key]
cum = self._cum[key]
seg_start = bisect.bisect_right(cum, begin)
seg_end = bisect.bisect_left(cum, end)
results = []
for i in range(seg_start, seg_end + 1):
prev = cum[i - 1] if i > 0 else 0
s = max(begin - prev, 0)
e = min(end - prev, segments[i].shape[0])
results.append(segments[i][s:e])
return results[0] if len(results) == 1 else torch.cat(results, dim=0)
def _normalize(self, raw: Dict[str, List[Tensor]]):
"""Register segments and pre-compute cumulative lengths.
Does NOT concatenate segments are kept as-is to avoid OOM on
large datasets. Sets ``self._length`` to the minimum total
element count across all keys.
"""
for key, tensors in raw.items():
self._data[key] = tensors
cum = []
total = 0
for t in tensors:
total += t.shape[0]
cum.append(total)
self._cum[key] = cum
self._length = (
min((cum[-1] if cum else 0) for cum in self._cum.values())
if self._cum
else 0
)
class StoreFactory(BaseFactory["Store"]):
"""Factory for creating Store instances by type name.
Example::
@StoreFactory.register("custom")
class CustomStore(Store):
...
"""
@classmethod
def _validate_component(cls, store_cls: type):
if not issubclass(store_cls, Store):
raise TypeError(f"{store_cls.__name__} must inherit from Store")
@StoreFactory.register("h5")
class H5Store(Store):
"""HDF5-based storage backend (pre-tokenized data)."""
def load(self, path: str):
self._normalize(load_h5(path))
@StoreFactory.register("bin")
class MmapStore(Store):
"""Memory-mapped binary storage backend.
Each key is a single .bin file backed by ``np.memmap(mode="r")``.
No per-process memory duplication all DataLoader workers share the
same OS page-cache pages.
Format on disk::
data_root/
meta.json # {key: {shape, dtype}, ...}
<key>.bin # raw numpy array, one per key
"""
def load(self, path: str):
self._mmap_refs = []
root = Path(path)
all_raw: Dict[str, List[Tensor]] = {}
meta_paths = list(root.rglob("meta.json"))
for meta_path in meta_paths:
raw = load_bin(str(meta_path.parent))
for key, tensors in raw.items():
if key not in all_raw:
all_raw[key] = []
all_raw[key].extend(tensors)
if not meta_paths:
raise FileNotFoundError(f"No meta.json found under {path}")
self._normalize(all_raw)
for tensors in self._data.values():
self._mmap_refs.extend(tensors)

226
astrai/factory.py Normal file
View File

@ -0,0 +1,226 @@
"""Base factory class for extensible component registration."""
import inspect
from abc import ABC
from typing import Callable, Dict, Generic, List, Optional, Tuple, Type, TypeVar
T = TypeVar("T")
class Registry:
"""Flexible registry for component classes with category and priority support.
This registry stores component classes with optional metadata (category, priority).
It provides methods for registration, retrieval, and listing with filtering.
"""
def __init__(self):
self._entries = {} # name -> (component_cls, category, priority)
def register(
self,
name: str,
component_cls: Type,
category: Optional[str] = None,
priority: int = 0,
):
"""Register a component class with optional category and priority."""
if name in self._entries:
raise ValueError(f"Component '{name}' is already registered")
self._entries[name] = (component_cls, category, priority)
def get(self, name: str) -> Type:
"""Get component class by name."""
if name not in self._entries:
raise KeyError(f"Component '{name}' not found in registry")
return self._entries[name][0]
def get_with_metadata(self, name: str) -> Tuple[Type, Optional[str], int]:
"""Get component class with its metadata."""
entry = self._entries.get(name)
if entry is None:
raise KeyError(f"Component '{name}' not found in registry")
return entry
def contains(self, name: str) -> bool:
"""Check if a name is registered."""
return name in self._entries
def list_names(self) -> List[str]:
"""Return list of registered component names."""
return sorted(self._entries.keys())
def list_by_category(self, category: str) -> List[str]:
"""Return names of components belonging to a specific category."""
return sorted(
name for name, (_, cat, _) in self._entries.items() if cat == category
)
def list_by_priority(self, reverse: bool = False) -> List[str]:
"""Return names sorted by priority (default ascending)."""
return sorted(
self._entries.keys(),
key=lambda name: self._entries[name][2],
reverse=reverse,
)
def entries(self) -> Dict[str, Tuple[Type, Optional[str], int]]:
"""Return raw entries dictionary."""
return self._entries.copy()
class BaseFactory(ABC, Generic[T]):
"""Generic factory class for component registration and creation.
This base class provides a decorator-based registration pattern
for creating extensible component factories.
Example usage:
class MyFactory(BaseFactory[MyBaseClass]):
pass
@MyFactory.register("custom")
class CustomComponent(MyBaseClass):
...
component = MyFactory.create("custom", *args, **kwargs)
"""
_registry: Registry
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
cls._registry = Registry()
@classmethod
def register(
cls, name: str, category: Optional[str] = None, priority: int = 0
) -> Callable[[Type[T]], Type[T]]:
"""Decorator to register a component class with optional category and priority.
Args:
name: Registration name for the component
category: Optional category for grouping components
priority: Priority for ordering (default 0)
Returns:
Decorator function that registers the component class
Raises:
TypeError: If the decorated class doesn't inherit from the base type
"""
def decorator(component_cls: Type[T]) -> Type[T]:
cls._validate_component(component_cls)
cls._registry.register(
name, component_cls, category=category, priority=priority
)
return component_cls
return decorator
@classmethod
def create(cls, name: str, *args, **kwargs) -> T:
"""Create a component instance by name.
Filters kwargs to match the component's __init__ signature,
so components don't need to declare **kwargs just to absorb
parameters meant for other components.
Args:
name: Registered name of the component
*args: Positional arguments passed to component constructor
**kwargs: Keyword arguments passed to component constructor
Returns:
Component instance
Raises:
ValueError: If the component name is not registered
"""
if not cls._registry.contains(name):
raise ValueError(
f"Unknown component: '{name}'. "
f"Supported types: {sorted(cls._registry.list_names())}"
)
component_cls = cls._registry.get(name)
sig = inspect.signature(component_cls.__init__)
has_var_kwargs = any(
p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()
)
if not has_var_kwargs:
valid = {
p.name
for p in sig.parameters.values()
if p.name != "self" and p.kind != inspect.Parameter.VAR_KEYWORD
}
kwargs = {k: v for k, v in kwargs.items() if k in valid}
return component_cls(*args, **kwargs)
@classmethod
def _validate_component(cls, component_cls: Type[T]):
"""Validate that the component class is valid for this factory.
Override this method in subclasses to add custom validation.
Args:
component_cls: Component class to validate
Raises:
TypeError: If the component class is invalid
"""
pass
@classmethod
def get_component_class(cls, name: str) -> Type[T]:
"""Get the registered component class by name without instantiating it.
Args:
name: Registered name of the component
Returns:
The component class itself
Raises:
ValueError: If the component name is not registered
"""
if not cls._registry.contains(name):
raise ValueError(
f"Unknown component: '{name}'. "
f"Supported types: {sorted(cls._registry.list_names())}"
)
return cls._registry.get(name)
@classmethod
def list_registered(cls) -> list:
"""List all registered component names.
Returns:
List of registered component names
"""
return cls._registry.list_names()
@classmethod
def is_registered(cls, name: str) -> bool:
"""Check if a component name is registered.
Args:
name: Component name to check
Returns:
True if registered, False otherwise
"""
return cls._registry.contains(name)
@classmethod
def list_by_category(cls, category: str) -> List[str]:
"""List registered component names in a category."""
return cls._registry.list_by_category(category)
@classmethod
def list_by_priority(cls, reverse: bool = False) -> List[str]:
"""List registered component names sorted by priority."""
return cls._registry.list_by_priority(reverse)
__all__ = ["Registry", "BaseFactory"]

View File

@ -0,0 +1,85 @@
"""Inference module for continuous batching.
Layers:
- core/: Core inference loop (cache, executor, scheduler, task)
- api/: HTTP orchestration (ProtocolHandler, server)
- protocols/: Response builders (OpenAI, Anthropic)
- transport/: SSE transport utilities
- engine.py: Facade (InferenceEngine), Value Object (GenerationRequest)
- sample.py: Strategy pattern (TemperatureStrategy, TopKStrategy, TopPStrategy)
"""
from astrai.inference.api import (
AnthropicMessage,
ChatCompletionRequest,
ChatMessage,
GenContext,
MessagesRequest,
ProtocolHandler,
StopChecker,
app,
run_server,
)
from astrai.inference.api.anthropic import AnthropicResponseBuilder
from astrai.inference.api.openai import OpenAIResponseBuilder
from astrai.inference.core import (
STOP,
Allocator,
Executor,
InferenceScheduler,
KVCache,
KvcacheView,
PagePool,
PrefixCache,
Storage,
Task,
TaskManager,
TaskStatus,
TaskTable,
page_hash,
)
from astrai.inference.engine import GenerationRequest, InferenceEngine
from astrai.inference.sample import (
BaseSamplingStrategy,
SamplingPipeline,
TemperatureStrategy,
TopKStrategy,
TopPStrategy,
sample,
)
__all__ = [
"InferenceEngine",
"GenerationRequest",
"InferenceScheduler",
"Executor",
"STOP",
"Task",
"TaskManager",
"TaskStatus",
"Allocator",
"KVCache",
"KvcacheView",
"PagePool",
"PrefixCache",
"Storage",
"TaskTable",
"page_hash",
"sample",
"BaseSamplingStrategy",
"TemperatureStrategy",
"TopKStrategy",
"TopPStrategy",
"SamplingPipeline",
"ProtocolHandler",
"StopChecker",
"GenContext",
"OpenAIResponseBuilder",
"AnthropicResponseBuilder",
"ChatMessage",
"ChatCompletionRequest",
"AnthropicMessage",
"MessagesRequest",
"app",
"run_server",
]

View File

@ -0,0 +1,23 @@
"""Inference API: protocol handler, stop checker, and FastAPI server."""
from astrai.inference.api.protocol import GenContext, ProtocolHandler, StopChecker
from astrai.inference.api.server import (
AnthropicMessage,
ChatCompletionRequest,
ChatMessage,
MessagesRequest,
app,
run_server,
)
__all__ = [
"ProtocolHandler",
"StopChecker",
"GenContext",
"AnthropicMessage",
"ChatCompletionRequest",
"ChatMessage",
"MessagesRequest",
"app",
"run_server",
]

View File

@ -0,0 +1,141 @@
"""Anthropic message completion response builder."""
import time
import uuid
from typing import Any, Dict, List, Tuple, Union
from pydantic import BaseModel
from astrai.inference.api.protocol import (
GenContext,
ResponseBuilder,
StopInfo,
sse_event,
)
from astrai.inference.engine import InferenceEngine
def _extract_text(content: Union[str, List[Dict[str, Any]]]) -> str:
if isinstance(content, str):
return content
if isinstance(content, list):
for block in content:
if isinstance(block, dict) and block.get("type") == "text":
return block.get("text", "")
return ""
class AnthropicResponseBuilder(ResponseBuilder):
def prepare(
self, request: BaseModel, engine: InferenceEngine
) -> Tuple[str, GenContext, List[str]]:
messages: List[Dict[str, str]] = []
system = getattr(request, "system", None)
if system:
messages.append({"role": "system", "content": system})
for m in request.messages:
text = _extract_text(m.content)
if text:
messages.append({"role": m.role, "content": text})
prompt = engine.tokenizer.apply_chat_template(messages, tokenize=False)
ctx = GenContext(
resp_id=f"msg_{uuid.uuid4().hex[:24]}",
created=int(time.time()),
model=request.model,
prompt_tokens=0,
)
stop_sequences = getattr(request, "stop_sequences", None) or []
return prompt, ctx, stop_sequences
def format_stream_start(self, ctx: GenContext) -> List[str]:
return [
sse_event(
{
"type": "message_start",
"message": {
"id": ctx.resp_id,
"type": "message",
"role": "assistant",
"model": ctx.model,
"content": [],
"usage": {"input_tokens": ctx.prompt_tokens},
},
},
event="message_start",
),
sse_event(
{
"type": "content_block_start",
"index": 0,
"content_block": {"type": "text", "text": ""},
},
event="content_block_start",
),
]
def format_chunk(self, token: str) -> str:
return sse_event(
{
"type": "content_block_delta",
"index": 0,
"delta": {"type": "text_delta", "text": token},
},
event="content_block_delta",
)
def format_stream_end(self, ctx: GenContext, stop: StopInfo) -> List[str]:
events: List[str] = []
if stop.matched:
trimmed = stop.body[: stop.body.rfind(stop.matched)]
unyielded = trimmed[len(stop.yielded) :]
if unyielded:
events.append(
sse_event(
{
"type": "content_block_delta",
"index": 0,
"delta": {"type": "text_delta", "text": unyielded},
},
event="content_block_delta",
)
)
events.append(
sse_event(
{"type": "content_block_stop", "index": 0},
event="content_block_stop",
)
)
events.append(
sse_event(
{
"type": "message_delta",
"delta": {
"stop_reason": "stop_sequence" if stop.matched else "end_turn",
"stop_sequence": stop.matched,
},
"usage": {"output_tokens": ctx.completion_tokens},
},
event="message_delta",
)
)
events.append(sse_event({"type": "message_stop"}, event="message_stop"))
return events
def format_response(
self, ctx: GenContext, content: str, stop: StopInfo
) -> Dict[str, Any]:
if stop.matched:
content = content[: content.rfind(stop.matched)]
return {
"id": ctx.resp_id,
"type": "message",
"role": "assistant",
"model": ctx.model,
"content": [{"type": "text", "text": content}],
"stop_reason": "stop_sequence" if stop.matched else "end_turn",
"stop_sequence": stop.matched,
"usage": {
"input_tokens": ctx.prompt_tokens,
"output_tokens": ctx.completion_tokens,
},
}

View File

@ -0,0 +1,140 @@
"""OpenAI chat completion response builder."""
import logging
import time
import uuid
from typing import Any, Dict, List, Tuple
from pydantic import BaseModel
from astrai.inference.api.protocol import (
GenContext,
ResponseBuilder,
StopInfo,
sse_event,
)
from astrai.inference.engine import InferenceEngine
logger = logging.getLogger(__name__)
_UNSUPPORTED_PARAMS = (
"n",
"presence_penalty",
"frequency_penalty",
"logit_bias",
"user",
)
class OpenAIResponseBuilder(ResponseBuilder):
def prepare(
self, request: BaseModel, engine: InferenceEngine
) -> Tuple[str, GenContext, List[str]]:
messages = [{"role": m.role, "content": m.content} for m in request.messages]
prompt = engine.tokenizer.apply_chat_template(messages, tokenize=False)
self._resp_id = f"chatcmpl-{uuid.uuid4().hex[:12]}"
self._model = request.model
for param in _UNSUPPORTED_PARAMS:
value = getattr(request, param, None)
fields = getattr(type(request), "model_fields", {})
default = fields[param].default if param in fields else None
if value is not None and value != default:
logger.warning(
"ChatCompletionRequest param '%s'=%r is not supported and will be ignored",
param,
value,
)
if value is not None and value != default:
logger.warning(
"ChatCompletionRequest param '%s'=%r is not supported and will be ignored",
param,
value,
)
ctx = GenContext(
resp_id=self._resp_id,
created=int(time.time()),
model=self._model,
prompt_tokens=0,
)
stop = request.stop
stop_sequences = (
[] if stop is None else [stop] if isinstance(stop, str) else stop
)
return prompt, ctx, stop_sequences
def format_stream_start(self, ctx: GenContext) -> List[str]:
return [
sse_event(
{
"id": self._resp_id,
"object": "chat.completion.chunk",
"created": ctx.created,
"model": self._model,
"choices": [
{
"index": 0,
"delta": {"role": "assistant"},
"finish_reason": None,
}
],
}
)
]
def format_chunk(self, token: str) -> str:
return sse_event(
{
"id": self._resp_id,
"object": "chat.completion.chunk",
"created": 0,
"model": self._model,
"choices": [
{"index": 0, "delta": {"content": token}, "finish_reason": None}
],
}
)
def format_stream_end(self, ctx: GenContext, stop: StopInfo) -> List[str]:
return [
sse_event(
{
"id": self._resp_id,
"object": "chat.completion.chunk",
"created": ctx.created,
"model": self._model,
"choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
}
),
sse_event(
{
"prompt_tokens": ctx.prompt_tokens,
"completion_tokens": ctx.completion_tokens,
"total_tokens": ctx.prompt_tokens + ctx.completion_tokens,
}
),
]
def format_response(
self, ctx: GenContext, content: str, stop: StopInfo
) -> Dict[str, Any]:
return {
"id": self._resp_id,
"object": "chat.completion",
"created": ctx.created,
"model": self._model,
"choices": [
{
"index": 0,
"message": {"role": "assistant", "content": content},
"finish_reason": "stop",
}
],
"usage": {
"prompt_tokens": ctx.prompt_tokens,
"completion_tokens": ctx.completion_tokens,
"total_tokens": ctx.prompt_tokens + ctx.completion_tokens,
},
}

View File

@ -0,0 +1,182 @@
"""Orchestration layer: ProtocolHandler, StopChecker, GenContext, StopInfo, ResponseBuilder, SSE utils.
ProtocolHandler orchestrates the async generation loop and delegates
protocol-specific formatting to a ResponseBuilder.
"""
import json
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from astrai.inference.engine import InferenceEngine
def sse_event(data: Dict[str, Any], event: Optional[str] = None) -> str:
lines: List[str] = []
if event:
lines.append(f"event: {event}")
lines.append(f"data: {json.dumps(data, ensure_ascii=False)}")
lines.append("")
return "\n".join(lines)
def sse_done() -> str:
return "data: [DONE]\n\n"
@dataclass
class GenContext:
"""Per-generation metadata passed to builder format methods."""
resp_id: str
created: int
model: str
prompt_tokens: int
completion_tokens: int = 0
@dataclass
class StopInfo:
"""Stop-check result passed to format_stream_end / format_response."""
matched: Optional[str] = None
body: str = ""
yielded: str = ""
class StopChecker:
"""Scans accumulated text for stop sequence matches."""
def __init__(self, sequences: List[str]):
self._sequences = [s for s in sequences if s]
def check(self, text: str) -> Optional[str]:
for seq in self._sequences:
if seq in text:
return seq
return None
class ResponseBuilder(ABC):
"""Interface for protocol-specific response formatting.
A new protocol requires one concrete builder implementing 5 methods.
"""
@abstractmethod
def prepare(
self, request: BaseModel, engine: InferenceEngine
) -> Tuple[str, GenContext, List[str]]:
"""Return (prompt, ctx, stop_sequences) for a generation request."""
@abstractmethod
def format_stream_start(self, ctx: GenContext) -> List[str]:
"""SSE events that open the stream."""
@abstractmethod
def format_chunk(self, token: str) -> str:
"""SSE event for a single generated token."""
@abstractmethod
def format_stream_end(self, ctx: GenContext, stop: StopInfo) -> List[str]:
"""SSE events that close the stream."""
@abstractmethod
def format_response(
self, ctx: GenContext, content: str, stop: StopInfo
) -> Dict[str, Any]:
"""JSON response body for non-streaming mode."""
class ProtocolHandler:
"""Orchestrates the generation loop, delegates formatting to a builder.
Usage::
handler = ProtocolHandler(request, engine, OpenAIResponseBuilder())
response = await handler.handle()
"""
def __init__(
self, request: BaseModel, engine: InferenceEngine, builder: ResponseBuilder
):
self.request = request
self.engine = engine
self.builder = builder
async def handle(self) -> Union[StreamingResponse, Dict[str, Any]]:
prompt, ctx, stop_sequences = self.builder.prepare(self.request, self.engine)
ctx.prompt_tokens = len(self.engine.tokenizer.encode(prompt))
agen = self.engine.generate_async(
prompt=prompt,
max_tokens=self.request.max_tokens,
temperature=self.request.temperature,
top_p=self.request.top_p,
top_k=self.request.top_k,
)
if self.request.stream:
return self._handle_stream(agen, ctx, stop_sequences)
else:
return await self._handle_non_stream(agen, ctx, stop_sequences)
def _handle_stream(
self, agen: AsyncGenerator, ctx: GenContext, stop_sequences: List[str]
) -> StreamingResponse:
checker = StopChecker(stop_sequences)
async def event_stream():
for event in self.builder.format_stream_start(ctx):
yield event
body = ""
yielded = ""
matched = None
async for token in agen:
body += token
matched = checker.check(body)
if matched:
break
ctx.completion_tokens += 1
yield self.builder.format_chunk(token)
yielded += token
stop = StopInfo(matched=matched, body=body, yielded=yielded)
for event in self.builder.format_stream_end(ctx, stop):
yield event
yield sse_done()
return StreamingResponse(
event_stream(),
media_type="text/event-stream",
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
)
async def _handle_non_stream(
self, agen: AsyncGenerator, ctx: GenContext, stop_sequences: List[str]
) -> Dict[str, Any]:
checker = StopChecker(stop_sequences)
chunks: List[str] = []
body = ""
matched = None
async for token in agen:
chunks.append(token)
body += token
matched = checker.check(body)
if matched:
break
ctx.completion_tokens += 1
content = "".join(chunks)
stop = StopInfo(matched=matched, body=body)
return self.builder.format_response(ctx, content, stop)

View File

@ -0,0 +1,169 @@
"""
OpenAI / Anthropic-compatible chat completion server backed by continuous-batching inference.
Protocol-specific formatting is delegated to ``astrai.inference.protocol``.
This module owns the FastAPI app, request/response schemas, and dependency wiring.
"""
import logging
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
import torch
import uvicorn
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from astrai.inference.api.anthropic import AnthropicResponseBuilder
from astrai.inference.api.openai import OpenAIResponseBuilder
from astrai.inference.api.protocol import ProtocolHandler
from astrai.inference.engine import InferenceEngine
from astrai.model import AutoModel
from astrai.tokenize import AutoTokenizer
logger = logging.getLogger(__name__)
_project_root = Path(__file__).parent.parent.parent
class ChatMessage(BaseModel):
role: str
content: str
class ChatCompletionRequest(BaseModel):
"""OpenAI Chat Completion API request body."""
model: str = "astrai"
messages: List[ChatMessage]
temperature: Optional[float] = Field(default=1.0, ge=0.0, le=2.0)
top_p: Optional[float] = Field(default=1.0, ge=0.0, le=1.0)
top_k: Optional[int] = Field(default=50, ge=1)
stream: Optional[bool] = False
stop: Optional[Union[str, List[str]]] = None
max_tokens: Optional[int] = Field(default=2048, ge=1)
n: Optional[int] = Field(default=1, ge=1)
presence_penalty: Optional[float] = Field(default=0.0, ge=-2.0, le=2.0)
frequency_penalty: Optional[float] = Field(default=0.0, ge=-2.0, le=2.0)
logit_bias: Optional[Dict[int, float]] = None
user: Optional[str] = None
class AnthropicMessage(BaseModel):
role: str
content: Union[str, List[Dict[str, Any]]]
class MessagesRequest(BaseModel):
"""Anthropic Messages API request body."""
model: str = "astrai"
max_tokens: int = Field(default=1024, ge=1)
messages: List[AnthropicMessage]
system: Optional[str] = None
temperature: Optional[float] = Field(default=1.0, ge=0.0, le=2.0)
top_p: Optional[float] = Field(default=1.0, ge=0.0, le=1.0)
top_k: Optional[int] = Field(default=50, ge=1)
stream: Optional[bool] = False
stop_sequences: Optional[List[str]] = None
@asynccontextmanager
async def lifespan(app: FastAPI):
config = app.state.server_config
if not config.get("_test", False):
try:
app.state.engine = _create_engine(**config)
except Exception as e:
logger.error(f"Failed to load model: {e}")
raise
yield
if app.state.engine:
app.state.engine.shutdown()
logger.info("Inference engine shutdown complete")
app = FastAPI(title="AstrAI Inference Server", version="0.2.0", lifespan=lifespan)
def _create_engine(
param_path: Optional[Path] = None,
device: str = "cuda",
dtype: torch.dtype = torch.bfloat16,
max_batch_size: int = 16,
) -> InferenceEngine:
if param_path is None:
param_path = _project_root / "params"
if not param_path.exists():
raise FileNotFoundError(f"Parameter directory not found: {param_path}")
tokenizer = AutoTokenizer.from_pretrained(param_path)
model = AutoModel.from_pretrained(param_path)
model.to(device=device, dtype=dtype)
logger.info(f"Model loaded on {device} with dtype {dtype}")
engine = InferenceEngine(
model=model,
tokenizer=tokenizer,
max_batch_size=max_batch_size,
)
logger.info(f"Inference engine initialized with max_batch_size={max_batch_size}")
return engine
def _get_engine() -> InferenceEngine:
engine = app.state.engine
if engine is None:
raise HTTPException(status_code=503, detail="Engine not initialized")
return engine
@app.get("/health")
async def health():
return {
"status": "ok",
"model_loaded": app.state.engine is not None,
}
@app.get("/stats")
async def get_stats():
return _get_engine().get_stats()
@app.post("/v1/chat/completions")
async def chat_completion(request: ChatCompletionRequest):
engine = _get_engine()
handler = ProtocolHandler(request, engine, OpenAIResponseBuilder())
return await handler.handle()
@app.post("/v1/messages")
async def create_message(request: MessagesRequest):
engine = _get_engine()
handler = ProtocolHandler(request, engine, AnthropicResponseBuilder())
return await handler.handle()
def run_server(
host: str = "0.0.0.0",
port: int = 8000,
reload: bool = False,
device: str = "cuda",
dtype: torch.dtype = torch.bfloat16,
param_path: Optional[Path] = None,
max_batch_size: int = 16,
):
app.state.server_config = {
"device": device,
"dtype": dtype,
"param_path": param_path,
"max_batch_size": max_batch_size,
}
uvicorn.run(
app,
host=host,
port=port,
reload=reload,
)

View File

@ -0,0 +1,32 @@
"""Inference core: cache, executor, scheduler, task management."""
from astrai.inference.core.cache import (
Allocator,
KVCache,
KvcacheView,
PagePool,
PrefixCache,
Storage,
TaskTable,
page_hash,
)
from astrai.inference.core.executor import Executor
from astrai.inference.core.scheduler import InferenceScheduler
from astrai.inference.core.task import STOP, Task, TaskManager, TaskStatus
__all__ = [
"Allocator",
"KVCache",
"KvcacheView",
"PagePool",
"PrefixCache",
"Storage",
"TaskTable",
"page_hash",
"Executor",
"InferenceScheduler",
"STOP",
"Task",
"TaskManager",
"TaskStatus",
]

View File

@ -0,0 +1,368 @@
import threading
from collections import OrderedDict
from typing import Callable, Dict, List, Optional, Tuple
import torch
from torch import Tensor
def page_hash(token_ids: List[int], page_idx: int, page_size: int) -> int:
start = page_idx * page_size
end = min(start + page_size, len(token_ids))
h = 0
for i in range(start, end):
h = (h * 31 + token_ids[i]) & 0xFFFFFFFFFFFFFFFF
return h
class Allocator:
"""Bitmask-based page allocator with ref-counting and LRU eviction."""
def __init__(self, n_pages: int):
self._free_mask = (1 << n_pages) - 1
self._refs: List[int] = [0] * n_pages
self._lru: OrderedDict[int, None] = OrderedDict()
self.on_evict: Optional[Callable[[int], None]] = None
self._lock = threading.Lock()
def alloc(self) -> int:
with self._lock:
if self._free_mask:
lsb = self._free_mask & -self._free_mask
idx = lsb.bit_length() - 1
self._free_mask ^= lsb
self._refs[idx] = 1
return idx
if self._lru:
idx, _ = self._lru.popitem(last=False)
if self.on_evict:
self.on_evict(idx)
self._refs[idx] = 1
self._free_mask &= ~(1 << idx)
return idx
return -1
def free(self, idx: int, keep_cached: bool = False):
with self._lock:
self._refs[idx] -= 1
if self._refs[idx] == 0:
if keep_cached:
self._lru[idx] = None
else:
self._free_mask |= 1 << idx
def inc_ref(self, idx: int):
with self._lock:
self._refs[idx] += 1
self._lru.pop(idx, None)
def ref_count(self, idx: int) -> int:
with self._lock:
return self._refs[idx]
def touch(self, idx: int):
with self._lock:
self._lru.move_to_end(idx)
class PrefixCache:
"""Hash-based prefix matching: maps page hashes to physical page indices."""
def __init__(self, page_size: int):
self._page_size = page_size
self._page_to_hash: Dict[int, int] = {}
self._hash_to_page: Dict[int, int] = {}
self._lock = threading.Lock()
def evict(self, idx: int):
with self._lock:
h = self._page_to_hash.pop(idx, None)
if h is not None:
self._hash_to_page.pop(h, None)
def has_page(self, idx: int) -> bool:
with self._lock:
return idx in self._page_to_hash
def lookup(self, token_ids: List[int]) -> List[int]:
with self._lock:
full_pages = len(token_ids) // self._page_size
hits: List[int] = []
for i in range(full_pages):
h = page_hash(token_ids, i, self._page_size)
p = self._hash_to_page.get(h)
if p is None:
break
hits.append(p)
return hits
def record(self, page_idx: int, token_ids: List[int], logical_page_idx: int):
with self._lock:
h = page_hash(token_ids, logical_page_idx, self._page_size)
old_h = self._page_to_hash.pop(page_idx, None)
if old_h is not None:
self._hash_to_page.pop(old_h, None)
self._page_to_hash[page_idx] = h
self._hash_to_page[h] = page_idx
class PagePool:
"""Orchestrates allocator (page management) and PrefixCache (content addressing)."""
def __init__(self, allocator: Allocator, prefix: PrefixCache):
self._alloc = allocator
self._prefix = prefix
self._alloc.on_evict = prefix.evict
@property
def allocator(self) -> Allocator:
return self._alloc
@property
def prefix(self) -> PrefixCache:
return self._prefix
def alloc(self) -> int:
return self._alloc.alloc()
def free(self, idx: int):
keep = self._prefix.has_page(idx)
self._alloc.free(idx, keep_cached=keep)
if not keep:
self._prefix.evict(idx)
def inc_ref(self, idx: int):
self._alloc.inc_ref(idx)
def lookup(self, token_ids: List[int]) -> List[int]:
hits = self._prefix.lookup(token_ids)
for p in hits:
self._alloc.touch(p)
return hits
def record(self, page_idx: int, token_ids: List[int], logical_page_idx: int):
self._prefix.record(page_idx, token_ids, logical_page_idx)
class TaskTable:
"""Maps task_ids to page tables and cached token counts."""
def __init__(self, page_size: int):
self._page_size = page_size
self._pages: Dict[str, List[int]] = {}
self._cached: Dict[str, int] = {}
self._lock = threading.Lock()
def set(self, task_id: str, page_table: List[int], cached: int):
with self._lock:
self._pages[task_id] = page_table
self._cached[task_id] = cached
def get(self, task_id: str) -> List[int]:
with self._lock:
return self._pages.get(task_id, [])
def get_cached(self, task_id: str) -> int:
with self._lock:
return self._cached.get(task_id, 0)
def pop(self, task_id: str) -> Tuple[List[int], int]:
with self._lock:
pages = self._pages.pop(task_id, [])
cached = self._cached.pop(task_id, 0)
return pages, cached
def get_ref(self, task_id: str) -> List[int]:
with self._lock:
return self._pages.setdefault(task_id, [])
def table_tensor(self, task_ids: List[str], device: torch.device) -> Tensor:
with self._lock:
states = [self._pages.get(tid, []) for tid in task_ids]
max_pages = max((len(s) for s in states), default=0)
rows = [s + [-1] * (max_pages - len(s)) for s in states]
return torch.tensor(rows, dtype=torch.long, device=device)
class Storage:
"""KV-cache tensor storage with paged write/gather."""
def __init__(
self,
n_layers: int,
n_pages: int,
page_size: int,
n_kv_heads: int,
head_dim: int,
device: torch.device,
dtype: torch.dtype,
):
self.page_size = page_size
self.k_cache = torch.empty(
(n_layers, n_pages, page_size, n_kv_heads, head_dim),
device=device,
dtype=dtype,
)
self.v_cache = torch.empty(
(n_layers, n_pages, page_size, n_kv_heads, head_dim),
device=device,
dtype=dtype,
)
def write(
self,
layer_id: int,
page_table: Tensor,
start_pos: int,
k: Tensor,
v: Tensor,
):
seq_len = k.size(1)
if seq_len == 0:
return
page_size = self.page_size
written = 0
first_page = start_pos // page_size
last_page = (start_pos + seq_len - 1) // page_size
for pi in range(first_page, last_page + 1):
phys_pages = page_table[:, pi]
page_start = pi * page_size
write_start = max(page_start, start_pos)
write_end = min(page_start + page_size, start_pos + seq_len)
offset = write_start - page_start
chunk = write_end - write_start
valid = phys_pages >= 0
if not valid.all():
if valid.any():
valid_pages = phys_pages[valid]
self.k_cache[layer_id, valid_pages, offset : offset + chunk] = k[
valid, written : written + chunk
]
self.v_cache[layer_id, valid_pages, offset : offset + chunk] = v[
valid, written : written + chunk
]
written += chunk
continue
self.k_cache[layer_id, phys_pages, offset : offset + chunk] = k[
:, written : written + chunk
]
self.v_cache[layer_id, phys_pages, offset : offset + chunk] = v[
:, written : written + chunk
]
written += chunk
def gather(
self, layer_id: int, page_table: Tensor, total_len: int
) -> Tuple[Tensor, Tensor]:
safe = page_table.clamp(min=0)
k = self.k_cache[layer_id, safe]
v = self.v_cache[layer_id, safe]
k = k.flatten(1, 2)
v = v.flatten(1, 2)
if (page_table < 0).any():
invalid = (
(page_table < 0)
.unsqueeze(-1)
.expand(-1, -1, self.page_size)
.flatten(1, 2)
)
invalid = invalid[:, :, None, None].expand_as(k)
k = k.masked_fill(invalid, 0.0)
v = v.masked_fill(invalid, 0.0)
k = k[:, :total_len]
v = v[:, :total_len]
return k, v
class KvcacheView:
"""Bundles Storage + page_table + total_len for attention layers."""
def __init__(self, storage: Storage, page_table: Tensor, total_len: int = 0):
self._storage = storage
self._page_table = page_table
self._total_len = total_len
def write(self, layer_id: int, k: Tensor, v: Tensor):
start_pos = self._total_len - k.size(1)
self._storage.write(layer_id, self._page_table, start_pos, k, v)
def gather(self, layer_id: int) -> Tuple[Tensor, Tensor]:
return self._storage.gather(layer_id, self._page_table, self._total_len)
class KVCache:
"""Facade: page management + KV-cache I/O for continuous batching."""
def __init__(
self,
n_layers: int,
n_pages: int,
page_size: int,
n_kv_heads: int,
head_dim: int,
device: torch.device,
dtype: torch.dtype,
):
self.page_size = page_size
self._pool = PagePool(Allocator(n_pages), PrefixCache(page_size))
self._table = TaskTable(page_size)
self._storage = Storage(
n_layers, n_pages, page_size, n_kv_heads, head_dim, device, dtype
)
def task_alloc(self, task_id: str, prompt_ids: List[int]) -> bool:
hits = self._pool.lookup(prompt_ids)
cached = len(hits) * self.page_size
for p in hits:
self._pool.inc_ref(p)
remaining = len(prompt_ids) - cached
n_new = (
(remaining + self.page_size - 1) // self.page_size if remaining > 0 else 0
)
new_pages: List[int] = []
if n_new > 0:
for _ in range(n_new):
p = self._pool.alloc()
if p < 0:
for hp in hits:
self._pool.free(hp)
for np in new_pages:
self._pool.free(np)
return False
new_pages.append(p)
self._table.set(task_id, hits + new_pages, cached)
return True
def task_free(self, task_id: str):
page_table, _ = self._table.pop(task_id)
for idx in page_table:
self._pool.free(idx)
def task_extend(self, task_id: str, pos: int) -> bool:
page_table = self._table.get(task_id)
needed = (pos + 1 + self.page_size - 1) // self.page_size
while len(page_table) < needed:
p = self._pool.alloc()
if p < 0:
return False
page_table.append(p)
return True
def task_cached(self, task_id: str) -> int:
return self._table.get_cached(task_id)
def task_record_hashes(
self, task_id: str, prompt_ids: List[int], start_logical_page: int = 0
):
page_table = self._table.get(task_id)
full_pages = len(prompt_ids) // self.page_size
for i in range(start_logical_page, full_pages):
self._pool.record(page_table[i], prompt_ids, i)
def make_table_tensor(self, task_ids: List[str], device: torch.device) -> Tensor:
return self._table.table_tensor(task_ids, device)
def bind(self, page_table: Tensor, total_len: int = 0) -> KvcacheView:
return KvcacheView(self._storage, page_table, total_len)

View File

@ -0,0 +1,94 @@
import logging
from typing import List, Optional
import torch
from astrai.inference.core.cache import KVCache
from astrai.inference.core.task import Task
from astrai.inference.sample import sample
from astrai.model.automodel import AutoModel
from astrai.tokenize.tokenizer import AutoTokenizer
logger = logging.getLogger(__name__)
class Executor:
"""Model forward passes for prefill and decode phases."""
def __init__(
self,
model: AutoModel,
tokenizer: AutoTokenizer,
page_cache: KVCache,
device: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
):
self.model = model
self.tokenizer = tokenizer
self.page_cache = page_cache
self.device = device or next(model.parameters()).device
self.dtype = dtype or next(model.parameters()).dtype
def execute_prefill(self, tasks: List[Task], prompt_len: int, start_pos: int = 0):
if start_pos >= prompt_len:
return
tasks = sorted(tasks, key=lambda t: t.task_id)
batch_sz = len(tasks)
input_ids = torch.tensor(
[t.prompt_ids[start_pos:prompt_len] for t in tasks],
dtype=torch.long,
device=self.device,
)
task_ids = [t.task_id for t in tasks]
page_tables = self.page_cache.make_table_tensor(task_ids, self.device)
with torch.inference_mode():
self.model(
input_ids,
position_ids=torch.arange(
start_pos, prompt_len, dtype=torch.long, device=self.device
)
.unsqueeze(0)
.expand(batch_sz, -1),
paged_cache=self.page_cache.bind(page_tables, total_len=prompt_len),
)
def execute_decode(self, tasks: List[Task]) -> List[int]:
if not tasks:
return []
input_ids = torch.tensor(
[t.output_ids[-1] if t.output_ids else t.prompt_ids[-1] for t in tasks],
dtype=torch.long,
device=self.device,
)
position_ids = torch.tensor(
[t.next_pos for t in tasks], dtype=torch.long, device=self.device
)
total_len = position_ids.max().item() + 1
task_ids = [t.task_id for t in tasks]
page_tables = self.page_cache.make_table_tensor(task_ids, self.device)
temperatures = torch.tensor([t.temperature for t in tasks], device=self.device)
top_ks = torch.tensor([t.top_k for t in tasks], device=self.device)
top_ps = torch.tensor([t.top_p for t in tasks], device=self.device)
with torch.inference_mode():
outputs = self.model(
input_ids.unsqueeze(1),
paged_cache=self.page_cache.bind(page_tables, total_len=total_len),
position_ids=position_ids.unsqueeze(1),
)
logits = outputs["logits"][:, -1, :]
return sample(
logits,
temperature=temperatures,
top_k=top_ks,
top_p=top_ps,
).tolist()

View File

@ -0,0 +1,212 @@
import logging
import threading
from typing import Any, Dict, List, Optional, Tuple
import torch
from astrai.inference.core.cache import KVCache
from astrai.inference.core.executor import Executor
from astrai.inference.core.task import STOP, Task, TaskManager, TaskStatus
from astrai.model.automodel import AutoModel
from astrai.tokenize.tokenizer import AutoTokenizer
logger = logging.getLogger(__name__)
class InferenceScheduler:
"""Four-phase continuous batching loop: cleanup -> refill -> prefill -> decode."""
def __init__(
self,
model: AutoModel,
tokenizer: AutoTokenizer,
max_batch_size: int = 16,
max_seq_len: Optional[int] = None,
max_prompt_len: int = 2048,
page_size: int = 64,
device: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
):
config = model.config
if max_seq_len is not None:
self.max_seq_len = max_seq_len
elif config.max_len is not None:
self.max_seq_len = config.max_len
else:
raise ValueError(
"max_seq_len must be provided either as argument "
"or in model config (config.max_len)"
)
self.device = device or next(model.parameters()).device
self.dtype = dtype or next(model.parameters()).dtype
n_pages = (
max_batch_size * (self.max_seq_len + page_size) + page_size - 1
) // page_size
self._page_cache = KVCache(
config.n_layers,
n_pages,
page_size,
config.n_kv_heads,
config.dim // config.n_heads,
self.device,
self.dtype,
)
self._task_mgr = TaskManager(
tokenizer=tokenizer,
max_batch_size=max_batch_size,
max_seq_len=self.max_seq_len,
max_prompt_len=max_prompt_len,
)
self._executor = Executor(
model=model,
tokenizer=tokenizer,
page_cache=self._page_cache,
device=self.device,
dtype=self.dtype,
)
self._running = False
self._fatal_error: Optional[Exception] = None
def add_task(self, prompt: str, **kwargs) -> str:
return self._task_mgr.add_task(prompt, **kwargs)
def remove_task(self, task_id: str):
for task in self._task_mgr.remove_task(task_id):
self._page_cache.task_free(task.task_id)
def get_stats(self) -> Dict[str, Any]:
return self._task_mgr.get_stats()
def _run_generation_loop(self):
stop_ids = self._task_mgr.tokenizer.stop_ids
try:
while self._running:
finished = self._task_mgr.remove_finished_tasks(stop_ids)
for task in finished:
self._page_cache.task_free(task.task_id)
active = self._task_mgr.get_active_tasks()
available = self._task_mgr.max_batch_size - len(active)
if available > 0:
candidates = self._task_mgr.pull_candidates(available)
failed = []
for task in candidates:
if self._page_cache.task_alloc(task.task_id, task.prompt_ids):
self._task_mgr.activate(task)
else:
failed.append(task)
if failed:
self._task_mgr.return_to_waiting(failed)
if not self._task_mgr.has_work():
self._task_mgr.wait_for_tasks(timeout=1.0)
continue
to_prefill = [
t
for t in self._task_mgr.get_active_tasks()
if t.output_tokens == 0
and self._page_cache.task_cached(t.task_id) < len(t.prompt_ids)
]
if to_prefill:
for t in to_prefill:
t.input_tokens = len(t.prompt_ids)
groups: Dict[Tuple[int, int], List[Task]] = {}
for t in to_prefill:
key = (
len(t.prompt_ids),
self._page_cache.task_cached(t.task_id),
)
groups.setdefault(key, []).append(t)
for (prompt_len, start_pos), group in groups.items():
self._executor.execute_prefill(group, prompt_len, start_pos)
start_logical_page = start_pos // self._page_cache.page_size
for t in group:
self._page_cache.task_record_hashes(
t.task_id,
t.prompt_ids,
start_logical_page=start_logical_page,
)
pos_groups: Dict[int, List[Task]] = {}
for t in self._task_mgr.get_active_tasks():
pos_groups.setdefault(t.next_pos, []).append(t)
if pos_groups:
best_key = max(pos_groups, key=lambda k: len(pos_groups[k]))
group = sorted(pos_groups[best_key], key=lambda t: t.task_id)
valid: List[Task] = []
for t in group:
if self._page_cache.task_extend(t.task_id, t.next_pos):
valid.append(t)
else:
t.status = TaskStatus.ABORTED
if t.stream_callback:
t.stream_callback(STOP)
if valid:
next_tokens = self._executor.execute_decode(valid)
for t, ntok in zip(valid, next_tokens):
t.output_ids.append(ntok)
t.output_tokens += 1
pos = t.input_tokens + t.output_tokens
extend_ok = self._page_cache.task_extend(t.task_id, pos)
if t.stream_callback:
t.stream_callback(
self._task_mgr.tokenizer.decode([ntok])
)
if not extend_ok:
t.status = TaskStatus.ABORTED
if t.stream_callback:
t.stream_callback(STOP)
for t in valid:
if t.is_finished(stop_ids):
if t.stream_callback:
t.stream_callback(STOP)
except Exception as e:
self._fatal_error = e
self._running = False
logger.error(f"Scheduler loop crashed: {e}", exc_info=True)
for task in self._task_mgr.get_active_tasks():
if task.stream_callback:
task.stream_callback(STOP)
self._page_cache.task_free(task.task_id)
for task in self._task_mgr.get_waiting_tasks():
if task.stream_callback:
task.stream_callback(STOP)
self._task_mgr.clear_queues()
def start(self):
if not self._running:
self._running = True
t = threading.Thread(target=self._run_generation_loop, daemon=True)
t.start()
self._loop_thread = t
def stop(self):
self._running = False
self._task_mgr.wake()
if hasattr(self, "_loop_thread"):
self._loop_thread.join(timeout=2.0)
for task in self._task_mgr.get_active_tasks():
if task.stream_callback:
task.stream_callback(STOP)
self._page_cache.task_free(task.task_id)
for task in self._task_mgr.get_waiting_tasks():
if task.stream_callback:
task.stream_callback(STOP)
self._task_mgr.clear_queues()
if torch.cuda.is_available():
torch.cuda.empty_cache()

View File

@ -0,0 +1,209 @@
import logging
import threading
import time
import uuid
from collections import deque
from enum import Enum
from typing import Any, Callable, Deque, Dict, List, Optional
from astrai.tokenize.tokenizer import AutoTokenizer
logger = logging.getLogger(__name__)
STOP = object()
class TaskStatus(Enum):
"""Task lifecycle states."""
PENDING = "pending"
RUNNING = "running"
FINISHED = "finished"
ABORTED = "aborted"
class Task:
"""Single generation request: prompt, sampling params, output state."""
def __init__(
self,
task_id: str,
prompt_ids: List[int],
max_tokens: Optional[int] = None,
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = 50,
stream_callback: Optional[Callable[[str], None]] = None,
):
self.task_id = task_id
self.prompt_ids = prompt_ids
self.max_tokens = max_tokens
self.temperature = temperature
self.top_p = top_p
self.top_k = top_k
self.status = TaskStatus.PENDING
self.output_ids: List[int] = []
self.input_tokens: int = 0
self.output_tokens: int = 0
self.arrival_time = time.time()
self.finish_time: Optional[float] = None
self.stream_callback = stream_callback
@property
def next_pos(self) -> int:
return self.input_tokens + len(self.output_ids)
def is_finished(self, stop_ids: List[int]) -> bool:
if self.max_tokens is not None and self.output_tokens >= self.max_tokens:
return True
if self.output_ids and self.output_ids[-1] in stop_ids:
return True
return False
class TaskManager:
"""Thread-safe task queues and lifecycle transitions (no page ops)."""
def __init__(
self,
tokenizer: AutoTokenizer,
max_batch_size: int = 16,
max_seq_len: int = 8192,
max_prompt_len: int = 512,
):
self.tokenizer = tokenizer
self.max_batch_size = max_batch_size
self.max_seq_len = max_seq_len
self.max_prompt_len = max_prompt_len
self.waiting_queue: Deque[Task] = deque()
self.active_tasks: List[Task] = []
self._task_event = threading.Event()
self._lock = threading.Lock()
self._total_tasks = 0
self._total_tokens = 0
def add_task(
self,
prompt: str,
max_tokens: Optional[int] = None,
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = 50,
stream_callback: Optional[Callable[[str], None]] = None,
) -> str:
task_id = f"task_{int(time.time())}_{uuid.uuid4().hex[:8]}"
prompt_ids = self.tokenizer.encode(prompt)
if len(prompt_ids) > self.max_prompt_len:
prompt_ids = prompt_ids[-self.max_prompt_len :]
if len(prompt_ids) >= self.max_seq_len:
if stream_callback:
stream_callback(STOP)
return task_id
if max_tokens is None:
max_tokens = self.max_seq_len - len(prompt_ids)
else:
max_tokens = min(max_tokens, self.max_seq_len - len(prompt_ids))
task = Task(
task_id=task_id,
prompt_ids=prompt_ids,
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
top_k=top_k,
stream_callback=stream_callback,
)
with self._lock:
self.waiting_queue.append(task)
self._total_tasks += 1
self._task_event.set()
return task_id
def remove_task(self, task_id: str) -> List[Task]:
with self._lock:
removed_active = [t for t in self.active_tasks if t.task_id == task_id]
self.waiting_queue = deque(
t for t in self.waiting_queue if t.task_id != task_id
)
self.active_tasks = [t for t in self.active_tasks if t.task_id != task_id]
return removed_active
def get_stats(self) -> Dict[str, Any]:
return {
"total_tasks": self._total_tasks,
"total_tokens": self._total_tokens,
"active_tasks": len(self.active_tasks),
"waiting_queue": len(self.waiting_queue),
}
def remove_finished_tasks(self, stop_ids: List[int]) -> List[Task]:
with self._lock:
finished = []
for task in self.active_tasks:
if task.status == TaskStatus.ABORTED:
task.finish_time = time.time()
finished.append(task)
elif task.is_finished(stop_ids):
task.status = TaskStatus.FINISHED
task.finish_time = time.time()
finished.append(task)
self._total_tokens += task.output_tokens
self.active_tasks = [
t
for t in self.active_tasks
if t.status not in (TaskStatus.FINISHED, TaskStatus.ABORTED)
]
return finished
def pull_candidates(self, n: int) -> List[Task]:
to_add: List[Task] = []
with self._lock:
take = min(n, len(self.waiting_queue))
for _ in range(take):
to_add.append(self.waiting_queue.popleft())
return to_add
def activate(self, task: Task):
task.status = TaskStatus.RUNNING
with self._lock:
self.active_tasks.append(task)
def return_to_waiting(self, tasks: List[Task]):
with self._lock:
for task in reversed(tasks):
self.waiting_queue.appendleft(task)
def has_work(self) -> bool:
return bool(self.active_tasks or self.waiting_queue)
def wait_for_tasks(self, timeout: float = 1.0):
with self._lock:
if self.waiting_queue or self.active_tasks:
return
self._task_event.clear()
self._task_event.wait(timeout=timeout)
def get_active_tasks(self) -> List[Task]:
with self._lock:
return list(self.active_tasks)
def get_waiting_tasks(self) -> List[Task]:
with self._lock:
return list(self.waiting_queue)
def clear_queues(self):
with self._lock:
self.waiting_queue.clear()
self.active_tasks.clear()
def wake(self):
self._task_event.set()

288
astrai/inference/engine.py Normal file
View File

@ -0,0 +1,288 @@
"""Unified inference engine for continuous batching."""
import asyncio
import gc
import threading
from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Tuple, Union
import torch
import torch.nn as nn
from astrai.inference.core.scheduler import InferenceScheduler
from astrai.inference.core.task import STOP
from astrai.tokenize import AutoTokenizer
class GenerateResult:
"""Thread-safe token accumulator for streaming and non-streaming modes."""
def __init__(self, count: int = 1):
self._cond = threading.Condition()
self._event = threading.Event()
self.tokens: List[Tuple[int, str]] = []
self.results: List[str] = [""] * count
self._done: List[bool] = [False] * count
self._completed = 0
self._total = count
def append(self, token: str, idx: int = 0):
with self._cond:
self.tokens.append((idx, token))
if token is not STOP:
self.results[idx] += token
else:
if not self._done[idx]:
self._done[idx] = True
self._completed += 1
self._cond.notify_all()
self._event.set()
def pop_all(self) -> List[Tuple[int, str]]:
with self._cond:
out = self.tokens.copy()
self.tokens.clear()
if not out:
self._event.clear()
return out
def wait(self, timeout: Optional[float] = None) -> bool:
return self._event.wait(timeout=timeout)
def wait_completion(self, timeout: float = 300.0):
with self._cond:
if not self._cond.wait_for(
lambda: self._completed >= self._total, timeout=timeout
):
raise TimeoutError(
f"Generation timeout after {timeout}s "
f"({self._completed}/{self._total} completed)"
)
def get_results(self) -> List[str]:
with self._cond:
return self.results.copy()
class GenerationRequest:
"""Request parameters for text generation."""
def __init__(
self,
messages: List[Dict[str, str]],
top_k: int = 50,
top_p: float = 1.0,
temperature: float = 1.0,
max_tokens: Optional[int] = None,
stream: bool = False,
):
if not (isinstance(top_k, int) and top_k >= 0):
raise ValueError("top_k must be a non-negative integer")
if not (0.0 <= top_p <= 1.0):
raise ValueError("top_p must be a float between 0.0 and 1.0")
if not (isinstance(temperature, (int, float)) and temperature > 0):
raise ValueError("temperature must be a positive number")
self.messages = messages
self.top_k = top_k
self.top_p = top_p
self.temperature = temperature
self.max_tokens = max_tokens
self.stream = stream
class InferenceEngine:
"""Unified inference engine backed by continuous-batching scheduler."""
def __init__(
self,
model: nn.Module,
tokenizer: AutoTokenizer,
max_batch_size: int = 1,
max_seq_len: Optional[int] = None,
max_prompt_len: int = 2048,
page_size: int = 128,
):
self.model = model
self.tokenizer = tokenizer
self.scheduler = InferenceScheduler(
model=self.model,
tokenizer=self.tokenizer,
max_batch_size=max_batch_size,
max_seq_len=max_seq_len,
max_prompt_len=max_prompt_len,
page_size=page_size,
)
self.scheduler.start()
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.shutdown()
return False
def generate(
self,
prompt: Union[str, List[str]],
stream: bool = False,
max_tokens: Optional[int] = None,
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = 50,
) -> Union[Generator, str, List[str]]:
is_batch = isinstance(prompt, list)
prompts = prompt if is_batch else [prompt]
if stream:
return self._generate_streaming(
prompts, is_batch, max_tokens, temperature, top_p, top_k
)
else:
return self._generate_non_streaming(
prompts, is_batch, max_tokens, temperature, top_p, top_k
)
def generate_async(
self,
prompt: str,
max_tokens: Optional[int] = None,
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = 50,
) -> AsyncGenerator[str, None]:
sync_gen = self._generate_streaming(
[prompt], False, max_tokens, temperature, top_p, top_k
)
async def _agen():
loop = asyncio.get_event_loop()
while True:
token = await loop.run_in_executor(None, self._next_token, sync_gen)
if token is None:
break
yield token
return _agen()
@staticmethod
def _next_token(gen: Generator) -> Optional[str]:
try:
return next(gen)
except StopIteration:
return None
def generate_with_request(
self, request: GenerationRequest
) -> Union[Generator[str, None, None], str, List[str]]:
prompt = self.tokenizer.apply_chat_template(request.messages, tokenize=False)
return self.generate(
prompt=prompt,
stream=request.stream,
max_tokens=request.max_tokens,
temperature=request.temperature,
top_p=request.top_p,
top_k=request.top_k,
)
def _submit_tasks(
self,
prompts: List[str],
max_tokens: Optional[int],
temperature: float,
top_p: float,
top_k: int,
) -> Tuple[GenerateResult, List[str]]:
n = len(prompts)
result = GenerateResult(count=n)
task_ids = []
for i, p in enumerate(prompts):
cb = self._make_callback(result, i)
task_id = self.scheduler.add_task(
prompt=p,
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
top_k=top_k,
stream_callback=cb,
)
task_ids.append(task_id)
return result, task_ids
@staticmethod
def _make_callback(result: GenerateResult, idx: int):
def cb(token):
result.append(token, idx)
return cb
def _generate_streaming(
self,
prompts: List[str],
is_batch: bool,
max_tokens: Optional[int],
temperature: float,
top_p: float,
top_k: int,
) -> Generator:
result, task_ids = self._submit_tasks(
prompts, max_tokens, temperature, top_p, top_k
)
n = len(prompts)
remaining = n
finished = [False] * n
def gen():
nonlocal remaining
try:
while remaining > 0:
items = result.pop_all()
for idx, token in items:
if token is STOP:
if not finished[idx]:
finished[idx] = True
remaining -= 1
else:
yield (idx, token) if is_batch else token
if remaining > 0:
result.wait(timeout=0.05)
finally:
for tid in task_ids:
self.scheduler.remove_task(tid)
return gen()
def _generate_non_streaming(
self,
prompts: List[str],
is_batch: bool,
max_tokens: Optional[int],
temperature: float,
top_p: float,
top_k: int,
) -> Union[str, List[str]]:
result, task_ids = self._submit_tasks(
prompts, max_tokens, temperature, top_p, top_k
)
try:
result.wait_completion()
except TimeoutError:
for tid in task_ids:
self.scheduler.remove_task(tid)
raise
for tid in task_ids:
self.scheduler.remove_task(tid)
res = result.get_results()
return res if is_batch else res[0]
def get_stats(self) -> Dict[str, Any]:
return self.scheduler.get_stats()
def shutdown(self):
self.scheduler.stop()
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()

190
astrai/inference/sample.py Normal file
View File

@ -0,0 +1,190 @@
"""Composable sampling strategies for logit transformation.
Implements the Strategy pattern: each sampling technique
(temperature, top-k, top-p) is a pluggable strategy that
can be composed into a pipeline.
All strategies accept both scalar and per-sample tensor
parameters, so a single pipeline works for any batch size.
"""
from abc import ABC, abstractmethod
from typing import List, Union
import torch
from torch import Tensor
class BaseSamplingStrategy(ABC):
"""Abstract base for a logit transformation strategy."""
@abstractmethod
def apply(self, logits: Tensor, filter_value: float = -float("inf")) -> Tensor:
"""Applies the strategy to logits.
Args:
logits: Raw logits tensor (batch, vocab_size).
filter_value: Value assigned to filtered-out positions.
Returns:
Transformed logits tensor.
"""
class TemperatureStrategy(BaseSamplingStrategy):
"""Divides logits by temperature to control randomness.
Args:
temperature: Scalar or ``[batch]`` tensor.
"""
def __init__(self, temperature: Union[float, Tensor] = 1.0):
self.temperature = temperature
def apply(self, logits, filter_value=-float("inf")):
t = self.temperature
if isinstance(t, Tensor):
t = t.to(logits.device, non_blocking=True).view(-1, 1)
t = torch.clamp(t, min=1e-8)
if (t != 1.0).any():
logits = logits / t
elif t != 1.0:
logits = logits / max(t, 1e-8)
return logits
class TopKStrategy(BaseSamplingStrategy):
"""Keeps only the top-k logits, setting the rest to filter_value.
Args:
top_k: Scalar or ``[batch]`` tensor (0 disables).
"""
def __init__(self, top_k: Union[int, Tensor] = 0):
self.top_k = top_k
def apply(self, logits, filter_value=-float("inf")):
tk = self.top_k
if isinstance(tk, Tensor):
tk = tk.to(logits.device, non_blocking=True).long().clamp(min=0)
max_k = int(tk.max().item())
if max_k <= 0:
return logits
max_k = min(max_k, logits.size(-1))
values, _ = torch.topk(logits, max_k, dim=-1)
per_row_k = tk.clamp(max=max_k)
thresholds = torch.full_like(logits[..., -1:], -float("inf"))
positive = per_row_k > 0
if positive.any():
row_idx = torch.arange(logits.size(0), device=logits.device)[positive]
thresholds[positive] = values[
row_idx, per_row_k[positive] - 1
].unsqueeze(-1)
logits[logits < thresholds] = filter_value
return logits
if tk > 0:
k = min(tk, logits.size(-1))
thresholds = torch.topk(logits, k, dim=-1)[0][..., -1:]
logits[logits < thresholds] = filter_value
return logits
class TopPStrategy(BaseSamplingStrategy):
"""Nucleus (top-p) filtering: keeps the smallest set of tokens whose
cumulative probability exceeds top_p.
Args:
top_p: Scalar or ``[batch]`` tensor (1.0 disables).
"""
def __init__(self, top_p: Union[float, Tensor] = 1.0):
self.top_p = top_p
def _apply(self, logits, top_p, filter_value):
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
cum_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
remove = cum_probs > top_p
remove[..., 1:] = remove[..., :-1].clone()
remove[..., 0] = False
mask = torch.zeros_like(logits, dtype=torch.bool)
mask.scatter_(1, sorted_indices, remove)
logits[mask] = filter_value
return logits
def apply(self, logits, filter_value=-float("inf")):
tp = self.top_p
if isinstance(tp, Tensor):
tp = tp.to(logits.device, non_blocking=True)
if (tp < 1.0).any():
logits = self._apply(logits, tp.view(-1, 1), filter_value)
elif tp < 1.0:
logits = self._apply(logits, tp, filter_value)
return logits
class SamplingPipeline(BaseSamplingStrategy):
"""Composes multiple sampling strategies into a single transformation.
Strategies are applied sequentially in the order they are provided,
matching the original temperature -> top-k -> top-p ordering.
Usage::
pipeline = SamplingPipeline([
TemperatureStrategy(0.8),
TopKStrategy(50),
TopPStrategy(0.95),
])
logits = pipeline.apply(logits)
token = pipeline.sample(logits) # softmax + multinomial
"""
def __init__(self, strategies: List[BaseSamplingStrategy]):
self.strategies = strategies
def apply(self, logits, filter_value=-float("inf")):
for strategy in self.strategies:
logits = strategy.apply(logits, filter_value)
return logits
@torch.no_grad()
def sample(self, logits: Tensor, filter_value: float = -float("inf")) -> Tensor:
"""Apply strategies then sample (softmax + multinomial).
Args:
logits: Raw logits ``[batch, vocab_size]``.
Returns:
Sampled token IDs ``[batch]``.
"""
return torch.multinomial(
torch.softmax(self.apply(logits, filter_value), dim=-1),
num_samples=1,
).squeeze(-1)
@torch.inference_mode()
def sample(
logits: Tensor,
temperature: Union[float, Tensor] = 1.0,
top_k: Union[int, Tensor] = 0,
top_p: Union[float, Tensor] = 1.0,
filter_value: float = -float("inf"),
) -> Tensor:
"""Apply sampling strategies then sample (softmax + multinomial).
Shortcut for ``SamplingPipeline(...).sample(logits)``.
Args:
logits: Raw logits ``[batch, vocab_size]``.
Returns:
Sampled token IDs ``[batch]``.
"""
return SamplingPipeline(
[
TemperatureStrategy(temperature),
TopKStrategy(top_k),
TopPStrategy(top_p),
]
).sample(logits, filter_value)

34
astrai/model/__init__.py Normal file
View File

@ -0,0 +1,34 @@
from astrai.model.automodel import AutoModel
from astrai.model.components.attention import GQA
from astrai.model.components.decoder_block import DecoderBlock
from astrai.model.components.linear import Linear
from astrai.model.components.lora import (
LoRAConfig,
inject_lora,
load_lora,
merge_lora,
save_lora,
)
from astrai.model.components.mlp import MLP
from astrai.model.components.norm import RMSNorm
from astrai.model.encoder import EmbeddingEncoder
from astrai.model.transformer import AutoRegressiveLM
__all__ = [
# Modules
"Linear",
"RMSNorm",
"MLP",
"GQA",
"DecoderBlock",
# Models
"AutoRegressiveLM",
"EmbeddingEncoder",
"AutoModel",
# LoRA
"LoRAConfig",
"inject_lora",
"merge_lora",
"save_lora",
"load_lora",
]

95
astrai/model/automodel.py Normal file
View File

@ -0,0 +1,95 @@
"""
AutoModel base class for model loading and saving.
"""
from contextlib import contextmanager
from pathlib import Path
from typing import Self, Union
import torch.nn as nn
from astrai.config.model_config import BaseModelConfig, ConfigFactory
from astrai.factory import BaseFactory
from astrai.serialization import load_model_config, load_model_weights, save_model
@contextmanager
def _disable_random_init(enable: bool = True):
if not enable:
yield
return
names = (
"xavier_normal_",
"xavier_uniform_",
"kaiming_normal_",
"kaiming_uniform_",
"zeros_",
"ones_",
"constant_",
"normal_",
"uniform_",
)
orig = {n: getattr(nn.init, n) for n in names if hasattr(nn.init, n)}
for n in orig:
setattr(nn.init, n, lambda *a, **kw: None)
try:
yield
finally:
for n, fn in orig.items():
setattr(nn.init, n, fn)
class AutoModel(BaseFactory["AutoModel"], nn.Module):
"""
Autoregressive language model base class.
Provides model loading/saving, registration, and generation.
"""
def __init__(self, config: BaseModelConfig):
super().__init__()
self.config = config
@classmethod
def from_pretrained(
cls,
path: Union[str, Path],
disable_random_init: bool = True,
strict: bool = True,
) -> nn.Module:
model_path = Path(path)
config_path = model_path / "config.json"
if not config_path.exists():
raise FileNotFoundError(f"Config file not found: {config_path}")
raw = load_model_config(str(model_path))
config = ConfigFactory.load(raw)
model_type = config.model_type or "autoregressive_lm"
actual_cls = AutoModel.get_component_class(model_type)
with _disable_random_init(enable=disable_random_init):
model = actual_cls(config)
weights_path = model_path / "model.safetensors"
if weights_path.exists():
state_dict = load_model_weights(str(model_path))
model.load_state_dict(state_dict, strict=strict)
return model
def save_pretrained(
self,
save_directory: Union[str, Path],
):
save_model(
config=self.config.to_dict(),
state_dict=self.state_dict(),
save_directory=str(save_directory),
)
def to(self, *args, **kwargs) -> Self:
"""Move model to device/dtype."""
return super().to(*args, **kwargs)

View File

@ -0,0 +1,25 @@
from astrai.model.components.attention import GQA, MLA, repeat_kv
from astrai.model.components.decoder_block import DecoderBlock
from astrai.model.components.embedding import Embedding
from astrai.model.components.linear import Linear
from astrai.model.components.mlp import MLP
from astrai.model.components.norm import RMSNorm
from astrai.model.components.rope import (
RotaryEmbedding,
apply_rotary_emb,
get_rotary_emb,
)
__all__ = [
"Linear",
"RMSNorm",
"MLP",
"Embedding",
"GQA",
"MLA",
"DecoderBlock",
"RotaryEmbedding",
"apply_rotary_emb",
"get_rotary_emb",
"repeat_kv",
]

View File

@ -0,0 +1,212 @@
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from astrai.factory import BaseFactory
from astrai.inference.core.cache import KvcacheView
from astrai.model.components.linear import Linear
from astrai.model.components.norm import RMSNorm
from astrai.model.components.rope import apply_rotary_emb
def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
bs, slen, n_heads, head_dim = x.shape
if n_rep == 1:
return x
return (
x[:, :, :, None, :]
.expand(bs, slen, n_heads, n_rep, head_dim)
.reshape(bs, slen, n_heads * n_rep, head_dim)
)
class AttnFactory(BaseFactory[nn.Module]):
@classmethod
def create(cls, attn_type: str, **kwargs) -> nn.Module:
return super().create(attn_type, **kwargs)
@AttnFactory.register("gqa")
class GQA(nn.Module):
def __init__(
self,
dim: int,
n_heads: int,
n_kv_heads: int,
use_qk_norm: bool,
norm_eps: float,
use_gated_attention: bool,
layer_id: int,
):
super().__init__()
assert dim % n_heads == 0
assert n_heads % n_kv_heads == 0
self.head_dim = dim // n_heads
self.layer_id = layer_id
self.dim = dim
self.n_heads = n_heads
self.n_kv_heads = n_kv_heads
self.n_rep = n_heads // n_kv_heads
self.use_qk_norm = use_qk_norm
self.use_gated_attention = use_gated_attention
self.q_proj = Linear(dim, n_heads * self.head_dim)
self.k_proj = Linear(dim, n_kv_heads * self.head_dim)
self.v_proj = Linear(dim, n_kv_heads * self.head_dim)
self.o_proj = Linear(dim, dim)
if self.use_qk_norm:
self.q_norm = RMSNorm(self.head_dim, norm_eps)
self.k_norm = RMSNorm(self.head_dim, norm_eps)
if self.use_gated_attention:
self.gate = Linear(dim, dim)
def _split_heads(self, x: Tensor, n_heads) -> Tensor:
batch_size, seq_len, _ = x.shape
x = x.reshape(batch_size, seq_len, n_heads, self.head_dim)
return x
def forward(
self,
x: Tensor,
rotary_emb: Tensor,
attn_mask: Tensor = None,
paged_cache: Optional[KvcacheView] = None,
) -> Tensor:
is_causal = attn_mask is None
q = self._split_heads(self.q_proj(x), self.n_heads)
k = self._split_heads(self.k_proj(x), self.n_kv_heads)
v = self._split_heads(self.v_proj(x), self.n_kv_heads)
q, k = apply_rotary_emb(q, rotary_emb), apply_rotary_emb(k, rotary_emb)
if self.use_qk_norm:
q, k = self.q_norm(q), self.k_norm(k)
if paged_cache is not None:
paged_cache.write(self.layer_id, k, v)
k, v = paged_cache.gather(self.layer_id)
k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep)
q, k, v = q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3)
sdqa_out = (
F.scaled_dot_product_attention(q, k, v, attn_mask, is_causal=is_causal)
.permute(0, 2, 1, 3)
.contiguous()
.flatten(2)
)
if self.use_gated_attention:
sdqa_out = sdqa_out * F.sigmoid(self.gate(x))
out = self.o_proj(sdqa_out)
return out
@AttnFactory.register("mla")
class MLA(nn.Module):
def __init__(
self,
dim: int,
n_heads: int,
n_kv_heads: int,
kv_lora_rank: int,
qk_nope_head_dim: int,
qk_rope_head_dim: int,
norm_eps: float,
use_qk_norm: bool,
use_gated_attention: bool,
layer_id: int,
):
super().__init__()
self.dim = dim
self.n_heads = n_heads
self.n_kv_heads = n_kv_heads
self.kv_lora_rank = kv_lora_rank
self.qk_nope_head_dim = qk_nope_head_dim
self.qk_rope_head_dim = qk_rope_head_dim
self.head_dim = qk_nope_head_dim + qk_rope_head_dim
self.layer_id = layer_id
self.n_rep = n_heads // n_kv_heads
self.use_qk_norm = use_qk_norm
self.use_gated_attention = use_gated_attention
self.q_proj = Linear(dim, n_heads * self.head_dim, bias=False)
if self.use_qk_norm:
self.q_norm = RMSNorm(self.head_dim, norm_eps)
self.k_norm = RMSNorm(self.head_dim, norm_eps)
self.kv_a_proj = Linear(dim, kv_lora_rank, bias=False)
self.kv_norm = RMSNorm(kv_lora_rank, norm_eps)
self.kv_b_proj = Linear(
kv_lora_rank,
n_kv_heads * (2 * self.head_dim),
)
self.o_proj = Linear(dim, dim, bias=False)
if use_gated_attention:
self.gate = Linear(dim, dim, bias=False)
def forward(
self,
x: Tensor,
rotary_emb: Tensor,
attn_mask: Tensor = None,
paged_cache: Optional[KvcacheView] = None,
) -> Tensor:
bsz, seq_len, _ = x.size()
is_causal = attn_mask is None
q = self.q_proj(x)
q = q.view(bsz, seq_len, self.n_heads, self.head_dim)
kv_compressed = self.kv_a_proj(x)
kv_compressed = self.kv_norm(kv_compressed)
kv = self.kv_b_proj(kv_compressed)
kv = kv.view(bsz, seq_len, self.n_kv_heads, -1)
k_nope, k_rope, v = torch.split(
kv, [self.qk_nope_head_dim, self.qk_rope_head_dim, self.head_dim], dim=-1
)
q_nope, q_rope = (
q[..., : self.qk_nope_head_dim],
q[..., self.qk_nope_head_dim :],
)
q_rope = apply_rotary_emb(q_rope, rotary_emb)
k_rope = apply_rotary_emb(k_rope, rotary_emb)
q = torch.cat([q_nope, q_rope], dim=-1)
k = torch.cat([k_nope, k_rope], dim=-1)
if self.use_qk_norm:
q = self.q_norm(q)
k = self.k_norm(k)
if paged_cache is not None:
paged_cache.write(self.layer_id, k, v)
k, v = paged_cache.gather(self.layer_id)
q = q.permute(0, 2, 1, 3)
k = k.permute(0, 2, 1, 3)
v = v.permute(0, 2, 1, 3)
attn_out = F.scaled_dot_product_attention(
q, k, v, attn_mask, is_causal=is_causal
)
attn_out = attn_out.permute(0, 2, 1, 3).contiguous().flatten(2)
if self.use_gated_attention:
attn_out = attn_out * F.sigmoid(self.gate(x))
out = self.o_proj(attn_out)
return out

View File

@ -0,0 +1,59 @@
from typing import Optional
import torch.nn as nn
from torch import Tensor
from astrai.inference.core.cache import KvcacheView
from astrai.model.components.attention import AttnFactory
from astrai.model.components.mlp import FFNFactory
from astrai.model.components.norm import RMSNorm
class DecoderBlock(nn.Module):
def __init__(
self,
dim: int,
n_heads: int,
dim_ffn: int,
n_kv_heads: int,
norm_eps: float,
use_qk_norm: bool,
use_gated_attention: bool,
layer_id: int,
attn_type: str = "gqa",
ffn_type: str = "mlp",
**kwargs,
):
super().__init__()
self.attention = AttnFactory.create(
attn_type,
dim=dim,
n_heads=n_heads,
n_kv_heads=n_kv_heads,
use_qk_norm=use_qk_norm,
norm_eps=norm_eps,
use_gated_attention=use_gated_attention,
layer_id=layer_id,
**kwargs,
)
self.input_norm = RMSNorm(dim, norm_eps)
self.post_attention_norm = RMSNorm(dim, norm_eps)
self.mlp = FFNFactory.create(ffn_type, dim, dim_ffn, **kwargs)
def forward(
self,
x: Tensor,
rotary_emb: Tensor,
attention_mask: Optional[Tensor] = None,
paged_cache: Optional[KvcacheView] = None,
) -> Tensor:
attn_output = self.attention(
self.input_norm(x),
rotary_emb,
attention_mask,
paged_cache,
)
x = attn_output + x
x = self.mlp(self.post_attention_norm(x)) + x
return x

View File

@ -0,0 +1,16 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
class Embedding(nn.Module):
def __init__(self, vocab_size: int, embedding_dim: int):
super().__init__()
self.weight = nn.Parameter(torch.empty((vocab_size, embedding_dim)))
def reset_parameters(self):
nn.init.normal_(self.weight, mean=0.0, std=0.02)
def forward(self, x: Tensor) -> Tensor:
return F.embedding(x, self.weight)

View File

@ -0,0 +1,21 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
class Linear(nn.Module):
def __init__(self, in_dim: int, out_dim: int, bias: bool = False):
super().__init__()
self.weight = nn.Parameter(torch.empty((out_dim, in_dim)))
self.bias = nn.Parameter(torch.zeros(out_dim)) if bias else None
def reset_parameters(self):
nn.init.kaiming_uniform_(self.weight, a=5**0.5)
if self.bias is not None:
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
bound = 1 / (fan_in**0.5)
nn.init.uniform_(self.bias, -bound, bound)
def forward(self, x: Tensor) -> Tensor:
return F.linear(x, self.weight, self.bias)

View File

@ -0,0 +1,194 @@
import logging
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Optional, Set
import torch
import torch.nn as nn
import torch.nn.functional as F
from astrai.model.components.linear import Linear
from astrai.serialization import (
load_json,
load_safetensors,
save_json,
save_safetensors,
)
logger = logging.getLogger(__name__)
TARGET_MODULES_ATTN = {"q_proj", "k_proj", "v_proj", "o_proj"}
TARGET_MODULES_FFN = {"up", "gate", "down"}
@dataclass
class LoRAConfig:
r: int = 16
alpha: int = 32
target_modules: tuple = ("q_proj", "v_proj")
class LoRALinear(nn.Module):
def __init__(self, base: Linear, r: int = 16, alpha: int = 32):
super().__init__()
self.register_parameter("weight", base.weight)
self.weight.requires_grad_(False)
self.bias = base.bias
if self.bias is not None:
self.bias.requires_grad_(False)
self.r = r
self.scaling = alpha / r
self.lora_A = nn.Parameter(torch.randn(r, self.weight.shape[1]) / r)
self.lora_B = nn.Parameter(torch.zeros(self.weight.shape[0], r))
self._merged = False
def forward(self, x):
out = F.linear(x, self.weight, self.bias)
if not self._merged:
out += (F.linear(x, self.lora_A) @ self.lora_B.T) * self.scaling
return out
def merge(self):
if self._merged:
return
self.weight.data += (self.lora_B @ self.lora_A) * self.scaling
self._merged = True
del self.lora_A
del self.lora_B
def _collect_lora_info(model: nn.Module) -> dict:
names = {}
for n, m in model.named_modules():
if isinstance(m, Linear):
_, _, child = n.rpartition(".")
names.setdefault(child, []).append(n)
return names
def _get_lora_count(model: nn.Module) -> int:
return sum(1 for m in model.modules() if isinstance(m, LoRALinear))
def inject_lora(
model: nn.Module,
r: int = 16,
alpha: int = 32,
target_modules: Optional[Set[str]] = None,
) -> LoRAConfig:
if target_modules is None:
target_modules = TARGET_MODULES_ATTN
available = _collect_lora_info(model)
injected = 0
for name, module in list(model.named_modules()):
if not isinstance(module, Linear):
continue
parent_name, _, child_name = name.rpartition(".")
if child_name not in target_modules:
continue
parent = model.get_submodule(parent_name) if parent_name else model
setattr(parent, child_name, LoRALinear(module, r=r, alpha=alpha))
injected += 1
if injected == 0:
logger.warning(
"No LoRA layers injected. Available Linear child names: %s. "
"target_modules: %s. Check model type and target_modules.",
sorted(available),
sorted(target_modules),
)
else:
logger.info("LoRA injected: %d layers (r=%d, alpha=%d)", injected, r, alpha)
return LoRAConfig(r=r, alpha=alpha, target_modules=tuple(target_modules))
def merge_lora(model: nn.Module):
n = 0
for module in model.modules():
if isinstance(module, LoRALinear):
module.merge()
n += 1
if n == 0:
logger.warning("No LoRA layers to merge.")
else:
logger.info("Merged %d LoRA layers", n)
def save_lora(model: nn.Module, save_dir: str, config: LoRAConfig):
lora_sd = {
k: v
for k, v in model.state_dict().items()
if k.endswith((".lora_A", ".lora_B"))
}
if not lora_sd:
raise RuntimeError(
"No LoRA parameters found in model. "
"The model may not have been injected or was already merged."
)
path = Path(save_dir)
path.mkdir(parents=True, exist_ok=True)
save_safetensors(lora_sd, path / "adapter_model.safetensors")
save_json(asdict(config), path / "adapter_config.json")
logger.info("LoRA adapter saved to %s (%d keys)", save_dir, len(lora_sd))
def load_lora(model: nn.Module, load_dir: str) -> LoRAConfig:
path = Path(load_dir)
raw = load_json(path / "adapter_config.json")
config = LoRAConfig(
r=raw["r"], alpha=raw["alpha"], target_modules=tuple(raw["target_modules"])
)
existing = _get_lora_count(model)
if existing > 0:
logger.warning(
"Model already has %d LoRA layers. Skipping injection, "
"loading weights onto existing layers only.",
existing,
)
else:
inject_lora(
model,
r=config.r,
alpha=config.alpha,
target_modules=set(config.target_modules),
)
weights = load_safetensors(path / "adapter_model.safetensors")
try:
missing, unexpected = model.load_state_dict(weights, strict=False)
except RuntimeError as e:
msg = str(e)
if "size mismatch" in msg:
raise RuntimeError(
f"LoRA weight shapes do not match the model. "
f"The adapter config (r={config.r}) may not match the injected layers. "
f"Original error: {msg}"
) from e
raise
injected = _get_lora_count(model)
if injected == 0:
raise RuntimeError(
"No LoRA layers found after loading. "
"Inject LoRA before calling load_lora, or check the adapter config."
)
if missing:
lora_missing = [k for k in missing if "lora" in k]
if lora_missing:
raise RuntimeError(
f"LoRA weight keys not found in model: {lora_missing}. "
f"The adapter config (r={config.r}) may not match the model."
)
logger.debug("LoRA load: %d missing base-weight keys (expected)", len(missing))
if unexpected:
logger.warning("LoRA load: %d unexpected keys", len(unexpected))
logger.info("LoRA adapter loaded from %s", load_dir)
return config

View File

@ -0,0 +1,93 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from astrai.factory import BaseFactory
from astrai.model.components.linear import Linear
class FFNFactory(BaseFactory[nn.Module]):
@classmethod
def create(cls, ffn_type: str, dim: int, dim_ffn: int, **kwargs) -> nn.Module:
return super().create(ffn_type, dim, dim_ffn, **kwargs)
@FFNFactory.register("mlp")
class MLP(nn.Module):
def __init__(self, dim: int, dim_ffn: int):
super().__init__()
self.up = Linear(dim, dim_ffn)
self.gate = Linear(dim, dim_ffn)
self.down = Linear(dim_ffn, dim)
def forward(self, x: Tensor) -> Tensor:
gated = self.up(x) * F.silu(self.gate(x))
out = self.down(gated)
return out
@FFNFactory.register("moe")
class DeepSeekMoE(nn.Module):
def __init__(
self,
dim: int,
dim_ffn: int,
n_routed_experts: int,
n_shared_experts: int = 1,
n_activated_experts: int = 2,
topk_method: str = "greedy",
):
super().__init__()
self.dim = dim
self.n_routed_experts = n_routed_experts
self.n_shared_experts = n_shared_experts
self.n_activated_experts = n_activated_experts
self.topk_method = topk_method
self.router = Linear(dim, n_routed_experts, bias=False)
self.shared_experts = nn.ModuleList(
[MLP(dim, dim_ffn) for _ in range(n_shared_experts)]
)
self.routed_experts = nn.ModuleList(
[MLP(dim, dim_ffn) for _ in range(n_routed_experts)]
)
def forward(self, x: Tensor) -> Tensor:
bsz, seq_len, dim = x.shape
x_flat = x.view(-1, dim)
shared_out = self._shared_forward(x_flat)
routed_out = self._routed_forward(x_flat)
out = (shared_out + routed_out).view(bsz, seq_len, dim)
return out
def _shared_forward(self, x: Tensor) -> Tensor:
if self.n_shared_experts == 0:
return torch.zeros_like(x)
return sum(e(x) for e in self.shared_experts) / self.n_shared_experts
def _routed_forward(self, x: Tensor) -> Tensor:
N, D = x.shape
K = self.n_activated_experts
router_logits = self.router(x)
router_probs = torch.softmax(router_logits.float(), dim=-1).to(x.dtype)
topk_weights, topk_indices = torch.topk(router_probs, K, dim=-1)
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
output = torch.zeros(N, D, device=x.device, dtype=x.dtype)
for expert_idx in range(self.n_routed_experts):
expert_mask = topk_indices == expert_idx
token_idx, k_idx = expert_mask.nonzero(as_tuple=True)
if token_idx.numel() == 0:
continue
expert_input = x[token_idx]
expert_output = self.routed_experts[expert_idx](expert_input)
weights = topk_weights[token_idx, k_idx].unsqueeze(-1)
output.index_add_(0, token_idx, expert_output * weights)
return output

View File

@ -0,0 +1,15 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
class RMSNorm(nn.Module):
def __init__(self, dim, norm_eps):
super().__init__()
self.weight = nn.Parameter(torch.ones(dim))
self.normalized_shape = (dim,)
self.norm_eps = norm_eps
def forward(self, x: Tensor) -> Tensor:
return F.rms_norm(x, self.normalized_shape, self.weight, self.norm_eps)

View File

@ -0,0 +1,71 @@
from typing import Dict, Optional
import torch
import torch.nn as nn
from torch import Tensor
def get_rotary_emb(
dim: int,
max_len: int,
base: float = 10000,
device: Optional[torch.device] = None,
) -> Tensor:
theta = base ** (-torch.arange(0, dim, 2, dtype=torch.float64, device=device) / dim)
t = torch.arange(0, max_len, dtype=torch.float64, device=device)
freqs = torch.outer(t, theta).float()
cos = torch.cos(freqs)
sin = torch.sin(freqs)
return torch.complex(cos, sin)
def ntk_base(base: float, dim: int, factor: float) -> float:
return base * (factor ** (dim / (dim - 2)))
def apply_rotary_emb(x: torch.Tensor, freqs_cis: Tensor) -> Tensor:
dtype = x.dtype
x_ = x.float().reshape(*x.shape[:-1], -1, 2)
x_complex = torch.view_as_complex(x_)
freqs_cis = freqs_cis.unsqueeze(2)
x_rotated = x_complex * freqs_cis
x_out = torch.view_as_real(x_rotated).flatten(-2)
return x_out.to(dtype)
class RotaryEmbedding(nn.Module):
def __init__(
self,
dim: int,
max_len: int,
base: float = 10000,
rope_scaling: Optional[Dict] = None,
):
super().__init__()
self.dim = dim
self.max_len = max_len
self.base = base
self.rope_scaling = rope_scaling
if rope_scaling is not None:
scaling_type = rope_scaling.get("type", "ntk")
factor = rope_scaling.get("factor", 1.0)
if scaling_type == "ntk":
self.base = ntk_base(base, dim, factor)
self._set_rotary_buffer(self.max_len)
def _set_rotary_buffer(self, max_len: int):
rotary_emb = get_rotary_emb(self.dim, max_len, self.base)
freqs_cis = torch.view_as_real(rotary_emb)
self.register_buffer("freqs_cis", freqs_cis, persistent=False)
def forward(self, x: Tensor, position_ids: Optional[Tensor] = None) -> Tensor:
if position_ids is None:
position_ids = (
torch.arange(x.size(1), device=x.device)
.unsqueeze(0)
.expand(x.size(0), -1)
)
position_freq_cis = self.freqs_cis[position_ids].float()
return torch.view_as_complex(position_freq_cis)

99
astrai/model/encoder.py Normal file
View File

@ -0,0 +1,99 @@
from typing import Any, Mapping, Optional
import torch
import torch.nn as nn
from torch import Tensor
from astrai.config.model_config import EncoderConfig
from astrai.model.automodel import AutoModel
from astrai.model.components.decoder_block import DecoderBlock
from astrai.model.components.embedding import Embedding
from astrai.model.components.norm import RMSNorm
from astrai.model.components.rope import RotaryEmbedding
from astrai.model.transformer import process_attention_mask
@AutoModel.register("embedding")
class EmbeddingEncoder(AutoModel):
def __init__(self, config: EncoderConfig):
super().__init__(config)
self.config = config
rope_dim = config.dim // config.n_heads
rope_base = config.rope_theta if config.rope_theta is not None else 10000
self.rotary_embedding = RotaryEmbedding(
rope_dim, config.max_len, rope_base, rope_scaling=config.rope_scaling
)
self.embed_tokens = Embedding(config.vocab_size, config.dim)
self.layers = nn.ModuleList(
[
DecoderBlock(
config.dim,
config.n_heads,
config.dim_ffn,
config.n_kv_heads,
config.norm_eps,
config.use_qk_norm,
config.use_gated_attention,
layer_id,
)
for layer_id in range(config.n_layers)
]
)
self.norm = RMSNorm(config.dim, config.norm_eps)
self.pooling_type = config.pooling_type or "mean"
self.normalize_embeddings = config.normalize_embeddings or False
self.apply(self._init_weights)
def _init_weights(self, module):
if hasattr(module, "reset_parameters"):
module.reset_parameters()
def load_state_dict(self, state_dict: Mapping[str, Any], strict=True, assign=False):
state_dict = dict(state_dict)
state_dict.pop("lm_head.weight", None)
return super().load_state_dict(state_dict, strict=strict, assign=assign)
def forward(
self,
input_ids: Tensor,
input_mask: Optional[Tensor] = None,
position_ids: Optional[Tensor] = None,
) -> Tensor:
assert input_ids.ndim == 2
B, S = input_ids.shape
x = self.embed_tokens(input_ids)
rotary_emb = self.rotary_embedding(x, position_ids)
attn_mask = process_attention_mask(x, position_ids, input_mask, is_causal=False)
for layer in self.layers:
x = layer(x, rotary_emb, attn_mask, paged_cache=None)
hidden_states = self.norm(x)
if self.pooling_type == "cls":
pooled = hidden_states[:, 0]
elif self.pooling_type == "last":
if input_mask is not None:
lengths = input_mask.sum(dim=1) - 1
pooled = hidden_states[torch.arange(B, device=x.device), lengths]
else:
pooled = hidden_states[:, -1]
else:
if input_mask is not None:
mask = input_mask.unsqueeze(-1).to(dtype=hidden_states.dtype)
pooled = (hidden_states * mask).sum(dim=1) / mask.sum(dim=1).clamp(
min=1.0
)
else:
pooled = hidden_states.mean(dim=1)
if self.normalize_embeddings:
pooled = torch.nn.functional.normalize(pooled, p=2, dim=-1)
return pooled

152
astrai/model/transformer.py Normal file
View File

@ -0,0 +1,152 @@
from typing import Any, Dict, Mapping, Optional
import torch
import torch.nn as nn
from torch import Tensor
from astrai.config.model_config import AutoRegressiveLMConfig
from astrai.inference.core.cache import KvcacheView
from astrai.model.automodel import AutoModel
from astrai.model.components.decoder_block import DecoderBlock
from astrai.model.components.embedding import Embedding
from astrai.model.components.linear import Linear
from astrai.model.components.norm import RMSNorm
from astrai.model.components.rope import RotaryEmbedding
def process_attention_mask(
input_tensor: Tensor,
position_ids: Optional[Tensor],
input_mask: Optional[Tensor] = None,
is_causal: bool = False,
) -> Optional[Tensor]:
if position_ids is None:
return None
if input_mask is not None and input_mask.dim() > 2:
return input_mask
device = input_tensor.device
dtype = input_tensor.dtype
B, S = input_tensor.size()[:2]
T = position_ids.max().item() + 1
if input_mask is None:
if position_ids.min().item() == 0 and is_causal:
return None
pad = torch.ones(B, T, dtype=torch.bool, device=device)
else:
pad = input_mask[:, :T].to(device=device, dtype=torch.bool)
attend = pad.view(B, 1, T).expand(B, S, T).clone()
if is_causal:
attend &= position_ids.unsqueeze(-1) >= torch.arange(T, device=device)
return torch.full(
(B, 1, S, T), -torch.finfo(dtype).max / 2, dtype=dtype, device=device
).masked_fill_(attend.unsqueeze(1), 0.0)
@AutoModel.register("autoregressive_lm")
class AutoRegressiveLM(AutoModel):
"""Autoregressive language model with paged KV cache."""
def __init__(self, config: AutoRegressiveLMConfig):
super().__init__(config)
self.config = config
rope_dim = (
config.qk_rope_head_dim
if config.attn_type == "mla"
else config.dim // config.n_heads
)
rope_base = config.rope_theta if config.rope_theta is not None else 10000
self.rotary_embedding = RotaryEmbedding(
rope_dim, config.max_len, rope_base, rope_scaling=config.rope_scaling
)
self.embed_tokens = Embedding(config.vocab_size, config.dim)
self.layers = nn.ModuleList(
[
DecoderBlock(
config.dim,
config.n_heads,
config.dim_ffn,
config.n_kv_heads,
config.norm_eps,
config.use_qk_norm,
config.use_gated_attention,
layer_id,
attn_type=config.attn_type,
ffn_type=config.ffn_type,
n_routed_experts=config.n_routed_experts,
n_shared_experts=config.n_shared_experts,
n_activated_experts=config.n_activated_experts,
topk_method=config.topk_method,
kv_lora_rank=config.kv_lora_rank,
qk_nope_head_dim=config.qk_nope_head_dim,
qk_rope_head_dim=config.qk_rope_head_dim,
)
for layer_id in range(config.n_layers)
]
)
self.norm = RMSNorm(config.dim, config.norm_eps)
self.lm_head = Linear(config.dim, config.vocab_size)
if self.config.tie_weight is True:
self.lm_head.weight = self.embed_tokens.weight
self.apply(self._init_weights)
def _init_weights(self, module):
if hasattr(module, "reset_parameters"):
module.reset_parameters()
def load_state_dict(self, state_dict: Mapping[str, Any], strict=True, assign=False):
lm_head_key = "lm_head.weight"
embed_key = "embed_tokens.weight"
state_dict = dict(state_dict)
if self.config.tie_weight is True:
# same tensor for embed and lm_head
if embed_key in state_dict:
state_dict[lm_head_key] = state_dict[embed_key]
else:
if lm_head_key not in state_dict and embed_key in state_dict:
# clone to avoid sharing gradients
state_dict[lm_head_key] = torch.clone(state_dict[embed_key])
return super().load_state_dict(state_dict, strict, assign)
def state_dict(self, destination=None, prefix="", keep_vars=False):
state_dict = super().state_dict(
destination=destination, prefix=prefix, keep_vars=keep_vars
)
if self.config.tie_weight is True:
lm_head_key = prefix + "lm_head.weight"
if lm_head_key in state_dict:
del state_dict[lm_head_key]
return state_dict
def forward(
self,
input_ids: Tensor,
input_mask: Optional[Tensor] = None,
paged_cache: Optional[KvcacheView] = None,
position_ids: Optional[Tensor] = None,
) -> Dict[str, Tensor]:
assert input_ids.ndim == 2
x = self.embed_tokens(input_ids)
rotary_emb = self.rotary_embedding(x, position_ids)
attn_mask = process_attention_mask(x, position_ids, input_mask, is_causal=True)
for layer in self.layers:
x = layer(x, rotary_emb, attn_mask, paged_cache)
hidden_states = self.norm(x)
logits = self.lm_head(hidden_states)
return {"logits": logits, "hidden_states": hidden_states}

View File

@ -0,0 +1,38 @@
from astrai.parallel.executor import (
AccumOptimizer,
AccumScheduler,
BaseExecutor,
DDPExecutor,
ExecutorFactory,
FSDPExecutor,
GradientState,
NoneExecutor,
)
from astrai.parallel.module import ColumnParallelLinear, RowParallelLinear
from astrai.parallel.setup import (
get_current_device,
get_rank,
get_world_size,
only_on_rank,
setup_parallel,
spawn_parallel_fn,
)
__all__ = [
"get_world_size",
"get_rank",
"get_current_device",
"only_on_rank",
"setup_parallel",
"spawn_parallel_fn",
"RowParallelLinear",
"ColumnParallelLinear",
"ExecutorFactory",
"BaseExecutor",
"GradientState",
"AccumOptimizer",
"AccumScheduler",
"NoneExecutor",
"DDPExecutor",
"FSDPExecutor",
]

271
astrai/parallel/executor.py Normal file
View File

@ -0,0 +1,271 @@
"""Unified training executor — parallel strategy + gradient accumulation."""
import contextlib
import logging
from contextlib import contextmanager
from typing import Optional, Tuple
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullStateDictConfig, StateDictType
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from torch.utils.data import DataLoader
from astrai.factory import BaseFactory
from astrai.parallel.setup import get_rank, get_world_size
logger = logging.getLogger(__name__)
class GradientState:
def __init__(self, grad_accum_steps: int = 1):
self.num_steps = max(grad_accum_steps, 1)
self._step: int = 0
self._sync_gradients: bool = True
@property
def sync_gradients(self) -> bool:
return self._sync_gradients
def _do_sync(self):
self._step += 1
self._sync_gradients = self._step % self.num_steps == 0
class AccumOptimizer:
def __init__(self, optimizer: Optimizer, gradient_state: GradientState):
self.optimizer = optimizer
self.gradient_state = gradient_state
def step(self, closure=None):
if self.gradient_state.sync_gradients:
self.optimizer.step(closure)
def zero_grad(self):
if self.gradient_state.sync_gradients:
self.optimizer.zero_grad()
@property
def param_groups(self):
return self.optimizer.param_groups
def state_dict(self):
return self.optimizer.state_dict()
def load_state_dict(self, d):
self.optimizer.load_state_dict(d)
class AccumScheduler:
def __init__(self, scheduler: LRScheduler, gradient_state: GradientState):
self.scheduler = scheduler
self.gradient_state = gradient_state
def step(self):
if self.gradient_state.sync_gradients:
self.scheduler.step()
def state_dict(self):
return self.scheduler.state_dict()
def load_state_dict(self, d):
self.scheduler.load_state_dict(d)
def get_last_lr(self):
return self.scheduler.get_last_lr()
class BaseExecutor:
def __init__(self, grad_accum_steps: int = 1):
self.gradient_state = GradientState(grad_accum_steps)
def prepare(
self,
model: nn.Module,
optimizer: Optional[Optimizer] = None,
dataloader: Optional[DataLoader] = None,
scheduler: Optional[LRScheduler] = None,
) -> Tuple[
nn.Module, Optional[Optimizer], Optional[DataLoader], Optional[LRScheduler]
]:
model = self._prepare_model(model)
if optimizer is not None:
optimizer = AccumOptimizer(optimizer, self.gradient_state)
if scheduler is not None:
scheduler = AccumScheduler(scheduler, self.gradient_state)
return model, optimizer, dataloader, scheduler
def _prepare_model(self, model: nn.Module) -> nn.Module:
return model
def _no_sync(self, model: nn.Module):
return contextlib.nullcontext()
@contextmanager
def accumulate(self, model: nn.Module):
self.gradient_state._do_sync()
if not self.gradient_state.sync_gradients:
with self._no_sync(model):
yield
else:
yield
def backward(self, loss: torch.Tensor):
loss.backward()
def unwrap_model(self, model: nn.Module):
return model.state_dict()
@property
def use_distributed(self) -> bool:
return get_world_size() > 1
@property
def sync_gradients(self) -> bool:
return self.gradient_state.sync_gradients
@property
def grad_accum_steps(self) -> int:
return self.gradient_state.num_steps
class ExecutorFactory(BaseFactory[BaseExecutor]):
pass
@ExecutorFactory.register("none")
class NoneExecutor(BaseExecutor):
pass
@ExecutorFactory.register("ddp")
class DDPExecutor(BaseExecutor):
def __init__(
self,
grad_accum_steps: int = 1,
dim: int = 0,
broadcast_buffers: bool = True,
init_sync: bool = True,
process_group=None,
bucket_cap_mb: int = 25,
find_unused_parameters: bool = False,
check_reduction: bool = False,
gradient_as_bucket_view: bool = False,
static_graph: bool = False,
delay_all_reduce_named_params=None,
param_to_hook_all_reduce=None,
mixed_precision=None,
device_mesh=None,
):
super().__init__(grad_accum_steps=grad_accum_steps)
self._ddp_kwargs = dict(
dim=dim,
broadcast_buffers=broadcast_buffers,
init_sync=init_sync,
process_group=process_group,
bucket_cap_mb=bucket_cap_mb,
find_unused_parameters=find_unused_parameters,
check_reduction=check_reduction,
gradient_as_bucket_view=gradient_as_bucket_view,
static_graph=static_graph,
delay_all_reduce_named_params=delay_all_reduce_named_params,
param_to_hook_all_reduce=param_to_hook_all_reduce,
mixed_precision=mixed_precision,
device_mesh=device_mesh,
)
def _prepare_model(self, model: nn.Module) -> nn.Module:
if not self.use_distributed:
logger.warning("DDP backend selected but world_size=1, model not wrapped")
return model
local_rank = get_rank()
model = DDP(
model,
device_ids=[local_rank],
output_device=local_rank,
**self._ddp_kwargs,
)
logger.info("Model wrapped with DDP (world_size=%d)", get_world_size())
return model
def _no_sync(self, model: nn.Module):
if isinstance(model, DDP):
return model.no_sync()
return contextlib.nullcontext()
def unwrap_model(self, model: nn.Module):
if isinstance(model, DDP):
return model.module.state_dict()
return model.state_dict()
@ExecutorFactory.register("fsdp")
class FSDPExecutor(BaseExecutor):
def __init__(
self,
grad_accum_steps: int = 1,
process_group=None,
sharding_strategy=None,
cpu_offload=None,
auto_wrap_policy=None,
backward_prefetch=None,
mixed_precision=None,
ignored_modules=None,
param_init_fn=None,
sync_module_states: bool = False,
forward_prefetch: bool = False,
limit_all_gathers: bool = True,
ignored_states=None,
device_mesh=None,
):
super().__init__(grad_accum_steps=grad_accum_steps)
self._fsdp_kwargs = {
k: v
for k, v in dict(
process_group=process_group,
sharding_strategy=sharding_strategy,
cpu_offload=cpu_offload,
auto_wrap_policy=auto_wrap_policy,
backward_prefetch=backward_prefetch,
mixed_precision=mixed_precision,
ignored_modules=ignored_modules,
param_init_fn=param_init_fn,
sync_module_states=sync_module_states,
forward_prefetch=forward_prefetch,
limit_all_gathers=limit_all_gathers,
use_orig_params=True,
ignored_states=ignored_states,
device_mesh=device_mesh,
).items()
if v is not None
}
self._original_model: Optional[nn.Module] = None
def _prepare_model(self, model: nn.Module) -> nn.Module:
if not self.use_distributed:
logger.warning("FSDP backend selected but world_size=1, model not wrapped")
return model
self._original_model = model
device_id = torch.device("cuda", get_rank())
model = FSDP(model, device_id=device_id, **self._fsdp_kwargs)
logger.info("Model wrapped with FSDP (world_size=%d)", get_world_size())
return model
def _no_sync(self, model: nn.Module):
if isinstance(model, FSDP):
return model.no_sync()
return contextlib.nullcontext()
def unwrap_model(self, model: nn.Module):
if isinstance(model, FSDP) and self.use_distributed:
with FSDP.state_dict_type(
model,
StateDictType.FULL_STATE_DICT,
FullStateDictConfig(offload_to_cpu=True, rank0_only=False),
):
return model.state_dict()
return model.state_dict()

115
astrai/parallel/module.py Normal file
View File

@ -0,0 +1,115 @@
from typing import Dict
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
class ParallelModel(nn.Module):
def __init__(self, process_group: dist.ProcessGroup):
super().__init__()
self.process_group = process_group
self.rank = dist.get_rank(self.process_group)
self.world_size = dist.get_world_size(self.process_group)
class RowParallelLinear(ParallelModel):
def __init__(
self,
process_group: dist.ProcessGroup,
in_features: int,
out_features: int,
bias: bool = True,
reduce_results: bool = True,
):
super().__init__(process_group)
self.in_features = in_features
self.out_features = out_features
self.in_features_per_rank = in_features // self.world_size
self.reduce_results = reduce_results
if in_features % self.world_size != 0:
raise ValueError(
f"in_features must be divisible by world_size. Got {in_features} and {self.world_size}"
)
self.weight = nn.Parameter(torch.empty(out_features, self.in_features_per_rank))
self.bias = nn.Parameter(torch.zeros(out_features)) if bias else None
def forward(self, input: Tensor) -> Tensor:
output = F.linear(input, self.weight)
if self.reduce_results:
dist.all_reduce(output, op=dist.ReduceOp.SUM, group=self.process_group)
if self.bias is not None:
output += self.bias
return output
def load_state_dict(self, state_dict: Dict[str, Tensor]):
full_weight = state_dict.get("weight")
full_bias = state_dict.get("bias")
start_idx = self.rank * self.in_features_per_rank
end_idx = start_idx + self.in_features_per_rank
weight_slice = full_weight[:, start_idx:end_idx]
self.weight.data.copy_(weight_slice)
if self.bias is not None:
self.bias.data.copy_(full_bias)
class ColumnParallelLinear(ParallelModel):
def __init__(
self,
process_group: dist.ProcessGroup,
in_features: int,
out_features: int,
bias: bool = True,
gather_results: bool = True,
):
super().__init__(process_group)
self.in_features = in_features
self.out_features = out_features
self.out_features_per_rank = out_features // self.world_size
self.gather_results = gather_results
if out_features % self.world_size != 0:
raise ValueError(
f"out_features must be divisible by world_size. Got {out_features} and {self.world_size}"
)
self.weight = nn.Parameter(
torch.empty(self.out_features_per_rank, self.in_features)
)
self.bias = (
nn.Parameter(torch.zeros(self.out_features_per_rank)) if bias else None
)
def forward(self, input: Tensor) -> Tensor:
output = F.linear(input, self.weight, self.bias)
if self.gather_results:
output_list = [torch.empty_like(output) for _ in range(self.world_size)]
dist.all_gather(output_list, output, group=self.process_group)
output = torch.cat(output_list, dim=-1)
return output
def load_state_dict(self, state_dict: Dict[str, Tensor]):
full_weight = state_dict.get("weight")
full_bias = state_dict.get("bias")
start_idx = self.rank * self.out_features_per_rank
end_idx = start_idx + self.out_features_per_rank
weight_slice = full_weight[start_idx:end_idx, :]
self.weight.data.copy_(weight_slice)
if self.bias is not None:
bias_slice = full_bias[start_idx:end_idx]
self.bias.data.copy_(bias_slice)

166
astrai/parallel/setup.py Normal file
View File

@ -0,0 +1,166 @@
import os
from contextlib import contextmanager
from functools import wraps
from typing import Callable
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
def get_current_device():
return os.environ["LOCAL_DEVICE"]
def get_world_size() -> int:
if dist.is_available() and dist.is_initialized():
return dist.get_world_size()
else:
return 1
def get_rank() -> int:
if dist.is_available() and dist.is_initialized():
return dist.get_rank()
else:
return 0
@contextmanager
def setup_parallel(
rank: int,
world_size: int,
backend: str = "nccl",
master_addr: str = "localhost",
master_port: str = "29500",
device_type: str = "cuda",
):
if dist.is_available() and dist.is_initialized():
yield dist.group.WORLD
return
if world_size <= 1:
yield None
return
device_id = torch.device(device_type, rank)
os.environ["MASTER_ADDR"] = master_addr
os.environ["MASTER_PORT"] = master_port
os.environ["LOCAL_RANK"] = str(rank)
os.environ["WORLD_SIZE"] = str(world_size)
os.environ["LOCAL_DEVICE"] = str(device_id)
dist.init_process_group(
rank=rank, world_size=world_size, backend=backend, device_id=device_id
)
try:
if backend == "nccl" and torch.cuda.is_available():
torch.cuda.set_device(device_id)
elif backend == "ccl" and hasattr(torch, "xpu") and torch.xpu.is_available():
torch.xpu.set_device(device_id)
yield dist.group.WORLD
finally:
if dist.is_initialized():
dist.destroy_process_group()
def only_on_rank(rank, sync=False):
"""
decorator to run a function only on a specific rank.
"""
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
ret_args = None
if get_rank() == rank:
ret_args = func(*args, **kwargs)
if sync and dist.is_available() and dist.is_initialized():
dist.barrier()
return ret_args
return wrapper
return decorator
def wrapper_spawn_func(
rank: int,
world_size: int,
backend: str,
master_addr: str,
master_port: str,
device_type: str,
func: Callable,
kwargs: dict,
):
try:
with setup_parallel(
rank=rank,
world_size=world_size,
backend=backend,
master_addr=master_addr,
master_port=master_port,
device_type=device_type,
):
func(**kwargs)
except Exception as e:
print(f"Error in rank {rank}: {e}")
raise
def spawn_parallel_fn(
func: Callable,
world_size: int,
backend: str = "nccl",
master_addr: str = "localhost",
master_port: str = "29500",
device_type: str = "cuda",
start_method: str = "spawn",
**kwargs,
):
# clear environment variables
for key in [
"MASTER_ADDR",
"MASTER_PORT",
"RANK",
"WORLD_SIZE",
"LOCAL_RANK",
"LOCAL_DEVICE",
]:
if key in os.environ:
del os.environ[key]
if world_size == 1:
device_id = torch.device(device_type, 0)
os.environ["LOCAL_RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"
os.environ["LOCAL_DEVICE"] = str(device_id)
func(**kwargs)
return
wrapper_spawn_func_args = (
world_size,
backend,
master_addr,
master_port,
device_type,
func,
kwargs,
)
mp.start_processes(
wrapper_spawn_func,
args=wrapper_spawn_func_args,
nprocs=world_size,
start_method=start_method,
join=True,
)

View File

@ -0,0 +1,14 @@
from astrai.preprocessing.builder import (
BaseMaskBuilder,
MaskBuilderFactory,
SectionedMaskBuilder,
)
from astrai.preprocessing.pipeline import Pipeline, filter_by_length
__all__ = [
"BaseMaskBuilder",
"MaskBuilderFactory",
"SectionedMaskBuilder",
"Pipeline",
"filter_by_length",
]

View File

@ -0,0 +1,159 @@
"""Mask building strategies for preprocessing pipeline.
The single :class:`SectionedMaskBuilder` handles all input formats
via declarative ``input.sections`` config.
"""
from abc import ABC, abstractmethod
from typing import Optional
from astrai.factory import BaseFactory
class BaseMaskBuilder(ABC):
"""Convert a JSONL item into token ids and optional loss_mask."""
@abstractmethod
def build(self, item: dict, config, tokenizer) -> Optional[dict]:
"""Build ``{ids, loss_mask?, domain}`` from a JSONL record.
Returns ``None`` to skip the item entirely.
"""
...
class MaskBuilderFactory(BaseFactory["BaseMaskBuilder"]):
@classmethod
def _validate_component(cls, component_cls: type):
if not issubclass(component_cls, BaseMaskBuilder):
raise TypeError(
f"{component_cls.__name__} must inherit from BaseMaskBuilder"
)
def _extract_domain(item: dict, domain_key: Optional[str]) -> str:
if not domain_key:
return "__default__"
val = item.get(domain_key, "__default__")
return val if isinstance(val, str) else "__default__"
def _resolve_action(action: str, role: str, config) -> str:
"""Resolve action to "train" or "mask".
- ``"train"`` / ``"mask"`` literal
- ``"$role"`` look up ``role`` in ``config.mask``, fall back to ``config.mask_default``
"""
if action == "$role":
return config.mask.get(role, config.mask_default)
return action
@MaskBuilderFactory.register("sectioned")
class SectionedMaskBuilder(BaseMaskBuilder):
"""Config-driven builder: iterates over ``input.sections`` in order.
Each section specifies a JSONL field + mask action.
Section spec::
{
"field": "messages", # JSONL key
"action": "$role", # "train" | "mask" | "$role"
"template": true, # apply chat_template per message (optional)
"add_special_tokens": false # override encode flag (optional)
}
Example configs::
# Chat
{"input": {"sections": [
{"field": "messages", "action": "$role", "template": true}
]}}
# Instruction
{"input": {"sections": [
{"field": "prompt", "action": "mask", "add_special_tokens": true},
{"field": "response", "action": "train"}
]}}
# Text
{"input": {"sections": [
{"field": "text", "action": "train"}
]}}
"""
def build(self, item: dict, config, tokenizer) -> Optional[dict]:
sections = config.input.sections
if not sections:
return None
all_ids: list[int] = []
loss_mask: list[int] = []
has_template = any(s.get("template") for s in sections)
is_text_config = not has_template and all(
s["action"] == "train" for s in sections
)
if has_template and tokenizer.bos_token_id is not None:
all_ids.append(tokenizer.bos_token_id)
loss_mask.append(0)
first_section = True
for sec in sections:
field = sec["field"]
action = sec["action"]
use_template = sec.get("template", False)
add_special = sec.get(
"add_special_tokens", not use_template and first_section
)
if use_template:
messages = item.get(field)
if not isinstance(messages, list) or not messages:
continue
for msg in messages:
role = msg.get("role", "")
act = _resolve_action(action, role, config)
rendered = tokenizer.apply_chat_template(
[msg], tokenize=False, add_generation_prompt=False
)
ids = tokenizer.encode(rendered, add_special_tokens=False)
all_ids.extend(ids)
val = 1 if act == "train" else 0
loss_mask.extend([val] * len(ids))
else:
text = str(item.get(field, ""))
if not text.strip():
continue
if is_text_config:
pp = config.preprocessing
if pp.min_chars > 0 and len(text) < pp.min_chars:
continue
if len(text) > pp.max_chars:
continue
ids = tokenizer.encode(text, add_special_tokens=add_special)
all_ids.extend(ids)
val = 1 if action == "train" else 0
loss_mask.extend([val] * len(ids))
first_section = False
max_len = config.preprocessing.max_seq_len
all_ids = all_ids[:max_len]
loss_mask = loss_mask[: len(all_ids)]
if not all_ids:
return None
if has_template and len(all_ids) <= 1:
return None
result: dict = {
"sequence": all_ids,
"domain": _extract_domain(item, config.output.domain_key),
}
if not all(m == 1 for m in loss_mask):
result["loss_mask"] = loss_mask
return result

View File

@ -0,0 +1,141 @@
"""Config-driven JSONL preprocessing pipeline.
Composes a :class:`BaseMaskBuilder` (selected by ``input.type``) with
sharding and flush to ``.h5`` / ``.bin`` storage.
"""
import json
import os
from collections import defaultdict
from itertools import chain
from typing import Optional
import torch
import tqdm
from astrai.config.preprocess_config import PipelineConfig
from astrai.dataset.storage import save_bin, save_h5
from astrai.preprocessing.builder import SectionedMaskBuilder
from astrai.tokenize import AutoTokenizer
_STR_TO_DTYPE: dict[str, torch.dtype] = {
"bool": torch.bool,
"uint8": torch.uint8,
"int8": torch.int8,
"int16": torch.int16,
"int32": torch.int32,
"int64": torch.int64,
"float16": torch.float16,
"float32": torch.float32,
"float64": torch.float64,
}
def filter_by_length(text: str, min_len: int = 50, max_len: int = 2_000_000) -> bool:
return min_len <= len(text) <= max_len
class Pipeline:
"""Tokenization pipeline driven by a declarative :class:`PipelineConfig`.
Usage::
config = PipelineConfig.from_json("sft_pipeline.json")
Pipeline(config, ["data.jsonl"], output_dir="out", tokenizer_path="params").run()
"""
def __init__(
self,
config: PipelineConfig,
input_paths: list[str],
output_dir: str,
tokenizer_path: str,
):
os.makedirs(output_dir, exist_ok=True)
self.config = config
self.paths = input_paths
self.output_dir = output_dir
self.tokenizer_path = tokenizer_path
self.mask_builder = SectionedMaskBuilder()
def transform(self, item: dict) -> Optional[dict]:
return self.mask_builder.build(item, self.config, self._tokenizer)
def run(self):
self._tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_path)
domains: dict = defaultdict(lambda: defaultdict(list))
total_tokens = 0
shard_idx: dict[str, int] = defaultdict(int)
count = 0
pp = self.config.preprocessing
for item in tqdm.tqdm(
self._iter_items(), desc="Tokenizing", unit="docs", mininterval=0.5
):
if pp.max_items and count >= pp.max_items:
break
result = self.transform(item)
if result is None:
continue
ids = result.pop("sequence")
if not ids:
continue
domain = result.pop("domain", "__default__")
result["sequence"] = ids
bucket = domains[domain]
for key in list(bucket.keys()):
if key not in result:
bucket[key].append([1] * len(ids))
for key, val in result.items():
bucket[key].append(val)
count += 1
total_tokens += len(ids)
if total_tokens >= self.config.output.max_tokens_per_shard:
self._flush(domains, shard_idx)
domains.clear()
total_tokens = 0
if total_tokens > 0:
self._flush(domains, shard_idx)
print(f"Done. {count} documents tokenized.")
def _iter_items(self):
for path in self.paths:
with open(path, "r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if not line:
continue
yield json.loads(line)
def _flush(self, domains, shard_idx):
for domain, keys in domains.items():
idx = shard_idx[domain]
tensors = {}
for key, ids_list in keys.items():
dt = _STR_TO_DTYPE.get(
self.config.output.dtype.get(key, "int32"), torch.int32
)
tensors[key] = [
torch.tensor(list(chain.from_iterable(ids_list)), dtype=dt)
]
chunk_dir = os.path.join(self.output_dir, domain)
fmt = self.config.output.storage_format
if fmt == "bin":
save_bin(os.path.join(chunk_dir, f"shard_{idx:04d}"), tensors)
else:
save_h5(chunk_dir, f"data_{idx:04d}", tensors)
shard_idx[domain] = idx + 1
tqdm.tqdm.write(
f" saved {domain}/shard_{idx:04d} "
f"({tensors['sequence'][0].numel():,} tokens)"
)

21
astrai/protocols.py Normal file
View File

@ -0,0 +1,21 @@
"""Training component protocols — structural subtyping for optimizer/scheduler wrappers."""
from typing import Any, Protocol, runtime_checkable
@runtime_checkable
class OptimizerProtocol(Protocol):
def step(self, closure=None): ...
def zero_grad(self): ...
@property
def param_groups(self) -> Any: ...
def state_dict(self) -> dict: ...
def load_state_dict(self, d: dict): ...
@runtime_checkable
class SchedulerProtocol(Protocol):
def step(self): ...
def state_dict(self) -> dict: ...
def load_state_dict(self, d: dict): ...
def get_last_lr(self): ...

182
astrai/serialization.py Normal file
View File

@ -0,0 +1,182 @@
import io
import json
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, Union
import safetensors.torch as st
import torch
import torch.distributed as dist
from astrai.parallel.setup import get_rank
_META_FILE = "meta.json"
_CONFIG_FILE = "config.json"
_WEIGHTS_FILE = "model.safetensors"
def save_safetensors(state_dict: dict, path: Union[str, Path]):
st.save_file(state_dict, str(path))
def load_safetensors(path: Union[str, Path], broadcast: bool = False) -> dict:
if not broadcast or not dist.is_initialized():
return st.load_file(str(path))
rank = get_rank()
if rank == 0:
state_dict = st.load_file(str(path))
else:
state_dict = {}
tmp = [state_dict]
dist.broadcast_object_list(tmp, src=0)
return tmp[0]
def save_json(data: dict, path: Union[str, Path]):
with open(str(path), "w") as f:
json.dump(data, f, indent=2)
def load_json(path: Union[str, Path], broadcast: bool = False) -> dict:
if not broadcast or not dist.is_initialized():
with open(str(path), "r") as f:
return json.load(f)
rank = get_rank()
if rank == 0:
with open(str(path), "r") as f:
data = json.load(f)
else:
data = {}
tmp = [data]
dist.broadcast_object_list(tmp, src=0)
return tmp[0]
def save_torch(obj: Any, path: Union[str, Path]):
torch.save(obj, str(path))
def load_torch(path: Union[str, Path], broadcast: bool = False) -> Any:
if not broadcast or not dist.is_initialized():
return torch.load(str(path), map_location="cpu", weights_only=False)
path = Path(path)
rank = get_rank()
if rank == 0:
with open(path, "rb") as f:
raw = f.read()
data_tensor = torch.frombuffer(bytearray(raw), dtype=torch.uint8)
num_bytes = torch.tensor([len(raw)], dtype=torch.long)
else:
num_bytes = torch.tensor([0], dtype=torch.long)
dist.broadcast(num_bytes, src=0)
if rank != 0:
data_tensor = torch.empty(num_bytes.item(), dtype=torch.uint8)
dist.broadcast(data_tensor, src=0)
buf = io.BytesIO(data_tensor.numpy().tobytes())
return torch.load(buf, map_location="cpu", weights_only=False)
def save_model(config: dict, state_dict: dict, save_directory: str):
save_path = Path(save_directory)
save_path.mkdir(parents=True, exist_ok=True)
save_json(config, save_path / _CONFIG_FILE)
save_safetensors(state_dict, save_path / _WEIGHTS_FILE)
def load_model_config(save_directory: str) -> dict:
return load_json(Path(save_directory) / _CONFIG_FILE)
def load_model_weights(save_directory: str) -> dict:
return load_state_dict(Path(save_directory) / _WEIGHTS_FILE)
def load_state_dict(path: Union[str, Path], broadcast: bool = False) -> dict:
path = Path(path)
if not broadcast or not dist.is_initialized():
return load_safetensors(path)
rank = get_rank()
if rank == 0:
state_dict = load_safetensors(path)
specs = [
(k, list(state_dict[k].shape), str(state_dict[k].dtype).split(".")[-1])
for k in sorted(state_dict)
]
else:
state_dict = {}
specs = []
specs_list = [specs]
dist.broadcast_object_list(specs_list, src=0)
specs = specs_list[0]
for key, shape, dtype_name in specs:
dtype = getattr(torch, dtype_name)
if rank != 0:
tensor = torch.empty(shape, dtype=dtype, device="cpu")
else:
tensor = state_dict[key].contiguous().cpu()
dist.broadcast(tensor, src=0)
if rank != 0:
state_dict[key] = tensor
return state_dict
@dataclass
class Checkpoint:
state_dict: Dict[str, Any] = field(default_factory=dict)
epoch: int = 0
iteration: int = 0
extra: Dict[str, Any] = field(default_factory=dict)
meta: Dict[str, Any] = field(default_factory=dict)
config: Dict[str, Any] = field(default_factory=dict)
def save(self, save_dir: str):
save_path = Path(save_dir)
save_path.mkdir(parents=True, exist_ok=True)
if get_rank() != 0:
return
meta = {
"epoch": self.epoch,
"iteration": self.iteration,
"timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"),
**self.meta,
}
save_json(meta, save_path / _META_FILE)
save_json(self.config, save_path / _CONFIG_FILE)
save_safetensors(self.state_dict, save_path / _WEIGHTS_FILE)
for key, value in self.extra.items():
save_torch(value, save_path / f"{key}.pt")
@classmethod
def load(cls, save_dir: str, broadcast: bool = False) -> "Checkpoint":
save_path = Path(save_dir)
meta = load_json(save_path / _META_FILE, broadcast)
config = load_json(save_path / _CONFIG_FILE, broadcast)
state_dict = load_state_dict(save_path / _WEIGHTS_FILE, broadcast=broadcast)
extra = {}
for f in sorted(save_path.iterdir()):
if f.suffix == ".pt":
extra[f.stem] = load_torch(f, broadcast=broadcast)
return cls(
state_dict=state_dict,
epoch=meta.get("epoch", 0),
iteration=meta.get("iteration", 0),
extra=extra,
config=config,
)

View File

@ -0,0 +1,8 @@
from astrai.tokenize.chat_template import ChatTemplate, MessageType
from astrai.tokenize.tokenizer import AutoTokenizer
__all__ = [
"AutoTokenizer",
"ChatTemplate",
"MessageType",
]

View File

@ -0,0 +1,74 @@
from typing import Any, Dict, List, Optional
from jinja2 import Template
type MessageType = Dict[str, Any]
class ChatTemplate:
"""A chat template with Jinja2 rendering support.
Attributes:
name: Unique identifier for the template.
template_str: Jinja2 template string.
description: Optional description.
default_variables: Optional dictionary of default variable values.
special_tokens: Optional dictionary mapping token names to their string values.
"""
def __init__(
self,
name: str = "",
template_str: str = "",
description: str = "",
default_variables: Optional[Dict[str, Any]] = None,
special_tokens: Optional[Dict[str, str]] = None,
):
self.name = name
self.template_str = template_str
self.description = description
self.default_variables = default_variables or {}
self.special_tokens = special_tokens or {}
self._compiled: Template = Template(template_str)
@classmethod
def from_string(
cls,
template_str: str,
description: str = "",
default_variables: Optional[Dict[str, Any]] = None,
special_tokens: Optional[Dict[str, str]] = None,
) -> "ChatTemplate":
"""Create a ChatTemplate instance directly from a template string."""
return cls(
name="",
template_str=template_str,
description=description,
default_variables=default_variables,
special_tokens=special_tokens,
)
def render(
self,
messages: List[MessageType],
system_prompt: Optional[str] = None,
**extra_variables: Any,
) -> str:
"""Render the template with given messages and variables.
Args:
messages: List of message dicts with 'role' and 'content'.
system_prompt: Optional system prompt string.
**extra_variables: Additional variables to pass to the template.
These override default_variables and special_tokens.
Returns:
Rendered prompt string.
"""
# Merge default variables, special tokens, and extra variables
variables = {**self.default_variables, **self.special_tokens, **extra_variables}
variables["messages"] = messages
if system_prompt is not None:
variables["system_prompt"] = system_prompt
return self._compiled.render(**variables)

View File

@ -0,0 +1,264 @@
"""
Tokenizer module with implementation and auto-loading support.
"""
import json
from pathlib import Path
from typing import Dict, List, Optional, Union
from tokenizers import Tokenizer
from astrai.tokenize.chat_template import ChatTemplate
class AutoTokenizer:
"""Base tokenizer class with automatic loading support"""
TOKENIZER_CLASSES = {} # Registry for auto-loading
def __init__(
self,
path: Optional[Union[str, Path]] = None,
special_token_map: Optional[Dict[str, str]] = None,
chat_template: Optional[str] = None,
):
self._tokenizer: Tokenizer = None
self._chat_template: Optional[ChatTemplate] = None
self._special_token_map: Optional[Dict] = special_token_map or {}
if chat_template:
self.set_chat_template(chat_template)
if path:
self.load(path)
def load(self, path: Union[str, Path]):
"""Load tokenizer from directory."""
path = Path(path)
tokenizer_file = path / "tokenizer.json"
config_file = path / "tokenizer_config.json"
self._tokenizer = Tokenizer.from_file(str(tokenizer_file))
if config_file.exists():
with open(config_file, "r", encoding="utf-8") as f:
config = json.load(f)
if "special_tokens" in config:
self._special_token_map.update(config["special_tokens"])
# Load chat template from config
if "chat_template" in config:
self.set_chat_template(config["chat_template"])
@classmethod
def from_pretrained(cls, path: Union[str, Path]) -> "AutoTokenizer":
"""Load tokenizer from pretrained directory.
Raises:
FileNotFoundError: If tokenizer.json is missing.
RuntimeError: If tokenizer failed to initialize.
"""
path = Path(path)
tokenizer_file = path / "tokenizer.json"
if not tokenizer_file.exists():
raise FileNotFoundError(
f"Tokenizer file not found: {tokenizer_file}. "
"A valid tokenizer.json is required."
)
instance = cls(path)
if instance._tokenizer is None:
raise RuntimeError(
f"Failed to load tokenizer from {path}. "
"The tokenizer.json may be corrupted or incompatible."
)
return instance
def save_pretrained(self, save_path: str):
"""
Save tokenizer to pretrained directory.
Args:
save_path: Path to save the tokenizer
"""
if self._tokenizer is None:
raise RuntimeError(
"Tokenizer not initialized. Load or create a tokenizer first."
)
save_path = Path(save_path)
save_path.mkdir(parents=True, exist_ok=True)
# Save tokenizer
self._tokenizer.save(str(save_path / "tokenizer.json"))
# Save tokenizer config
config = {}
if self._special_token_map is not None:
config["special_tokens"] = self._special_token_map
if self._chat_template is not None:
config["chat_template"] = self._chat_template.template_str
with open(save_path / "tokenizer_config.json", "w", encoding="utf-8") as f:
json.dump(config, f, ensure_ascii=False, indent=2)
@classmethod
def register_tokenizer(cls, name: str, tokenizer_class: type):
"""
Register a new tokenizer class.
Args:
name: Name to register the tokenizer class under
tokenizer_class: The tokenizer class to register
"""
cls.TOKENIZER_CLASSES[name] = tokenizer_class
def encode(
self,
tokens: Union[str, List[str]],
out_ids: bool = True,
is_pretokenized: bool = False,
add_special_tokens: bool = True,
) -> List:
"""Encode text to tokens or token IDs."""
if self._tokenizer is None:
raise RuntimeError(
"Tokenizer not initialized. Load or create a tokenizer first."
)
if isinstance(tokens, str):
encoded = self._tokenizer.encode(
tokens,
is_pretokenized=is_pretokenized,
add_special_tokens=add_special_tokens,
)
return encoded.ids if out_ids else encoded.tokens
else:
encoded_list = self._tokenizer.encode_batch(
tokens,
is_pretokenized=is_pretokenized,
add_special_tokens=add_special_tokens,
)
return [
encoded.ids if out_ids else encoded.tokens for encoded in encoded_list
]
def decode(self, tokens: List[int], skip_special_tokens: bool = True) -> str:
"""Decode token IDs to text."""
if self._tokenizer is None:
raise RuntimeError(
"Tokenizer not initialized. Load or create a tokenizer first."
)
return self._tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens)
def __len__(self) -> int:
if self._tokenizer is None:
return 0
return self._tokenizer.get_vocab_size()
def __getattr__(self, key: str):
"""
Dynamically intercept special token attribute access.
Supports three forms:
- tokenizer.bos_token returns string
- tokenizer.bos_token_id returns corresponding integer ID
- tokenizer.stop_ids returns list of corresponding integer IDs for all special tokens
"""
# Handle stop_ids - return IDs for all special tokens
if key == "stop_ids":
stop_ids = []
if self._tokenizer is None:
return stop_ids
for val in self._special_token_map.values():
token_id = self._tokenizer.token_to_id(val)
if token_id is not None:
stop_ids.append(token_id)
return stop_ids
# Handle _id suffix (e.g., bos_token_id -> bos_token)
if key.endswith("_id"):
base_attr = key[:-3] # Remove "_id"
token_str = self._special_token_map.get(base_attr)
if token_str is None:
return None
if self._tokenizer is None:
raise RuntimeError("Tokenizer not loaded, cannot convert token to id.")
return self._tokenizer.token_to_id(token_str)
# Handle regular string attributes
if key in self._special_token_map:
return self._special_token_map.get(key)
# Other attributes trigger default AttributeError
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{key}'")
@property
def vocab_size(self) -> int:
return len(self)
def set_chat_template(self, template: Union[str, ChatTemplate]):
"""
Set the chat template for the tokenizer.
Args:
template: Either a template name (str) registered in the global registry,
or a ChatTemplate instance, or a Jinja2 template string.
Raises:
KeyError: If template name is not registered.
"""
if isinstance(template, str):
self._chat_template = ChatTemplate.from_string(template)
elif isinstance(template, ChatTemplate):
self._chat_template = template
else:
raise ValueError("Invalid template type, must be str or ChatTemplate.")
def apply_chat_template(
self,
messages: List[Dict[str, str]],
system_prompt: Optional[str] = None,
tokenize: bool = True,
add_generation_prompt: bool = True,
**kwargs,
) -> Union[str, List[int]]:
"""
Apply the chat template to messages and optionally tokenize the result.
Args:
messages: List of message dicts with 'role' and 'content'.
system_prompt: Optional system prompt string (auto-converted to first message).
tokenize: Whether to return token IDs (True) or raw string (False).
add_generation_prompt: Whether to add the generation prompt (default: True).
**kwargs: Additional variables to pass to the template.
Returns:
Either the rendered string or list of token IDs.
Raises:
RuntimeError: If chat template is not set.
"""
if self._chat_template is None:
raise RuntimeError(
"Chat template not set. Use set_chat_template() to set a template first."
)
# Auto-convert system_prompt to first message if provided
if system_prompt:
messages = [{"role": "system", "content": system_prompt}] + list(messages)
# Render the template
rendered = self._chat_template.render(
messages=messages,
add_generation_prompt=add_generation_prompt,
**kwargs,
)
if tokenize:
return self.encode(rendered)
return rendered

View File

@ -0,0 +1,24 @@
from astrai.trainer.optim import Muon
from astrai.trainer.schedule import BaseScheduler, SchedulerFactory
from astrai.trainer.strategy import BaseStrategy, StrategyFactory
from astrai.trainer.train_callback import (
CallbackFactory,
TrainCallback,
)
from astrai.trainer.trainer import Trainer
__all__ = [
# Main trainer
"Trainer",
# Optimizer
"Muon",
# Strategy factory
"StrategyFactory",
"BaseStrategy",
# Scheduler factory
"SchedulerFactory",
"BaseScheduler",
# Callback factory
"TrainCallback",
"CallbackFactory",
]

View File

@ -0,0 +1,75 @@
from typing import Any, Callable, Dict
import torch
import torch.nn as nn
def _grad_stat(
model: nn.Module, fn: Callable[[torch.Tensor], Any], default: Any
) -> dict:
results = {}
for name, param in model.named_parameters():
results[name] = default
if param.grad is not None:
results[name] = fn(param.grad.data)
return results
def grad_norm(model: nn.Module, norm_type: int = 2) -> Dict[str, float]:
return _grad_stat(model, lambda g: g.norm(norm_type).item(), 0.0)
def grad_std(model: nn.Module) -> Dict[str, float]:
return _grad_stat(model, lambda g: g.std().item(), 0.0)
def grad_max(model: nn.Module) -> Dict[str, float]:
return _grad_stat(model, lambda g: g.max().item(), -float("inf"))
def grad_min(model: nn.Module) -> Dict[str, float]:
return _grad_stat(model, lambda g: g.min().item(), float("inf"))
def grad_mean(model: nn.Module) -> Dict[str, float]:
return _grad_stat(model, lambda g: g.mean().item(), 0.0)
def grad_nan_num(model: nn.Module) -> Dict[str, int]:
return _grad_stat(model, lambda g: g.isnan().sum().item(), 0)
def ctx_get_loss(ctx):
return ctx.loss
def ctx_get_lr(ctx):
return ctx.optimizer.param_groups[-1]["lr"]
def ctx_get_val_loss(ctx):
return ctx.val_loss
def ctx_get_grad_norm(ctx):
return grad_norm(ctx.model)
def ctx_get_grad_std(ctx):
return grad_std(ctx.model)
def ctx_get_grad_max(ctx):
return grad_max(ctx.model)
def ctx_get_grad_min(ctx):
return grad_min(ctx.model)
def ctx_get_grad_mean(ctx):
return grad_mean(ctx.model)
def ctx_get_grad_nan_num(ctx):
return grad_nan_num(ctx.model)

143
astrai/trainer/optim.py Normal file
View File

@ -0,0 +1,143 @@
import torch
from torch.optim import Optimizer
def _zeropower_via_newtonschulz(G: torch.Tensor, steps: int = 5):
assert G.ndim == 2
X = G
scale = max(1, G.size(0) / G.size(1)) ** 0.5
X = X / (X.norm() + 1e-7) * scale
if steps == 0:
return X
a, b, c = (3.4445, -4.7750, 2.0315)
for _ in range(steps):
A = X @ X.T
B = A @ X
X = a * X + b * B + c * (A @ B)
return X
class Muon(Optimizer):
def __init__(
self,
params,
lr: float = 2e-3,
momentum: float = 0.95,
weight_decay: float = 0.0,
nesterov: bool = True,
ns_steps: int = 5,
adamw_lr: float = None,
adamw_betas: tuple = (0.9, 0.95),
adamw_eps: float = 1e-8,
adamw_wd: float = 0.0,
):
defaults = dict(
lr=lr,
momentum=momentum,
weight_decay=weight_decay,
nesterov=nesterov,
ns_steps=ns_steps,
adamw_lr=adamw_lr if adamw_lr is not None else lr * 0.1,
adamw_betas=adamw_betas,
adamw_eps=adamw_eps,
adamw_wd=adamw_wd,
)
super().__init__(params, defaults)
@torch.no_grad()
def step(self, closure=None):
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
for group in self.param_groups:
params_2d, params_1d = [], []
grads_2d, grads_1d = [], []
for p in group["params"]:
if p.grad is None:
continue
if p.grad.is_sparse:
raise RuntimeError("Muon does not support sparse gradients")
if p.ndim >= 2:
params_2d.append(p)
grads_2d.append(p.grad)
else:
params_1d.append(p)
grads_1d.append(p.grad)
if params_2d:
self._muon_update_foreach(params_2d, grads_2d, group)
if params_1d:
self._adamw_update_foreach(params_1d, grads_1d, group)
return loss
def _muon_update_foreach(self, params_2d, grads_2d, group):
lr = group["lr"]
momentum = group["momentum"]
wd = group["weight_decay"]
nesterov = group["nesterov"]
ns_steps = group["ns_steps"]
if wd != 0:
torch._foreach_mul_(params_2d, 1 - lr * wd)
if nesterov:
grads_2d = torch._foreach_add(grads_2d, params_2d, alpha=wd)
bufs = []
for p, grad in zip(params_2d, grads_2d):
state = self.state[p]
if "momentum_buffer" not in state:
state["momentum_buffer"] = torch.zeros_like(grad)
bufs.append(state["momentum_buffer"])
torch._foreach_lerp_(bufs, grads_2d, 1 - momentum)
for p, buf in zip(params_2d, bufs):
update = _zeropower_via_newtonschulz(buf, steps=ns_steps)
scale = max(1, p.size(0) / p.size(1)) ** 0.5
p.add_(update, alpha=-lr * scale)
def _adamw_update_foreach(self, params_1d, grads_1d, group):
lr = group["adamw_lr"]
betas = group["adamw_betas"]
eps = group["adamw_eps"]
wd = group["adamw_wd"]
steps: list[int] = []
exp_avgs, exp_avg_sqs = [], []
has_state = []
for p in params_1d:
state = self.state[p]
if not state:
state["step"] = 0
state["exp_avg"] = torch.zeros_like(p)
state["exp_avg_sq"] = torch.zeros_like(p)
has_state.append(False)
else:
has_state.append(True)
state["step"] += 1
steps.append(state["step"])
exp_avgs.append(state["exp_avg"])
exp_avg_sqs.append(state["exp_avg_sq"])
beta1, beta2 = betas
torch._foreach_lerp_(exp_avgs, grads_1d, 1 - beta1)
grads_sq = torch._foreach_mul(grads_1d, grads_1d)
torch._foreach_lerp_(exp_avg_sqs, grads_sq, 1 - beta2)
bias_correction1 = [1 - beta1**s for s in steps]
bias_correction2 = [1 - beta2**s for s in steps]
if wd != 0:
torch._foreach_mul_(params_1d, 1 - lr * wd)
exp_avg_corrected = torch._foreach_div(exp_avgs, bias_correction1)
denom = torch._foreach_div(exp_avg_sqs, bias_correction2)
denom = torch._foreach_sqrt(denom)
torch._foreach_add_(denom, eps)
torch._foreach_addcdiv_(params_1d, exp_avg_corrected, denom, value=-lr)

194
astrai/trainer/schedule.py Normal file
View File

@ -0,0 +1,194 @@
"""Learning rate scheduler implementations with factory pattern."""
import math
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Type
from torch.optim.lr_scheduler import LRScheduler
from astrai.factory import BaseFactory
class BaseScheduler(LRScheduler, ABC):
"""Base scheduler class for all other schedulers."""
def __init__(self, optimizer, last_epoch: int = -1):
super().__init__(optimizer, last_epoch)
@abstractmethod
def get_lr(self) -> List[float]:
"""Calculate the current learning rate."""
raise NotImplementedError
def state_dict(self) -> Dict[str, Any]:
return super().state_dict()
def load_state_dict(self, state_dict: Dict[str, Any]):
super().load_state_dict(state_dict)
class SchedulerFactory(BaseFactory["BaseScheduler"]):
"""Factory class for creating learning rate schedulers.
Supports decorator-based registration for extensible scheduler types.
Also supports creation from ScheduleConfig objects.
Example usage:
@SchedulerFactory.register("custom")
class CustomScheduler(BaseScheduler):
...
scheduler = SchedulerFactory.create("custom", optimizer, **kwargs)
"""
@classmethod
def _validate_component(cls, scheduler_cls: Type[BaseScheduler]):
"""Validate that the scheduler class inherits from BaseScheduler."""
if not issubclass(scheduler_cls, BaseScheduler):
raise TypeError(f"{scheduler_cls.__name__} must inherit from BaseScheduler")
@classmethod
def create(
cls, optimizer, schedule_type: str = "none", **kwargs
) -> "BaseScheduler":
"""Create a scheduler instance by type name.
Args:
optimizer: PyTorch optimizer
schedule_type: Type of scheduler ("cosine", "sgdr")
**kwargs: Arguments passed to the scheduler constructor
Returns:
Scheduler instance
"""
return super().create(schedule_type, optimizer, **kwargs)
@classmethod
def available_types(cls) -> list:
"""Return list of registered scheduler type names."""
return cls.list_registered()
# ----------- Scheduler implementations -----------
@SchedulerFactory.register("cosine")
class CosineScheduler(BaseScheduler):
"""Cosine decay scheduler with warmup, implemented as PyTorch LRScheduler."""
def __init__(
self,
optimizer,
warmup_steps: int,
lr_decay_steps: int,
min_rate: float = 0.05,
last_epoch: int = -1,
):
self.warmup_steps = warmup_steps
self.lr_decay_steps = lr_decay_steps
self.min_rate = min_rate
self.total_steps = warmup_steps + lr_decay_steps
super().__init__(optimizer, last_epoch)
def get_lr(self) -> List[float]:
# warmup
if self.last_epoch < self.warmup_steps:
warmup_factor = max(self.min_rate, self.last_epoch / self.warmup_steps)
return [base_lr * warmup_factor for base_lr in self.base_lrs]
# cosine decay
decay_progress = (self.last_epoch - self.warmup_steps) / self.lr_decay_steps
decay_progress = min(decay_progress, 1.0)
cosine_decay = 0.5 * (1.0 + math.cos(math.pi * decay_progress))
decay_factor = max(self.min_rate, cosine_decay)
return [base_lr * decay_factor for base_lr in self.base_lrs]
def state_dict(self):
state = super().state_dict()
state.update(
{
"warmup_steps": self.warmup_steps,
"lr_decay_steps": self.lr_decay_steps,
"min_rate": self.min_rate,
"total_steps": self.total_steps,
}
)
return state
def load_state_dict(self, state_dict):
self.warmup_steps = state_dict.pop("warmup_steps")
self.lr_decay_steps = state_dict.pop("lr_decay_steps")
self.min_rate = state_dict.pop("min_rate")
self.total_steps = state_dict.pop("total_steps")
super().load_state_dict(state_dict)
@SchedulerFactory.register("sgdr")
class SGDRScheduler(BaseScheduler):
"""SGDR (Stochastic Gradient Descent with Warm Restarts) scheduler."""
def __init__(
self,
optimizer,
warmup_steps: int,
cycle_length: int,
min_rate: float = 0.05,
t_mult: int = 2,
last_epoch: int = -1,
):
self.warmup_steps = warmup_steps
self.cycle_length = cycle_length
self.min_rate = min_rate
self.t_mult = t_mult
super().__init__(optimizer, last_epoch)
def get_lr(self):
# warmup
if self.last_epoch < self.warmup_steps:
warmup_factor = max(self.min_rate, self.last_epoch / self.warmup_steps)
return [base_lr * warmup_factor for base_lr in self.base_lrs]
# SGDR
steps_since_warmup = self.last_epoch - self.warmup_steps
# 1. Calculate current cycle and position within cycle
current_cycle_length = self.cycle_length
total_cycles_length = 0
cycle_num = 0
while total_cycles_length + current_cycle_length <= steps_since_warmup:
total_cycles_length += current_cycle_length
current_cycle_length *= self.t_mult
cycle_num += 1
steps_in_cycle = steps_since_warmup - total_cycles_length
# 2. Cosine annealing within the current cycle
cosine_factor = 0.5 * (
1 + math.cos(math.pi * steps_in_cycle / current_cycle_length)
)
learning_rate_factor = self.min_rate + (1 - self.min_rate) * cosine_factor
return [base_lr * learning_rate_factor for base_lr in self.base_lrs]
def state_dict(self):
"""Returns the state of the scheduler as a dict."""
state = super().state_dict()
state.update(
{
"warmup_steps": self.warmup_steps,
"cycle_length": self.cycle_length,
"min_rate": self.min_rate,
"t_mult": self.t_mult,
}
)
return state
def load_state_dict(self, state_dict):
"""Loads the scheduler's state."""
self.warmup_steps = state_dict.pop("warmup_steps")
self.cycle_length = state_dict.pop("cycle_length")
self.min_rate = state_dict.pop("min_rate")
self.t_mult = state_dict.pop("t_mult")
super().load_state_dict(state_dict)

334
astrai/trainer/strategy.py Normal file
View File

@ -0,0 +1,334 @@
"""Training strategy implementations with factory pattern."""
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from astrai.factory import BaseFactory
def create_ref_model(model_fn, state_dict: dict) -> nn.Module:
"""Create a frozen reference model from model_fn + full state dict."""
ref_model = model_fn()
ref_model.load_state_dict(state_dict)
ref_model.requires_grad_(False)
ref_model.eval()
return ref_model
def move_to_device(batch: Dict[str, Tensor], device: str) -> Any:
"""Move batch tensors to specified device with non-blocking transfer."""
return {key: value.to(device, non_blocking=True) for key, value in batch.items()}
def get_logprobs(
model: Union[nn.Module, Callable[..., Dict[str, Tensor]]],
input_ids: Tensor,
mask: Tensor,
reduction: str,
):
"""Compute token-wise log probabilities from model outputs.
Args:
model: The language model
input_ids: Input token IDs of shape [batch_size, seq_len]
mask: Attention mask of shape [batch_size, seq_len]
reduction: How to reduce over sequence dimension ("mean", "sum", "none")
Returns:
Log probabilities with reduction applied over sequence dimension
"""
allowed_reductions = ["mean", "sum", "none"]
if reduction not in allowed_reductions:
raise ValueError(
f"reduction must be one of {allowed_reductions}, got '{reduction}'"
)
shifted_input_ids = input_ids[:, 1:]
shifted_mask = mask[:, 1:]
logits = model(input_ids[:, :-1], mask[:, :-1])["logits"]
log_probs = torch.log_softmax(logits.float(), dim=-1)
token_logprobs = torch.gather(
log_probs, dim=-1, index=shifted_input_ids.unsqueeze(-1)
).squeeze(-1)
if reduction == "mean":
return (token_logprobs * shifted_mask).sum(dim=-1) / shifted_mask.sum(
dim=-1
).clamp(min=1.0)
elif reduction == "sum":
return (token_logprobs * shifted_mask).sum(dim=-1)
else:
return token_logprobs * shifted_mask
class BaseStrategy(ABC):
"""Abstract base class for training strategies."""
def __init__(
self, model: Union[Callable[..., Dict[str, Tensor]]], device: str, **kwargs
):
self.model = model
self.device = device
self.executor = kwargs.pop("executor", None)
self.model_fn = kwargs.pop("model_fn", None)
self.extra_kwargs = kwargs
@abstractmethod
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
"""Compute loss for the given batch.
Args:
batch: Dictionary containing batch tensors
Returns:
Computed loss tensor
"""
raise NotImplementedError
def __call__(self, batch: Dict[str, Tensor]) -> Tensor:
"""Allow calling strategy directly as a callable."""
return self.compute_loss(batch)
class StrategyFactory(BaseFactory["BaseStrategy"]):
"""Factory class for creating training strategy instances.
Supports decorator-based registration for extensible strategy types.
All default strategies (seq, sft, dpo, grpo) are automatically registered.
Example usage:
@StrategyFactory.register("custom")
class CustomStrategy(BaseStrategy):
...
strategy = StrategyFactory.create("custom", model, device)
"""
@classmethod
def _validate_component(cls, strategy_cls: type):
"""Validate that the strategy class inherits from BaseStrategy."""
if not issubclass(strategy_cls, BaseStrategy):
raise TypeError(f"{strategy_cls.__name__} must inherit from BaseStrategy")
@classmethod
def create(cls, train_type: str, model, device: str, **kwargs) -> "BaseStrategy":
"""Create a strategy instance based on training type.
Args:
train_type: Type of training ("seq", "sft", "dpo", "grpo")
model: Model instance for the strategy
device: Device to run the strategy on
**kwargs: Additional arguments passed to strategy constructor
Returns:
Strategy instance
"""
return super().create(train_type, model, device, **kwargs)
@classmethod
def available_strategies(cls) -> list:
"""Return list of registered strategy names."""
return cls.list_registered()
# ============== Strategy Classes ==============
# All strategies are registered at class definition time using the decorator
@StrategyFactory.register("seq")
class SEQStrategy(BaseStrategy):
"""Standard next-token prediction training strategy.
Computes cross-entropy loss for next token prediction.
"""
def __init__(self, model, device, label_smoothing: float = 0.0, **kwargs):
super().__init__(model, device, **kwargs)
self.label_smoothing = label_smoothing
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
batch = move_to_device(batch, self.device)
input_ids, target_ids = batch["input_ids"], batch["target_ids"]
logits = self.model(input_ids=input_ids)["logits"]
loss = F.cross_entropy(
input=logits.flatten(0, 1).float(),
target=target_ids.flatten(),
label_smoothing=self.label_smoothing,
)
return loss
@StrategyFactory.register("sft")
class SFTStrategy(BaseStrategy):
"""Supervised Fine-tuning strategy with loss masking.
Applies cross-entropy loss only to tokens where loss_mask is True.
"""
def __init__(self, model, device, label_smoothing: float = 0.0, **kwargs):
super().__init__(model, device, **kwargs)
self.label_smoothing = label_smoothing
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
batch = move_to_device(batch, self.device)
input_ids, target_ids, loss_mask = (
batch["input_ids"],
batch["target_ids"],
batch["loss_mask"],
)
ignore_index = -100
logits = self.model(input_ids=input_ids)["logits"]
target_ids = target_ids.masked_fill(loss_mask == 0, ignore_index)
loss = F.cross_entropy(
input=logits.flatten(0, 1).float(),
target=target_ids.flatten(),
ignore_index=ignore_index,
label_smoothing=self.label_smoothing,
)
return loss
@StrategyFactory.register("dpo")
class DPOStrategy(BaseStrategy):
"""Direct Preference Optimization strategy.
Implements the DPO loss from the paper "Direct Preference Optimization".
Uses a reference model to compute KL divergence penalty.
"""
def __init__(
self,
model: nn.Module,
device: str,
beta: float = 0.1,
reduction: str = "mean",
**kwargs,
):
super().__init__(model, device, **kwargs)
self.ref_model = create_ref_model(
self.model_fn, self.executor.unwrap_model(model)
).to(device=self.device)
self.beta = beta
self.reduction = reduction
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
batch = move_to_device(batch, self.device)
chosen_ids, rejected_ids = batch["chosen"], batch["rejected"]
chosen_mask, rejected_mask = batch["chosen_mask"], batch["rejected_mask"]
concat_ids = torch.cat([chosen_ids, rejected_ids], dim=0)
concat_mask = torch.cat([chosen_mask, rejected_mask], dim=0)
log_pi = get_logprobs(self.model, concat_ids, concat_mask, self.reduction)
with torch.no_grad():
log_ref = get_logprobs(
self.ref_model, concat_ids, concat_mask, self.reduction
)
log_pi_chosen = log_pi[: chosen_ids.shape[0]]
log_pi_rejected = log_pi[chosen_ids.shape[0] :]
log_ref_chosen = log_ref[: chosen_ids.shape[0]]
log_ref_rejected = log_ref[chosen_ids.shape[0] :]
pi_log_ratio = log_pi_chosen - log_pi_rejected
ref_log_ratio = log_ref_chosen - log_ref_rejected
ratio_diff = pi_log_ratio - ref_log_ratio
dpo_loss = -F.logsigmoid(self.beta * ratio_diff).mean()
return dpo_loss
@StrategyFactory.register("grpo")
class GRPOStrategy(BaseStrategy):
"""Group Relative Policy Optimization strategy.
On-policy GRPO following DeepSeek-R1: the policy model is updated while
a frozen ref_model stores the old-policy log-probs. ratio = exp(logπ_θ - logπ_ref),
clipped PPO objective. Call ``sync_ref_model()`` after each data-generation round.
"""
def __init__(
self,
model: nn.Module,
device: str,
clip_eps: float = 0.2,
kl_coef: float = 0.01,
group_size: int = 4,
reduction: str = "mean",
sync_interval: int = 200,
**kwargs,
):
super().__init__(model, device, **kwargs)
self.ref_model = create_ref_model(
self.model_fn, self.executor.unwrap_model(model)
).to(device=self.device)
self.clip_eps = clip_eps
self.kl_coef = kl_coef
self.group_size = group_size
self.reduction = reduction
self.sync_interval = sync_interval
self._step = 0
def sync_ref_model(self):
"""Copy current model weights to ref model."""
self.ref_model.load_state_dict(self.executor.unwrap_model(self.model))
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
self._step += 1
if self._step % self.sync_interval == 0:
self.sync_ref_model()
batch = move_to_device(batch, self.device)
prompts = batch["prompts"]
responses = batch["responses"]
masks = batch["masks"]
rewards = batch["rewards"]
batch_size, group_size, response_len = responses.shape
responses_flat = responses.view(-1, response_len)
masks_flat = masks.view(-1, response_len)
prompt_expanded = prompts.unsqueeze(1).repeat(1, group_size, 1).flatten(0, 1)
full_sequences = torch.cat([prompt_expanded, responses_flat], dim=-1)
full_masks = torch.cat([torch.ones_like(prompt_expanded), masks_flat], dim=-1)
log_probs_policy = get_logprobs(
self.model, full_sequences, full_masks, self.reduction
)
log_probs_policy = log_probs_policy.view(batch_size, group_size)
with torch.no_grad():
log_probs_ref = get_logprobs(
self.ref_model, full_sequences, full_masks, self.reduction
)
log_probs_ref = log_probs_ref.view(batch_size, group_size)
eps = torch.finfo(log_probs_policy.dtype).eps
mean = rewards.mean(dim=-1, keepdim=True)
std = rewards.std(dim=-1, keepdim=True)
advantages = (rewards - mean) / (std + eps)
ratio = torch.exp(log_probs_policy - log_probs_ref)
surr1 = ratio * advantages
surr2 = torch.clamp(ratio, 1 - self.clip_eps, 1 + self.clip_eps) * advantages
policy_loss = -torch.min(surr1, surr2).mean()
kl_penalty = self.kl_coef * (log_probs_policy - log_probs_ref).square().mean()
total_loss = policy_loss + kl_penalty
return total_loss

View File

@ -0,0 +1,333 @@
import json
import logging
import os
import sys
import time
from pathlib import Path
from typing import IO, Callable, List, Optional, Protocol, runtime_checkable
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.utils import clip_grad_norm_
from torch.utils.checkpoint import checkpoint as torch_checkpoint
from tqdm import tqdm
from astrai.factory import BaseFactory
from astrai.parallel import only_on_rank
from astrai.parallel.setup import get_current_device, get_rank
from astrai.serialization import Checkpoint
from astrai.trainer.metric_util import (
ctx_get_grad_max,
ctx_get_grad_mean,
ctx_get_grad_min,
ctx_get_grad_nan_num,
ctx_get_grad_norm,
ctx_get_grad_std,
ctx_get_loss,
ctx_get_lr,
ctx_get_val_loss,
)
from astrai.trainer.train_context import TrainContext
logger = logging.getLogger(__name__)
@runtime_checkable
class TrainCallback(Protocol):
"""
Callback interface for trainer.
"""
def on_train_begin(self, context: TrainContext):
"""Called at the beginning of training."""
def on_train_end(self, context: TrainContext):
"""Called at the end of training."""
def on_epoch_begin(self, context: TrainContext):
"""Called at the beginning of each epoch."""
def on_epoch_end(self, context: TrainContext):
"""Called at the end of each epoch."""
def on_batch_begin(self, context: TrainContext):
"""Called at the beginning of each batch."""
def on_batch_end(self, context: TrainContext):
"""Called at the end of each batch."""
def on_optimizer_step(self, context: TrainContext):
"""Called on every optimizer step (sync step only)."""
def on_error(self, context: TrainContext):
"""Called when an error occurs during training."""
class CallbackFactory(BaseFactory[TrainCallback]):
"""Factory for registering and creating training callbacks.
Example:
@CallbackFactory.register("my_callback")
class MyCallback(TrainCallback):
...
callback = CallbackFactory.create("my_callback", **kwargs)
"""
@CallbackFactory.register("gradient_clipping")
class GradientClippingCallback(TrainCallback):
"""
Gradient clipping callback for trainer.
"""
def __init__(self, max_grad_norm: float):
self.max_grad_norm = max_grad_norm
def on_optimizer_step(self, context: TrainContext):
clip_grad_norm_(context.model.parameters(), self.max_grad_norm)
@CallbackFactory.register("gradient_checkpointing")
class GradientCheckpointingCallback(TrainCallback):
"""
Activation checkpointing callback trades compute for memory
by recomputing specified module activations during the backward pass.
Args:
modules: Module types to apply checkpointing to.
"""
def __init__(self, modules: Optional[List[type]] = None):
self.modules = tuple(modules) if modules else ()
def _enable(self, module: nn.Module):
if self.modules and isinstance(module, self.modules):
fn = module.forward
module._original_forward = fn
module.forward = lambda *a, **kw: torch_checkpoint(
fn, *a, use_reentrant=False, **kw
)
@staticmethod
def _disable(module: nn.Module):
if hasattr(module, "_original_forward"):
module.forward = module._original_forward
del module._original_forward
def on_train_begin(self, context: TrainContext):
context.model.apply(self._enable)
logger.info("Gradient checkpointing enabled")
def on_train_end(self, context: TrainContext):
context.model.apply(self._disable)
@CallbackFactory.register("checkpoint")
class CheckpointCallback(TrainCallback):
"""
Checkpoint callback for trainer.
"""
extra_keys = ("optimizer", "scheduler")
def __init__(
self,
save_dir: str,
interval: int,
weight_only: bool = False,
save_extra_fn: Optional[Callable[["TrainContext"], dict]] = None,
):
self.save_dir = save_dir
self.interval = interval
self.weight_only = weight_only
self.save_extra_fn = save_extra_fn or CheckpointCallback.save_extra
self.last_ckpt_iter = 0
def _save_checkpoint(self, context: TrainContext):
state_dict = context.executor.unwrap_model(context.model)
self.last_ckpt_iter = context.iteration
if get_rank() == 0:
save_path = os.path.join(
self.save_dir, f"epoch_{context.epoch}_iter_{context.iteration}"
)
extra = self.save_extra_fn(context)
context.checkpoint = Checkpoint(
state_dict=state_dict,
epoch=context.epoch,
iteration=context.iteration,
extra=extra,
config=context.model_config,
)
context.checkpoint.save(save_path)
def on_batch_end(self, context: TrainContext):
if context.iteration - self.last_ckpt_iter >= self.interval:
self._save_checkpoint(context)
def on_train_end(self, context: TrainContext):
if context.iteration != self.last_ckpt_iter:
self._save_checkpoint(context)
def on_error(self, context: TrainContext):
self._save_checkpoint(context)
@staticmethod
def save_extra(context: TrainContext) -> dict:
extra = {}
for name in CheckpointCallback.extra_keys:
obj = getattr(context, name, None)
if obj:
extra[name] = obj.state_dict()
return extra
@CallbackFactory.register("progress_bar")
class ProgressBarCallback(TrainCallback):
"""
Progress bar callback for trainer.
"""
def __init__(
self, num_epoch: int, log_interval: int = 100, file: Optional[IO[str]] = None
):
self.num_epoch = num_epoch
self.log_interval = log_interval
self.file = file
self.progress_bar: tqdm = None
@only_on_rank(0)
def on_epoch_begin(self, context: TrainContext):
self.progress_bar = tqdm(
context.dataloader,
desc=f"Epoch {context.epoch + 1}/{self.num_epoch}",
dynamic_ncols=True,
file=self.file or sys.stdout,
)
@only_on_rank(0)
def on_batch_end(self, context: TrainContext):
postfix = {
"loss": f"{context.loss:.4f}",
"lr": f"{context.optimizer.param_groups[-1]['lr']:.2e}",
}
if context.val_loss > 0:
postfix["val_loss"] = f"{context.val_loss:.4f}"
self.progress_bar.set_postfix(postfix)
self.progress_bar.update(1)
@only_on_rank(0)
def on_epoch_end(self, context: TrainContext):
_ = context
if self.progress_bar:
self.progress_bar.close()
@CallbackFactory.register("metric_logger")
class MetricLoggerCallback(TrainCallback):
def __init__(
self,
log_dir: str,
save_interval: int,
log_interval: int = 10,
metrics: List[str] = None,
):
self.last_log_iter = 0
self.save_interval = save_interval
self.log_interval = log_interval
self.metrics = metrics or ["loss", "lr"]
self.log_dir = Path(log_dir) if log_dir else Path.cwd() / "logs"
self.log_dir.mkdir(parents=True, exist_ok=True)
self.log_cache = []
self._metric_funcs = {
"loss": ctx_get_loss,
"lr": ctx_get_lr,
"val_loss": ctx_get_val_loss,
"grad_norm": ctx_get_grad_norm,
"grad_std": ctx_get_grad_std,
"grad_max": ctx_get_grad_max,
"grad_min": ctx_get_grad_min,
"grad_mean": ctx_get_grad_mean,
"grad_nan_num": ctx_get_grad_nan_num,
}
def _get_log_data(self, context: TrainContext):
return {
"timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"),
"epoch": context.epoch,
"iter": context.iteration,
**{m: self._metric_funcs[m](context) for m in self.metrics},
}
@only_on_rank(0)
def _add_log(self, log_data):
self.log_cache.append(log_data)
@only_on_rank(0)
def _save_log(self, epoch, iter):
log_file = self.log_dir / f"epoch_{epoch}_iter_{iter}_metric.jsonl"
with open(log_file, "w") as f:
for log in self.log_cache:
f.write(json.dumps(log) + "\n")
def on_batch_end(self, context):
if context.iteration % self.log_interval == 0:
log_data = self._get_log_data(context)
self._add_log(log_data)
if context.iteration - self.last_log_iter >= self.save_interval:
self._save_log(context.epoch, context.iteration)
self.last_log_iter = context.iteration
def on_train_end(self, context):
if context.iteration != self.last_log_iter:
self._save_log(context.epoch, context.iteration)
def on_error(self, context):
self._save_log(context.epoch, context.iteration)
@CallbackFactory.register("validation")
class ValidationCallback(TrainCallback):
def _run_validation(self, context: TrainContext):
context.model.eval()
total_loss = 0.0
num_batches = 0
with torch.no_grad():
for batch in context.val_dataloader:
loss = context.strategy(batch)
total_loss += loss.item()
num_batches += 1
avg_loss = total_loss / max(num_batches, 1)
if context.world_size > 1 and dist.is_initialized():
loss_tensor = torch.tensor([avg_loss], device=get_current_device())
dist.all_reduce(loss_tensor, op=dist.ReduceOp.AVG)
avg_loss = loss_tensor.item()
context.val_loss = avg_loss
context.model.train()
step_count = context.iteration // context.config.grad_accum_steps
logger.info(
f"Epoch {context.epoch + 1}, Step {step_count}, Val Loss: {avg_loss:.4f}"
)
def on_optimizer_step(self, context: TrainContext):
if context.val_dataloader is None:
return
cfg = context.config
if cfg.val_step <= 0:
return
step_count = context.iteration // cfg.grad_accum_steps
if step_count % cfg.val_step == 0:
self._run_validation(context)

View File

@ -0,0 +1,170 @@
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional, Self
import torch.nn as nn
from torch.utils.data import DataLoader
from astrai.config.train_config import TrainConfig
from astrai.dataset import ResumableDistributedSampler
from astrai.model.components.lora import inject_lora
from astrai.parallel.executor import BaseExecutor, ExecutorFactory
from astrai.parallel.setup import get_current_device, get_rank, get_world_size
from astrai.protocols import OptimizerProtocol, SchedulerProtocol
from astrai.serialization import Checkpoint, load_json, load_model_weights
from astrai.trainer.strategy import BaseStrategy, StrategyFactory
@dataclass
class TrainContext:
model: nn.Module = field(default=None)
strategy: BaseStrategy = field(default=None)
dataloader: DataLoader = field(default=None)
optimizer: OptimizerProtocol = field(default=None)
scheduler: SchedulerProtocol = field(default=None)
checkpoint: Checkpoint = field(default=None)
config: TrainConfig = field(default=None)
model_config: dict = field(default_factory=dict)
executor: BaseExecutor = field(default=None)
epoch: int = field(default=0)
iteration: int = field(default=0)
loss: float = field(default=0.0)
val_dataloader: DataLoader = field(default=None)
val_loss: float = field(default=0.0)
world_size: int = field(default=1)
rank: int = field(default=0)
kwargs: dict = field(default_factory=dict)
class TrainContextBuilder:
def __init__(
self,
config: TrainConfig,
):
self.config = config
self._resume_dir: Optional[str] = None
def with_resume_dir(self, resume_dir: Optional[str]) -> Self:
self._resume_dir = resume_dir
return self
def build(self) -> TrainContext:
cfg = self.config
device = get_current_device()
executor = ExecutorFactory.create(
cfg.parallel_mode,
grad_accum_steps=cfg.grad_accum_steps,
**cfg.executor_kwargs,
)
model = cfg.model_fn()
model = model.to(device=device)
model_config = {}
if self._resume_dir:
config_path = Path(self._resume_dir) / "config.json"
if config_path.exists():
model_config = load_json(config_path)
if not model_config and hasattr(model, "config"):
model_config = model.config.to_dict()
context = TrainContext(
model=model,
world_size=get_world_size(),
rank=get_rank(),
config=cfg,
model_config=model_config,
executor=executor,
)
if self._resume_dir is not None:
resume_path = Path(self._resume_dir)
if (resume_path / "meta.json").exists():
checkpoint = Checkpoint.load(self._resume_dir)
state_dict = checkpoint.state_dict
if checkpoint.config:
context.model_config = checkpoint.config
else:
checkpoint = None
state_dict = load_model_weights(self._resume_dir)
model.load_state_dict(state_dict, strict=False)
if checkpoint is not None:
context.epoch = cfg.start_epoch
context.iteration = cfg.start_batch
context.checkpoint = checkpoint
if cfg.lora is not None:
inject_lora(
model,
r=cfg.lora.r,
alpha=cfg.lora.alpha,
target_modules=set(cfg.lora.target_modules),
)
context.optimizer = cfg.optimizer_fn(model)
context.scheduler = cfg.scheduler_fn(context.optimizer)
sampler_offset = context.iteration * cfg.batch_per_device
sampler = ResumableDistributedSampler(
data_source=cfg.dataset,
start_epoch=context.epoch,
start_iter=sampler_offset,
seed=cfg.random_seed,
)
context.dataloader = DataLoader(
cfg.dataset,
batch_size=cfg.batch_per_device,
sampler=sampler,
num_workers=cfg.num_workers,
pin_memory=cfg.pin_memory,
prefetch_factor=cfg.prefetch_factor,
)
if cfg.val_dataset is not None:
val_sampler = ResumableDistributedSampler(
data_source=cfg.val_dataset,
start_epoch=0,
start_iter=0,
seed=cfg.random_seed,
shuffle=False,
)
context.val_dataloader = DataLoader(
cfg.val_dataset,
batch_size=cfg.batch_per_device,
sampler=val_sampler,
num_workers=cfg.num_workers,
pin_memory=cfg.pin_memory,
prefetch_factor=cfg.prefetch_factor,
)
context.model, context.optimizer, context.dataloader, context.scheduler = (
executor.prepare(
model,
context.optimizer,
context.dataloader,
context.scheduler,
)
)
if context.checkpoint and context.checkpoint.extra:
extra = context.checkpoint.extra
for name in ("optimizer", "scheduler"):
if name in extra:
obj = getattr(context, name, None)
if obj is not None:
obj.load_state_dict(extra[name])
context.strategy = StrategyFactory.create(
model=context.model,
train_type=cfg.strategy,
device=device,
executor=executor,
model_fn=cfg.model_fn,
**cfg.extra_kwargs,
)
return context

109
astrai/trainer/trainer.py Normal file
View File

@ -0,0 +1,109 @@
import logging
from typing import List, Optional
from astrai.config import TrainConfig
from astrai.parallel.setup import spawn_parallel_fn
from astrai.trainer.train_callback import (
CallbackFactory,
TrainCallback,
)
from astrai.trainer.train_context import TrainContext, TrainContextBuilder
logger = logging.getLogger(__name__)
class Trainer:
def __init__(
self, train_config: TrainConfig, callbacks: Optional[List[TrainCallback]] = None
):
self.train_config = train_config
default_callbacks = self._get_default_callbacks()
self.callbacks = (
default_callbacks + callbacks if callbacks else default_callbacks
)
def _get_default_callbacks(self) -> List[TrainCallback]:
cfg = self.train_config
callbacks = [
CallbackFactory.create(
"gradient_checkpointing",
modules=cfg.gradient_checkpointing_modules,
),
CallbackFactory.create(
"checkpoint",
cfg.ckpt_dir,
cfg.ckpt_interval,
),
CallbackFactory.create(
"metric_logger",
log_dir=cfg.log_dir,
save_interval=cfg.ckpt_interval,
log_interval=cfg.log_interval,
metrics=cfg.metrics,
),
CallbackFactory.create("progress_bar", cfg.n_epoch),
CallbackFactory.create("gradient_clipping", cfg.max_grad_norm),
CallbackFactory.create("validation"),
]
return callbacks
def _call_callbacks(self, method_name: str, context: TrainContext):
for callback in self.callbacks:
method = getattr(callback, method_name, None)
if method:
method(context)
def _trainer_loop(self, resume_dir: Optional[str] = None):
context = (
TrainContextBuilder(self.train_config).with_resume_dir(resume_dir).build()
)
executor = context.executor
self._call_callbacks("on_train_begin", context)
try:
context.model.train()
for epoch in range(context.epoch, context.config.n_epoch):
context.epoch = epoch
self._call_callbacks("on_epoch_begin", context)
for batch in context.dataloader:
self._call_callbacks("on_batch_begin", context)
with executor.accumulate(context.model):
loss = context.strategy(batch)
context.loss = loss.item()
stand_loss = loss / executor.grad_accum_steps
executor.backward(stand_loss)
context.iteration += 1
self._call_callbacks("on_batch_end", context)
if executor.sync_gradients:
self._call_callbacks("on_optimizer_step", context)
context.optimizer.step()
context.optimizer.zero_grad()
if context.scheduler:
context.scheduler.step()
self._call_callbacks("on_epoch_end", context)
except Exception as e:
logger.error("Training failed: %s", str(e), exc_info=True)
self._call_callbacks("on_error", context)
raise
finally:
self._call_callbacks("on_train_end", context)
def train(self, resume_dir: Optional[str] = None):
cfg = self.train_config
spawn_parallel_fn(
self._trainer_loop,
backend=cfg.backend,
world_size=cfg.nprocs,
master_addr=cfg.master_addr,
master_port=cfg.master_port,
device_type=cfg.device_type,
start_method=cfg.start_method,
resume_dir=resume_dir,
)

View File

@ -1,215 +0,0 @@
import torch
from typing import Dict, Any
from dataclasses import dataclass
from khaosz.core.transformer import TransformerConfig, Transformer
@dataclass
class BenchmarkResult:
total_tokens: int
total_time: float
tokens_per_second: float
metadata: Dict[str, Any]
class GenerationBenchmark:
def __init__(
self,
config: TransformerConfig,
device: str = "cuda",
dtype: torch.dtype = torch.float16
):
self.config = config
self.device = device
self.dtype = dtype
self.model = Transformer(config).to(device=device, dtype=dtype)
self.model.eval()
def _initialize_kv_cache(self, batch_size: int, max_len: int) -> list:
"""初始化KV缓存"""
kv_cache = []
head_dim = self.config.n_dim // self.config.n_head
for _ in range(self.config.n_layer):
k_cache = torch.zeros(
(batch_size, max_len, self.config.n_kvhead, head_dim),
device=self.device, dtype=self.dtype
)
v_cache = torch.zeros(
(batch_size, max_len, self.config.n_kvhead, head_dim),
device=self.device, dtype=self.dtype
)
kv_cache.append((k_cache, v_cache))
return kv_cache
def _prepare_inputs(self, batch_size: int, prompt_length: int, total_length: int):
prompt_ids = torch.randint(
low=0,
high=self.config.vocab_size,
size=(batch_size, prompt_length),
device=self.device,
dtype=torch.long
)
gen_ids = torch.randint(
low=0,
high=self.config.vocab_size,
size=(batch_size, total_length - prompt_length),
device=self.device,
dtype=torch.long
)
return prompt_ids, gen_ids
@torch.inference_mode()
def run_prefill_benchmark(
self,
batch_size: int = 1,
prompt_length: int = 512,
num_trials: int = 10,
) -> BenchmarkResult:
for _ in range(3):
prompt_ids, _ = self._prepare_inputs(batch_size, prompt_length, prompt_length)
_ = self.model(prompt_ids)
torch.cuda.synchronize()
total_time = 0.0
total_tokens = batch_size * prompt_length * num_trials
for trial in range(num_trials):
prompt_ids, _ = self._prepare_inputs(batch_size, prompt_length, prompt_length)
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
_ = self.model(prompt_ids)
end_event.record()
torch.cuda.synchronize()
trial_time = start_event.elapsed_time(end_event) / 1000
total_time += trial_time
print(f"Trial {trial + 1}/{num_trials}: {prompt_length} tokens in {trial_time:.3f}s "
f"({prompt_length / trial_time:.1f} tokens/s)")
return BenchmarkResult(
total_tokens=total_tokens,
total_time=total_time,
tokens_per_second=total_tokens / total_time,
metadata={
"benchmark_type": "prefill",
"batch_size": batch_size,
"prompt_length": prompt_length,
"dtype": self.dtype,
"device": self.device,
}
)
@torch.inference_mode()
def run_decoding_benchmark(
self,
batch_size: int = 1,
prompt_length: int = 512,
gen_length: int = 128,
num_trials: int = 5,
) -> BenchmarkResult:
total_time = 0.0
total_tokens = batch_size * gen_length * num_trials
for trial in range(num_trials):
prompt_ids, gen_ids = self._prepare_inputs(batch_size, prompt_length, prompt_length + gen_length)
kv_cache = self._initialize_kv_cache(batch_size, self.config.m_len)
_ = self.model(prompt_ids, persistent_key_values=kv_cache, start_pos=0)
torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
current_pos = prompt_length
for i in range(gen_length):
input_token = gen_ids[:, i:i+1]
_ = self.model(input_token, persistent_key_values=kv_cache, start_pos=current_pos)
current_pos += 1
end_event.record()
torch.cuda.synchronize()
trial_time = start_event.elapsed_time(end_event) / 1000
total_time += trial_time
print(f"Trial {trial + 1}/{num_trials}: {gen_length} tokens in {trial_time:.3f}s "
f"({gen_length / trial_time:.1f} tokens/s)")
return BenchmarkResult(
total_tokens=total_tokens,
total_time=total_time,
tokens_per_second=total_tokens / total_time,
metadata={
"benchmark_type": "generation",
"batch_size": batch_size,
"prompt_length": prompt_length,
"gen_length": gen_length,
"dtype": self.dtype,
"device": self.device,
}
)
def print_benchmark_result(result: BenchmarkResult):
"""打印基准测试结果"""
benchmark_type = result.metadata["benchmark_type"]
print(f"\n{' ' + benchmark_type.upper().replace('_', ' ') + ' Benchmark ':-^80}")
print(f"Total Tokens Processed: {result.total_tokens:,}")
print(f"Time Consumed: {result.total_time:.3f}s")
print(f"Throughput: {result.tokens_per_second:,.1f} tokens/s")
if benchmark_type == "prefill":
print(f"Batch Size: {result.metadata['batch_size']} | Prompt Length: {result.metadata['prompt_length']}")
elif benchmark_type == "generation":
print(f"Batch Size: {result.metadata['batch_size']} | Gen Length: {result.metadata['gen_length']}")
print(f"Device: {result.metadata['device']} | Dtype: {result.metadata['dtype']}")
print("-" * 80)
if __name__ == "__main__":
config = TransformerConfig(
vocab_size=10000,
n_dim=1536,
n_head=24,
n_kvhead=4,
d_ffn=6912,
m_len=2048,
n_layer=24,
norm_eps=1e-5,
)
benchmark = GenerationBenchmark(config)
print("=" * 80)
print("Running Transformer Generation Benchmark")
print("=" * 80)
prefill_result = benchmark.run_prefill_benchmark(
batch_size=4,
prompt_length=512,
num_trials=5
)
print_benchmark_result(prefill_result)
gen_result = benchmark.run_decoding_benchmark(
batch_size=4,
prompt_length=512,
gen_length=128,
num_trials=5
)
print_benchmark_result(gen_result)

44
docker-compose.yml Normal file
View File

@ -0,0 +1,44 @@
services:
server:
build:
context: .
dockerfile: Dockerfile
user: "${UID:-1000}:${GID:-1000}"
ports:
- "8000:8000"
volumes:
- ./params:/app/params:ro
command: python -m scripts.tools.server --port 8000 --device cuda
deploy:
resources:
reservations:
devices:
- driver: nvidia
count: 1
capabilities: [gpu]
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
interval: 30s
timeout: 10s
retries: 3
start_period: 60s
restart: unless-stopped
server-cpu:
profiles: [cpu]
build:
context: .
dockerfile: Dockerfile
user: "${UID:-1000}:${GID:-1000}"
ports:
- "8000:8000"
volumes:
- ./params:/app/params:ro
command: python -m scripts.tools.server --port 8000 --device cpu
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
interval: 30s
timeout: 10s
retries: 3
start_period: 120s
restart: unless-stopped

View File

@ -1,101 +0,0 @@
import os
import torch
import json
import torch
import argparse
from khaosz import Khaosz
from typing import List
from tqdm import tqdm
PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))
def batch_generate(
model: Khaosz,
queries: List[str],
temperature: float,
top_k: int,
top_p: float,
batch_size: int,
) -> List:
assert batch_size > 0
sorted_queries = sorted(queries, key=lambda x: len(x), reverse=True)
original_indices = {query: idx for idx, query in enumerate(queries)}
responses = [None] * len(queries)
total_batches = (len(sorted_queries) + batch_size - 1) // batch_size
for i in tqdm(range(0, total_batches * batch_size, batch_size), desc="Generating responses"):
batch_queries = sorted_queries[i: min(i + batch_size, len(queries))]
if not isinstance(batch_queries, list):
batch_queries = [batch_queries]
batch_responses = model.batch_generate(
queries=batch_queries,
temperature=temperature,
top_k=top_k,
top_p=top_p
)
for batch_query, batch_response in zip(batch_queries, batch_responses):
print(f"Q: {batch_query[:50]} \nR: {batch_response[:50]})")
for query, response in zip(batch_queries, batch_responses):
original_idx = original_indices[query]
responses[original_idx] = response
return responses
def processor(
model: Khaosz,
input_json_file: str,
output_json_file: str,
batch_size: int,
temperature: float,
top_p: float,
top_k: int,
question_key: str="question",
):
with open(input_json_file, "r", encoding='utf-8') as f:
input_dict = [json.loads(line) for line in f]
queries = [item[question_key] for item in input_dict]
output_dict = batch_generate(
model=model,
queries=queries,
temperature=temperature,
top_k=top_k,
top_p=top_p,
batch_size=batch_size
)
with open(output_json_file, "w", encoding='utf-8') as f:
json.dump(output_dict, f, indent=4, ensure_ascii=False)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run generate with a Khaosz model.")
parser.add_argument("--model_dir", type=str, required=True, help="Path to the model directory.")
parser.add_argument("--input_json_file", type=str, required=True, help="Path to the input JSONL file.")
parser.add_argument("--output_json_file", type=str, required=True, help="Path to the output JSONL file.")
parser.add_argument("--question_key", type=str, default="question", help="Key for the question in the input JSON.")
parser.add_argument("--temperature", type=float, default=0.60, help="Temperature for generating responses.")
parser.add_argument("--top_p", type=float, default=0.95, help="Top-p value for generating responses.")
parser.add_argument("--top_k", type=int, default=30, help="Top-k value for generating responses.")
parser.add_argument("--batch_size", type=int, default=1, help="Batch size for generating responses.")
args = parser.parse_args()
model = Khaosz(args.model_dir).to(device='cuda', dtype=torch.bfloat16)
processor(
model,
input_json_file=args.input_json_file,
output_json_file=args.output_json_file,
question_key=args.question_key,
batch_size=args.batch_size,
temperature=args.temperature,
top_k=args.top_k,
top_p=args.top_p
)

View File

@ -1,56 +0,0 @@
__version__ = "1.2.2"
__author__ = "ViperEkura"
from khaosz.model import Khaosz
from khaosz.core.transformer import Transformer, TransformerConfig
from khaosz.utils.retriever import Retriever
from khaosz.utils.splitter import (
SemanticTextSplitter,
PriorityTextSplitter
)
from khaosz.core.tokenizer import BpeTokenizer
from khaosz.core.parameter import ParameterLoader
from khaosz.core.generator import (
TextGenerator,
ChatGenerator,
StreamGenerator,
BatchGenerator,
RetrievalGenerator,
EmbeddingEncoder
)
from khaosz.trainer import (
Trainer,
DatasetLoader,
TrainConfig,
StrategyFactory,
SchedulerFactory
)
__all__ = [
# model
"Khaosz",
# module
"Transformer",
"TransformerConfig",
"BpeTokenizer",
"ParameterLoader",
"TextGenerator",
"ChatGenerator",
"StreamGenerator",
"BatchGenerator",
"RetrievalGenerator",
"EmbeddingEncoder",
# trainer
"Trainer",
"DatasetLoader",
"TrainConfig",
"StrategyFactory",
"SchedulerFactory",
# utils
"Retriever",
"SemanticTextSplitter",
"PriorityTextSplitter",
]

View File

@ -1,27 +0,0 @@
from khaosz.core.tokenizer import BpeTokenizer
from khaosz.core.transformer import Transformer, TransformerConfig
from khaosz.core.parameter import ParameterLoader, ModelParameter, Checkpoint
from khaosz.core.generator import (
TextGenerator,
ChatGenerator,
StreamGenerator,
BatchGenerator,
RetrievalGenerator,
EmbeddingEncoder
)
__all__ = [
"Transformer",
"TransformerConfig",
"BpeTokenizer",
"ParameterLoader",
"ModelParameter",
"Checkpoint",
"TextGenerator",
"ChatGenerator",
"StreamGenerator",
"BatchGenerator",
"RetrievalGenerator",
"EmbeddingEncoder"
]

View File

@ -1,568 +0,0 @@
import torch
from torch import Tensor
from typing import List, Tuple, Union, Optional, Generator, Self
from khaosz.core.parameter import ModelParameter
def build_prompt(query: str, history: Optional[List[Tuple[str, str]]] = None) -> str:
"""
Build prompt for query and history
Args:
query(str): query string
history(Optional[List[Tuple[str, str]]]): history list of query and response
Returns:
str: prompt string
"""
prompt_parts = []
if history is None:
history = []
for his_query, his_response in history:
prompt_parts.append(f"<|user|> {his_query} <|system|> <bos>{his_response}<eos>")
if query is not None:
prompt_parts.append(f"<|user|> {query} <|system|> <bos>")
return "\n".join(prompt_parts)
def pad_sequence(ids_list: List[List[int]], max_ids_len: int, pad_id: int) -> List[List[int]]:
"""
Pad a list of sequences to a fixed length.
Args:
ids_list (List[List[int]]): A list of sequences.
max_ids_len (int): The maximum length of sequences.
pad_id (int): The id to pad sequences.
Returns:
List[List[int]]: A list of padded sequences.
"""
new_ids_list = []
for ids in ids_list:
pad_len = max_ids_len - len(ids)
padded_seq = [pad_id] * pad_len + ids
new_ids_list.append(padded_seq)
return new_ids_list
def apply_sampling_strategies(
logits: Tensor,
temperature: float,
top_k: int,
top_p: float,
filter_value: float = -float("inf")
) -> Tensor:
"""
Apply sampling strategies to the logits tensor.
Args:
logits (Tensor): The logits tensor.
temperature (float): The temperature parameter.
top_k (int): The top-k parameter.
top_p (float): The top-p parameter.
filter_value (float, optional): The filter value. Defaults to -float("inf").
Returns:
Tensor: The sampled logits tensor.
"""
if temperature != 1.0:
logits = logits / temperature
if top_k > 0:
top_k = min(top_k, logits.size(-1))
indices_to_remove = logits < torch.topk(logits, top_k, dim=-1)[0][..., -1, None]
logits[indices_to_remove] = filter_value
if top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
sorted_indices_to_remove = cumulative_probs > top_p
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
indices_to_remove = torch.zeros_like(logits, dtype=torch.bool)
indices_to_remove.scatter_(
dim=1,
index=sorted_indices,
src=sorted_indices_to_remove
)
logits[indices_to_remove] = filter_value
return logits
class KVCacheManager:
def __init__(
self,
num_layers: int,
batch_size: int,
max_len: int,
num_heads: int,
head_dim: int,
device: torch.device = "cuda",
dtype: torch.dtype = torch.bfloat16
):
self.num_layers = num_layers
self.batch_size = batch_size
self.max_len = max_len
self.num_heads = num_heads
self.head_dim = head_dim
self.device = device
self.dtype = dtype
self._kv_cache: List[Tuple[Tensor, Tensor]] = None
self._seq_mask: Tensor = None
self._initialize()
def _initialize(self):
self._kv_cache = []
for _ in range(self.num_layers):
k_cache = torch.zeros(
(self.batch_size, self.max_len, self.num_heads, self.head_dim),
device=self.device, dtype=self.dtype
)
v_cache = torch.zeros(
(self.batch_size, self.max_len, self.num_heads, self.head_dim),
device=self.device, dtype=self.dtype
)
self._kv_cache.append((k_cache, v_cache))
self._seq_mask = torch.ones(
(self.batch_size, self.max_len),
device=self.device, dtype=torch.bool
)
def update(self, active_mask: Tensor):
for i in range(self.num_layers):
k_cache, v_cache = self._kv_cache[i]
new_k_cache, new_v_cache = k_cache[active_mask], v_cache[active_mask]
self._kv_cache[i] = (new_k_cache, new_v_cache)
self._seq_mask = self._seq_mask[active_mask]
def reset(self, full_reset=False):
if full_reset:
self._kv_cache = None
self._seq_mask = None
else:
self._initialize()
def set_seq_mask(self, input_ids: Tensor, pad_id: int):
batch_size, seq_len = input_ids.shape
bool_mask = (input_ids != pad_id)
self._seq_mask[: batch_size, : seq_len] = bool_mask
def get_kvcache(self) -> List[Tuple[Tensor, Tensor]]:
return self._kv_cache
def get_seq_mask(self) -> Tensor:
return self._seq_mask
class GeneratorCore:
def __init__(self, parameter: ModelParameter):
self.model = parameter.model
self.tokenizer = parameter.tokenizer
self.config = parameter.config
def compute_logits(
self,
input_ids: Tensor,
attn_mask: Optional[Tensor] = None,
kv_caches: Optional[List[Tuple[Tensor, Tensor]]] = None,
start_pos: int = 0
) -> Tuple[Tensor, int]:
with torch.inference_mode():
outputs = self.model(input_ids, attn_mask, kv_caches, start_pos)
logits = outputs["logits"][:, -1, :]
cache_increase = input_ids.size(-1)
return logits, cache_increase
def to(self, *args, **kargs) -> Self:
self.model.to(*args, **kargs)
return self
class EmbeddingEncoderCore:
def __init__(self, parameter: ModelParameter):
self.model = parameter.model
self.tokenizer = parameter.tokenizer
self.config = parameter.config
def encode(self, sentence: Union[str, List[str]]) -> Union[Tensor, List[Tensor]]:
with_batch = isinstance(sentence, list)
ids = self.tokenizer.encode(sentence)
batch_ids = ids if with_batch else [ids]
max_model_len = self.config.m_len
all_fragments = []
fragment_origin_idx = []
for i, seq in enumerate(batch_ids):
if len(seq) > max_model_len:
fragments = [seq[j:j+max_model_len] for j in range(0, len(seq), max_model_len)]
all_fragments.extend(fragments)
fragment_origin_idx.extend([i] * len(fragments))
else:
all_fragments.append(seq)
fragment_origin_idx.append(i)
#if empty fragments
if not all_fragments or not ids:
return [] if with_batch else torch.tensor([])
device = next(self.model.parameters()).device
max_len = min(max(len(seq) for seq in all_fragments), max_model_len)
padded_ids = []
masks = []
for seq in all_fragments:
pad_len = max_len - len(seq)
padded_seq = seq + [self.tokenizer.pad_id] * pad_len
mask = [token_id != self.tokenizer.pad_id for token_id in padded_seq]
padded_ids.append(padded_seq)
masks.append(mask)
input_tensor = torch.tensor(padded_ids, device=device, dtype=torch.long)
seq_mask = torch.tensor(masks, device=device, dtype=torch.bool)
with torch.inference_mode():
outputs = self.model(input_tensor, seq_mask)["hidden_states"]
# [num_fragments, seq_len, hidden_size]
fragment_embs = torch.mul(outputs, seq_mask.unsqueeze(-1))
sentence_embs: List[Tensor] = []
for i in range(len(batch_ids)):
indices = [idx for idx, orig_idx in enumerate(fragment_origin_idx) if orig_idx == i]
if indices is not None:
sum_frags = torch.sum(fragment_embs[indices, :, :], dim=1) # [frags, hidden_size]
length = torch.sum(seq_mask[indices, :], dim=1).unsqueeze(1) # [frags, 1]
emb = torch.sum(sum_frags / length, dim=0) # [frags, hidden_size]
sentence_embs.append(emb.flatten())
if with_batch:
return [emb.flatten() for emb in sentence_embs]
else:
return sentence_embs[0].flatten()
def to(self, *args, **kargs) -> Self:
self.model.to(*args, **kargs)
return self
class TextGenerator(GeneratorCore):
def __init__(self, parameter: ModelParameter):
super().__init__(parameter)
def generate(
self,
query: str,
temperature: float,
top_k: int,
top_p: float,
) -> str:
assert temperature >= 0.0
assert top_k >= 0
assert top_p >= 0.0 and top_p <= 1.0
device = next(self.model.parameters()).device
cache_manager = KVCacheManager(
num_layers=self.config.n_layer,
batch_size=1,
max_len=self.config.m_len,
num_heads=self.config.n_kvhead,
head_dim=self.config.n_dim // self.config.n_head,
device=device,
)
ids = self.tokenizer.encode(query)
input_ids = torch.tensor([ids], device=device, dtype=torch.long)
start_cache_pos = len(ids)
cur_cache_pos = 0
self.model.eval()
while len(ids) < self.config.m_len:
kv_caches = cache_manager.get_kvcache()
logits, cache_increase = self.compute_logits(
input_ids,
kv_caches=kv_caches,
start_pos=cur_cache_pos
)
logits = apply_sampling_strategies(logits, temperature, top_k, top_p)
probs = torch.softmax(logits, dim=-1)
next_token_id = torch.multinomial(probs, num_samples=1)
input_ids = next_token_id
ids.append(next_token_id.item())
cur_cache_pos += cache_increase
if next_token_id.item() in self.tokenizer.stop_ids:
break
response = self.tokenizer.decode(ids[start_cache_pos:])
return response
class ChatGenerator(GeneratorCore):
def __init__(self, parameter: ModelParameter):
super().__init__(parameter)
def generate(
self,
query: str,
history: List[Tuple[str, str]],
temperature: float,
top_k: int,
top_p: float,
) -> str:
assert temperature >= 0.0
assert top_k >= 0
assert top_p >= 0.0 and top_p <= 1.0
if history is None:
history = []
device = next(self.model.parameters()).device
cache_manager = KVCacheManager(
num_layers=self.config.n_layer,
batch_size=1,
max_len=self.config.m_len,
num_heads=self.config.n_kvhead,
head_dim=self.config.n_dim // self.config.n_head,
device=device,
)
ids = self.tokenizer.encode(build_prompt(query, history))
input_ids = torch.tensor([ids], device=device, dtype=torch.long)
cpy_history = history.copy()
start_cache_pos = len(ids)
cur_cache_pos = 0
self.model.eval()
while len(ids) < self.config.m_len:
kv_caches = cache_manager.get_kvcache()
logits, cache_increase = self.compute_logits(
input_ids,
kv_caches=kv_caches,
start_pos=cur_cache_pos
)
logits = apply_sampling_strategies(logits, temperature, top_k, top_p)
probs = torch.softmax(logits, dim=-1)
next_token_id = torch.multinomial(probs, num_samples=1)
input_ids = next_token_id
ids.append(next_token_id.item())
cur_cache_pos += cache_increase
if next_token_id.item() in self.tokenizer.stop_ids:
break
response = self.tokenizer.decode(ids[start_cache_pos:])
cpy_history.append((query, response))
return response, cpy_history
class StreamGenerator(GeneratorCore):
def __init__(self, parameter: ModelParameter):
super().__init__(parameter)
def generate(
self,
query: str,
history: List[Tuple[str, str]],
temperature: float,
top_k: int,
top_p: float,
) -> Generator[Tuple[str, List[Tuple[str, str]]], None, None]:
assert temperature >= 0.0
assert top_k >= 0
assert top_p >= 0.0 and top_p <= 1.0
if history is None:
history = []
device = next(self.model.parameters()).device
cache_manager = KVCacheManager(
num_layers=self.config.n_layer,
batch_size=1,
max_len=self.config.m_len,
num_heads=self.config.n_kvhead,
head_dim=self.config.n_dim // self.config.n_head,
device=device,
)
ids = self.tokenizer.encode(build_prompt(query, history))
input_ids = torch.tensor([ids], device=device, dtype=torch.long)
cpy_history = history.copy()
start_cache_pos = len(ids)
cur_cache_pos = 0
self.model.eval()
while len(ids) < self.config.m_len:
kv_caches = cache_manager.get_kvcache()
logits, cache_increase = self.compute_logits(
input_ids,
kv_caches=kv_caches,
start_pos=cur_cache_pos
)
logits = apply_sampling_strategies(logits, temperature, top_k, top_p)
probs = torch.softmax(logits, dim=-1)
next_token_id = torch.multinomial(probs, num_samples=1)
input_ids = next_token_id
ids.append(next_token_id.item())
cur_cache_pos += cache_increase
response = self.tokenizer.decode(ids[start_cache_pos:])
yield response, cpy_history + [(query, response)]
if next_token_id.item() in self.tokenizer.stop_ids:
yield response + "\n", cpy_history + [(query, response)]
break
class BatchGenerator(GeneratorCore):
def __init__(self, parameter: ModelParameter):
super().__init__(parameter)
def generate(
self,
queries: List[str],
histories: List[List[Tuple[str, str]]],
temperature: float,
top_k: int,
top_p: float
) -> List[str]:
assert temperature >= 0.0
assert top_k >= 0
assert top_p >= 0.0 and top_p <= 1.0
batch_size = len(queries)
if histories is None:
histories = [[] for _ in range(batch_size)]
prompts = [build_prompt(query, history) for query, history in zip(queries, histories)]
ids_list = [self.tokenizer.encode(prompt) for prompt in prompts]
max_ids_len = max(len(ids) for ids in ids_list)
ids_list = pad_sequence(ids_list, max_ids_len, self.tokenizer.pad_id)
device = next(self.model.parameters()).device
cache_manager = KVCacheManager(
num_layers=self.config.n_layer,
batch_size=batch_size,
max_len=self.config.m_len,
num_heads=self.config.n_kvhead,
head_dim=self.config.n_dim // self.config.n_head,
device=device,
)
input_tensor = torch.tensor(ids_list, device=device, dtype=torch.long)
cache_manager.set_seq_mask(input_tensor, self.tokenizer.pad_id)
activate_task_mask = [True] * batch_size
start_cache_pos = max_ids_len
cur_cache_pos = 0
while max_ids_len < self.config.m_len and sum(activate_task_mask) != 0:
kv_caches = cache_manager.get_kvcache()
attn_mask =cache_manager.get_seq_mask()
logits, cache_increase = self.compute_logits(
input_tensor,
attn_mask=attn_mask,
kv_caches=kv_caches,
start_pos=cur_cache_pos
)
cur_cache_pos += cache_increase
logits = apply_sampling_strategies(logits, temperature, top_k, top_p)
probs = torch.softmax(logits, dim=-1)
next_token_id = torch.multinomial(probs, num_samples=1)
active_mask = []
c_ids = 0
for i in range(batch_size):
if activate_task_mask[i]:
token = next_token_id[c_ids, :].item()
ids_list[i].append(token)
c_ids += 1
is_active = not token in self.tokenizer.stop_ids
activate_task_mask[i] = is_active
active_mask.append(is_active)
active_mask = torch.tensor(active_mask, device=device, dtype=torch.bool)
cache_manager.update(active_mask)
input_tensor = next_token_id[active_mask, :]
max_ids_len += 1
responses = [str()] * batch_size
for i in range(batch_size):
responses[i] = self.tokenizer.decode(ids_list[i][start_cache_pos:])
histories[i].append((queries[i], responses[i]))
return responses
class RetrievalGenerator(GeneratorCore):
def __init__(self, retriever_parameter: ModelParameter):
super().__init__(retriever_parameter)
def generate(
self,
retrieved: List[str],
query: str,
history: List[Tuple[str, str]],
temperature: float,
top_k: int,
top_p: float,
) -> str:
assert temperature >= 0.0
assert top_k >= 0
assert top_p >= 0.0 and top_p <= 1.0
if history is None:
history = []
retrieved = "\n".join([f"{idx + 1}. {key}" for idx, key in enumerate(retrieved)]) if retrieved else ""
retrieved_query = f"{retrieved}<eos>\n\n根据以上内容回答: {query}" if retrieved else query
parameter = ModelParameter(self.model, self.tokenizer, self.config)
return ChatGenerator(parameter).generate(
retrieved_query,
history,
temperature=temperature,
top_k=top_k,
top_p=top_p,
)
class EmbeddingEncoder(EmbeddingEncoderCore):
def __init__(self, parameter: ModelParameter):
super().__init__(parameter)
def encode(self, sentence: Union[str, List[str]]) -> Union[Tensor, List[Tensor]]:
return super().encode(sentence)

View File

@ -1,237 +0,0 @@
import pickle as pkl
import matplotlib.pyplot as plt
import safetensors.torch as st
import torch.nn as nn
import torch.optim as optim
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Self, Union
from pathlib import Path
from khaosz.core.tokenizer import BpeTokenizer
from khaosz.core.transformer import TransformerConfig, Transformer
class BaseModelIO:
"""Base class for model I/O operations."""
def __init__(
self,
model: Optional[nn.Module] = None,
tokenizer: Optional[BpeTokenizer] = None,
config: Optional[TransformerConfig] = None
):
self.model = model
self.tokenizer = tokenizer or BpeTokenizer()
self.config = config or TransformerConfig()
def _get_file_paths(self, directory: Union[str, Path]) -> dict[str, Path]:
"""Get standardized file paths for model components."""
dir_path = Path(directory)
return {
"model": dir_path / "model.safetensors",
"config": dir_path / "config.json",
"tokenizer": dir_path / "tokenizer.json"
}
def save_components(self, save_dir: Union[str, Path]):
"""Save core model components."""
paths = self._get_file_paths(save_dir)
paths["model"].parent.mkdir(parents=True, exist_ok=True)
if self.model is not None:
st.save_file(self.model.state_dict(), str(paths["model"]))
self.config.save(str(paths["config"]))
self.tokenizer.save(str(paths["tokenizer"]))
def load_components(self, load_dir: Union[str, Path]) -> Self:
"""Load core model components."""
paths = self._get_file_paths(load_dir)
self.config.load(str(paths["config"]))
self.tokenizer.load(str(paths["tokenizer"]))
if paths["model"].exists():
state_dict = st.load_file(str(paths["model"]))
if self.model is None:
self.model = Transformer(self.config)
self.model.load_state_dict(state_dict)
return self
def to(self, *args, **kwargs) -> Self:
"""Move model to device."""
if self.model is not None:
self.model.to(*args, **kwargs)
return self
@dataclass
class ModelParameter(BaseModelIO):
"""Container for model parameters with serialization capabilities."""
model: Optional[nn.Module] = field(
default=None,
metadata={"help": "Transformer model."}
)
tokenizer: BpeTokenizer = field(
default_factory=BpeTokenizer,
metadata={"help": "Tokenizer for the model."}
)
config: TransformerConfig = field(
default_factory=TransformerConfig,
metadata={"help": "Transformer model configuration."}
)
def save(self, save_dir: Union[str, Path]):
self.save_components(save_dir)
def load(self, load_dir: Union[str, Path]) -> Self:
return self.load_components(load_dir)
@dataclass
class Checkpoint(BaseModelIO):
"""Extended model parameters with training state."""
model: Optional[nn.Module] = field(
default=None,
metadata={"help": "Transformer model."}
)
tokenizer: BpeTokenizer = field(
default_factory=BpeTokenizer,
metadata={"help": "Tokenizer for the model."}
)
config: TransformerConfig = field(
default_factory=TransformerConfig,
metadata={"help": "Transformer model configuration."}
)
optim_state: Dict[str, Any] = field(
default=None,
metadata={"help": "Optimizer state."}
)
sampler_state: Dict[str, Any] = field(
default=None,
metadata={"help": "Sampler state."}
)
loss_list: List[float] = field(
default_factory=list,
metadata={"help": "List of training losses."}
)
def _get_training_paths(self, directory: Union[str, Path]) -> dict[str, Path]:
paths = self._get_file_paths(directory)
paths.update({
"loss_list": paths["model"].parent / "loss.pkl",
"loss_plot": paths["model"].parent / "loss.png",
"optim_state": paths["model"].parent / "optim_state.pkl",
"sampler_state": paths["model"].parent / "sampler_state.pkl"
})
return paths
def save_training_state(self, save_dir: Union[str, Path]):
paths = self._get_training_paths(save_dir)
# Save loss plot
self._plot_loss(str(paths["loss_plot"]))
# Save loss list
with open(str(paths["loss_list"]), "wb") as f:
pkl.dump(self.loss_list, f)
# Save optimizer state
with open(str(paths["optim_state"]), "wb") as f:
pkl.dump(self.optim_state, f)
# Save sampler state
with open(str(paths["sampler_state"]), "wb") as f:
pkl.dump(self.sampler_state, f)
def load_training_state(self, load_dir: Union[str, Path]) -> Self:
paths = self._get_training_paths(load_dir)
# Load loss list
if paths["loss_list"].exists():
with open(str(paths["loss_list"]), "rb") as f:
self.loss_list = pkl.load(f)
# Load optimizer state
if paths["optim_state"].exists():
with open(str(paths["optim_state"]), "rb") as f:
self.optim_state = pkl.load(f)
# Load sampler state
if paths["sampler_state"].exists():
with open(str(paths["sampler_state"]), "rb") as f:
self.sampler_state = pkl.load(f)
return self
def _plot_loss(self, save_path: str):
"""Plot and save loss curve."""
if not self.loss_list:
return
current_iter = len(self.loss_list)
plt.figure(figsize=(10, 6))
plt.plot(self.loss_list)
plt.title(f"Training Loss - Iteration {current_iter}")
plt.xlabel("Batch")
plt.ylabel("Loss")
plt.grid(True)
plt.savefig(save_path, dpi=300, bbox_inches="tight")
plt.close()
def save(self, save_dir: Union[str, Path]):
"""Save complete checkpoint."""
self.save_components(save_dir)
self.save_training_state(save_dir)
def load(self, load_dir: Union[str, Path]) -> Self:
"""Load complete checkpoint."""
self.load_components(load_dir)
self.load_training_state(load_dir)
return self
class ParameterLoader:
"""Factory class for loading model parameters or checkpoints."""
@staticmethod
def load(load_dir: Union[str, Path]) -> Union[ModelParameter, Checkpoint]:
"""Load either ModelParameter or Checkpoint based on directory contents."""
load_dir = Path(load_dir)
# Check for training-specific files
loss_file = load_dir / "loss.pkl"
has_training_data = loss_file.exists()
# Create appropriate instance
if has_training_data:
checkpoint = Checkpoint()
checkpoint.load(str(load_dir))
return checkpoint
else:
params = ModelParameter()
params.load(str(load_dir))
return params
@staticmethod
def create_checkpoint(
model: nn.Module,
tokenizer: BpeTokenizer,
config: TransformerConfig,
loss_list: Optional[list[float]] = None,
optimizer: Optional[optim.Optimizer] = None,
) -> Checkpoint:
"""Convenience method to create a training checkpoint."""
return Checkpoint(
model=model,
tokenizer=tokenizer,
config=config,
loss_list=loss_list or [],
optimizer_state=optimizer
)

View File

@ -1,119 +0,0 @@
from tokenizers import Tokenizer, Encoding
from tokenizers import decoders, processors, normalizers, pre_tokenizers
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from typing import List, Union
class BpeTokenizer:
def __init__(self, path=None):
self._control_tokens = ["<bos>", "<eos>", "<pad>"]
self._special_tokens = ["<|user|>", "<|system|>"]
model = BPE()
tokenizer = Tokenizer(model)
tokenizer.normalizer = normalizers.Sequence([
normalizers.NFC()
])
tokenizer.pre_tokenizer = pre_tokenizers.Sequence([
pre_tokenizers.Punctuation(behavior="isolated"),
pre_tokenizers.Metaspace(prepend_scheme="never"),
pre_tokenizers.Split(pattern=r"(\d+|[a-zA-Z]+|(?:'s|'t|'re|'ve|'m|'ll|'d))", behavior="isolated"),
pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=False)
])
tokenizer.decoder = decoders.Sequence([
decoders.ByteLevel(),
decoders.Metaspace(prepend_scheme="never")
])
tokenizer.post_processor = processors.Sequence([
processors.ByteLevel(trim_offsets=False)
])
self._tokenizer = tokenizer
if path is not None:
self._tokenizer = Tokenizer.from_file(path)
def _prepare_trainer(self, vocab_size: int, min_freq: int, reserved_token_size: int) -> tuple:
assert reserved_token_size > len(self._special_tokens)
reserved_tokens = [f"<|rsv{i:02d}|>" for i in range(reserved_token_size - len(self._special_tokens))]
detail_vocab_size = vocab_size - (len(reserved_tokens) + len(self._special_tokens))
alphabet = pre_tokenizers.ByteLevel.alphabet()
min_size = len(alphabet) + len(self._control_tokens)
assert detail_vocab_size > min_size
trainer = BpeTrainer(
vocab_size=detail_vocab_size,
min_frequency=min_freq,
limit_alphabet=detail_vocab_size // 4,
max_token_length=18,
special_tokens=self._control_tokens,
show_progress=True,
initial_alphabet=alphabet,
)
return trainer, detail_vocab_size, reserved_tokens
def train(self, files, vocab_size, min_freq, reserved_token_size=100):
trainer, _, reserved_tokens = self._prepare_trainer(
vocab_size=vocab_size,
min_freq=min_freq,
reserved_token_size=reserved_token_size
)
self._tokenizer.train(files=files, trainer=trainer)
self._tokenizer.add_special_tokens(self._special_tokens + reserved_tokens)
def train_from_iterator(self, iterator, vocab_size, min_freq, reserved_token_size=100):
trainer, _, reserved_tokens = self._prepare_trainer(
vocab_size=vocab_size,
min_freq=min_freq,
reserved_token_size=reserved_token_size
)
self._tokenizer.train_from_iterator(iterator=iterator, trainer=trainer)
self._tokenizer.add_special_tokens(self._special_tokens + reserved_tokens)
def save(self, path):
self._tokenizer.save(path)
def load(self, path):
self._tokenizer = Tokenizer.from_file(path)
def encode(self, tokens: Union[str, List[str]], out_ids: bool=True, add_special_tokens: bool=False) -> List:
if isinstance(tokens, str):
encoded: Encoding = self._tokenizer.encode(tokens, add_special_tokens=add_special_tokens)
return encoded.ids if out_ids else encoded.tokens
elif isinstance(tokens, list):
encoded_list: List[Encoding] = self._tokenizer.encode_batch(tokens, add_special_tokens=add_special_tokens)
return [encoded.ids if out_ids else encoded.tokens for encoded in encoded_list]
def decode(self, tokens: List[int], skip_special_tokens: bool=True) -> str:
return self._tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens)
def __len__(self) -> int:
return self._tokenizer.get_vocab_size()
@property
def stop_ids(self) -> List[int]:
stop_ids = []
for token in self._control_tokens:
stop_ids.append(self._tokenizer.token_to_id(token))
return stop_ids
@property
def bos_id(self) -> int:
return self._tokenizer.token_to_id("<bos>")
@property
def eos_id(self) -> int:
return self._tokenizer.token_to_id("<eos>")
@property
def pad_id(self) -> int:
return self._tokenizer.token_to_id("<pad>")
@property
def user_id(self) -> int:
return self._tokenizer.token_to_id("<|user|>")
@property
def system_id(self) -> int:
return self._tokenizer.token_to_id("<|system|>")

View File

@ -1,346 +0,0 @@
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.nn import init
from dataclasses import asdict, dataclass
from typing import List, Optional, Self, Tuple
def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
"""
Repeat k times along the dimension for attention heads.
Args:
x (Tensor): The input tensor.
n_rep (int): The number of repetitions.
Returns:
Tensor: The repeated tensor.
"""
bs, slen, n_heads, head_dim = x.shape
if n_rep == 1:
return x
return (
x[:, :, :, None, :]
.expand(bs, slen, n_heads, n_rep, head_dim)
.reshape(bs, slen, n_heads * n_rep, head_dim)
)
def get_rotary_emb(
dim: int,
max_len: int,
base: float = 10000,
device: torch.device = "cuda",
) -> torch.Tensor:
"""
Get the rotary embedding for the given dimension and maximum length.
Args:
dim (int): The dimension of the input.
max_len (int): The maximum length of the input.
base (float, optional): The base for the frequency. Defaults to 10000.
device (torch.device, optional): The device to use. Defaults to "cuda".
Returns:
Tensor: The rotary embedding tensor.
"""
theta = base ** (-torch.arange(0, dim, 2, device=device).float() / dim)
t = torch.arange(0, max_len, device=device).float()
freqs = torch.outer(t, theta)
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
return freqs_cis
def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
"""
Apply rotary embedding to the input tensor.
Args:
x (Tensor): The input tensor.
freqs_cis (Tensor): The rotary embedding tensor.
Returns:
Tensor: The output tensor.
"""
dtype = x.dtype
seq_len = x.size(1)
x_complex = torch.view_as_complex(x.view(*x.shape[:-1], -1, 2).float())
freqs_cis = freqs_cis.reshape(1, seq_len, 1, -1)
x_out = torch.view_as_real(x_complex * freqs_cis).flatten(3)
return x_out.to(dtype)
def process_attention_mask(
seq_mask: Tensor,
start_pos: int = 0,
seq_len: int = 0,
is_causal: bool = False,
device: torch.device = "cuda",
dtype: torch.dtype = torch.float32
) -> Tensor:
"""
Create attention mask for GQA
Args:
seq_mask (Tensor): A tensor indicating whether each position is valid or not.
start_pos (int): The starting position of the sequence.
seq_len (int): The length of the sequence.
is_causal (bool): Whether the attention is causal or not.
device (torch.device): The device to use.
Returns:
Tensor: The attention mask tensor.
"""
if seq_mask is None:
if start_pos != 0:
# for single prompt chat
seq_mask = torch.ones((1, seq_len), dtype=torch.bool, device=device)
else:
return None
if seq_mask.dim() > 2:
# shape (bsz, seq_len) or (bsz,n_heads, seq_len, seq_len + start_pos)
# if ndim > 2, it's 4D tensor
return seq_mask
batch_size = seq_mask.size(0)
seq_mask = seq_mask[:, :start_pos + seq_len].to(device=device, dtype=torch.bool)
# (bsz, start_pos + seq_len)
expanded_mask = seq_mask.unsqueeze(1).expand(batch_size, seq_len, start_pos + seq_len)
# (bsz, seq_len, start_pos + seq_len)
if is_causal:
causal_mask = torch.tril(
torch.ones((seq_len, start_pos + seq_len), dtype=torch.bool, device=device),
diagonal=start_pos
)
causal_mask = causal_mask.unsqueeze(0).expand(batch_size, seq_len, start_pos + seq_len)
expanded_mask = expanded_mask & causal_mask
attention_mask = torch.zeros_like(expanded_mask, dtype=dtype, device=device)
attention_mask = attention_mask.masked_fill_(~expanded_mask, -torch.finfo(dtype).max / 2).unsqueeze(1)
# (bsz, 1, seq_len, seq_len + start_pos)
return attention_mask
@dataclass
class TransformerConfig:
# basic config
vocab_size: Optional[int] = None
n_dim: Optional[int] = None
n_head: Optional[int] = None
n_layer: Optional[int] = None
m_len: Optional[int] = None
norm_eps: Optional[float] = None
d_ffn: Optional[int] = None
# GQA
n_kvhead: Optional[int] = None
def load(self, config_path: str) -> Self:
with open(config_path, 'r') as f:
config: dict = json.load(f)
for key, value in config.items():
if hasattr(self, key):
setattr(self, key, value)
return self
def save(self, config_path: str) -> None:
config_dict = asdict(self)
config_dict = {k: v for k, v in config_dict.items() if v is not None}
with open(config_path, 'w') as f:
json.dump(config_dict, f, indent=4)
class Linear(nn.Module):
def __init__(self, in_dim: int, out_dim: int, bias: bool=False):
super().__init__()
self.weight = nn.Parameter(torch.empty((out_dim, in_dim)))
self.bias = nn.Parameter(torch.zeros(out_dim)) if bias else None
init.normal_(self.weight, mean=0, std=0.006)
def forward(self, x: Tensor) -> Tensor:
return F.linear(x, self.weight, self.bias)
class RMSNorm(nn.Module):
def __init__(self, n_dim, norm_eps):
super().__init__()
self.weight = nn.Parameter(torch.ones(n_dim))
self.norm_eps = norm_eps
def forward(self, x: Tensor) -> Tensor:
dtype = x.dtype
x = x.float()
mean_square = torch.mean(torch.pow(x, 2), dim=-1, keepdim=True)
norm = x * torch.rsqrt(mean_square + self.norm_eps)
norm = norm.to(dtype)
out = norm * self.weight
return out
class MLP(nn.Module):
def __init__(self, n_dim: int, d_ffn: int):
super().__init__()
self.up = Linear(n_dim, d_ffn)
self.gate = Linear(n_dim, d_ffn)
self.down = Linear(d_ffn, n_dim)
def forward(self, x: Tensor) -> Tensor:
gated = self.up(x) * F.silu(self.gate(x))
out = self.down(gated)
return out
class GQA(nn.Module):
def __init__(
self,
n_dim: int,
n_head: int,
n_kvhead: int,
):
super().__init__()
assert n_dim % n_head == 0
assert n_head % n_kvhead == 0
self.head_dim = n_dim // n_head
self.n_dim = n_dim
self.n_heads = n_head
self.n_kvheads = n_kvhead
self.n_rep = n_head // n_kvhead
self.q_proj = Linear(n_dim, n_head * self.head_dim)
self.k_proj = Linear(n_dim, n_kvhead * self.head_dim)
self.v_proj = Linear(n_dim, n_kvhead * self.head_dim)
self.o_proj = Linear(n_dim, n_dim)
def forward(
self,
x: Tensor,
freqs_cis: Tensor,
mask: Tensor = None,
kv_cache: Optional[Tuple[Tensor, Tensor]] = None,
start_pos: int = 0
) -> Tensor:
bsz, seq_len, _ = x.size()
# x(bsz, seq_len, n_heads * head_dim) -> (bsz, seq_len, n_heads, head_dim)
q = self._split_heads(self.q_proj(x), self.n_heads)
k = self._split_heads(self.k_proj(x), self.n_kvheads)
v = self._split_heads(self.v_proj(x), self.n_kvheads)
q, k = apply_rotary_emb(q, freqs_cis), apply_rotary_emb(k, freqs_cis)
if kv_cache is not None:
k_cache, v_cache = kv_cache
# copy to cache
k_cache[:bsz, start_pos:start_pos + seq_len] = k
v_cache[:bsz, start_pos:start_pos + seq_len] = v
# get cache
k = k_cache[:bsz, :start_pos + seq_len]
v = v_cache[:bsz, :start_pos + seq_len]
k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep)
# (bsz, seq_len, n_heads, head_dim) -> (bsz, n_heads, seq_len, head_dim)
q, k, v = q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3)
sdqa_out = F.scaled_dot_product_attention(q, k, v, mask, is_causal=(mask == None)).permute(0, 2, 1, 3)
out = self.o_proj(sdqa_out.contiguous().view(bsz, seq_len, -1))
return out
def _split_heads(self, x: Tensor, n_heads) -> Tensor:
batch_size, seq_len, _ = x.shape
x = x.reshape(batch_size, seq_len, n_heads, self.head_dim)
return x
class DecoderBlock(nn.Module):
def __init__(self, n_dim, n_head, d_ffn, n_kvhead, norm_eps):
super().__init__()
self.attention = GQA(n_dim, n_head, n_kvhead)
self.norm_attn = RMSNorm(n_dim, norm_eps)
self.ffn = MLP(n_dim, d_ffn)
self.norm_ffn = RMSNorm(n_dim, norm_eps)
def forward(
self,
x: Tensor,
freqs_cis: Tensor,
attention_mask: Optional[Tensor] = None,
kv_cache: Optional[Tuple[Tensor, Tensor]] = None,
start_pos: int = 0
) -> Tensor:
# attention
attn_output = self.attention(
self.norm_attn(x),
freqs_cis,
attention_mask,
kv_cache,
start_pos
)
x = attn_output + x
# feed forward
x = self.ffn(self.norm_ffn(x)) + x
return x
class Transformer(nn.Module):
def __init__(self, config: TransformerConfig):
super().__init__()
self.embedding = nn.Parameter(torch.empty(config.vocab_size, config.n_dim))
self.layers = nn.ModuleList([
DecoderBlock(
config.n_dim,
config.n_head,
config.d_ffn,
config.n_kvhead,
config.norm_eps
)
for _ in range(config.n_layer)
])
self.norm = RMSNorm(config.n_dim, config.norm_eps)
self.freq_cis = get_rotary_emb(config.n_dim // config.n_head, config.m_len)
init.normal_(self.embedding, mean=0, std=0.02)
def forward(
self,
input_ids: Tensor,
input_mask: Optional[Tensor]=None,
persistent_key_values: Optional[List[Tuple[Tensor, Tensor]]]=None,
start_pos: int = 0
) -> Tensor:
assert input_ids.ndim == 2
seq_len = input_ids.size(-1)
x = F.embedding(input_ids, self.embedding)
self.freq_cis = self.freq_cis.to(x.device)
freqs_cis = self.freq_cis[start_pos:start_pos+seq_len]
has_kvcache = persistent_key_values is not None
attn_mask = process_attention_mask(
input_mask,
start_pos=start_pos,
seq_len=seq_len,
is_causal=has_kvcache,
device=x.device,
dtype=x.dtype
)
for i, layer in enumerate(self.layers):
kv_cache = persistent_key_values[i] if persistent_key_values else None
x = layer(x, freqs_cis, attn_mask, kv_cache, start_pos)
hidden_states = self.norm(x)
logits = F.linear(hidden_states, self.embedding)
return {
"logits": logits,
"hidden_states": hidden_states
}

View File

@ -1,112 +0,0 @@
from torch import Tensor
from typing import List, Tuple, Generator, Union
from khaosz.core.generator import (
TextGenerator,
ChatGenerator,
StreamGenerator,
BatchGenerator,
RetrievalGenerator,
EmbeddingEncoder
)
from khaosz.core.parameter import ParameterLoader
class Khaosz:
def __init__(self, model_dir: str):
self.parameter = ParameterLoader.load(model_dir)
def to(self, *args, **kwargs):
self.parameter.to(*args, **kwargs)
return self
def generate(
self,
query: str,
history: List[Tuple[str, str]]=None,
temperature: float=0.8,
top_k: int=50,
top_p: float=0.95,
) -> str:
generator = ChatGenerator(self.parameter)
return generator.generate(
query,
history=history,
temperature=temperature,
top_k=top_k,
top_p=top_p,
)
def batch_generate(
self,
queries: List[str],
histories: List[Tuple[str, str]]=None,
temperature: float=0.8,
top_k: int=50,
top_p: float=0.95,
) -> List[str]:
generator = BatchGenerator(self.parameter)
return generator.generate(
queries,
histories=histories,
temperature=temperature,
top_k=top_k,
top_p=top_p,
)
def stream_generate(
self,
query: str,
history: List[Tuple[str, str]]=None,
temperature: float=0.8,
top_k: int=50,
top_p: float=0.95,
) -> Generator[Tuple[str, List[Tuple[str, str]]], None, None]:
stream_generator = StreamGenerator(self.parameter)
return stream_generator.generate(
query,
history=history,
temperature=temperature,
top_k=top_k,
top_p=top_p,
)
def retrieve_generate(
self,
retrieved,
query: str,
history: List[Tuple[str, str]] = None,
temperature: float=0.8,
top_k: int=50,
top_p: float=0.95,
) -> str:
generator = RetrievalGenerator(self.parameter)
return generator.generate(
retrieved,
query,
history=history,
temperature=temperature,
top_k=top_k,
top_p=top_p,
)
def text_generate(
self,
query: str,
temperature: float=0.8,
top_k: int=50,
top_p: float=0.95,
) -> str:
generator = TextGenerator(self.parameter)
return generator.generate(
query,
temperature=temperature,
top_k=top_k,
top_p=top_p,
)
def encode(self, sentence: Union[str, List[str]]) -> Union[Tensor, List[Tensor]]:
encoder = EmbeddingEncoder(self.parameter)
return encoder.encode(sentence)

View File

@ -1,34 +0,0 @@
from khaosz.trainer.data_util import DatasetLoader
from khaosz.trainer.trainer import Trainer
from khaosz.trainer.strategy import (
TrainConfig,
CosineScheduleConfig,
SgdrScheduleConfig,
StrategyFactory,
SchedulerFactory
)
from khaosz.trainer.trainer_callback import (
TrainerCallback,
ProgressBarCallback,
CheckpointCallback,
TrainerCallback,
SchedulerCallback
)
__all__ = [
# strategy
"DatasetLoader",
"Trainer",
"TrainConfig",
"CosineScheduleConfig",
"SgdrScheduleConfig",
"StrategyFactory",
"SchedulerFactory",
# callback
"TrainerCallback",
"ProgressBarCallback",
"CheckpointCallback",
"TrainerCallback",
"SchedulerCallback",
]

View File

@ -1,326 +0,0 @@
import torch
import bisect
import pickle as pkl
from abc import ABC, abstractmethod
from torch import Tensor
from torch.utils.data import Dataset, Sampler
from typing import Callable, List, Dict, Literal, Union
MutiSeg = Dict[str, List[Tensor]]
Seg = Dict[str, Tensor]
def load_pkl_files(paths: List[str]):
segments: MutiSeg = {}
total_samples = 0
for path in paths:
with open(path, "rb") as f:
pkl_file: Seg = pkl.load(f)
for key, value in pkl_file.items():
if key not in segments:
segments[key] = []
segments[key].append(value)
first_key = list(pkl_file.keys())[0]
total_samples += pkl_file[first_key].numel()
return segments, total_samples
def build_attention_mask(input_ids: Tensor, user_token_id: int, multi_turn: bool) -> Tensor:
seq_len = input_ids.size(0)
turn_id = input_ids.eq(user_token_id).cumsum(dim=-1)
iq = turn_id.view(seq_len, 1)
ik = turn_id.view(1, seq_len)
# fix the causual attention mask(iq >= ik condition)
seq_mask = (iq >= ik) if multi_turn else (iq == ik)
attention_mask = torch.tril(seq_mask)
# fix the shape (bsz, 1, seq_len, seq_len) unsqueeze for broadcast
return attention_mask.unsqueeze(0)
def build_loss_mask(input_ids: Tensor, bos_token_id: int, eos_token_id: int) -> Tensor:
token_markers = torch.zeros_like(input_ids, dtype=torch.int8)
is_bos_token = input_ids.eq(bos_token_id)
is_eos_token = input_ids.eq(eos_token_id)
# fix the eos_token_id bug(change target_ids to input_ids)
token_markers[is_bos_token] = 1
token_markers[is_eos_token] = -1
cumulative_markers = torch.cumsum(token_markers, dim=-1)
min_cumulative = cumulative_markers.min(dim=-1, keepdim=True).values
loss_mask = cumulative_markers - min_cumulative
return loss_mask.to(dtype=torch.bool)
class BaseSegmentFetcher:
def __init__(self, segments: List[Tensor]):
self.segments = segments
self.cum_lengths = []
total = 0
for seg in segments:
total += len(seg)
self.cum_lengths.append(total)
self.total_length = total if segments else 0
def fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
if not (0 <= begin_idx < self.total_length and 0 <= end_idx <= self.total_length):
raise ValueError("begin_idx or end_idx out of bounds")
if begin_idx >= end_idx:
return torch.tensor([], dtype=torch.long)
seg_start_idx = bisect.bisect_right(self.cum_lengths, begin_idx - 1)
seg_end_idx = bisect.bisect_left(self.cum_lengths, end_idx - 1)
result_segments = []
for i in range(seg_start_idx, seg_end_idx + 1):
prev_cum = self.cum_lengths[i - 1] if i > 0 else 0
start = max(begin_idx - prev_cum, 0)
end = min(end_idx - prev_cum, len(self.segments[i]))
result_segments.append(self.segments[i][start:end])
return torch.cat(result_segments, dim=0)
class MutiSegmentFetcher:
def __init__(self, muti_segments: MutiSeg):
self.muti_keys = list(muti_segments.keys())
self.muti_fetchers = {
key: BaseSegmentFetcher(segments)
for key, segments in muti_segments.items()
}
def key_fetch(self, begin_idx: int, end_idx: int, keys: Union[str, List[str]]) -> Union[Tensor, Seg]:
fetch_dict = {}
keys = [keys] if isinstance(keys, str) else keys
for key in keys:
fetcher = self.muti_fetchers[key]
fetch_tensor = fetcher.fetch_data(begin_idx, end_idx)
fetch_dict[key] = fetch_tensor
return fetch_dict if len(keys) > 1 else fetch_dict[keys[0]]
def fetch_data(self, begin_idx: int, end_idx: int) -> Union[Tensor, Seg]:
return self.key_fetch(begin_idx, end_idx, self.muti_keys)
class BaseDataset(Dataset, ABC):
def __init__(self, chunk_size: int, device: str):
super().__init__()
self.segments: MutiSeg = {}
self.chunk_size = chunk_size
self.total_samples = 0
self.device = device
def save(self, save_path: str):
keys = list(self.segments.keys())
if not keys:
return
first_item = self.segments[keys[0]]
segment_size = len(first_item)
for i in range(segment_size):
formated_segment = {key: self.segments[key][i] for key in keys}
pkl.dump(formated_segment, open(f"{save_path}_{i}.pkl", "wb"))
def load(self, load_path: Union[str, List[str]]):
paths = [load_path] if isinstance(load_path, str) else load_path
self.segments, self.total_samples = load_pkl_files(paths)
self.fetcher = MutiSegmentFetcher(self.segments)
@abstractmethod
def __getitem__(self, index: int) -> Dict[str, Tensor]:
raise NotImplementedError
def __len__(self) -> int:
assert self.total_samples // self.chunk_size > 0
return self.total_samples // self.chunk_size
class SeqDataset(BaseDataset):
def __init__(
self,
chunk_size,
device='cuda'
):
super().__init__(chunk_size, device)
self.fetcher = MutiSegmentFetcher(self.segments)
def _fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
return self.fetcher.key_fetch(begin_idx, end_idx, "sequence")
def __getitem__(self, index):
begin_idx = index * self.chunk_size
end_idx = min(begin_idx + self.chunk_size, self.total_samples - 1)
x = self._fetch_data(begin_idx, end_idx).to(device=self.device, dtype=torch.long)
y = self._fetch_data(begin_idx + 1, end_idx + 1).to(device=self.device, dtype=torch.long)
return {"input_ids": x, "target_ids": y}
class SftDataset(BaseDataset):
def __init__(
self,
chunk_size,
bos_token_id,
eos_token_id,
user_token_id,
multi_turn=False,
device='cuda'
):
super().__init__(chunk_size, device)
self.fetcher = MutiSegmentFetcher(self.segments)
self.bos_token_id = bos_token_id
self.eos_token_id = eos_token_id
self.user_token_id = user_token_id
self.multi_turn = multi_turn
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
return self.fetcher.key_fetch(begin_idx, end_idx, key)
def __getitem__(self, index):
begin_idx = index * self.chunk_size
end_idx = min(begin_idx + self.chunk_size, self.total_samples - 1)
x = self._fetch_data(begin_idx, end_idx, "sequence").to(device=self.device, dtype=torch.long)
y = self._fetch_data(begin_idx + 1, end_idx + 1, "sequence").to(device=self.device, dtype=torch.long)
# fix the eos_token_id bug(change target_ids to input_ids)
loss_mask = build_loss_mask(x, self.bos_token_id, self.eos_token_id)
attn_mask = build_attention_mask(x, self.user_token_id, self.multi_turn)
return {"input_ids": x, "target_ids": y, "loss_mask": loss_mask, "attn_mask": attn_mask}
class DpoDataset(BaseDataset):
def __init__(self, chunk_size: int, device="cuda"):
super().__init__(chunk_size, device)
self.fetcher = MutiSegmentFetcher(self.segments)
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
return self.fetcher.key_fetch(begin_idx, end_idx, key)
def __getitem__(self, index: int):
start_idx = index * self.chunk_size
end_idx = min(start_idx + self.chunk_size, self.total_samples - 1)
chosen = self._fetch_data(start_idx, end_idx, "chosen").to(device=self.device, dtype=torch.long)
rejected = self._fetch_data(start_idx, end_idx, "rejected").to(device=self.device, dtype=torch.long)
chosen_mask = self._fetch_data(start_idx, end_idx, "chosen_mask").to(device=self.device, dtype=torch.bool)
rejected_mask = self._fetch_data(start_idx, end_idx, "rejected_mask").to(device=self.device, dtype=torch.bool)
return {"chosen": chosen, "rejected": rejected, "chosen_mask": chosen_mask, "rejected_mask": rejected_mask}
class PpoDataset(BaseDataset):
def __init__(self, chunk_size: int, device="cuda"):
super().__init__(chunk_size, device)
self.fetcher = MutiSegmentFetcher(self.segments)
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
return self.fetcher.key_fetch(begin_idx, end_idx, key)
def __getitem__(self, index: int) -> Dict[str, Tensor]:
begin_idx = index * self.chunk_size
end_idx = min(begin_idx + self.chunk_size, self.total_samples - 1)
input_ids = self._fetch_data(begin_idx, end_idx, "input_ids").to(self.device),
actions = self._fetch_data(begin_idx, end_idx, "actions").to(self.device),
logprobs = self._fetch_data(begin_idx, end_idx, "logprobs").to(self.device),
rewards = self._fetch_data(begin_idx, end_idx, "rewards").to(self.device)
return {"input_ids": input_ids, "actions": actions, "logprobs": logprobs, "rewards": rewards}
class DatasetLoader:
@staticmethod
def load(
train_type: Literal["seq", "sft", "dpo"],
load_path: Union[str, List[str]],
max_len: int,
device: str,
**kwargs
) -> BaseDataset:
dataset_router: Dict[str, Callable[[int, torch.device], BaseDataset]] = {
"seq": lambda m_len, device: SeqDataset(m_len, device=device),
"sft": lambda m_len, device: SftDataset(
m_len,
device=device,
bos_token_id=kwargs.get("bos_token_id"),
eos_token_id=kwargs.get("eos_token_id"),
user_token_id=kwargs.get("user_token_id"),
multi_turn=kwargs.get("multi_turn")
),
"dpo": lambda m_len, device: DpoDataset(m_len, device=device),
}
dataset = dataset_router[train_type](max_len, device)
dataset.load(load_path)
return dataset
class RandomSampler(Sampler[int]):
def __init__(self, data_source, generator=None, seed=42):
self.data_source = data_source
self.seed = seed
self.epoch = 0
self.current_iter = 0
self._indices = None
if generator is None:
self.generator = torch.Generator()
self.generator.manual_seed(seed)
else:
self.generator = generator
def _generate_indices(self):
n = len(self.data_source)
self._indices = torch.randperm(n, generator=self.generator).tolist()
def __iter__(self):
n = len(self.data_source)
if self._indices is None:
self._generate_indices()
start = self.current_iter % n
for i in range(start, n):
yield self._indices[i]
self.current_iter += 1
self.epoch += 1
self._indices = None
def __len__(self):
return len(self.data_source)
def state_dict(self):
return {
'epoch': self.epoch,
'current_iter': self.current_iter,
'seed': self.seed,
'generator_state': self.generator.get_state() if self.generator else None,
'indices': self._indices
}
def load_state_dict(self, state_dict):
self.epoch = state_dict['epoch']
self.current_iter = state_dict['current_iter']
self.seed = state_dict['seed']
if self.generator and state_dict['generator_state'] is not None:
self.generator.set_state(state_dict['generator_state'])
self._indices = state_dict['indices']

View File

@ -1,396 +0,0 @@
import copy
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.optim import Optimizer
from torch.utils.data import Dataset
from typing import Any, Literal, Tuple, Callable, Dict
from abc import ABC, abstractmethod
from dataclasses import asdict, dataclass, field
def get_logprobs(model:nn.Module, input_ids: Tensor, mask: Tensor, pad_token_id: int):
input_mask = input_ids.ne(pad_token_id)
logits = model(input_ids, input_mask)["logits"]
log_probs = torch.log_softmax(logits, dim=-1)
shifted_log_probs = log_probs[:, :-1, :]
shifted_input_ids = input_ids[:, 1:]
shifted_response_mask = mask[:, 1:]
token_logprobs = torch.gather(
shifted_log_probs,
dim=-1,
index=shifted_input_ids.unsqueeze(-1)
).squeeze(-1)
prompt_mask = input_mask[:, 1:]
valid_mask = (prompt_mask & shifted_response_mask).float()
return (token_logprobs * valid_mask).sum(dim=-1)
class BaseStrategy(ABC):
def __init__(self, model: nn.Module):
self.model = model
@abstractmethod
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
raise NotImplementedError
def __call__(self, batch: Tuple[Tensor, ...]) -> Tensor:
return self.compute_loss(batch)
class SeqStrategy(BaseStrategy):
def __init__(self, model):
super().__init__(model)
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
input_ids, target_ids = batch["input_ids"], batch["target_ids"]
B, L = input_ids.size()
logits: Tensor = self.model(input_ids=input_ids)["logits"]
loss = F.cross_entropy(
input=logits.view(B * L, -1),
target=target_ids.flatten()
)
return loss
class SftStrategy(BaseStrategy):
def __init__(self, model: nn.Module):
super().__init__(model)
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
input_ids, target_ids = batch["input_ids"], batch["target_ids"]
loss_mask, attn_mask = batch["loss_mask"], batch["attn_mask"]
ignore_index = -100
B, L = input_ids.size()
logits: Tensor = self.model(
input_ids=input_ids,
input_mask=attn_mask
)["logits"]
target_ids = target_ids.masked_fill(loss_mask == 0, ignore_index)
loss = F.cross_entropy(
input=logits.view(B * L, -1),
target=target_ids.flatten(),
ignore_index=ignore_index
)
return loss
class DpoStrategy(BaseStrategy):
def __init__(self, model, pad_token_id, beta):
super().__init__(model)
ref_model = copy.deepcopy(self.model)
ref_model.requires_grad_(False)
ref_model.eval()
self.ref_model = ref_model
self.pad_token_id = pad_token_id
self.beta = beta
def compute_loss(self, batch: Tuple[Tensor, ...]) -> Tensor:
good_ids, bad_ids = batch["chosen"], batch["rejected"]
good_mask, bad_mask = batch["chosen_mask"], batch["rejected_mask"]
log_pi_good = get_logprobs(self.model, good_ids, good_mask, self.pad_token_id)
log_pi_bad = get_logprobs(self.model, bad_ids, bad_mask, self.pad_token_id)
with torch.no_grad():
log_ref_good = get_logprobs(self.ref_model, good_ids, good_mask, self.pad_token_id)
log_ref_bad = get_logprobs(self.ref_model, bad_ids, bad_mask, self.pad_token_id)
pi_log_ratio = log_pi_good - log_pi_bad
ref_log_ratio = log_ref_good - log_ref_bad
ratio_diff = pi_log_ratio - ref_log_ratio
dpo_loss = -F.logsigmoid(self.beta * ratio_diff).mean()
return dpo_loss
class PpoStrategy(BaseStrategy):
def __init__(self, model, pad_token_id, epsilon):
super().__init__(model)
ref_model = copy.deepcopy(self.model)
ref_model.requires_grad_(False)
ref_model.eval()
self.ref_model = ref_model
self.pad_token_id = pad_token_id
self.epsilon = epsilon
def ppo_clip_loss_masked(
self,
log_probs: Tensor,
old_log_probs: Tensor,
advantages: Tensor,
values: Tensor,
returns: Tensor,
mask: Tensor,
clip_eps: float=0.2,
):
ratio = torch.exp(log_probs - old_log_probs)
surr1 = ratio * advantages
surr2 = torch.clamp(ratio, 1 - clip_eps, 1 + clip_eps) * advantages
policy_loss = -torch.min(surr1, surr2).masked_select(mask).mean()
value_loss = F.mse_loss(values.masked_select(mask),
returns.masked_select(mask))
entropy = -(log_probs.exp() * log_probs).masked_select(mask).mean()
entropy_loss = -entropy
return policy_loss, value_loss, entropy_loss
class StrategyFactory:
def load(model, train_type, **kwargs):
train_strategy: Dict[str, Callable[[], BaseStrategy]] = {
"seq": lambda: SeqStrategy(model),
"sft": lambda: SftStrategy(model),
"dpo": lambda: DpoStrategy(
model,
kwargs.get("pad_token_id"),
kwargs.get("dpo_beta")
)
}
strategy = train_strategy[train_type]()
return strategy
@dataclass
class TrainConfig:
strategy: BaseStrategy = field(
default=None,
metadata={"help": "Training strategy."}
)
dataset: Dataset = field(
default=None,
metadata={"help": "Dataset for training."}
)
optimizer: Optimizer = field(
default=None,
metadata={"help": "Optimizer for training."}
)
checkpoint_dir: str = field(
default="./checkpoint",
metadata={"help": "Checkpoint directory."}
)
n_epoch: int = field(
default=1,
metadata={"help": "Number of epochs for training."}
)
batch_size: int = field(
default=4,
metadata={"help": "Batch size for training."}
)
checkpoint_interval: int = field(
default=5000,
metadata={"help": "Number of iterations between checkpoints."}
)
accumulation_steps: int = field(
default=1,
metadata={"help": "Number of iterations between steps."}
)
max_grad_norm: float = field(
default=1.0,
metadata={"help": "Maximum gradient norm."}
)
random_seed: int = field(
default=3407,
metadata={"help": "Random seed."}
)
def get_kwargs(self)-> Dict[str, Any]:
config_dict = asdict(self)
return {k: v for k, v in config_dict.items() if v is not None}
@dataclass
class ScheduleConfig(ABC):
schedule_type: str = field(
default="cosine",
metadata={
"help": "Type of learning rate schedule.",
"choices": ["cosine", "sgdr"]
}
)
warmup_steps: int = field(
default=1000,
metadata={"help": "Number of warmup steps."}
)
min_rate: float = field(
default=0.05,
metadata={"help": "Minimum learning rate multiplier."}
)
@abstractmethod
def get_kwargs(self) -> Dict[str, Any]:
raise NotImplementedError
def validate(self) -> None:
"""Validate configuration parameters."""
if self.warmup_steps < 0:
raise ValueError(f"warmup_steps must be non-negative, got {self.warmup_steps}")
if not 0 <= self.min_rate <= 1:
raise ValueError(f"min_rate must be between 0 and 1, got {self.min_rate}")
@dataclass
class CosineScheduleConfig(ScheduleConfig):
total_steps: int = field(
default=None,
metadata={"help": "Total training steps for cosine schedule."}
)
schedule_type: Literal["cosine"] = "cosine"
def get_kwargs(self) -> Dict[str, Any]:
if self.total_steps is None:
raise ValueError("total_steps must be specified for cosine schedule")
return {
"schedule_type": self.schedule_type,
"warmup_steps": self.warmup_steps,
"lr_decay_steps": self.total_steps - self.warmup_steps,
"min_rate": self.min_rate
}
def validate(self) -> None:
super().validate()
if self.total_steps is not None and self.total_steps <= self.warmup_steps:
raise ValueError(f"total_steps ({self.total_steps}) must be greater than warmup_steps ({self.warmup_steps})")
@dataclass
class SgdrScheduleConfig(ScheduleConfig):
cycle_length: int = field(
default=1000,
metadata={"help": "Length of the first cycle in steps."}
)
t_mult: int = field(
default=2,
metadata={"help": "Multiplier for cycle length growth."}
)
schedule_type: Literal["sgdr"] = "sgdr"
def get_kwargs(self) -> Dict[str, Any]:
return {
"schedule_type": self.schedule_type,
"warmup_steps": self.warmup_steps,
"cycle_length": self.cycle_length,
"min_rate": self.min_rate,
"t_mult": self.t_mult
}
def validate(self) -> None:
super().validate()
if self.cycle_length <= 0:
raise ValueError(f"cycle_length must be positive, got {self.cycle_length}")
if self.t_mult < 1:
raise ValueError(f"t_mult must be >= 1, got {self.t_mult}")
class SchedulerFactory:
"""Factory for creating learning rate schedule functions."""
@staticmethod
def get_sgdr_schedule(
warmup_steps: int,
cycle_length: int,
min_rate: float = 0.05,
t_mult: int = 2
) -> Callable[[int], float]:
"""
Create SGDR (Stochastic Gradient Descent with Warm Restarts) schedule.
Args:
warmup_steps: Number of warmup steps
cycle_length: Length of the first cycle
min_rate: Minimum learning rate multiplier
t_mult: Cycle length multiplier
Returns:
Schedule function that takes current step and returns LR multiplier
"""
def sgdr_schedule(current_step: int) -> float:
# Warmup phase
if current_step < warmup_steps:
return max(min_rate, current_step / warmup_steps)
# SGDR phase
steps_since_warmup = current_step - warmup_steps
# Find current cycle and position within cycle
cycle_start = 0
current_cycle_length = cycle_length
cycle_index = 0
while steps_since_warmup >= cycle_start + current_cycle_length:
cycle_start += current_cycle_length
current_cycle_length *= t_mult
cycle_index += 1
position_in_cycle = steps_since_warmup - cycle_start
progress = position_in_cycle / current_cycle_length
# Cosine annealing within cycle
return max(min_rate, 0.5 * (1 + math.cos(math.pi * progress)))
return sgdr_schedule
@staticmethod
def get_cosine_schedule(
warmup_steps: int,
lr_decay_steps: int,
min_rate: float = 0.05
) -> Callable[[int], float]:
"""
Create cosine decay schedule with warmup.
Args:
warmup_steps: Number of warmup steps
lr_decay_steps: Number of steps for cosine decay after warmup
min_rate: Minimum learning rate multiplier
Returns:
Schedule function that takes current step and returns LR multiplier
"""
def cosine_schedule(current_step: int) -> float:
if current_step < warmup_steps:
# Linear warmup
return max(min_rate, current_step / warmup_steps)
else:
# Cosine decay
decay_progress = (current_step - warmup_steps) / lr_decay_steps
decay_progress = min(decay_progress, 1.0) # Clamp at 1.0
return max(min_rate, 0.5 * (1.0 + math.cos(math.pi * decay_progress)))
return cosine_schedule
@staticmethod
def load_schedule_fn(scedule_config: ScheduleConfig) -> Callable[[int], float]:
kwargs = scedule_config.get_kwargs()
schedule_type = kwargs.pop("schedule_type")
if schedule_type == "cosine":
return SchedulerFactory.get_cosine_schedule(**kwargs)
elif schedule_type == "sgdr":
return SchedulerFactory.get_sgdr_schedule(**kwargs)
else:
raise ValueError(f"Unsupported schedule type: {schedule_type}")

View File

@ -1,140 +0,0 @@
import logging
from typing import Optional, List, cast
from torch.utils.data import DataLoader
from khaosz.core import ModelParameter, Checkpoint
from khaosz.trainer.data_util import RandomSampler
from khaosz.trainer.strategy import TrainConfig, ScheduleConfig
from khaosz.trainer.trainer_callback import (
TrainerCallback,
ProgressBarCallback,
CheckpointCallback,
GradientClippingCallback,
SchedulerCallback
)
logger = logging.getLogger(__name__)
class Trainer:
def __init__(
self,
parameter: ModelParameter,
train_config: TrainConfig,
schedule_config: ScheduleConfig,
callbacks: Optional[List[TrainerCallback]] = None
):
self.parameter = parameter
self.train_config = train_config
self.schedule_config = schedule_config
self.callbacks = callbacks or self._get_default_callbacks()
def _get_default_callbacks(self) -> List[TrainerCallback]:
return [
ProgressBarCallback(),
CheckpointCallback(self.train_config.checkpoint_interval),
GradientClippingCallback(),
SchedulerCallback(self.schedule_config),
]
def _set_train_kwargs(self, kwargs: dict):
seed = self.train_config.random_seed
sampler = RandomSampler(data_source=self.train_config.dataset, seed=seed)
optim = self.train_config.optimizer
checkpoint = cast(Checkpoint, kwargs.get('checkpoint', None))
if checkpoint is None:
checkpoint = Checkpoint(
model=self.parameter.model,
tokenizer=self.parameter.tokenizer,
config=self.parameter.config,
sampler_state=None,
optim_state=None,
loss_list=[]
)
sampler_state = checkpoint.sampler_state
optim_state = checkpoint.optim_state
if sampler_state:
sampler.load_state_dict(sampler_state)
if optim_state:
optim.load_state_dict(optim_state)
checkpoint.optim_state = optim.state_dict()
checkpoint.sampler_state = sampler.state_dict()
dataloader = DataLoader(
self.train_config.dataset,
batch_size=self.train_config.batch_size,
sampler=sampler
)
kwargs["dataloader"] = dataloader
kwargs["optimizer"] = self.train_config.optimizer
kwargs["epoch"] = sampler.epoch
kwargs["current_iter"] = sampler.current_iter
kwargs["sampler"] = sampler
kwargs["checkpoint"] = checkpoint
def _call_callbacks(self, method_name: str, **kwargs):
for callback in self.callbacks:
method = getattr(callback, method_name, None)
if method:
method(self, **kwargs)
def train(
self,
checkpoint: Optional[Checkpoint] = None
) -> Checkpoint:
# train
train_kwargs = {
'checkpoint': checkpoint,
'dataloader': None,
'optimizer': None,
'sampler': None,
'epoch': 0,
'current_iter': 0,
'loss': 0.0,
}
self._set_train_kwargs(train_kwargs)
self._call_callbacks('on_train_begin', **train_kwargs)
dataloader = train_kwargs['dataloader']
checkpoint = train_kwargs['checkpoint']
start_epoch = train_kwargs['epoch']
try:
self.parameter.model.train()
for epoch in range(start_epoch, self.train_config.n_epoch):
# epoch
train_kwargs["epoch"] = epoch
self._call_callbacks('on_epoch_begin', **train_kwargs)
for batch in dataloader:
if train_kwargs["current_iter"] % self.train_config.accumulation_steps == 0:
# step
self._call_callbacks('on_step_begin', **train_kwargs)
self.train_config.optimizer.step()
self.train_config.optimizer.zero_grad()
self._call_callbacks('on_step_end', **train_kwargs)
# batch
self._call_callbacks('on_batch_begin', **train_kwargs)
loss = self.train_config.strategy(batch)
train_kwargs["loss"] = loss.item()
train_kwargs["current_iter"] += 1
loss.backward()
self._call_callbacks('on_batch_end', **train_kwargs)
self._call_callbacks('on_epoch_end', **train_kwargs)
except Exception as e:
logger.error(f"Training failed: {str(e)}", exc_info=True)
raise
finally:
self._call_callbacks('on_train_end', **train_kwargs)
return checkpoint

View File

@ -1,169 +0,0 @@
import os
import torch.optim as optim
from tqdm import tqdm
from torch.nn.utils import clip_grad_norm_
from torch.optim.lr_scheduler import LambdaLR
from typing import Optional, cast, TYPE_CHECKING
from khaosz.core.parameter import Checkpoint
from khaosz.trainer.data_util import RandomSampler
from khaosz.trainer.strategy import ScheduleConfig, SchedulerFactory
if TYPE_CHECKING:
from khaosz.trainer.trainer import Trainer
class TrainerCallback:
"""
Callback interface for trainer.
and we use '_' to ignore unused parameters.
"""
def on_train_begin(self, trainer: 'Trainer', **kwargs):
""" Called at the beginning of training. """
_ = trainer, kwargs
def on_train_end(self, trainer: 'Trainer', **kwargs):
""" Called at the end of training. """
_ = trainer, kwargs
def on_epoch_begin(self, trainer: 'Trainer', **kwargs):
""" Called at the beginning of each epoch. """
_ = trainer, kwargs
def on_epoch_end(self, trainer: 'Trainer', **kwargs):
""" Called at the end of each epoch. """
_ = trainer, kwargs
def on_batch_begin(self, trainer: 'Trainer', **kwargs):
""" Called at the beginning of each batch. """
_ = trainer, kwargs
def on_batch_end(self, trainer: 'Trainer', **kwargs):
""" Called at the end of each batch. """
_ = trainer, kwargs
def on_step_begin(self, trainer: 'Trainer', **kwargs):
""" Called at the beginning of each step. """
_ = trainer, kwargs
def on_step_end(self, trainer: 'Trainer', **kwargs):
""" Called at the end of each step."""
_ = trainer, kwargs
class ProgressBarCallback(TrainerCallback):
"""
Progress bar callback for trainer.
"""
def __init__(self):
self.progress_bar: tqdm = None
def on_epoch_begin(self, trainer: 'Trainer', **kwargs):
epoch = kwargs.get('epoch')
dataloader = kwargs.get('dataloader')
self.progress_bar = tqdm(
dataloader,
desc=f"Epoch {epoch+1}/{trainer.train_config.n_epoch}",
dynamic_ncols=True
)
def on_batch_end(self, trainer: 'Trainer', **kwargs):
_ = trainer
loss = kwargs.get('loss')
optimizer = cast(optim.Optimizer, kwargs.get('optimizer'))
self.progress_bar.set_postfix({
"loss": f"{loss:.4f}",
"lr": f"{optimizer.param_groups[-1]['lr']:.2e}"
})
self.progress_bar.update(1)
def on_epoch_end(self, trainer: 'Trainer', **kwargs):
_ = trainer, kwargs
if self.progress_bar:
self.progress_bar.close()
class CheckpointCallback(TrainerCallback):
"""
Checkpoint callback for trainer.
"""
def __init__(self, checkpoint_interval: int):
self.checkpoint_interval = checkpoint_interval
self.last_ckpt_iter = 0
@staticmethod
def _save_checkpoint(trainer: 'Trainer', **kwargs):
current_iter = kwargs.get('current_iter')
random_sampler = cast(RandomSampler, kwargs.get('sampler'))
optimizer = cast(optim.Optimizer, kwargs.get('optimizer'))
checkpoint = cast(Checkpoint, kwargs.get('checkpoint'))
save_path = os.path.join(trainer.train_config.checkpoint_dir, f"iter_{current_iter}")
checkpoint.sampler_state = random_sampler.state_dict()
checkpoint.optim_state = optimizer.state_dict()
checkpoint.save(save_path)
def on_batch_end(self, trainer: 'Trainer', **kwargs):
current_iter = kwargs.get('current_iter')
checkpoint = cast(Checkpoint, kwargs.get('checkpoint'))
loss = kwargs.get('loss')
checkpoint.loss_list.append(loss)
if current_iter - self.last_ckpt_iter >= self.checkpoint_interval:
CheckpointCallback._save_checkpoint(trainer, **kwargs)
self.last_ckpt_iter = current_iter
def on_train_end(self, trainer: 'Trainer', **kwargs):
current_iter = kwargs.get('current_iter')
if current_iter != self.last_ckpt_iter:
CheckpointCallback._save_checkpoint(trainer, **kwargs)
self.last_ckpt_iter = current_iter
class GradientClippingCallback(TrainerCallback):
"""
Gradient clipping callback for trainer.
"""
def on_step_begin(self, trainer: 'Trainer', **kwargs):
_ = kwargs
clip_grad_norm_(
trainer.parameter.model.parameters(),
trainer.train_config.max_grad_norm
)
class SchedulerCallback(TrainerCallback):
"""
Scheduler callback for trainer.
"""
def __init__(self, schedule_config: ScheduleConfig):
self.schedule_config = schedule_config
self.scheduler: Optional[LambdaLR] = None
self.current_iter = 0
def on_train_begin(self, trainer: 'Trainer', **kwargs):
self.current_iter = kwargs.get('current_iter')
for group in trainer.train_config.optimizer.param_groups:
if "initial_lr" not in group:
group["initial_lr"] = group["lr"]
self.schedule_config.validate()
lambda_scheduler_fn = SchedulerFactory.load_schedule_fn(
self.schedule_config
)
self.scheduler = LambdaLR(
trainer.train_config.optimizer,
lambda_scheduler_fn,
last_epoch=self.current_iter - 1
)
def on_batch_end(self, trainer: 'Trainer', **kwargs):
_ = trainer, kwargs
if self.scheduler:
self.scheduler.step()
self.current_iter += 1

View File

@ -1,88 +0,0 @@
import torch
import sqlite3
import numpy as np
from torch import Tensor
from typing import Dict, List, Tuple
class Retriever:
def __init__(self, db_path=None):
self.data: Dict[str, Tensor] = {}
self.embedding_cache: Tensor = None
self.is_caculated: bool = False
if db_path is not None:
self.load(db_path)
def retrieve(self, query: Tensor, top_k: int) -> List[Tuple[str, float]]:
if not self.data:
return []
query = query.flatten().unsqueeze(1) # [dim, 1]
norm_embeddings = self._embeddings.to(
device=query.device,
dtype=query.dtype
) # [n_vectors, dim]
sim_scores = torch.matmul(norm_embeddings, query).squeeze() # [n_vectors]
top_k = min(top_k, len(self.data))
indices = sim_scores.topk(top_k).indices
keys = list(self.data.keys())
return [(keys[i], sim_scores[i].item()) for i in indices]
def add_vector(self, key: str, vector_data: Tensor):
self.is_caculated = False
self.data[key] = vector_data.flatten().float().cpu()
def delete_vector(self, key: str):
self.is_caculated = False
self.data.pop(key, None)
def save(self, db_path):
conn = sqlite3.connect(db_path)
cursor = conn.cursor()
self._init_db(cursor)
cursor.execute('DELETE FROM vectors')
for item, vec in self.data.items():
vec_bytes = vec.numpy().tobytes()
cursor.execute('INSERT OR REPLACE INTO vectors (key, vector) VALUES (?, ?)',
(item, vec_bytes))
conn.commit()
conn.close()
def load(self, db_path):
conn = sqlite3.connect(db_path)
cursor = conn.cursor()
self._init_db(cursor)
cursor.execute('SELECT key, vector FROM vectors')
rows = cursor.fetchall()
self.data = {}
for row in rows:
key, vec_bytes = row
vec_numpy = np.frombuffer(vec_bytes, dtype=np.float32).copy()
vec = torch.from_numpy(vec_numpy)
self.data[key] = vec
conn.close()
def _init_db(self,cursor: sqlite3.Cursor):
# Create table if not exists (in case loading from a new database)
cursor.execute('''
CREATE TABLE IF NOT EXISTS vectors (
id INTEGER PRIMARY KEY AUTOINCREMENT,
key TEXT UNIQUE NOT NULL,
vector BLOB NOT NULL
)''')
@property
def _embeddings(self) -> Tensor:
if not self.is_caculated:
embeddings = torch.stack(list(self.data.values()))
norm_embeddings = embeddings / torch.norm(embeddings, dim=-1, keepdim=True)
self.embedding_cache = norm_embeddings
return self.embedding_cache

Some files were not shown because too many files have changed in this diff Show More