Compare commits

..

4 Commits

Author SHA1 Message Date
ViperEkura 44dab27fdc feat: 数据集加载时校验必填字段
- BaseDataset.required_keys 属性声明所需存储 key
- load() 时自动校验,缺失立即抛 KeyError
- SEQ/SFT/DPO/GRPO 各自声明 required_keys
2026-05-17 11:50:38 +08:00
ViperEkura a44fd22a99 fix: 修复训练与模型参数传递问题
- state_dict_fn 传入 CheckpointCallback,修复多卡 DDP 下 key 前缀丢失
- MLA 增加 use_qk_norm 支持,消除参数静默丢失
- moe_topk_method 统一命名为 topk_method
- checkpoint 回调移至最前
2026-05-17 11:20:13 +08:00
ViperEkura 8a11a7d444 fix: 修复训练脚本两处参数传递问题
- prepare_checkpoint 增加 DDP 判断,单卡时不访问 .module
- dpo_beta 改为 beta,对齐 DPOStrategy 参数名
2026-05-17 11:04:40 +08:00
ViperEkura 1d54491809 refactor: 改用递归子模块 init 替代统一 normal_(0.006)
- Embedding.reset_parameters: normal_(std=0.02)
- Linear.reset_parameters: kaiming_uniform_ + uniform_ bias
- Transformer._init_weights 通过 apply 递归调用子模块 reset_parameters
- 移除全局 normal_(0.006) 覆盖,各模块使用更合适的分布
2026-05-17 10:44:18 +08:00
9 changed files with 78 additions and 11 deletions

View File

@ -67,4 +67,4 @@ class ModelConfig(BaseModelConfig):
n_routed_experts: Optional[int] = None
n_shared_experts: Optional[int] = None
n_activated_experts: Optional[int] = None
moe_topk_method: Optional[str] = None
topk_method: Optional[str] = None

View File

@ -28,6 +28,26 @@ class BaseDataset(Dataset, ABC):
self.stride = stride
self.storage: Optional[BaseStorage] = None
@property
def required_keys(self) -> List[str]:
"""Return required storage keys for this dataset type.
Subclasses should override to specify expected keys.
"""
return []
def _validate_keys(self):
if not self.required_keys:
return
actual_keys = set(self.storage.keys)
missing = [k for k in self.required_keys if k not in actual_keys]
if missing:
raise KeyError(
f"Dataset {type(self).__name__} requires keys {self.required_keys}, "
f"but storage at {self._load_path} only has {sorted(actual_keys)}. "
f"Missing: {missing}"
)
def load(self, load_path: str, storage_type: Optional[str] = None, tokenizer=None):
"""Load dataset from the given path.
@ -39,11 +59,16 @@ class BaseDataset(Dataset, ABC):
or None for auto-detection
tokenizer: Callable str -> List[int], used to tokenize raw text
in JSON files. Ignored for HDF5.
Raises:
KeyError: If the loaded storage is missing required keys.
"""
if storage_type is None:
storage_type = detect_format(load_path)
self.storage = StorageFactory.create(storage_type)
self._load_path = load_path
self.storage.load(load_path, tokenizer=tokenizer)
self._validate_keys()
def load_json(self, load_path: str, tokenizer=None):
"""Load dataset from JSON files explicitly.
@ -186,6 +211,10 @@ class SEQDataset(BaseDataset):
def __init__(self, window_size: int, stride: int):
super().__init__(window_size, stride)
@property
def required_keys(self) -> List[str]:
return ["sequence"]
def _fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
return self.storage.fetch(begin_idx, end_idx, "sequence")
@ -205,6 +234,10 @@ class SFTDataset(BaseDataset):
def __init__(self, window_size: int, stride: int):
super().__init__(window_size, stride)
@property
def required_keys(self) -> List[str]:
return ["sequence", "loss_mask"]
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
return self.storage.fetch(begin_idx, end_idx, key)
@ -229,6 +262,10 @@ class DPODataset(BaseDataset):
def __init__(self, window_size: int, stride: int):
super().__init__(window_size, stride)
@property
def required_keys(self) -> List[str]:
return ["chosen", "rejected", "chosen_mask", "rejected_mask"]
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
return self.storage.fetch(begin_idx, end_idx, key)
@ -259,6 +296,10 @@ class GRPODataset(BaseDataset):
def __init__(self, window_size: int, stride: int):
super().__init__(window_size, stride)
@property
def required_keys(self) -> List[str]:
return ["prompts", "responses", "masks", "rewards"]
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
return self.storage.fetch(begin_idx, end_idx, key)

View File

@ -120,6 +120,7 @@ class MLA(nn.Module):
qk_nope_head_dim: int,
qk_rope_head_dim: int,
norm_eps: float,
use_qk_norm: bool,
use_gated_attention: bool,
layer_id: int,
):
@ -133,9 +134,14 @@ class MLA(nn.Module):
self.head_dim = qk_nope_head_dim + qk_rope_head_dim
self.layer_id = layer_id
self.n_rep = n_heads // n_kv_heads
self.use_qk_norm = use_qk_norm
self.use_gated_attention = use_gated_attention
self.q_proj = Linear(dim, n_heads * self.head_dim, bias=False)
if self.use_qk_norm:
self.q_norm = RMSNorm(self.head_dim, norm_eps)
self.k_norm = RMSNorm(self.head_dim, norm_eps)
self.kv_a_proj = Linear(dim, kv_lora_rank, bias=False)
self.kv_norm = RMSNorm(kv_lora_rank, norm_eps)
@ -182,6 +188,10 @@ class MLA(nn.Module):
q = torch.cat([q_nope, q_rope], dim=-1)
k = torch.cat([k_nope, k_rope], dim=-1)
if self.use_qk_norm:
q = self.q_norm(q)
k = self.k_norm(k)
if paged_cache is not None:
paged_cache.write(self.layer_id, k, v)
k, v = paged_cache.gather(self.layer_id)

View File

@ -9,5 +9,8 @@ class Embedding(nn.Module):
super().__init__()
self.weight = nn.Parameter(torch.empty((vocab_size, embedding_dim)))
def reset_parameters(self):
nn.init.normal_(self.weight, mean=0.0, std=0.02)
def forward(self, x: Tensor) -> Tensor:
return F.embedding(x, self.weight)

View File

@ -10,5 +10,12 @@ class Linear(nn.Module):
self.weight = nn.Parameter(torch.empty((out_dim, in_dim)))
self.bias = nn.Parameter(torch.zeros(out_dim)) if bias else None
def reset_parameters(self):
nn.init.kaiming_uniform_(self.weight, a=5**0.5)
if self.bias is not None:
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
bound = 1 / (fan_in**0.5)
nn.init.uniform_(self.bias, -bound, bound)
def forward(self, x: Tensor) -> Tensor:
return F.linear(x, self.weight, self.bias)

View File

@ -78,7 +78,7 @@ class Transformer(AutoModel):
n_routed_experts=config.n_routed_experts,
n_shared_experts=config.n_shared_experts,
n_activated_experts=config.n_activated_experts,
topk_method=config.moe_topk_method,
topk_method=config.topk_method,
kv_lora_rank=config.kv_lora_rank,
qk_nope_head_dim=config.qk_nope_head_dim,
qk_rope_head_dim=config.qk_rope_head_dim,
@ -93,12 +93,11 @@ class Transformer(AutoModel):
if self.config.tie_weight is True:
self.lm_head.weight = self.embed_tokens.weight
self._init_weights()
self.apply(self._init_weights)
def _init_weights(self):
for param in self.parameters():
if param.dim() > 1:
nn.init.normal_(param, mean=0.0, std=0.006)
def _init_weights(self, module):
if hasattr(module, "reset_parameters"):
module.reset_parameters()
def load_state_dict(self, state_dict: Mapping[str, Any], strict=True, assign=False):
lm_head_key = "lm_head.weight"

View File

@ -26,8 +26,13 @@ class Trainer:
def _get_default_callbacks(self) -> List[TrainCallback]:
cfg = self.train_config
return [
CallbackFactory.create(
"checkpoint",
cfg.ckpt_dir,
cfg.ckpt_interval,
state_dict_fn=cfg.state_dict_fn,
),
CallbackFactory.create("progress_bar", cfg.n_epoch),
CallbackFactory.create("checkpoint", cfg.ckpt_dir, cfg.ckpt_interval),
CallbackFactory.create("metric_logger", cfg.ckpt_dir, cfg.ckpt_interval),
CallbackFactory.create("gradient_clipping", cfg.max_grad_norm),
]

View File

@ -180,7 +180,9 @@ def create_scheduler(
def prepare_checkpoint(model: nn.Module) -> dict:
if isinstance(model, DDP):
return model.module.state_dict()
return model.state_dict()
def compute_total_steps(
@ -253,7 +255,7 @@ def train(
model = model.to(dtype=torch.bfloat16)
strategy_kwargs = {
"dpo_beta": dpo_beta,
"beta": dpo_beta,
"label_smoothing": label_smoothing,
"clip_eps": grpo_clip_eps,
"kl_coef": grpo_kl_coef,

View File

@ -40,7 +40,7 @@ CONFIGS = [
"n_routed_experts": 4,
"n_shared_experts": 1,
"n_activated_experts": 2,
"moe_topk_method": "greedy",
"topk_method": "greedy",
},
id="gqa_moe",
),