# Source code for renate.updaters.experimental.l2p

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import logging
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch
import torch.nn as nn
import torchmetrics
from pytorch_lightning.loggers.logger import Logger
from pytorch_lightning.utilities.types import STEP_OUTPUT
from torch.nn import Parameter
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler

from renate import defaults
from renate.benchmark.models.l2p import LearningToPromptTransformer
from renate.models import RenateModule
from renate.types import NestedTensors
from renate.updaters.experimental.offline_er import OfflineExperienceReplayLearner
from renate.updaters.learner import Learner
from renate.updaters.model_updater import SingleTrainingLoopUpdater
from renate.utils.misc import maybe_populate_mask_and_ignore_logits

logger = logging.getLogger(__name__)


class LearningToPromptLearner(Learner):
    """Learner for Learning to Prompt (L2P).

    This is identical to the base learner with the addition of a loss term that
    rewards similarity between the selected prompt keys and the image
    representation (read from ``self._model.similarity_score`` after the
    forward pass).

    TODO: Make this loss a component.

    Args:
        prompt_sim_loss_weight: Loss weight for the prompt key - image
            representation similarity.
    """

    def __init__(
        self,
        prompt_sim_loss_weight: float = defaults.PROMPT_SIM_LOSS_WEIGHT,
        **kwargs,
    ) -> None:
        # Fixed: the two message fragments previously concatenated without a
        # separating space ("...modelbut got...").
        assert isinstance(kwargs["model"], LearningToPromptTransformer), (
            f"{self.__class__.__name__} can only train a LearningToPromptTransformer model "
            f"but got {type(kwargs['model'])}"
        )
        super().__init__(
            **kwargs,
        )
        self.prompt_sim_loss_weight = prompt_sim_loss_weight
        self._loss_collections["train_losses"].update({"key_sim": torchmetrics.MeanMetric()})

    def training_step(
        self, batch: Tuple[NestedTensors, torch.Tensor], batch_idx: int
    ) -> STEP_OUTPUT:
        """PyTorch Lightning function to return the training loss.

        Extends the base learner's loss with the weighted prompt-key
        similarity term. The similarity is to be maximized, hence it enters
        the loss with a negative sign.
        """
        loss_dict = super().training_step(batch, batch_idx=batch_idx)
        key_similarity = -1 * self.prompt_sim_loss_weight * self._model.similarity_score
        loss_dict["loss"] += key_similarity
        # Log the (positive) similarity contribution for monitoring.
        self._loss_collections["train_losses"]["key_sim"](-key_similarity)
        return loss_dict
class LearningToPromptReplayLearner(OfflineExperienceReplayLearner):
    """L2P with an off-line ER learner.

    The model will be trained on a weighted mixture of losses computed on the
    new data and a replay buffer. In contrast to the online version, the buffer
    will only be updated after training has terminated.

    Args:
        prompt_sim_loss_weight: Loss weight for the prompt key - image
            representation similarity.
    """

    def __init__(
        self,
        prompt_sim_loss_weight: float = defaults.PROMPT_SIM_LOSS_WEIGHT,
        **kwargs,
    ) -> None:
        # Fixed: the two message fragments previously concatenated without a
        # separating space ("...modelbut got...").
        assert isinstance(kwargs["model"], LearningToPromptTransformer), (
            f"{self.__class__.__name__} can only train a LearningToPromptTransformer model "
            f"but got {type(kwargs['model'])}"
        )
        super().__init__(**kwargs)
        self.prompt_sim_loss_weight = prompt_sim_loss_weight
        self._loss_collections["train_losses"].update({"key_sim_loss": torchmetrics.MeanMetric()})

    def training_step(
        self, batch: Tuple[NestedTensors, torch.Tensor], batch_idx: int
    ) -> STEP_OUTPUT:
        """PyTorch Lightning function to return the training loss."""
        # The reason for rewriting is to ensure two independent forward props of inputs and memory
        # samples. LearningToPromptTransformer uses per_batch_prompt which uses a single prompt
        # repeated across the batch. Hence, the separate processing of memory and input samples.
        if self._loss_weight_new_data is None:
            # Default mixing weight: fraction of current-task points among all
            # points seen so far.
            alpha = self._num_points_current_task / (
                self._num_points_current_task + self._num_points_previous_tasks
            )
        else:
            alpha = self._loss_weight_new_data
        alpha = torch.tensor(alpha, device=batch["current_task"][0].device)
        inputs, targets = batch["current_task"]
        outputs = self(inputs)
        outputs, self._class_mask = maybe_populate_mask_and_ignore_logits(
            self._mask_unused_classes,
            self._class_mask,
            self._classes_in_current_task,
            outputs,
        )
        if "memory" in batch:
            # Separate forward pass for memory samples (see note above).
            (inputs_mem, targets_mem), _ = batch["memory"]
            outputs_mem = self(inputs_mem)
            outputs_mem, self._class_mask = maybe_populate_mask_and_ignore_logits(
                self._mask_unused_classes,
                self._class_mask,
                self._classes_in_current_task,
                outputs_mem,
            )
        loss_current = self._loss_fn(outputs, targets).mean()
        if "memory" in batch:
            loss_memory = self._loss_fn(outputs_mem, targets_mem).mean()
            self._loss_collections["train_losses"]["base_loss"](loss_current)
            self._loss_collections["train_losses"]["memory_loss"](loss_memory)
            loss = alpha * loss_current + (1.0 - alpha) * loss_memory
        else:
            # loss_current is already reduced to a scalar above; no extra
            # .mean() needed.
            loss = loss_current
            self._loss_collections["train_losses"]["base_loss"](loss)
        self._update_metrics(outputs, targets, "train")
        # Similarity is to be maximized, hence it enters the loss negatively.
        key_similarity = -1 * self.prompt_sim_loss_weight * self._model.similarity_score
        loss += key_similarity
        # Bug fix: __init__ registers this metric under "key_sim_loss"; the
        # previous code indexed "key_sim" and raised KeyError every step.
        self._loss_collections["train_losses"]["key_sim_loss"](-key_similarity)
        return {"loss": loss}
class LearningToPromptModelUpdater(SingleTrainingLoopUpdater):
    """Model updater wiring :class:`LearningToPromptLearner` into a single
    training loop.

    All arguments are forwarded to :class:`SingleTrainingLoopUpdater`; the
    learner-specific options are packed into ``learner_kwargs``.
    """

    def __init__(
        self,
        model: RenateModule,
        loss_fn: torch.nn.Module,
        optimizer: Callable[[List[nn.Parameter]], Optimizer],
        batch_size: int = defaults.BATCH_SIZE,
        seed: int = defaults.SEED,
        learner_kwargs: Optional[Dict[str, Any]] = None,
        input_state_folder: Optional[str] = None,
        output_state_folder: Optional[str] = None,
        max_epochs: int = defaults.MAX_EPOCHS,
        learning_rate_scheduler: Optional[Optional[Callable[[Optimizer], _LRScheduler]]] = None,
        learning_rate_scheduler_interval: defaults.SUPPORTED_LR_SCHEDULER_INTERVAL_TYPE = defaults.LR_SCHEDULER_INTERVAL,  # noqa: E501
        prompt_sim_loss_weight: float = defaults.PROMPT_SIM_LOSS_WEIGHT,
        train_transform: Optional[Callable] = None,
        train_target_transform: Optional[Callable] = None,
        test_transform: Optional[Callable] = None,
        test_target_transform: Optional[Callable] = None,
        buffer_transform: Optional[Callable] = None,
        buffer_target_transform: Optional[Callable] = None,
        metric: Optional[str] = None,
        mode: defaults.SUPPORTED_TUNING_MODE_TYPE = "min",
        logged_metrics: Optional[Dict[str, torchmetrics.Metric]] = None,
        early_stopping_enabled: bool = False,
        logger: Logger = defaults.LOGGER(**defaults.LOGGER_KWARGS),
        accelerator: defaults.SUPPORTED_ACCELERATORS_TYPE = defaults.ACCELERATOR,
        devices: Optional[int] = None,
        strategy: Optional[str] = defaults.DISTRIBUTED_STRATEGY,
        precision: str = defaults.PRECISION,
        deterministic_trainer: bool = defaults.DETERMINISTIC_TRAINER,
        gradient_clip_val: Optional[float] = defaults.GRADIENT_CLIP_VAL,
        gradient_clip_algorithm: Optional[str] = defaults.GRADIENT_CLIP_ALGORITHM,
        mask_unused_classes: bool = defaults.MASK_UNUSED_CLASSES,
    ):
        # The incoming ``learner_kwargs`` argument is intentionally ignored;
        # the learner configuration is fully determined by the explicit
        # parameters below.
        super().__init__(
            model=model,
            loss_fn=loss_fn,
            optimizer=optimizer,
            learner_class=LearningToPromptLearner,
            learner_kwargs={
                "batch_size": batch_size,
                "seed": seed,
                "loss_fn": loss_fn,
                "prompt_sim_loss_weight": prompt_sim_loss_weight,
            },
            input_state_folder=input_state_folder,
            output_state_folder=output_state_folder,
            max_epochs=max_epochs,
            learning_rate_scheduler=learning_rate_scheduler,
            learning_rate_scheduler_interval=learning_rate_scheduler_interval,
            train_transform=train_transform,
            train_target_transform=train_target_transform,
            test_transform=test_transform,
            test_target_transform=test_target_transform,
            buffer_transform=buffer_transform,
            buffer_target_transform=buffer_target_transform,
            metric=metric,
            mode=mode,
            logged_metrics=logged_metrics,
            early_stopping_enabled=early_stopping_enabled,
            logger=logger,
            accelerator=accelerator,
            devices=devices,
            strategy=strategy,
            precision=precision,
            deterministic_trainer=deterministic_trainer,
            gradient_clip_algorithm=gradient_clip_algorithm,
            gradient_clip_val=gradient_clip_val,
            mask_unused_classes=mask_unused_classes,
        )
class LearningToPromptReplayModelUpdater(SingleTrainingLoopUpdater):
    """Model updater wiring :class:`LearningToPromptReplayLearner` into a
    single training loop.

    All arguments are forwarded to :class:`SingleTrainingLoopUpdater`; the
    learner-specific options (memory size, mixing weight, similarity weight)
    are packed into ``learner_kwargs``.
    """

    def __init__(
        self,
        model: RenateModule,
        loss_fn: torch.nn.Module,
        optimizer: Callable[[List[Parameter]], Optimizer],
        memory_size: int,
        batch_memory_frac: float = defaults.BATCH_MEMORY_FRAC,
        loss_weight_new_data: Optional[float] = None,
        learning_rate_scheduler: Optional[partial] = None,
        learning_rate_scheduler_interval: defaults.SUPPORTED_LR_SCHEDULER_INTERVAL_TYPE = defaults.LR_SCHEDULER_INTERVAL,  # noqa: E501
        prompt_sim_loss_weight: float = defaults.PROMPT_SIM_LOSS_WEIGHT,
        batch_size: int = defaults.BATCH_SIZE,
        input_state_folder: Optional[str] = None,
        output_state_folder: Optional[str] = None,
        max_epochs: int = defaults.MAX_EPOCHS,
        train_transform: Optional[Callable] = None,
        train_target_transform: Optional[Callable] = None,
        test_transform: Optional[Callable] = None,
        test_target_transform: Optional[Callable] = None,
        buffer_transform: Optional[Callable] = None,
        buffer_target_transform: Optional[Callable] = None,
        metric: Optional[str] = None,
        mode: defaults.SUPPORTED_TUNING_MODE_TYPE = "min",
        logged_metrics: Optional[Dict[str, torchmetrics.Metric]] = None,
        early_stopping_enabled: bool = False,
        logger: Logger = defaults.LOGGER(**defaults.LOGGER_KWARGS),
        accelerator: defaults.SUPPORTED_ACCELERATORS_TYPE = defaults.ACCELERATOR,
        devices: Optional[int] = None,
        strategy: str = defaults.DISTRIBUTED_STRATEGY,
        precision: str = defaults.PRECISION,
        seed: int = defaults.SEED,
        deterministic_trainer: bool = defaults.DETERMINISTIC_TRAINER,
        gradient_clip_val: Optional[float] = defaults.GRADIENT_CLIP_VAL,
        gradient_clip_algorithm: Optional[str] = defaults.GRADIENT_CLIP_ALGORITHM,
        mask_unused_classes: bool = defaults.MASK_UNUSED_CLASSES,
    ):
        super().__init__(
            model=model,
            loss_fn=loss_fn,
            optimizer=optimizer,
            learner_class=LearningToPromptReplayLearner,
            learner_kwargs={
                "memory_size": memory_size,
                "batch_memory_frac": batch_memory_frac,
                "loss_weight_new_data": loss_weight_new_data,
                "batch_size": batch_size,
                "seed": seed,
                "prompt_sim_loss_weight": prompt_sim_loss_weight,
            },
            input_state_folder=input_state_folder,
            output_state_folder=output_state_folder,
            max_epochs=max_epochs,
            learning_rate_scheduler=learning_rate_scheduler,
            learning_rate_scheduler_interval=learning_rate_scheduler_interval,
            train_transform=train_transform,
            train_target_transform=train_target_transform,
            test_transform=test_transform,
            test_target_transform=test_target_transform,
            buffer_transform=buffer_transform,
            buffer_target_transform=buffer_target_transform,
            metric=metric,
            mode=mode,
            logged_metrics=logged_metrics,
            early_stopping_enabled=early_stopping_enabled,
            logger=logger,
            accelerator=accelerator,
            devices=devices,
            strategy=strategy,
            precision=precision,
            deterministic_trainer=deterministic_trainer,
            gradient_clip_algorithm=gradient_clip_algorithm,
            gradient_clip_val=gradient_clip_val,
            mask_unused_classes=mask_unused_classes,
        )