Source code for renate.updaters.experimental.gdumb

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch
import torchmetrics
from pytorch_lightning.loggers.logger import Logger
from pytorch_lightning.utilities.types import STEP_OUTPUT
from torch.nn import Parameter
from torch.optim import Optimizer
from torch.utils.data import DataLoader, Dataset

from renate import defaults
from renate.memory import GreedyClassBalancingBuffer
from renate.models import RenateModule
from renate.types import NestedTensors
from renate.updaters.learner import ReplayLearner
from renate.updaters.model_updater import SingleTrainingLoopUpdater


class GDumbLearner(ReplayLearner):
    """A Learner that implements the GDumb strategy.

    Prabhu, Ameya, Philip HS Torr, and Puneet K. Dokania. "GDumb: A simple approach that
    questions our progress in continual learning." ECCV, 2020.

    It maintains a memory of previously observed data points and does the training after
    updating the buffer. Note that the model is reinitialized before training on the buffer.

    Args:
        memory_size: The maximum size of the memory.
        buffer_transform: The transform to be applied to the data points in the memory.
        buffer_target_transform: The transform to be applied to the targets in the memory.
        seed: A random seed.
    """

    def __init__(
        self,
        memory_size: int,
        buffer_transform: Optional[Callable] = None,
        buffer_target_transform: Optional[Callable] = None,
        seed: int = defaults.SEED,
        **kwargs: Any,
    ) -> None:
        super().__init__(
            memory_size=memory_size,
            seed=seed,
            **kwargs,
        )
        self._memory_buffer = GreedyClassBalancingBuffer(
            max_size=memory_size,
            seed=seed,
            transform=buffer_transform,
            target_transform=buffer_target_transform,
        )
    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        super().on_load_checkpoint(checkpoint)
        self._memory_buffer.load_state_dict(checkpoint["memory_buffer"])
    def on_model_update_start(
        self,
        train_dataset: Dataset,
        val_dataset: Dataset,
        train_dataset_collate_fn: Optional[Callable] = None,
        val_dataset_collate_fn: Optional[Callable] = None,
        task_id: Optional[str] = None,
    ) -> None:
        """Called before a model update starts."""
        super().on_model_update_start(
            train_dataset=train_dataset,
            val_dataset=val_dataset,
            train_dataset_collate_fn=train_dataset_collate_fn,
            val_dataset_collate_fn=val_dataset_collate_fn,
            task_id=task_id,
        )
        self._memory_buffer.update(train_dataset)
    def train_dataloader(self) -> DataLoader:
        # GDumb trains exclusively on the class-balanced memory buffer, not on the
        # incoming dataset directly.
        return DataLoader(
            self._memory_buffer,
            batch_size=self._batch_size,
            shuffle=True,
            generator=self._rng,
            pin_memory=True,
            collate_fn=self._train_collate_fn,
        )
    def training_step(
        self,
        batch: Tuple[Tuple[NestedTensors, torch.Tensor], Dict[str, torch.Tensor]],
        batch_idx: int,
    ) -> STEP_OUTPUT:
        """PyTorch Lightning function to return the training loss."""
        batch, _ = batch
        return super().training_step(batch=batch, batch_idx=batch_idx)
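
# --- Illustrative sketch, not part of the renate source -------------------------------------
# The greedy class-balancing rule behind ``GreedyClassBalancingBuffer`` follows Algorithm 1
# of Prabhu et al. (ECCV, 2020): a new sample is admitted while its class stays below an
# equal per-class share of the memory, evicting from the currently largest class once the
# memory is full. The minimal class below is a simplified stand-in with hypothetical names,
# shown only to clarify that policy; renate's actual buffer lives in ``renate.memory``.
class _GreedyBalancedMemorySketch:
    def __init__(self, max_size: int) -> None:
        self._max_size = max_size
        self._per_class: Dict[int, List[Tuple[Any, int]]] = {}

    def add(self, x: Any, y: int) -> None:
        # Equal share of the memory for every class seen so far (including ``y``).
        quota = self._max_size / len(set(self._per_class) | {y})
        if len(self._per_class.get(y, [])) < quota:
            if sum(len(v) for v in self._per_class.values()) >= self._max_size:
                # Memory full: evict a sample from the currently largest class.
                largest = max(self._per_class, key=lambda c: len(self._per_class[c]))
                self._per_class[largest].pop()
            self._per_class.setdefault(y, []).append((x, y))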
class GDumbModelUpdater(SingleTrainingLoopUpdater):
    def __init__(
        self,
        model: RenateModule,
        loss_fn: torch.nn.Module,
        optimizer: Callable[[List[Parameter]], Optimizer],
        memory_size: int,
        batch_memory_frac: float = defaults.BATCH_MEMORY_FRAC,
        learning_rate_scheduler: Optional[partial] = None,
        learning_rate_scheduler_interval: defaults.SUPPORTED_LR_SCHEDULER_INTERVAL_TYPE = defaults.LR_SCHEDULER_INTERVAL,  # noqa: E501
        batch_size: int = defaults.BATCH_SIZE,
        input_state_folder: Optional[str] = None,
        output_state_folder: Optional[str] = None,
        max_epochs: int = defaults.MAX_EPOCHS,
        train_transform: Optional[Callable] = None,
        train_target_transform: Optional[Callable] = None,
        test_transform: Optional[Callable] = None,
        test_target_transform: Optional[Callable] = None,
        buffer_transform: Optional[Callable] = None,
        buffer_target_transform: Optional[Callable] = None,
        metric: Optional[str] = None,
        mode: defaults.SUPPORTED_TUNING_MODE_TYPE = "min",
        logged_metrics: Optional[Dict[str, torchmetrics.Metric]] = None,
        early_stopping_enabled: bool = False,
        logger: Logger = defaults.LOGGER(**defaults.LOGGER_KWARGS),
        accelerator: defaults.SUPPORTED_ACCELERATORS_TYPE = defaults.ACCELERATOR,
        devices: Optional[int] = None,
        strategy: str = defaults.DISTRIBUTED_STRATEGY,
        precision: str = defaults.PRECISION,
        seed: int = defaults.SEED,
        deterministic_trainer: bool = defaults.DETERMINISTIC_TRAINER,
        gradient_clip_val: Optional[float] = defaults.GRADIENT_CLIP_VAL,
        gradient_clip_algorithm: Optional[str] = defaults.GRADIENT_CLIP_ALGORITHM,
        mask_unused_classes: bool = defaults.MASK_UNUSED_CLASSES,
    ):
        learner_kwargs = {
            "memory_size": memory_size,
            "batch_memory_frac": batch_memory_frac,
            "batch_size": batch_size,
            "seed": seed,
        }
        super().__init__(
            model,
            loss_fn=loss_fn,
            optimizer=optimizer,
            learner_class=GDumbLearner,
            learner_kwargs=learner_kwargs,
            input_state_folder=input_state_folder,
            output_state_folder=output_state_folder,
            max_epochs=max_epochs,
            learning_rate_scheduler=learning_rate_scheduler,
            learning_rate_scheduler_interval=learning_rate_scheduler_interval,
            train_transform=train_transform,
            train_target_transform=train_target_transform,
            test_transform=test_transform,
            test_target_transform=test_target_transform,
            buffer_transform=buffer_transform,
            buffer_target_transform=buffer_target_transform,
            metric=metric,
            mode=mode,
            logged_metrics=logged_metrics,
            early_stopping_enabled=early_stopping_enabled,
            logger=logger,
            accelerator=accelerator,
            devices=devices,
            strategy=strategy,
            precision=precision,
            deterministic_trainer=deterministic_trainer,
            gradient_clip_algorithm=gradient_clip_algorithm,
            gradient_clip_val=gradient_clip_val,
            mask_unused_classes=mask_unused_classes,
        )
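
# --- Hedged usage sketch, not part of the renate source -------------------------------------
# How ``GDumbModelUpdater`` might be wired up. ``my_model`` (a ``RenateModule``) and
# ``train_dataset`` (a torch ``Dataset``) are hypothetical names assumed to be defined
# elsewhere; ``update`` is the entry point inherited from the base model updater (check its
# exact signature in ``renate.updaters.model_updater``), so the snippet is left commented out.
#
# updater = GDumbModelUpdater(
#     model=my_model,
#     loss_fn=torch.nn.CrossEntropyLoss(),
#     optimizer=partial(torch.optim.SGD, lr=0.01, momentum=0.9),
#     memory_size=500,
#     batch_size=32,
#     max_epochs=10,
# )
# my_model = updater.update(train_dataset=train_dataset, task_id="task_0")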