Source code for renate.updaters.experimental.fine_tuning

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from functools import partial
from typing import Callable, Dict, List, Optional

import torch
import torchmetrics
from pytorch_lightning.loggers.logger import Logger
from torch.nn import Parameter
from torch.optim import Optimizer

from renate import defaults
from renate.models import RenateModule
from renate.updaters.learner import Learner
from renate.updaters.model_updater import SingleTrainingLoopUpdater


class FineTuningModelUpdater(SingleTrainingLoopUpdater):
    """Model updater that runs a single training loop with the base ``Learner``.

    The constructor only assembles arguments: ``batch_size``, ``seed`` and ``loss_fn`` are
    packed into ``learner_kwargs`` for the ``Learner``, and the remaining arguments are
    forwarded unchanged to ``SingleTrainingLoopUpdater.__init__``.

    NOTE(review): presumably this is plain fine-tuning on the new data without any
    continual-learning mechanism -- confirm against ``Learner``.

    Args:
        model: The model to be updated.
        loss_fn: Loss module; handed to both the learner and the base updater.
        optimizer: Callable that builds an ``Optimizer`` from a list of parameters.
        learning_rate_scheduler: Optional ``partial`` used to build the scheduler.
        learning_rate_scheduler_interval: Forwarded to the base updater.
        batch_size: Forwarded to the learner.
        input_state_folder: Forwarded to the base updater.
        output_state_folder: Forwarded to the base updater.
        max_epochs: Forwarded to the base updater.
        train_transform: Forwarded to the base updater.
        train_target_transform: Forwarded to the base updater.
        test_transform: Forwarded to the base updater.
        test_target_transform: Forwarded to the base updater.
        metric: Forwarded to the base updater.
        mode: Forwarded to the base updater; defaults to ``"min"``.
        logged_metrics: Forwarded to the base updater.
        early_stopping_enabled: Forwarded to the base updater.
        logger: Forwarded to the base updater.
        accelerator: Forwarded to the base updater.
        devices: Forwarded to the base updater.
        strategy: Forwarded to the base updater.
        precision: Forwarded to the base updater.
        seed: Forwarded to the learner only (it is not passed to the base updater).
        deterministic_trainer: Forwarded to the base updater.
        gradient_clip_val: Forwarded to the base updater.
        gradient_clip_algorithm: Forwarded to the base updater.
        mask_unused_classes: Forwarded to the base updater.
    """

    def __init__(
        self,
        model: RenateModule,
        loss_fn: torch.nn.Module,
        optimizer: Callable[[List[Parameter]], Optimizer],
        learning_rate_scheduler: Optional[partial] = None,
        learning_rate_scheduler_interval: defaults.SUPPORTED_LR_SCHEDULER_INTERVAL_TYPE = defaults.LR_SCHEDULER_INTERVAL,  # noqa: E501
        batch_size: int = defaults.BATCH_SIZE,
        input_state_folder: Optional[str] = None,
        output_state_folder: Optional[str] = None,
        max_epochs: int = defaults.MAX_EPOCHS,
        train_transform: Optional[Callable] = None,
        train_target_transform: Optional[Callable] = None,
        test_transform: Optional[Callable] = None,
        test_target_transform: Optional[Callable] = None,
        metric: Optional[str] = None,
        mode: defaults.SUPPORTED_TUNING_MODE_TYPE = "min",
        logged_metrics: Optional[Dict[str, torchmetrics.Metric]] = None,
        early_stopping_enabled: bool = False,
        # NOTE: this default is evaluated once, when the class body is executed, so every
        # instance that relies on the default shares the same logger object.
        logger: Logger = defaults.LOGGER(**defaults.LOGGER_KWARGS),
        accelerator: defaults.SUPPORTED_ACCELERATORS_TYPE = defaults.ACCELERATOR,
        devices: Optional[int] = None,
        strategy: str = defaults.DISTRIBUTED_STRATEGY,
        precision: str = defaults.PRECISION,
        seed: int = defaults.SEED,
        deterministic_trainer: bool = defaults.DETERMINISTIC_TRAINER,
        gradient_clip_val: Optional[float] = defaults.GRADIENT_CLIP_VAL,
        gradient_clip_algorithm: Optional[str] = defaults.GRADIENT_CLIP_ALGORITHM,
        mask_unused_classes: bool = defaults.MASK_UNUSED_CLASSES,
    ):
        # Arguments consumed by the learner itself rather than by the updater.
        learner_kwargs = {
            "batch_size": batch_size,
            "seed": seed,
            "loss_fn": loss_fn,
        }
        # Everything else is handed to the generic single-loop updater as-is.
        super().__init__(
            model,
            loss_fn=loss_fn,
            optimizer=optimizer,
            learner_class=Learner,
            learner_kwargs=learner_kwargs,
            input_state_folder=input_state_folder,
            output_state_folder=output_state_folder,
            max_epochs=max_epochs,
            learning_rate_scheduler=learning_rate_scheduler,
            learning_rate_scheduler_interval=learning_rate_scheduler_interval,
            train_transform=train_transform,
            train_target_transform=train_target_transform,
            test_transform=test_transform,
            test_target_transform=test_target_transform,
            metric=metric,
            mode=mode,
            logged_metrics=logged_metrics,
            early_stopping_enabled=early_stopping_enabled,
            logger=logger,
            accelerator=accelerator,
            devices=devices,
            deterministic_trainer=deterministic_trainer,
            strategy=strategy,
            precision=precision,
            gradient_clip_algorithm=gradient_clip_algorithm,
            gradient_clip_val=gradient_clip_val,
            mask_unused_classes=mask_unused_classes,
        )