Compare commits
No commits in common. "44dab27fdc364dfc1f101e1641009c0dcab0f00d" and "ad9f4d9cf60f35cf742509b8096c7b541252c5be" have entirely different histories.
44dab27fdc
...
ad9f4d9cf6
|
|
@ -67,4 +67,4 @@ class ModelConfig(BaseModelConfig):
|
||||||
n_routed_experts: Optional[int] = None
|
n_routed_experts: Optional[int] = None
|
||||||
n_shared_experts: Optional[int] = None
|
n_shared_experts: Optional[int] = None
|
||||||
n_activated_experts: Optional[int] = None
|
n_activated_experts: Optional[int] = None
|
||||||
topk_method: Optional[str] = None
|
moe_topk_method: Optional[str] = None
|
||||||
|
|
|
||||||
|
|
@ -28,26 +28,6 @@ class BaseDataset(Dataset, ABC):
|
||||||
self.stride = stride
|
self.stride = stride
|
||||||
self.storage: Optional[BaseStorage] = None
|
self.storage: Optional[BaseStorage] = None
|
||||||
|
|
||||||
@property
|
|
||||||
def required_keys(self) -> List[str]:
|
|
||||||
"""Return required storage keys for this dataset type.
|
|
||||||
|
|
||||||
Subclasses should override to specify expected keys.
|
|
||||||
"""
|
|
||||||
return []
|
|
||||||
|
|
||||||
def _validate_keys(self):
|
|
||||||
if not self.required_keys:
|
|
||||||
return
|
|
||||||
actual_keys = set(self.storage.keys)
|
|
||||||
missing = [k for k in self.required_keys if k not in actual_keys]
|
|
||||||
if missing:
|
|
||||||
raise KeyError(
|
|
||||||
f"Dataset {type(self).__name__} requires keys {self.required_keys}, "
|
|
||||||
f"but storage at {self._load_path} only has {sorted(actual_keys)}. "
|
|
||||||
f"Missing: {missing}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def load(self, load_path: str, storage_type: Optional[str] = None, tokenizer=None):
|
def load(self, load_path: str, storage_type: Optional[str] = None, tokenizer=None):
|
||||||
"""Load dataset from the given path.
|
"""Load dataset from the given path.
|
||||||
|
|
||||||
|
|
@ -59,16 +39,11 @@ class BaseDataset(Dataset, ABC):
|
||||||
or None for auto-detection
|
or None for auto-detection
|
||||||
tokenizer: Callable str -> List[int], used to tokenize raw text
|
tokenizer: Callable str -> List[int], used to tokenize raw text
|
||||||
in JSON files. Ignored for HDF5.
|
in JSON files. Ignored for HDF5.
|
||||||
|
|
||||||
Raises:
|
|
||||||
KeyError: If the loaded storage is missing required keys.
|
|
||||||
"""
|
"""
|
||||||
if storage_type is None:
|
if storage_type is None:
|
||||||
storage_type = detect_format(load_path)
|
storage_type = detect_format(load_path)
|
||||||
self.storage = StorageFactory.create(storage_type)
|
self.storage = StorageFactory.create(storage_type)
|
||||||
self._load_path = load_path
|
|
||||||
self.storage.load(load_path, tokenizer=tokenizer)
|
self.storage.load(load_path, tokenizer=tokenizer)
|
||||||
self._validate_keys()
|
|
||||||
|
|
||||||
def load_json(self, load_path: str, tokenizer=None):
|
def load_json(self, load_path: str, tokenizer=None):
|
||||||
"""Load dataset from JSON files explicitly.
|
"""Load dataset from JSON files explicitly.
|
||||||
|
|
@ -211,10 +186,6 @@ class SEQDataset(BaseDataset):
|
||||||
def __init__(self, window_size: int, stride: int):
|
def __init__(self, window_size: int, stride: int):
|
||||||
super().__init__(window_size, stride)
|
super().__init__(window_size, stride)
|
||||||
|
|
||||||
@property
|
|
||||||
def required_keys(self) -> List[str]:
|
|
||||||
return ["sequence"]
|
|
||||||
|
|
||||||
def _fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
|
def _fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
|
||||||
return self.storage.fetch(begin_idx, end_idx, "sequence")
|
return self.storage.fetch(begin_idx, end_idx, "sequence")
|
||||||
|
|
||||||
|
|
@ -234,10 +205,6 @@ class SFTDataset(BaseDataset):
|
||||||
def __init__(self, window_size: int, stride: int):
|
def __init__(self, window_size: int, stride: int):
|
||||||
super().__init__(window_size, stride)
|
super().__init__(window_size, stride)
|
||||||
|
|
||||||
@property
|
|
||||||
def required_keys(self) -> List[str]:
|
|
||||||
return ["sequence", "loss_mask"]
|
|
||||||
|
|
||||||
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
|
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
|
||||||
return self.storage.fetch(begin_idx, end_idx, key)
|
return self.storage.fetch(begin_idx, end_idx, key)
|
||||||
|
|
||||||
|
|
@ -262,10 +229,6 @@ class DPODataset(BaseDataset):
|
||||||
def __init__(self, window_size: int, stride: int):
|
def __init__(self, window_size: int, stride: int):
|
||||||
super().__init__(window_size, stride)
|
super().__init__(window_size, stride)
|
||||||
|
|
||||||
@property
|
|
||||||
def required_keys(self) -> List[str]:
|
|
||||||
return ["chosen", "rejected", "chosen_mask", "rejected_mask"]
|
|
||||||
|
|
||||||
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
|
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
|
||||||
return self.storage.fetch(begin_idx, end_idx, key)
|
return self.storage.fetch(begin_idx, end_idx, key)
|
||||||
|
|
||||||
|
|
@ -296,10 +259,6 @@ class GRPODataset(BaseDataset):
|
||||||
def __init__(self, window_size: int, stride: int):
|
def __init__(self, window_size: int, stride: int):
|
||||||
super().__init__(window_size, stride)
|
super().__init__(window_size, stride)
|
||||||
|
|
||||||
@property
|
|
||||||
def required_keys(self) -> List[str]:
|
|
||||||
return ["prompts", "responses", "masks", "rewards"]
|
|
||||||
|
|
||||||
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
|
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
|
||||||
return self.storage.fetch(begin_idx, end_idx, key)
|
return self.storage.fetch(begin_idx, end_idx, key)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -120,7 +120,6 @@ class MLA(nn.Module):
|
||||||
qk_nope_head_dim: int,
|
qk_nope_head_dim: int,
|
||||||
qk_rope_head_dim: int,
|
qk_rope_head_dim: int,
|
||||||
norm_eps: float,
|
norm_eps: float,
|
||||||
use_qk_norm: bool,
|
|
||||||
use_gated_attention: bool,
|
use_gated_attention: bool,
|
||||||
layer_id: int,
|
layer_id: int,
|
||||||
):
|
):
|
||||||
|
|
@ -134,14 +133,9 @@ class MLA(nn.Module):
|
||||||
self.head_dim = qk_nope_head_dim + qk_rope_head_dim
|
self.head_dim = qk_nope_head_dim + qk_rope_head_dim
|
||||||
self.layer_id = layer_id
|
self.layer_id = layer_id
|
||||||
self.n_rep = n_heads // n_kv_heads
|
self.n_rep = n_heads // n_kv_heads
|
||||||
self.use_qk_norm = use_qk_norm
|
|
||||||
self.use_gated_attention = use_gated_attention
|
self.use_gated_attention = use_gated_attention
|
||||||
|
|
||||||
self.q_proj = Linear(dim, n_heads * self.head_dim, bias=False)
|
self.q_proj = Linear(dim, n_heads * self.head_dim, bias=False)
|
||||||
|
|
||||||
if self.use_qk_norm:
|
|
||||||
self.q_norm = RMSNorm(self.head_dim, norm_eps)
|
|
||||||
self.k_norm = RMSNorm(self.head_dim, norm_eps)
|
|
||||||
self.kv_a_proj = Linear(dim, kv_lora_rank, bias=False)
|
self.kv_a_proj = Linear(dim, kv_lora_rank, bias=False)
|
||||||
self.kv_norm = RMSNorm(kv_lora_rank, norm_eps)
|
self.kv_norm = RMSNorm(kv_lora_rank, norm_eps)
|
||||||
|
|
||||||
|
|
@ -188,10 +182,6 @@ class MLA(nn.Module):
|
||||||
q = torch.cat([q_nope, q_rope], dim=-1)
|
q = torch.cat([q_nope, q_rope], dim=-1)
|
||||||
k = torch.cat([k_nope, k_rope], dim=-1)
|
k = torch.cat([k_nope, k_rope], dim=-1)
|
||||||
|
|
||||||
if self.use_qk_norm:
|
|
||||||
q = self.q_norm(q)
|
|
||||||
k = self.k_norm(k)
|
|
||||||
|
|
||||||
if paged_cache is not None:
|
if paged_cache is not None:
|
||||||
paged_cache.write(self.layer_id, k, v)
|
paged_cache.write(self.layer_id, k, v)
|
||||||
k, v = paged_cache.gather(self.layer_id)
|
k, v = paged_cache.gather(self.layer_id)
|
||||||
|
|
|
||||||
|
|
@ -9,8 +9,5 @@ class Embedding(nn.Module):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.weight = nn.Parameter(torch.empty((vocab_size, embedding_dim)))
|
self.weight = nn.Parameter(torch.empty((vocab_size, embedding_dim)))
|
||||||
|
|
||||||
def reset_parameters(self):
|
|
||||||
nn.init.normal_(self.weight, mean=0.0, std=0.02)
|
|
||||||
|
|
||||||
def forward(self, x: Tensor) -> Tensor:
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
return F.embedding(x, self.weight)
|
return F.embedding(x, self.weight)
|
||||||
|
|
|
||||||
|
|
@ -10,12 +10,5 @@ class Linear(nn.Module):
|
||||||
self.weight = nn.Parameter(torch.empty((out_dim, in_dim)))
|
self.weight = nn.Parameter(torch.empty((out_dim, in_dim)))
|
||||||
self.bias = nn.Parameter(torch.zeros(out_dim)) if bias else None
|
self.bias = nn.Parameter(torch.zeros(out_dim)) if bias else None
|
||||||
|
|
||||||
def reset_parameters(self):
|
|
||||||
nn.init.kaiming_uniform_(self.weight, a=5**0.5)
|
|
||||||
if self.bias is not None:
|
|
||||||
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
|
|
||||||
bound = 1 / (fan_in**0.5)
|
|
||||||
nn.init.uniform_(self.bias, -bound, bound)
|
|
||||||
|
|
||||||
def forward(self, x: Tensor) -> Tensor:
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
return F.linear(x, self.weight, self.bias)
|
return F.linear(x, self.weight, self.bias)
|
||||||
|
|
|
||||||
|
|
@ -78,7 +78,7 @@ class Transformer(AutoModel):
|
||||||
n_routed_experts=config.n_routed_experts,
|
n_routed_experts=config.n_routed_experts,
|
||||||
n_shared_experts=config.n_shared_experts,
|
n_shared_experts=config.n_shared_experts,
|
||||||
n_activated_experts=config.n_activated_experts,
|
n_activated_experts=config.n_activated_experts,
|
||||||
topk_method=config.topk_method,
|
topk_method=config.moe_topk_method,
|
||||||
kv_lora_rank=config.kv_lora_rank,
|
kv_lora_rank=config.kv_lora_rank,
|
||||||
qk_nope_head_dim=config.qk_nope_head_dim,
|
qk_nope_head_dim=config.qk_nope_head_dim,
|
||||||
qk_rope_head_dim=config.qk_rope_head_dim,
|
qk_rope_head_dim=config.qk_rope_head_dim,
|
||||||
|
|
@ -93,11 +93,12 @@ class Transformer(AutoModel):
|
||||||
if self.config.tie_weight is True:
|
if self.config.tie_weight is True:
|
||||||
self.lm_head.weight = self.embed_tokens.weight
|
self.lm_head.weight = self.embed_tokens.weight
|
||||||
|
|
||||||
self.apply(self._init_weights)
|
self._init_weights()
|
||||||
|
|
||||||
def _init_weights(self, module):
|
def _init_weights(self):
|
||||||
if hasattr(module, "reset_parameters"):
|
for param in self.parameters():
|
||||||
module.reset_parameters()
|
if param.dim() > 1:
|
||||||
|
nn.init.normal_(param, mean=0.0, std=0.006)
|
||||||
|
|
||||||
def load_state_dict(self, state_dict: Mapping[str, Any], strict=True, assign=False):
|
def load_state_dict(self, state_dict: Mapping[str, Any], strict=True, assign=False):
|
||||||
lm_head_key = "lm_head.weight"
|
lm_head_key = "lm_head.weight"
|
||||||
|
|
|
||||||
|
|
@ -26,13 +26,8 @@ class Trainer:
|
||||||
def _get_default_callbacks(self) -> List[TrainCallback]:
|
def _get_default_callbacks(self) -> List[TrainCallback]:
|
||||||
cfg = self.train_config
|
cfg = self.train_config
|
||||||
return [
|
return [
|
||||||
CallbackFactory.create(
|
|
||||||
"checkpoint",
|
|
||||||
cfg.ckpt_dir,
|
|
||||||
cfg.ckpt_interval,
|
|
||||||
state_dict_fn=cfg.state_dict_fn,
|
|
||||||
),
|
|
||||||
CallbackFactory.create("progress_bar", cfg.n_epoch),
|
CallbackFactory.create("progress_bar", cfg.n_epoch),
|
||||||
|
CallbackFactory.create("checkpoint", cfg.ckpt_dir, cfg.ckpt_interval),
|
||||||
CallbackFactory.create("metric_logger", cfg.ckpt_dir, cfg.ckpt_interval),
|
CallbackFactory.create("metric_logger", cfg.ckpt_dir, cfg.ckpt_interval),
|
||||||
CallbackFactory.create("gradient_clipping", cfg.max_grad_norm),
|
CallbackFactory.create("gradient_clipping", cfg.max_grad_norm),
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -180,9 +180,7 @@ def create_scheduler(
|
||||||
|
|
||||||
|
|
||||||
def prepare_checkpoint(model: nn.Module) -> dict:
|
def prepare_checkpoint(model: nn.Module) -> dict:
|
||||||
if isinstance(model, DDP):
|
return model.module.state_dict()
|
||||||
return model.module.state_dict()
|
|
||||||
return model.state_dict()
|
|
||||||
|
|
||||||
|
|
||||||
def compute_total_steps(
|
def compute_total_steps(
|
||||||
|
|
@ -255,7 +253,7 @@ def train(
|
||||||
model = model.to(dtype=torch.bfloat16)
|
model = model.to(dtype=torch.bfloat16)
|
||||||
|
|
||||||
strategy_kwargs = {
|
strategy_kwargs = {
|
||||||
"beta": dpo_beta,
|
"dpo_beta": dpo_beta,
|
||||||
"label_smoothing": label_smoothing,
|
"label_smoothing": label_smoothing,
|
||||||
"clip_eps": grpo_clip_eps,
|
"clip_eps": grpo_clip_eps,
|
||||||
"kl_coef": grpo_kl_coef,
|
"kl_coef": grpo_kl_coef,
|
||||||
|
|
|
||||||
|
|
@ -40,7 +40,7 @@ CONFIGS = [
|
||||||
"n_routed_experts": 4,
|
"n_routed_experts": 4,
|
||||||
"n_shared_experts": 1,
|
"n_shared_experts": 1,
|
||||||
"n_activated_experts": 2,
|
"n_activated_experts": 2,
|
||||||
"topk_method": "greedy",
|
"moe_topk_method": "greedy",
|
||||||
},
|
},
|
||||||
id="gqa_moe",
|
id="gqa_moe",
|
||||||
),
|
),
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue