diff --git a/astrai/config/base.py b/astrai/config/base.py index 0e71578..7758d54 100644 --- a/astrai/config/base.py +++ b/astrai/config/base.py @@ -89,10 +89,10 @@ class BaseConfig: raise TypeError @classmethod - def from_json(cls, path: Union[str, Path]) -> Self: + def from_file(cls, path: Union[str, Path]) -> Self: with open(path, "r", encoding="utf-8") as f: return cls.from_dict(json.load(f)) - def to_json(self, path: Union[str, Path]): + def to_file(self, path: Union[str, Path]): with open(path, "w", encoding="utf-8") as f: json.dump(self.to_dict(), f, indent=2, ensure_ascii=False) diff --git a/astrai/config/model_config.py b/astrai/config/model_config.py index ad016bc..ceb12db 100644 --- a/astrai/config/model_config.py +++ b/astrai/config/model_config.py @@ -1,6 +1,5 @@ -import json from dataclasses import dataclass -from typing import Any, Dict, Optional, Self +from typing import Any, Dict, Optional from astrai.config.base import BaseConfig from astrai.factory import BaseFactory @@ -22,18 +21,6 @@ class BaseModelConfig(BaseConfig): model_type: Optional[str] = None - @classmethod - def from_file(cls, config_path: str) -> Self: - with open(config_path, "r") as f: - raw: Dict[str, Any] = json.load(f) - return cls.from_dict(raw) - - def to_file(self, config_path: str): - d = self.to_dict() - config_dict = {k: v for k, v in d.items() if v is not None} - with open(config_path, "w") as f: - json.dump(config_dict, f, indent=4) - @dataclass @ConfigFactory.register("autoregressive_lm") diff --git a/astrai/dataset/dataset.py b/astrai/dataset/dataset.py index bbb8c70..34e161d 100644 --- a/astrai/dataset/dataset.py +++ b/astrai/dataset/dataset.py @@ -136,20 +136,6 @@ class DatasetFactory(BaseFactory["BaseDataset"]): dataset = DatasetFactory.create("custom", window_size, stride) """ - @classmethod - def create(cls, train_type: str, window_size: int, stride: int) -> "BaseDataset": - """Create a dataset instance. - - Args: - train_type: Type of training ("seq", "sft", "dpo", "grpo") - window_size: Window size for data sampling - stride: Stride between consecutive samples - - Returns: - Dataset instance - """ - return super().create(train_type, window_size, stride) - @classmethod def load( cls, @@ -179,11 +165,6 @@ class DatasetFactory(BaseFactory["BaseDataset"]): return dataset - @classmethod - def available_types(cls) -> list: - """Return list of registered dataset type names.""" - return cls.list_registered() - @DatasetFactory.register("seq") class SEQDataset(BaseDataset): diff --git a/astrai/factory.py b/astrai/factory.py index 45178ca..8c27955 100644 --- a/astrai/factory.py +++ b/astrai/factory.py @@ -8,8 +8,6 @@ from typing import ( Dict, ForwardRef, Generic, - Optional, - Tuple, Type, TypeVar, ) @@ -56,21 +54,19 @@ class BaseFactory(ABC, Generic[T]): unrelated parameters. """ - _entries: Dict[str, Tuple[Type, Optional[str], int]] + _entries: Dict[str, Type[T]] def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) - cls._entries = {} for orig_base in getattr(cls, "__orig_bases__", ()): if _get_origin(orig_base) is BaseFactory: (arg,) = _get_args(orig_base) + cls._entries = {} cls._component_base = _resolve_type(arg, cls) return @classmethod - def register( - cls, name: str, category: Optional[str] = None, priority: int = 0 - ) -> Callable[[Type[T]], Type[T]]: + def register(cls, name: str) -> Callable[[Type[T]], Type[T]]: """Decorator to register a component class. Validates that the decorated class inherits from the generic @@ -81,7 +77,7 @@ class BaseFactory(ABC, Generic[T]): cls._validate_component(component_cls) if name in cls._entries: raise ValueError(f"Component '{name}' is already registered") - cls._entries[name] = (component_cls, category, priority) + cls._entries[name] = component_cls return component_cls return decorator @@ -96,7 +92,7 @@ class BaseFactory(ABC, Generic[T]): raise ValueError( f"Unknown component: '{name}'. Supported types: {sorted(cls._entries)}" ) - component_cls = entry[0] + component_cls = entry sig = inspect.signature(component_cls.__init__) has_var_kwargs = any( p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values() @@ -130,7 +126,7 @@ class BaseFactory(ABC, Generic[T]): raise ValueError( f"Unknown component: '{name}'. Supported types: {sorted(cls._entries)}" ) - return entry[0] + return entry @classmethod def list_registered(cls) -> list: diff --git a/astrai/inference/api/anthropic.py b/astrai/inference/api/anthropic.py index e7e7e7e..fbc6827 100644 --- a/astrai/inference/api/anthropic.py +++ b/astrai/inference/api/anthropic.py @@ -42,7 +42,6 @@ class AnthropicResponseBuilder(ResponseBuilder): resp_id=f"msg_{uuid.uuid4().hex[:24]}", created=int(time.time()), model=request.model, - prompt_tokens=0, ) stop_sequences = getattr(request, "stop_sequences", None) or [] return prompt, ctx, stop_sequences diff --git a/astrai/inference/api/openai.py b/astrai/inference/api/openai.py index f3fe27a..2078797 100644 --- a/astrai/inference/api/openai.py +++ b/astrai/inference/api/openai.py @@ -86,7 +86,6 @@ class OpenAIResponseBuilder(ResponseBuilder): resp_id=self._resp_id, created=int(time.time()), model=self._model, - prompt_tokens=0, ) stop = request.stop stop_sequences = ( diff --git a/astrai/inference/api/protocol.py b/astrai/inference/api/protocol.py index 627a97e..2fb7726 100644 --- a/astrai/inference/api/protocol.py +++ b/astrai/inference/api/protocol.py @@ -35,7 +35,7 @@ class GenContext: resp_id: str created: int model: str - prompt_tokens: int + prompt_tokens: int = 0 completion_tokens: int = 0 diff --git a/astrai/model/components/attention.py b/astrai/model/components/attention.py index dc27a7a..3cf00a7 100644 --- a/astrai/model/components/attention.py +++ b/astrai/model/components/attention.py @@ -24,9 +24,7 @@ def repeat_kv(x: Tensor, n_rep: int) -> Tensor: class AttnFactory(BaseFactory[nn.Module]): - @classmethod - def create(cls, attn_type: str, **kwargs) -> nn.Module: - return super().create(attn_type, **kwargs) + pass @AttnFactory.register("gqa") diff --git a/astrai/model/components/mlp.py b/astrai/model/components/mlp.py index e99ee51..0270bad 100644 --- a/astrai/model/components/mlp.py +++ b/astrai/model/components/mlp.py @@ -8,9 +8,7 @@ from astrai.model.components.linear import Linear class FFNFactory(BaseFactory[nn.Module]): - @classmethod - def create(cls, ffn_type: str, dim: int, dim_ffn: int, **kwargs) -> nn.Module: - return super().create(ffn_type, dim, dim_ffn, **kwargs) + pass @FFNFactory.register("mlp") diff --git a/astrai/preprocessing/pipeline.py b/astrai/preprocessing/pipeline.py index c1017ac..6fef201 100644 --- a/astrai/preprocessing/pipeline.py +++ b/astrai/preprocessing/pipeline.py @@ -44,7 +44,7 @@ class Pipeline: Usage:: - config = PipelineConfig.from_json("sft_pipeline.json") + config = PipelineConfig.from_file("sft_pipeline.json") Pipeline(config, ["data.jsonl"], output_dir="out", tokenizer_path="params").run() """ diff --git a/astrai/trainer/schedule.py b/astrai/trainer/schedule.py index c3bd071..5c0225e 100644 --- a/astrai/trainer/schedule.py +++ b/astrai/trainer/schedule.py @@ -31,7 +31,6 @@ class SchedulerFactory(BaseFactory["BaseScheduler"]): """Factory class for creating learning rate schedulers. Supports decorator-based registration for extensible scheduler types. - Also supports creation from ScheduleConfig objects. Example usage: @SchedulerFactory.register("custom") @@ -41,27 +40,6 @@ class SchedulerFactory(BaseFactory["BaseScheduler"]): scheduler = SchedulerFactory.create("custom", optimizer, **kwargs) """ - @classmethod - def create( - cls, optimizer, schedule_type: str = "none", **kwargs - ) -> "BaseScheduler": - """Create a scheduler instance by type name. - - Args: - optimizer: PyTorch optimizer - schedule_type: Type of scheduler ("cosine", "sgdr") - **kwargs: Arguments passed to the scheduler constructor - - Returns: - Scheduler instance - """ - return super().create(schedule_type, optimizer, **kwargs) - - @classmethod - def available_types(cls) -> list: - """Return list of registered scheduler type names.""" - return cls.list_registered() - # ----------- Scheduler implementations ----------- diff --git a/astrai/trainer/strategy.py b/astrai/trainer/strategy.py index 4cdcaca..1eb7e02 100644 --- a/astrai/trainer/strategy.py +++ b/astrai/trainer/strategy.py @@ -127,26 +127,6 @@ class StrategyFactory(BaseFactory["BaseStrategy"]): strategy = StrategyFactory.create("custom", model, device) """ - @classmethod - def create(cls, train_type: str, model, device: str, **kwargs) -> "BaseStrategy": - """Create a strategy instance based on training type. - - Args: - train_type: Type of training ("seq", "sft", "dpo", "grpo") - model: Model instance for the strategy - device: Device to run the strategy on - **kwargs: Additional arguments passed to strategy constructor - - Returns: - Strategy instance - """ - return super().create(train_type, model, device, **kwargs) - - @classmethod - def available_strategies(cls) -> list: - """Return list of registered strategy names.""" - return cls.list_registered() - # ============== Strategy Classes ============== # All strategies are registered at class definition time using the decorator diff --git a/astrai/trainer/train_context.py b/astrai/trainer/train_context.py index e172097..6af33a9 100644 --- a/astrai/trainer/train_context.py +++ b/astrai/trainer/train_context.py @@ -172,8 +172,8 @@ class TrainContextBuilder: obj.load_state_dict(extra[name]) context.strategy = StrategyFactory.create( + cfg.strategy, model=context.model, - train_type=cfg.strategy, device=device, executor=executor, model_fn=cfg.model_fn, diff --git a/scripts/tools/preprocess.py b/scripts/tools/preprocess.py index 56cb82d..4e19f42 100644 --- a/scripts/tools/preprocess.py +++ b/scripts/tools/preprocess.py @@ -24,7 +24,7 @@ def main(): ) args = parser.parse_args() - config = PipelineConfig.from_json(args.config) + config = PipelineConfig.from_file(args.config) Pipeline( config=config, diff --git a/scripts/tools/train.py b/scripts/tools/train.py index 6f2612b..d2c8eec 100644 --- a/scripts/tools/train.py +++ b/scripts/tools/train.py @@ -231,7 +231,8 @@ def create_optimizer(model, **kwargs) -> optim.Optimizer: def create_scheduler( optimizer: optim.Optimizer, **kwargs ) -> optim.lr_scheduler.LRScheduler: - return SchedulerFactory.create(optimizer, **kwargs) + schedule_type = kwargs.pop("schedule_type") + return SchedulerFactory.create(schedule_type, optimizer, **kwargs) def compute_total_steps( diff --git a/tests/data/test_preprocess_config.py b/tests/data/test_preprocess_config.py index 972be9e..55f37d3 100644 --- a/tests/data/test_preprocess_config.py +++ b/tests/data/test_preprocess_config.py @@ -53,15 +53,15 @@ def test_to_dict_roundtrip(): assert config2.mask == {"prompt": "mask", "response": "train"} -def test_to_json_from_json(temp_dir): +def test_to_file_from_file(temp_dir): config = PipelineConfig( input=InputConfig(sections=_TEXT_SECTIONS), mask={"text": "train"}, mask_default="mask", ) path = os.path.join(temp_dir, "config.json") - config.to_json(path) - loaded = PipelineConfig.from_json(path) + config.to_file(path) + loaded = PipelineConfig.from_file(path) assert loaded.input.sections == _TEXT_SECTIONS assert loaded.mask == {"text": "train"} @@ -69,8 +69,8 @@ def test_to_json_from_json(temp_dir): def test_dpo_config_roundtrip(temp_dir): config = make_dpo_chat_config() path = os.path.join(temp_dir, "config.json") - config.to_json(path) - loaded = PipelineConfig.from_json(path) + config.to_file(path) + loaded = PipelineConfig.from_file(path) assert loaded.input.sources is not None assert "chosen" in loaded.input.sources assert "rejected" in loaded.input.sources diff --git a/tests/trainer/conftest.py b/tests/trainer/conftest.py index 5ce6c51..e7cca22 100644 --- a/tests/trainer/conftest.py +++ b/tests/trainer/conftest.py @@ -65,7 +65,7 @@ def create_train_config( def scheduler_fn(optim): return SchedulerFactory.create( - optim, "cosine", warmup_steps=10, lr_decay_steps=10, min_rate=0.05 + "cosine", optim, warmup_steps=10, lr_decay_steps=10, min_rate=0.05 ) return TrainConfig( diff --git a/tests/trainer/test_callbacks.py b/tests/trainer/test_callbacks.py index b604a9c..1ca8d66 100644 --- a/tests/trainer/test_callbacks.py +++ b/tests/trainer/test_callbacks.py @@ -102,7 +102,7 @@ def test_gradient_checkpointing_trainer_integration(base_test_env, random_datase def scheduler_fn(optim): return SchedulerFactory.create( - optim, "cosine", warmup_steps=10, lr_decay_steps=10, min_rate=0.05 + "cosine", optim, warmup_steps=10, lr_decay_steps=10, min_rate=0.05 ) train_config = TrainConfig( @@ -136,7 +136,7 @@ def test_callback_integration(base_test_env, random_dataset): def scheduler_fn(optim): return SchedulerFactory.create( - optim, "cosine", warmup_steps=10, lr_decay_steps=10, min_rate=0.05 + "cosine", optim, warmup_steps=10, lr_decay_steps=10, min_rate=0.05 ) train_config = TrainConfig( diff --git a/tests/trainer/test_early_stopping.py b/tests/trainer/test_early_stopping.py index 2047d7f..729d069 100644 --- a/tests/trainer/test_early_stopping.py +++ b/tests/trainer/test_early_stopping.py @@ -16,7 +16,7 @@ def test_early_stopping_simulation(base_test_env, early_stopping_dataset): def scheduler_fn(optim): return SchedulerFactory.create( - optim, "cosine", warmup_steps=10, lr_decay_steps=10, min_rate=0.05 + "cosine", optim, warmup_steps=10, lr_decay_steps=10, min_rate=0.05 ) train_config = TrainConfig( diff --git a/tests/trainer/test_train_strategy.py b/tests/trainer/test_train_strategy.py index de43b2b..cb83dab 100644 --- a/tests/trainer/test_train_strategy.py +++ b/tests/trainer/test_train_strategy.py @@ -36,8 +36,8 @@ def test_schedule_factory_random_configs(): min_rate = params["min_rate"] lr_decay_steps = total_steps - warmup_steps scheduler = SchedulerFactory.create( - optimizer, schedule_type, + optimizer, warmup_steps=warmup_steps, lr_decay_steps=lr_decay_steps, min_rate=min_rate, @@ -52,8 +52,8 @@ def test_schedule_factory_random_configs(): t_mult = params["t_mult"] min_rate = params["min_rate"] scheduler = SchedulerFactory.create( - optimizer, schedule_type, + optimizer, warmup_steps=warmup_steps, cycle_length=cycle_length, t_mult=t_mult, @@ -103,8 +103,8 @@ def test_schedule_factory_edge_cases(): min_rate = params["min_rate"] lr_decay_steps = total_steps - warmup_steps scheduler = SchedulerFactory.create( - optimizer, "cosine", + optimizer, warmup_steps=warmup_steps, lr_decay_steps=lr_decay_steps, min_rate=min_rate, @@ -129,8 +129,8 @@ def test_schedule_factory_state_persistence(): min_rate = 0.1 lr_decay_steps = total_steps - warmup_steps scheduler = SchedulerFactory.create( - optimizer, "cosine", + optimizer, warmup_steps=warmup_steps, lr_decay_steps=lr_decay_steps, min_rate=min_rate, @@ -146,8 +146,8 @@ def test_schedule_factory_state_persistence(): # Create new scheduler with same parameters new_scheduler = SchedulerFactory.create( - optimizer, "cosine", + optimizer, warmup_steps=warmup_steps, lr_decay_steps=lr_decay_steps, min_rate=min_rate,