refactor: 检查点加载重构,路径替代对象传递
- model: nn.Module -> model_fn 工厂函数,spawn 边界只传字符串
- Trainer.train(resume_dir=path) — Checkpoint 不再通过 pickle 传递
- TrainContextBuilder.with_resume_dir(path) — 自动检测 meta.json 分流 resume/from-scratch
- CheckpointCallback: 拆分 state_dict 收集(全 rank)与磁盘写入(rank-0),修复 FSDP 死锁
- serialization: load_torch 支持 broadcast,消除 _load_extra/_load_torch_broadcast
- optimizer/scheduler 恢复逻辑内联到 build(),在 executor.prepare() 之后执行
- pyproject.toml: ruff exclude build/ 避免 CI 扫描构建产物
This commit is contained in:
parent
34c6c45bd6
commit
4145d35e3c
|
|
@ -17,8 +17,8 @@ def required(**kw):
|
||||||
@dataclass
|
@dataclass
|
||||||
class TrainConfig(BaseConfig):
|
class TrainConfig(BaseConfig):
|
||||||
# basic setting
|
# basic setting
|
||||||
model: nn.Module = field(
|
model_fn: Callable[[], nn.Module] = field(
|
||||||
default=None, metadata=required(help="Model for training.")
|
default=None, metadata=required(help="Model factory for training.")
|
||||||
)
|
)
|
||||||
strategy: str = field(default=None, metadata=required(help="Training strategy."))
|
strategy: str = field(default=None, metadata=required(help="Training strategy."))
|
||||||
dataset: Dataset = field(
|
dataset: Dataset = field(
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,9 @@
|
||||||
|
import io
|
||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict
|
from typing import Any, Dict, List, Tuple
|
||||||
|
|
||||||
import safetensors.torch as st
|
import safetensors.torch as st
|
||||||
import torch
|
import torch
|
||||||
|
|
@ -11,8 +12,8 @@ import torch.distributed as dist
|
||||||
from astrai.parallel.setup import get_rank
|
from astrai.parallel.setup import get_rank
|
||||||
|
|
||||||
_META_FILE = "meta.json"
|
_META_FILE = "meta.json"
|
||||||
|
_CONFIG_FILE = "config.json"
|
||||||
_WEIGHTS_FILE = "model.safetensors"
|
_WEIGHTS_FILE = "model.safetensors"
|
||||||
_MODEL_CONFIG_FILE = "config.json"
|
|
||||||
|
|
||||||
|
|
||||||
def save_safetensors(state_dict: dict, path: str | Path) -> None:
|
def save_safetensors(state_dict: dict, path: str | Path) -> None:
|
||||||
|
|
@ -37,9 +38,88 @@ def save_torch(obj: Any, path: str | Path) -> None:
|
||||||
torch.save(obj, str(path))
|
torch.save(obj, str(path))
|
||||||
|
|
||||||
|
|
||||||
def load_torch(path: str | Path) -> Any:
|
def load_torch(path: str | Path, broadcast: bool = False) -> Any:
|
||||||
|
if not broadcast or not dist.is_initialized():
|
||||||
return torch.load(str(path), map_location="cpu", weights_only=False)
|
return torch.load(str(path), map_location="cpu", weights_only=False)
|
||||||
|
|
||||||
|
path = Path(path)
|
||||||
|
rank = get_rank()
|
||||||
|
|
||||||
|
if rank == 0:
|
||||||
|
with open(path, "rb") as f:
|
||||||
|
raw = f.read()
|
||||||
|
data_tensor = torch.frombuffer(bytearray(raw), dtype=torch.uint8)
|
||||||
|
num_bytes = torch.tensor([len(raw)], dtype=torch.long)
|
||||||
|
else:
|
||||||
|
num_bytes = torch.tensor([0], dtype=torch.long)
|
||||||
|
|
||||||
|
dist.broadcast(num_bytes, src=0)
|
||||||
|
|
||||||
|
if rank != 0:
|
||||||
|
data_tensor = torch.empty(num_bytes.item(), dtype=torch.uint8)
|
||||||
|
|
||||||
|
dist.broadcast(data_tensor, src=0)
|
||||||
|
|
||||||
|
buf = io.BytesIO(data_tensor.numpy().tobytes())
|
||||||
|
return torch.load(buf, map_location="cpu", weights_only=False)
|
||||||
|
|
||||||
|
|
||||||
|
def save_model(config: dict, state_dict: dict, save_directory: str) -> None:
    """Write a model's config and weights into ``save_directory``.

    The directory (and any missing parents) is created if needed. The config
    is passed to ``save_json`` under the ``_CONFIG_FILE`` name and the weights
    to ``save_safetensors`` under the ``_WEIGHTS_FILE`` name.

    Args:
        config: Model configuration mapping handed to ``save_json``.
        state_dict: Model weights handed to ``save_safetensors``.
        save_directory: Target directory for both files.
    """
    target_dir = Path(save_directory)
    target_dir.mkdir(parents=True, exist_ok=True)

    config_path = target_dir / _CONFIG_FILE
    weights_path = target_dir / _WEIGHTS_FILE
    save_json(config, config_path)
    save_safetensors(state_dict, weights_path)
|
||||||
|
|
||||||
|
|
||||||
|
def load_model_config(save_directory: str) -> dict:
    """Return the result of ``load_json`` on ``save_directory / _CONFIG_FILE``.

    Args:
        save_directory: Directory that contains the model config file.
    """
    config_path = Path(save_directory) / _CONFIG_FILE
    return load_json(config_path)
|
||||||
|
|
||||||
|
|
||||||
|
def load_model_weights(save_directory: str) -> dict:
    """Return the result of ``load_safetensors`` on ``save_directory / _WEIGHTS_FILE``.

    Args:
        save_directory: Directory that contains the weights file.
    """
    weights_path = Path(save_directory) / _WEIGHTS_FILE
    return load_safetensors(weights_path)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_meta(save_path: Path) -> dict:
    """Load checkpoint metadata on rank 0 and share it with all ranks.

    Only the process whose ``get_rank()`` is 0 reads ``_META_FILE`` from
    ``save_path``; every other process starts with an empty dict. When
    ``torch.distributed`` is initialized, rank 0's value is broadcast so each
    rank returns the same metadata.

    Args:
        save_path: Checkpoint directory containing the metadata file.

    Returns:
        The metadata dict (empty on non-zero ranks when no process group
        is initialized).
    """
    is_source = get_rank() == 0
    meta = load_json(save_path / _META_FILE) if is_source else {}

    if not dist.is_initialized():
        return meta

    # broadcast_object_list overwrites the list in place on receiving ranks.
    payload = [meta]
    dist.broadcast_object_list(payload, src=0)
    return payload[0]
|
||||||
|
|
||||||
|
|
||||||
|
def _load_state_dict(save_path: Path, broadcast: bool = False) -> dict:
    """Load the weights file, optionally reading once and broadcasting.

    Without ``broadcast`` (or when no process group is initialized) the
    calling process reads ``_WEIGHTS_FILE`` directly. Otherwise only rank 0
    reads the file. It first broadcasts a list of ``(key, shape, dtype name)``
    specs, then broadcasts each tensor in sorted-key order so the other ranks
    can receive into pre-allocated CPU buffers.

    Args:
        save_path: Checkpoint directory containing the weights file.
        broadcast: If True and ``torch.distributed`` is initialized, read on
            rank 0 only and broadcast the tensors to the other ranks.

    Returns:
        Mapping from parameter name to tensor.
    """
    weights_file = save_path / _WEIGHTS_FILE
    if not (broadcast and dist.is_initialized()):
        return load_safetensors(weights_file)

    is_source = get_rank() == 0

    if is_source:
        state_dict = load_safetensors(weights_file)
        # Dtypes travel as attribute names (e.g. "float32") and are resolved
        # again with getattr(torch, ...) on every rank.
        specs: List[Tuple[str, List[int], str]] = []
        for name in sorted(state_dict):
            weight = state_dict[name]
            specs.append((name, list(weight.shape), str(weight.dtype).split(".")[-1]))
    else:
        state_dict = {}
        specs = []

    # Share the tensor layout so receivers can allocate matching buffers.
    spec_box = [specs]
    dist.broadcast_object_list(spec_box, src=0)
    specs = spec_box[0]

    for name, shape, dtype_name in specs:
        dtype = getattr(torch, dtype_name)
        if is_source:
            tensor = state_dict[name].contiguous().cpu()
        else:
            tensor = torch.empty(shape, dtype=dtype, device="cpu")
        dist.broadcast(tensor, src=0)
        if not is_source:
            state_dict[name] = tensor

    return state_dict
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Checkpoint:
|
class Checkpoint:
|
||||||
|
|
@ -68,24 +148,16 @@ class Checkpoint:
|
||||||
save_torch(value, save_path / f"{key}.pt")
|
save_torch(value, save_path / f"{key}.pt")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def load(cls, save_dir: str) -> "Checkpoint":
|
def load(cls, save_dir: str, broadcast: bool = False) -> "Checkpoint":
|
||||||
save_path = Path(save_dir)
|
save_path = Path(save_dir)
|
||||||
|
|
||||||
meta = {}
|
meta = _get_meta(save_path)
|
||||||
if get_rank() == 0:
|
state_dict = _load_state_dict(save_path, broadcast=broadcast)
|
||||||
meta = load_json(save_path / _META_FILE)
|
|
||||||
|
|
||||||
if dist.is_initialized():
|
|
||||||
meta_list = [meta]
|
|
||||||
dist.broadcast_object_list(meta_list, src=0)
|
|
||||||
meta = meta_list[0]
|
|
||||||
|
|
||||||
state_dict = load_safetensors(save_path / _WEIGHTS_FILE)
|
|
||||||
|
|
||||||
extra = {}
|
extra = {}
|
||||||
for f in save_path.iterdir():
|
for f in sorted(save_path.iterdir()):
|
||||||
if f.suffix == ".pt":
|
if f.suffix == ".pt":
|
||||||
extra[f.stem] = load_torch(f)
|
extra[f.stem] = load_torch(f, broadcast=broadcast)
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
state_dict=state_dict,
|
state_dict=state_dict,
|
||||||
|
|
@ -93,18 +165,3 @@ class Checkpoint:
|
||||||
iteration=meta.get("iteration", 0),
|
iteration=meta.get("iteration", 0),
|
||||||
extra=extra,
|
extra=extra,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def save_model(config: dict, state_dict: dict, save_directory: str) -> None:
|
|
||||||
save_path = Path(save_directory)
|
|
||||||
save_path.mkdir(parents=True, exist_ok=True)
|
|
||||||
save_json(config, save_path / _MODEL_CONFIG_FILE)
|
|
||||||
save_safetensors(state_dict, save_path / _WEIGHTS_FILE)
|
|
||||||
|
|
||||||
|
|
||||||
def load_model_config(save_directory: str) -> dict:
|
|
||||||
return load_json(Path(save_directory) / _MODEL_CONFIG_FILE)
|
|
||||||
|
|
||||||
|
|
||||||
def load_model_weights(save_directory: str) -> dict:
|
|
||||||
return load_safetensors(Path(save_directory) / _WEIGHTS_FILE)
|
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,7 @@ from tqdm import tqdm
|
||||||
|
|
||||||
from astrai.factory import BaseFactory
|
from astrai.factory import BaseFactory
|
||||||
from astrai.parallel import only_on_rank
|
from astrai.parallel import only_on_rank
|
||||||
from astrai.parallel.setup import get_current_device
|
from astrai.parallel.setup import get_current_device, get_rank
|
||||||
from astrai.serialization import Checkpoint
|
from astrai.serialization import Checkpoint
|
||||||
from astrai.trainer.metric_util import (
|
from astrai.trainer.metric_util import (
|
||||||
ctx_get_grad_max,
|
ctx_get_grad_max,
|
||||||
|
|
@ -139,27 +139,27 @@ class CheckpointCallback(TrainCallback):
|
||||||
weight_only: bool = False,
|
weight_only: bool = False,
|
||||||
state_dict_fn: Optional[Callable[[nn.Module], dict]] = None,
|
state_dict_fn: Optional[Callable[[nn.Module], dict]] = None,
|
||||||
save_extra_fn: Optional[Callable[["TrainContext"], dict]] = None,
|
save_extra_fn: Optional[Callable[["TrainContext"], dict]] = None,
|
||||||
load_extra_fn: Optional[Callable[[dict, "TrainContext"], None]] = None,
|
|
||||||
):
|
):
|
||||||
self.save_dir = save_dir
|
self.save_dir = save_dir
|
||||||
self.interval = interval
|
self.interval = interval
|
||||||
self.weight_only = weight_only
|
self.weight_only = weight_only
|
||||||
self.state_dict_fn = state_dict_fn
|
self.state_dict_fn = state_dict_fn
|
||||||
self.save_extra_fn = save_extra_fn or CheckpointCallback.save_extra
|
self.save_extra_fn = save_extra_fn or CheckpointCallback.save_extra
|
||||||
self.load_extra_fn = load_extra_fn or CheckpointCallback.load_extra
|
|
||||||
self.last_ckpt_iter = 0
|
self.last_ckpt_iter = 0
|
||||||
|
|
||||||
@only_on_rank(0)
|
|
||||||
def _save_checkpoint(self, context: TrainContext):
|
def _save_checkpoint(self, context: TrainContext):
|
||||||
save_path = os.path.join(
|
# All ranks gather state_dict — collective for FSDP, local for DDP
|
||||||
self.save_dir, f"epoch_{context.epoch}_iter_{context.iteration}"
|
|
||||||
)
|
|
||||||
state_dict = (
|
state_dict = (
|
||||||
self.state_dict_fn(context.model)
|
self.state_dict_fn(context.model)
|
||||||
if self.state_dict_fn
|
if self.state_dict_fn
|
||||||
else context.model.state_dict()
|
else context.model.state_dict()
|
||||||
)
|
)
|
||||||
|
self.last_ckpt_iter = context.iteration
|
||||||
|
|
||||||
|
if get_rank() == 0:
|
||||||
|
save_path = os.path.join(
|
||||||
|
self.save_dir, f"epoch_{context.epoch}_iter_{context.iteration}"
|
||||||
|
)
|
||||||
extra = self.save_extra_fn(context)
|
extra = self.save_extra_fn(context)
|
||||||
context.checkpoint = Checkpoint(
|
context.checkpoint = Checkpoint(
|
||||||
state_dict=state_dict,
|
state_dict=state_dict,
|
||||||
|
|
@ -168,13 +168,7 @@ class CheckpointCallback(TrainCallback):
|
||||||
extra=extra,
|
extra=extra,
|
||||||
meta=context.config.to_dict(),
|
meta=context.config.to_dict(),
|
||||||
)
|
)
|
||||||
|
|
||||||
context.checkpoint.save(save_path)
|
context.checkpoint.save(save_path)
|
||||||
self.last_ckpt_iter = context.iteration
|
|
||||||
|
|
||||||
def on_train_begin(self, context: TrainContext):
|
|
||||||
if context.checkpoint and context.checkpoint.extra:
|
|
||||||
self.load_extra_fn(context.checkpoint.extra, context)
|
|
||||||
|
|
||||||
def on_batch_end(self, context: TrainContext):
|
def on_batch_end(self, context: TrainContext):
|
||||||
if context.iteration - self.last_ckpt_iter >= self.interval:
|
if context.iteration - self.last_ckpt_iter >= self.interval:
|
||||||
|
|
@ -196,12 +190,6 @@ class CheckpointCallback(TrainCallback):
|
||||||
extra[name] = obj.state_dict()
|
extra[name] = obj.state_dict()
|
||||||
return extra
|
return extra
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def load_extra(extra: dict, context: TrainContext):
|
|
||||||
for name in CheckpointCallback.extra_keys:
|
|
||||||
if name in extra:
|
|
||||||
getattr(context, name).load_state_dict(extra[name])
|
|
||||||
|
|
||||||
|
|
||||||
@CallbackFactory.register("progress_bar")
|
@CallbackFactory.register("progress_bar")
|
||||||
class ProgressBarCallback(TrainCallback):
|
class ProgressBarCallback(TrainCallback):
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
from pathlib import Path
|
||||||
from typing import Optional, Self
|
from typing import Optional, Self
|
||||||
|
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
@ -10,7 +11,7 @@ from astrai.model.components.lora import inject_lora
|
||||||
from astrai.parallel.executor import BaseExecutor, ExecutorFactory
|
from astrai.parallel.executor import BaseExecutor, ExecutorFactory
|
||||||
from astrai.parallel.setup import get_current_device, get_rank, get_world_size
|
from astrai.parallel.setup import get_current_device, get_rank, get_world_size
|
||||||
from astrai.protocols import OptimizerProtocol, SchedulerProtocol
|
from astrai.protocols import OptimizerProtocol, SchedulerProtocol
|
||||||
from astrai.serialization import Checkpoint
|
from astrai.serialization import Checkpoint, load_model_weights
|
||||||
from astrai.trainer.strategy import BaseStrategy, StrategyFactory
|
from astrai.trainer.strategy import BaseStrategy, StrategyFactory
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -42,10 +43,10 @@ class TrainContextBuilder:
|
||||||
config: TrainConfig,
|
config: TrainConfig,
|
||||||
):
|
):
|
||||||
self.config = config
|
self.config = config
|
||||||
self._checkpoint: Optional[Checkpoint] = None
|
self._resume_dir: Optional[str] = None
|
||||||
|
|
||||||
def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self:
|
def with_resume_dir(self, resume_dir: Optional[str]) -> Self:
|
||||||
self._checkpoint = checkpoint
|
self._resume_dir = resume_dir
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def build(self) -> TrainContext:
|
def build(self) -> TrainContext:
|
||||||
|
|
@ -58,36 +59,40 @@ class TrainContextBuilder:
|
||||||
**cfg.executor_kwargs,
|
**cfg.executor_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
model = cfg.model_fn()
|
||||||
|
model = model.to(device=device)
|
||||||
|
|
||||||
context = TrainContext(
|
context = TrainContext(
|
||||||
model=cfg.model,
|
model=model,
|
||||||
world_size=get_world_size(),
|
world_size=get_world_size(),
|
||||||
rank=get_rank(),
|
rank=get_rank(),
|
||||||
config=cfg,
|
config=cfg,
|
||||||
executor=executor,
|
executor=executor,
|
||||||
)
|
)
|
||||||
|
|
||||||
context.model = context.model.to(device=device)
|
if self._resume_dir is not None:
|
||||||
|
resume_path = Path(self._resume_dir)
|
||||||
if self._checkpoint is not None:
|
if (resume_path / "meta.json").exists():
|
||||||
context.epoch = max(self._checkpoint.epoch, cfg.start_epoch)
|
checkpoint = Checkpoint.load(self._resume_dir)
|
||||||
context.iteration = max(self._checkpoint.iteration, cfg.start_batch)
|
state_dict = checkpoint.state_dict
|
||||||
if self._checkpoint.state_dict:
|
|
||||||
context.model.load_state_dict(self._checkpoint.state_dict)
|
|
||||||
context.checkpoint = self._checkpoint
|
|
||||||
else:
|
else:
|
||||||
context.checkpoint = Checkpoint(
|
checkpoint = None
|
||||||
state_dict=context.model.state_dict(),
|
state_dict = load_model_weights(self._resume_dir)
|
||||||
)
|
model.load_state_dict(state_dict, strict=False)
|
||||||
|
if checkpoint is not None:
|
||||||
|
context.epoch = max(checkpoint.epoch, cfg.start_epoch)
|
||||||
|
context.iteration = max(checkpoint.iteration, cfg.start_batch)
|
||||||
|
context.checkpoint = checkpoint
|
||||||
|
|
||||||
if cfg.lora is not None:
|
if cfg.lora is not None:
|
||||||
inject_lora(
|
inject_lora(
|
||||||
context.model,
|
model,
|
||||||
r=cfg.lora.r,
|
r=cfg.lora.r,
|
||||||
alpha=cfg.lora.alpha,
|
alpha=cfg.lora.alpha,
|
||||||
target_modules=set(cfg.lora.target_modules),
|
target_modules=set(cfg.lora.target_modules),
|
||||||
)
|
)
|
||||||
|
|
||||||
context.optimizer = cfg.optimizer_fn(context.model)
|
context.optimizer = cfg.optimizer_fn(model)
|
||||||
context.scheduler = cfg.scheduler_fn(context.optimizer)
|
context.scheduler = cfg.scheduler_fn(context.optimizer)
|
||||||
|
|
||||||
sampler_offset = context.iteration * cfg.batch_per_device
|
sampler_offset = context.iteration * cfg.batch_per_device
|
||||||
|
|
@ -125,13 +130,21 @@ class TrainContextBuilder:
|
||||||
|
|
||||||
context.model, context.optimizer, context.dataloader, context.scheduler = (
|
context.model, context.optimizer, context.dataloader, context.scheduler = (
|
||||||
executor.prepare(
|
executor.prepare(
|
||||||
context.model,
|
model,
|
||||||
context.optimizer,
|
context.optimizer,
|
||||||
context.dataloader,
|
context.dataloader,
|
||||||
context.scheduler,
|
context.scheduler,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if context.checkpoint and context.checkpoint.extra:
|
||||||
|
extra = context.checkpoint.extra
|
||||||
|
for name in ("optimizer", "scheduler"):
|
||||||
|
if name in extra:
|
||||||
|
obj = getattr(context, name, None)
|
||||||
|
if obj is not None:
|
||||||
|
obj.load_state_dict(extra[name])
|
||||||
|
|
||||||
context.strategy = StrategyFactory.create(
|
context.strategy = StrategyFactory.create(
|
||||||
model=context.model,
|
model=context.model,
|
||||||
train_type=cfg.strategy,
|
train_type=cfg.strategy,
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,6 @@ from typing import List, Optional
|
||||||
|
|
||||||
from astrai.config import TrainConfig
|
from astrai.config import TrainConfig
|
||||||
from astrai.parallel.setup import spawn_parallel_fn
|
from astrai.parallel.setup import spawn_parallel_fn
|
||||||
from astrai.serialization import Checkpoint
|
|
||||||
from astrai.trainer.train_callback import (
|
from astrai.trainer.train_callback import (
|
||||||
CallbackFactory,
|
CallbackFactory,
|
||||||
TrainCallback,
|
TrainCallback,
|
||||||
|
|
@ -54,9 +53,9 @@ class Trainer:
|
||||||
if method:
|
if method:
|
||||||
method(context)
|
method(context)
|
||||||
|
|
||||||
def _trainer_loop(self, checkpoint: Optional[Checkpoint] = None):
|
def _trainer_loop(self, resume_dir: Optional[str] = None):
|
||||||
context = (
|
context = (
|
||||||
TrainContextBuilder(self.train_config).with_checkpoint(checkpoint).build()
|
TrainContextBuilder(self.train_config).with_resume_dir(resume_dir).build()
|
||||||
)
|
)
|
||||||
executor = context.executor
|
executor = context.executor
|
||||||
self._call_callbacks("on_train_begin", context)
|
self._call_callbacks("on_train_begin", context)
|
||||||
|
|
@ -90,13 +89,13 @@ class Trainer:
|
||||||
self._call_callbacks("on_epoch_end", context)
|
self._call_callbacks("on_epoch_end", context)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Training failed: {str(e)}", exc_info=True)
|
logger.error("Training failed: %s", str(e), exc_info=True)
|
||||||
self._call_callbacks("on_error", context)
|
self._call_callbacks("on_error", context)
|
||||||
raise
|
raise
|
||||||
finally:
|
finally:
|
||||||
self._call_callbacks("on_train_end", context)
|
self._call_callbacks("on_train_end", context)
|
||||||
|
|
||||||
def train(self, checkpoint: Optional[Checkpoint] = None):
|
def train(self, resume_dir: Optional[str] = None):
|
||||||
cfg = self.train_config
|
cfg = self.train_config
|
||||||
spawn_parallel_fn(
|
spawn_parallel_fn(
|
||||||
self._trainer_loop,
|
self._trainer_loop,
|
||||||
|
|
@ -106,5 +105,5 @@ class Trainer:
|
||||||
master_port=cfg.master_port,
|
master_port=cfg.master_port,
|
||||||
device_type=cfg.device_type,
|
device_type=cfg.device_type,
|
||||||
start_method=cfg.start_method,
|
start_method=cfg.start_method,
|
||||||
checkpoint=checkpoint,
|
resume_dir=resume_dir,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,6 @@ import torch.optim as optim
|
||||||
from astrai.config import AutoRegressiveLMConfig, TrainConfig
|
from astrai.config import AutoRegressiveLMConfig, TrainConfig
|
||||||
from astrai.dataset import DatasetFactory
|
from astrai.dataset import DatasetFactory
|
||||||
from astrai.model import AutoRegressiveLM
|
from astrai.model import AutoRegressiveLM
|
||||||
from astrai.serialization import Checkpoint
|
|
||||||
from astrai.trainer import SchedulerFactory, Trainer
|
from astrai.trainer import SchedulerFactory, Trainer
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -166,6 +165,10 @@ def parse_args() -> argparse.Namespace:
|
||||||
return args
|
return args
|
||||||
|
|
||||||
|
|
||||||
|
def create_model(config):
    """Build an ``AutoRegressiveLM`` from ``config`` and cast it to bfloat16."""
    model = AutoRegressiveLM(config)
    return model.to(dtype=torch.bfloat16)
|
||||||
|
|
||||||
|
|
||||||
def create_optimizer(model, **kwargs) -> optim.Optimizer:
|
def create_optimizer(model, **kwargs) -> optim.Optimizer:
|
||||||
return optim.AdamW(model.parameters(), fused=True, **kwargs)
|
return optim.AdamW(model.parameters(), fused=True, **kwargs)
|
||||||
|
|
||||||
|
|
@ -238,15 +241,6 @@ def train(
|
||||||
if window_size is None:
|
if window_size is None:
|
||||||
window_size = config.max_len
|
window_size = config.max_len
|
||||||
|
|
||||||
# Create model and load full checkpoint (state_dict + optimizer + scheduler + meta)
|
|
||||||
checkpoint = Checkpoint.load(param_path)
|
|
||||||
model = AutoRegressiveLM(config).to(dtype=torch.bfloat16)
|
|
||||||
model.load_state_dict(checkpoint.state_dict, strict=False)
|
|
||||||
|
|
||||||
# Strip state_dict to avoid pickling ~7GB through mp.spawn pipe
|
|
||||||
# (model weights already loaded into model above)
|
|
||||||
checkpoint.state_dict = {}
|
|
||||||
|
|
||||||
strategy_kwargs = {
|
strategy_kwargs = {
|
||||||
"beta": dpo_beta,
|
"beta": dpo_beta,
|
||||||
"label_smoothing": label_smoothing,
|
"label_smoothing": label_smoothing,
|
||||||
|
|
@ -261,6 +255,7 @@ def train(
|
||||||
"broadcast_buffers": False,
|
"broadcast_buffers": False,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
model_fn = partial(create_model, config)
|
||||||
dataset = DatasetFactory.load(
|
dataset = DatasetFactory.load(
|
||||||
train_type=train_type,
|
train_type=train_type,
|
||||||
load_path=data_root_path,
|
load_path=data_root_path,
|
||||||
|
|
@ -292,7 +287,7 @@ def train(
|
||||||
)
|
)
|
||||||
|
|
||||||
train_config = TrainConfig(
|
train_config = TrainConfig(
|
||||||
model=model,
|
model_fn=model_fn,
|
||||||
strategy=train_type,
|
strategy=train_type,
|
||||||
dataset=dataset,
|
dataset=dataset,
|
||||||
optimizer_fn=optimizer_fn,
|
optimizer_fn=optimizer_fn,
|
||||||
|
|
@ -317,7 +312,7 @@ def train(
|
||||||
)
|
)
|
||||||
|
|
||||||
trainer = Trainer(train_config)
|
trainer = Trainer(train_config)
|
||||||
trainer.train(checkpoint=checkpoint)
|
trainer.train(resume_dir=param_path)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
||||||
|
|
@ -27,7 +27,7 @@ class TrainerDataset(Dataset):
|
||||||
|
|
||||||
|
|
||||||
def create_train_config(
|
def create_train_config(
|
||||||
model: torch.nn.Module,
|
model_fn,
|
||||||
dataset: Dataset,
|
dataset: Dataset,
|
||||||
test_dir: str,
|
test_dir: str,
|
||||||
device: str,
|
device: str,
|
||||||
|
|
@ -43,7 +43,7 @@ def create_train_config(
|
||||||
"""Factory function to create common TrainConfig for tests.
|
"""Factory function to create common TrainConfig for tests.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model: The model to train
|
model_fn: Model factory (callable returning nn.Module)
|
||||||
dataset: Training dataset
|
dataset: Training dataset
|
||||||
test_dir: Checkpoint directory
|
test_dir: Checkpoint directory
|
||||||
device: Device type ("cuda" or "cpu")
|
device: Device type ("cuda" or "cpu")
|
||||||
|
|
@ -70,7 +70,7 @@ def create_train_config(
|
||||||
|
|
||||||
return TrainConfig(
|
return TrainConfig(
|
||||||
strategy=strategy,
|
strategy=strategy,
|
||||||
model=model,
|
model_fn=model_fn,
|
||||||
dataset=dataset,
|
dataset=dataset,
|
||||||
optimizer_fn=optimizer_fn,
|
optimizer_fn=optimizer_fn,
|
||||||
scheduler_fn=scheduler_fn,
|
scheduler_fn=scheduler_fn,
|
||||||
|
|
|
||||||
|
|
@ -106,7 +106,7 @@ def test_gradient_checkpointing_trainer_integration(base_test_env, random_datase
|
||||||
)
|
)
|
||||||
|
|
||||||
train_config = TrainConfig(
|
train_config = TrainConfig(
|
||||||
model=base_test_env["model"],
|
model_fn=lambda: base_test_env["model"],
|
||||||
strategy="seq",
|
strategy="seq",
|
||||||
dataset=random_dataset,
|
dataset=random_dataset,
|
||||||
optimizer_fn=optimizer_fn,
|
optimizer_fn=optimizer_fn,
|
||||||
|
|
@ -140,7 +140,7 @@ def test_callback_integration(base_test_env, random_dataset):
|
||||||
)
|
)
|
||||||
|
|
||||||
train_config = TrainConfig(
|
train_config = TrainConfig(
|
||||||
model=base_test_env["model"],
|
model_fn=lambda: base_test_env["model"],
|
||||||
strategy="seq",
|
strategy="seq",
|
||||||
dataset=random_dataset,
|
dataset=random_dataset,
|
||||||
optimizer_fn=optimizer_fn,
|
optimizer_fn=optimizer_fn,
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,6 @@ import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from astrai.config.train_config import TrainConfig
|
from astrai.config.train_config import TrainConfig
|
||||||
from astrai.serialization import Checkpoint
|
|
||||||
from astrai.trainer.schedule import SchedulerFactory
|
from astrai.trainer.schedule import SchedulerFactory
|
||||||
from astrai.trainer.trainer import Trainer
|
from astrai.trainer.trainer import Trainer
|
||||||
|
|
||||||
|
|
@ -24,7 +23,7 @@ def test_early_stopping_simulation(base_test_env, early_stopping_dataset):
|
||||||
strategy="seq",
|
strategy="seq",
|
||||||
optimizer_fn=optimizer_fn,
|
optimizer_fn=optimizer_fn,
|
||||||
scheduler_fn=scheduler_fn,
|
scheduler_fn=scheduler_fn,
|
||||||
model=base_test_env["model"],
|
model_fn=lambda: base_test_env["model"],
|
||||||
dataset=early_stopping_dataset,
|
dataset=early_stopping_dataset,
|
||||||
ckpt_dir=base_test_env["test_dir"],
|
ckpt_dir=base_test_env["test_dir"],
|
||||||
log_dir=os.path.join(base_test_env["test_dir"], "logs"),
|
log_dir=os.path.join(base_test_env["test_dir"], "logs"),
|
||||||
|
|
@ -39,17 +38,20 @@ def test_early_stopping_simulation(base_test_env, early_stopping_dataset):
|
||||||
trainer = Trainer(train_config)
|
trainer = Trainer(train_config)
|
||||||
|
|
||||||
# Should handle early stopping gracefully
|
# Should handle early stopping gracefully
|
||||||
checkpoint = None
|
|
||||||
try:
|
try:
|
||||||
checkpoint = trainer.train()
|
trainer.train()
|
||||||
except Exception:
|
except Exception:
|
||||||
# Handle any exceptions
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
# Resume from latest checkpoint
|
||||||
load_dir = os.path.join(base_test_env["test_dir"], "epoch_0_iter_2")
|
load_dir = os.path.join(base_test_env["test_dir"], "epoch_0_iter_2")
|
||||||
checkpoint = Checkpoint.load(load_dir)
|
trainer = Trainer(train_config)
|
||||||
trainer.train(checkpoint)
|
trainer.train(resume_dir=load_dir)
|
||||||
|
|
||||||
|
# Verify checkpoint was saved at expected iteration
|
||||||
load_dir = os.path.join(base_test_env["test_dir"], "epoch_1_iter_10")
|
load_dir = os.path.join(base_test_env["test_dir"], "epoch_1_iter_10")
|
||||||
checkpoint = Checkpoint.load(load_dir)
|
import json
|
||||||
assert checkpoint.iteration == 10
|
|
||||||
|
with open(os.path.join(load_dir, "meta.json")) as f:
|
||||||
|
meta = json.load(f)
|
||||||
|
assert meta["iteration"] == 10
|
||||||
|
|
|
||||||
|
|
@ -9,7 +9,7 @@ def test_different_batch_sizes(base_test_env, random_dataset, train_config_facto
|
||||||
|
|
||||||
for batch_per_device in batch_sizes:
|
for batch_per_device in batch_sizes:
|
||||||
train_config = train_config_factory(
|
train_config = train_config_factory(
|
||||||
model=base_test_env["model"],
|
model_fn=lambda: base_test_env["model"],
|
||||||
dataset=random_dataset,
|
dataset=random_dataset,
|
||||||
test_dir=base_test_env["test_dir"],
|
test_dir=base_test_env["test_dir"],
|
||||||
device=base_test_env["device"],
|
device=base_test_env["device"],
|
||||||
|
|
@ -25,7 +25,7 @@ def test_gradient_accumulation(base_test_env, random_dataset, train_config_facto
|
||||||
|
|
||||||
for grad_accum_steps in grad_accum_steps_list:
|
for grad_accum_steps in grad_accum_steps_list:
|
||||||
train_config = train_config_factory(
|
train_config = train_config_factory(
|
||||||
model=base_test_env["model"],
|
model_fn=lambda: base_test_env["model"],
|
||||||
dataset=random_dataset,
|
dataset=random_dataset,
|
||||||
test_dir=base_test_env["test_dir"],
|
test_dir=base_test_env["test_dir"],
|
||||||
device=base_test_env["device"],
|
device=base_test_env["device"],
|
||||||
|
|
@ -50,7 +50,7 @@ def test_memory_efficient_training(base_test_env, random_dataset, train_config_f
|
||||||
|
|
||||||
for config in small_batch_configs:
|
for config in small_batch_configs:
|
||||||
train_config = train_config_factory(
|
train_config = train_config_factory(
|
||||||
model=base_test_env["model"],
|
model_fn=lambda: base_test_env["model"],
|
||||||
dataset=random_dataset,
|
dataset=random_dataset,
|
||||||
test_dir=base_test_env["test_dir"],
|
test_dir=base_test_env["test_dir"],
|
||||||
device=base_test_env["device"],
|
device=base_test_env["device"],
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue