Source code for renate.updaters.experimental.speft

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Callable, Dict, List, Optional

import torch
import torch.nn as nn
import torchmetrics
from pytorch_lightning.loggers.logger import Logger
from torch.nn import Parameter
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import Dataset

from renate import defaults
from renate.benchmark.models.spromptmodel import SPromptTransformer
from renate.models import RenateModule
from renate.updaters.learner import Learner
from renate.updaters.model_updater import SingleTrainingLoopUpdater


class SPeftLearner(Learner):
    """Learner to implement S-Prompts from

    ```Wang, Yabin, et al. "S-prompts learning with pre-trained transformers: An Occam's razor for domain incremental learning." # noqa: E501
    Advances in Neural Information Processing Systems 35 (2022): 5682-5695.```

    Args:
        model: The SPromptTransformer model to be trained.
        loss_fn: Loss function to be trained with.
        optimizer: Partial optimizer used to create an optimizer by passing the model parameters.
        learning_rate_scheduler: Partial object of learning rate scheduler that will be created
            by passing the optimizer.
        learning_rate_scheduler_interval: When to update the learning rate scheduler.
            Options: `epoch` and `step`.
        batch_size: Training batch size.
        train_transform: The transformation applied during training.
        train_target_transform: The target transformation applied during training.
        test_transform: The transformation at test time.
        test_target_transform: The target transformation at test time.
        logged_metrics: Metrics logged additional to the default ones.
        seed: See :func:`renate.models.utils.get_generator`.
        mask_unused_classes: Masking logits corresponding to unused classes. Useful only for
            class-incremental problems. Defaults to ``defaults.MASK_UNUSED_CLASSES``.
    """

    def __init__(
        self,
        model: RenateModule,
        loss_fn: torch.nn.Module,
        optimizer: Callable[[List[Parameter]], Optimizer],
        learning_rate_scheduler: Optional[Callable[[Optimizer], _LRScheduler]] = None,
        learning_rate_scheduler_interval: defaults.SUPPORTED_LR_SCHEDULER_INTERVAL_TYPE = defaults.LR_SCHEDULER_INTERVAL,  # noqa: E501
        batch_size: int = defaults.BATCH_SIZE,
        train_transform: Optional[Callable] = None,
        train_target_transform: Optional[Callable] = None,
        test_transform: Optional[Callable] = None,
        test_target_transform: Optional[Callable] = None,
        logged_metrics: Optional[Dict[str, torchmetrics.Metric]] = None,
        seed: int = defaults.SEED,
        mask_unused_classes: bool = defaults.MASK_UNUSED_CLASSES,
    ) -> None:
        if not isinstance(model, SPromptTransformer):
            raise ValueError(
                "SPrompt Learner can only be used with a SPromptTransformer model, "
                f"but got {type(model)}."
            )
        super().__init__(
            model,
            loss_fn,
            optimizer,
            learning_rate_scheduler,
            learning_rate_scheduler_interval,
            batch_size,
            train_transform,
            train_target_transform,
            test_transform,
            test_target_transform,
            logged_metrics,
            seed,
            mask_unused_classes,
        )
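    # Illustrative sketch, not part of the original source: the ``optimizer`` argument is a
    # partial that the learner later calls with the model parameters to build the actual torch
    # Optimizer, e.g. via ``functools.partial``. This helper, its name, and all hyperparameter
    # values are hypothetical placeholders.
    @classmethod
    def _example_from_model(cls, model: SPromptTransformer) -> "SPeftLearner":
        from functools import partial

        return cls(
            model=model,
            loss_fn=torch.nn.CrossEntropyLoss(),
            optimizer=partial(torch.optim.SGD, lr=0.01, momentum=0.9),
            batch_size=32,
        )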
    def on_model_update_start(
        self,
        train_dataset: Dataset,
        val_dataset: Dataset,
        train_dataset_collate_fn: Optional[Callable] = None,
        val_dataset_collate_fn: Optional[Callable] = None,
        task_id: Optional[str] = None,
    ) -> None:
        """A custom on_model_update_start hook for S-PEFT methods.

        Here, we iterate over the train dataset and extract features. These features are used
        to compute the task prototypes by the `update_task_identifier` call. Having this
        function at model update start instead of end results in val metrics being reflective
        of test accuracy.
        """
        super().on_model_update_start(
            train_dataset, val_dataset, train_dataset_collate_fn, val_dataset_collate_fn, task_id
        )
        # Extract features over the full training set; they are used to compute the task
        # prototypes (k-means) for task identification.
        device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
        self._model.to(device)
        features, labels = [], []
        with torch.inference_mode():
            for x, y in self.train_dataloader():
                features.append(self._model.features(x.to(device)).cpu())
                labels.append(y)
        features = torch.cat(features)
        labels = torch.cat(labels)
        self._model.update_task_identifier(features=features, labels=labels)
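    # Illustrative sketch, not part of the original source: one plausible way to turn the
    # extracted features into task prototypes is K-Means clustering (hinted at by the k-means
    # comment above). The actual logic lives in ``SPromptTransformer.update_task_identifier``;
    # this helper, its name, and the ``sklearn`` dependency are assumptions for illustration.
    @staticmethod
    def _example_task_prototypes(features: torch.Tensor, num_clusters: int = 5) -> torch.Tensor:
        from sklearn.cluster import KMeans

        kmeans = KMeans(n_clusters=num_clusters, n_init=10).fit(features.numpy())
        return torch.from_numpy(kmeans.cluster_centers_)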
    def setup(self, stage: str) -> None:
        # We don't support distributed training.
        assert (
            self.trainer.world_size == 1
        ), "SPrompt learner does not support multi-GPU training yet."
        if stage == "fit":
            # This needs to run before configure_optimizers is called; the only suitable hook
            # is setup("fit").
            self._model.increment_task()
    def optimizer_zero_grad(
        self, epoch: int, batch_idx: int, optimizer: Optimizer, optimizer_idx: int
    ) -> None:
        """Explicitly setting grads to None instead of zero."""
        optimizer.zero_grad(set_to_none=True)
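# Illustrative sketch, not part of the original source: ``zero_grad(set_to_none=True)`` drops
# the gradient tensors entirely instead of overwriting them with zeros, avoiding the extra
# allocation and memset between steps. The tiny parameter below is a placeholder used only to
# demonstrate the effect.
def _example_set_to_none() -> bool:
    p = torch.nn.Parameter(torch.ones(2))
    opt = torch.optim.SGD([p], lr=0.1)
    p.sum().backward()  # p.grad is now a tensor of ones
    opt.zero_grad(set_to_none=True)
    return p.grad is None  # True: the gradient was dropped rather than zeroed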
class SPeftModelUpdater(SingleTrainingLoopUpdater):
    def __init__(
        self,
        model: RenateModule,
        loss_fn: torch.nn.Module,
        optimizer: Callable[[List[nn.Parameter]], Optimizer],
        batch_size: int = defaults.BATCH_SIZE,
        seed: int = defaults.SEED,
        learner_kwargs: Optional[Dict[str, Any]] = None,
        input_state_folder: Optional[str] = None,
        output_state_folder: Optional[str] = None,
        max_epochs: int = defaults.MAX_EPOCHS,
        learning_rate_scheduler: Optional[Callable[[Optimizer], _LRScheduler]] = None,
        learning_rate_scheduler_interval: defaults.SUPPORTED_LR_SCHEDULER_INTERVAL_TYPE = defaults.LR_SCHEDULER_INTERVAL,  # noqa: E501
        train_transform: Optional[Callable] = None,
        train_target_transform: Optional[Callable] = None,
        test_transform: Optional[Callable] = None,
        test_target_transform: Optional[Callable] = None,
        buffer_transform: Optional[Callable] = None,
        buffer_target_transform: Optional[Callable] = None,
        metric: Optional[str] = None,
        mode: defaults.SUPPORTED_TUNING_MODE_TYPE = "min",
        logged_metrics: Optional[Dict[str, torchmetrics.Metric]] = None,
        early_stopping_enabled: bool = False,
        logger: Logger = defaults.LOGGER(**defaults.LOGGER_KWARGS),
        accelerator: defaults.SUPPORTED_ACCELERATORS_TYPE = defaults.ACCELERATOR,
        devices: Optional[int] = None,
        strategy: Optional[str] = defaults.DISTRIBUTED_STRATEGY,
        precision: str = defaults.PRECISION,
        deterministic_trainer: bool = defaults.DETERMINISTIC_TRAINER,
        gradient_clip_val: Optional[float] = defaults.GRADIENT_CLIP_VAL,
        gradient_clip_algorithm: Optional[str] = defaults.GRADIENT_CLIP_ALGORITHM,
        mask_unused_classes: bool = defaults.MASK_UNUSED_CLASSES,
    ):
        # Arguments forwarded to the SPeftLearner.
        learner_kwargs = {
            "batch_size": batch_size,
            "seed": seed,
            "loss_fn": loss_fn,
        }
        super().__init__(
            model=model,
            loss_fn=loss_fn,
            optimizer=optimizer,
            learner_class=SPeftLearner,
            learner_kwargs=learner_kwargs,
            input_state_folder=input_state_folder,
            output_state_folder=output_state_folder,
            max_epochs=max_epochs,
            learning_rate_scheduler=learning_rate_scheduler,
            learning_rate_scheduler_interval=learning_rate_scheduler_interval,
            train_transform=train_transform,
            train_target_transform=train_target_transform,
            test_transform=test_transform,
            test_target_transform=test_target_transform,
            buffer_transform=buffer_transform,
            buffer_target_transform=buffer_target_transform,
            metric=metric,
            mode=mode,
            logged_metrics=logged_metrics,
            early_stopping_enabled=early_stopping_enabled,
            logger=logger,
            accelerator=accelerator,
            devices=devices,
            strategy=strategy,
            precision=precision,
            deterministic_trainer=deterministic_trainer,
            gradient_clip_algorithm=gradient_clip_algorithm,
            gradient_clip_val=gradient_clip_val,
            mask_unused_classes=mask_unused_classes,
        )
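# Illustrative sketch, not part of the original source: a minimal end-to-end use of the updater,
# assuming the usual Renate flow of constructing a model updater and calling ``update`` with a
# training dataset. The ``update(train_dataset, task_id=...)`` signature and return value, the
# hyperparameter values, and this helper's name are assumptions, not verified API.
def _example_speft_update(
    model: SPromptTransformer, train_dataset: Dataset, task_id: str
) -> RenateModule:
    from functools import partial

    updater = SPeftModelUpdater(
        model=model,
        loss_fn=torch.nn.CrossEntropyLoss(),
        optimizer=partial(torch.optim.AdamW, lr=1e-3),
        max_epochs=5,
    )
    return updater.update(train_dataset, task_id=task_id)  # assumed to return the updated model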