156 lines
5.1 KiB
Python
156 lines
5.1 KiB
Python
from dataclasses import dataclass, field
|
|
from pathlib import Path
|
|
from typing import Optional, Self
|
|
|
|
import torch.nn as nn
|
|
from torch.utils.data import DataLoader
|
|
|
|
from astrai.config.train_config import TrainConfig
|
|
from astrai.dataset import ResumableDistributedSampler
|
|
from astrai.model.components.lora import inject_lora
|
|
from astrai.parallel.executor import BaseExecutor, ExecutorFactory
|
|
from astrai.parallel.setup import get_current_device, get_rank, get_world_size
|
|
from astrai.protocols import OptimizerProtocol, SchedulerProtocol
|
|
from astrai.serialization import Checkpoint, load_model_weights
|
|
from astrai.trainer.strategy import BaseStrategy, StrategyFactory
|
|
|
|
|
|
@dataclass
class TrainContext:
    """Mutable container for the objects and counters of one training run.

    Every field has a default, so an empty instance can be created and
    filled in step by step (as ``TrainContextBuilder.build`` does).
    """

    # Training components; each one is None until it is assigned.
    model: nn.Module = None
    strategy: BaseStrategy = None
    dataloader: DataLoader = None
    optimizer: OptimizerProtocol = None
    scheduler: SchedulerProtocol = None
    checkpoint: Checkpoint = None
    config: TrainConfig = None
    executor: BaseExecutor = None

    # Progress counters and training loss, all starting at zero.
    epoch: int = 0
    iteration: int = 0
    loss: float = 0.0

    # Validation loader (None when unset) and validation loss.
    val_dataloader: DataLoader = None
    val_loss: float = 0.0

    # Distributed layout; the defaults describe a single process.
    world_size: int = 1
    rank: int = 0

    # Free-form extras; default_factory gives each instance its own dict.
    kwargs: dict = field(default_factory=dict)
|
|
|
|
|
|
class TrainContextBuilder:
    """Build a ready-to-train :class:`TrainContext` from a :class:`TrainConfig`.

    Typical use::

        context = TrainContextBuilder(config).with_resume_dir(path).build()
    """

    def __init__(
        self,
        config: TrainConfig,
    ):
        """Store the configuration; no resume directory is set initially."""
        self.config = config
        self._resume_dir: Optional[str] = None

    def with_resume_dir(self, resume_dir: Optional[str]) -> Self:
        """Set (or clear, with ``None``) the directory to resume from.

        Returns:
            The builder itself, so calls can be chained.
        """
        self._resume_dir = resume_dir
        return self

    def build(self) -> TrainContext:
        """Create the executor, model, optimizer, data loaders and strategy.

        Steps, in order:

        1. Create the executor and the model, and move the model to the
           current device.
        2. If a resume directory was given, load weights from it (and the
           epoch/iteration counters when it holds a full checkpoint).
        3. Inject LoRA when ``config.lora`` is set.
        4. Create optimizer, scheduler and the train/validation loaders.
        5. Pass model, optimizer, train loader and scheduler through
           ``executor.prepare``.
        6. Restore optimizer/scheduler state from the checkpoint, if any.
        7. Create the training strategy.

        Returns:
            The populated :class:`TrainContext`.
        """
        cfg = self.config
        device = get_current_device()

        executor = ExecutorFactory.create(
            cfg.parallel_mode,
            grad_accum_steps=cfg.grad_accum_steps,
            **cfg.executor_kwargs,
        )

        model = cfg.model_fn()
        model = model.to(device=device)

        context = TrainContext(
            model=model,
            world_size=get_world_size(),
            rank=get_rank(),
            config=cfg,
            executor=executor,
        )

        if self._resume_dir is not None:
            self._restore_weights(context, model)

        # NOTE(review): LoRA is injected after the weights were loaded with
        # strict=False. If inject_lora adds parameters, entries for them in
        # the resume directory would not have been loaded above -- confirm
        # this ordering is intended.
        if cfg.lora is not None:
            inject_lora(
                model,
                r=cfg.lora.r,
                alpha=cfg.lora.alpha,
                target_modules=set(cfg.lora.target_modules),
            )

        context.optimizer = cfg.optimizer_fn(model)
        context.scheduler = cfg.scheduler_fn(context.optimizer)

        # The sampler start offset is the resumed iteration count scaled by
        # the per-device batch size.
        sampler_offset = context.iteration * cfg.batch_per_device
        sampler = ResumableDistributedSampler(
            data_source=cfg.dataset,
            start_epoch=context.epoch,
            start_iter=sampler_offset,
            seed=cfg.random_seed,
        )
        context.dataloader = self._make_dataloader(cfg.dataset, sampler)

        if cfg.val_dataset is not None:
            # Validation always starts from the beginning and is not shuffled.
            val_sampler = ResumableDistributedSampler(
                data_source=cfg.val_dataset,
                start_epoch=0,
                start_iter=0,
                seed=cfg.random_seed,
                shuffle=False,
            )
            context.val_dataloader = self._make_dataloader(
                cfg.val_dataset, val_sampler
            )

        # Only the training objects go through the executor; the validation
        # loader is left as constructed.
        context.model, context.optimizer, context.dataloader, context.scheduler = (
            executor.prepare(
                model,
                context.optimizer,
                context.dataloader,
                context.scheduler,
            )
        )

        # Done after prepare() so the state is loaded into the objects the
        # executor returned.
        self._restore_training_state(context)

        context.strategy = StrategyFactory.create(
            model=context.model,
            train_type=cfg.strategy,
            device=device,
            **cfg.extra_kwargs,
        )

        return context

    def _restore_weights(self, context: TrainContext, model: nn.Module) -> None:
        """Load weights from the resume directory into ``model``.

        A directory containing ``meta.json`` is loaded as a full checkpoint,
        which also sets ``context.epoch``, ``context.iteration`` and
        ``context.checkpoint``. Any other directory is read as weights only.
        """
        cfg = self.config
        resume_path = Path(self._resume_dir)

        if (resume_path / "meta.json").exists():
            checkpoint = Checkpoint.load(self._resume_dir)
            state_dict = checkpoint.state_dict
        else:
            checkpoint = None
            state_dict = load_model_weights(self._resume_dir)

        # strict=False: missing or unexpected keys do not raise.
        model.load_state_dict(state_dict, strict=False)

        if checkpoint is not None:
            # Explicitly configured start positions win when they are ahead
            # of the checkpoint.
            context.epoch = max(checkpoint.epoch, cfg.start_epoch)
            context.iteration = max(checkpoint.iteration, cfg.start_batch)
            context.checkpoint = checkpoint

    def _make_dataloader(self, dataset, sampler) -> DataLoader:
        """Create a DataLoader for ``dataset`` using the configured options."""
        cfg = self.config
        loader_kwargs = {
            "batch_size": cfg.batch_per_device,
            "sampler": sampler,
            "num_workers": cfg.num_workers,
            "pin_memory": cfg.pin_memory,
        }
        # PyTorch 2.x raises ValueError for a non-None prefetch_factor when
        # num_workers == 0, so forward it only when worker processes are used.
        if cfg.num_workers > 0:
            loader_kwargs["prefetch_factor"] = cfg.prefetch_factor
        return DataLoader(dataset, **loader_kwargs)

    @staticmethod
    def _restore_training_state(context: TrainContext) -> None:
        """Load optimizer/scheduler state from ``context.checkpoint.extra``.

        Does nothing when there is no checkpoint or its ``extra`` is empty.
        """
        if not (context.checkpoint and context.checkpoint.extra):
            return

        extra = context.checkpoint.extra
        for name in ("optimizer", "scheduler"):
            if name in extra:
                obj = getattr(context, name, None)
                if obj is not None:
                    obj.load_state_dict(extra[name])
|