# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import abc
import os
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
import torch
import torch.nn as nn
import torchmetrics
from pytorch_lightning import LightningModule
from pytorch_lightning.utilities.types import STEP_OUTPUT
from torch import Tensor
from torch.nn import Parameter
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader, Dataset
from renate import defaults
from renate.data.datasets import _TransformedDataset
from renate.memory import DataBuffer, InfiniteBuffer, ReservoirBuffer
from renate.models import RenateModule
from renate.types import NestedTensors
from renate.utils.misc import maybe_populate_mask_and_ignore_logits
from renate.utils.pytorch import get_generator, unique_classes
class RenateLightningModule(LightningModule, abc.ABC):
"""Base class for LightningModules, which implement metric logging and basic training logic.
The `RenateLightningModule` is a `LightningModule`, but provides additional hook functions
called by `ModelUpdater`. These hooks are:
- `on_model_update_start`, which is called at the beginning of a model update. It receives
the new training and (optionally) validation datasets, from which the train and validation
data loaders are created.
- `on_model_update_end`, which is called at the end of a model update.
Args:
model: The model to be trained.
loss_fn: The loss function optimized during training.
optimizer: Partial optimizer used to create an optimizer by passing the model parameters.
learning_rate_scheduler: Partial object of learning rate scheduler that will be created by
passing the optimizer.
learning_rate_scheduler_interval: When to update the learning rate scheduler.
Options: `epoch` and `step`.
batch_size: Training batch size.
logged_metrics: Metrics logged in addition to the default ones.
seed: See :func:`renate.utils.pytorch.get_generator`.
mask_unused_classes: Flag indicating whether logits corresponding to unused classes should be
ignored in the loss computation. Possibly useful for class-incremental learning.
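Example (illustrative sketch; in practice a concrete subclass is instantiated, and
`MyRenateModule` as well as the hyperparameter values below are placeholders)::

    from functools import partial

    module = RenateLightningModule(
        model=MyRenateModule(),
        loss_fn=torch.nn.CrossEntropyLoss(reduction="none"),
        optimizer=partial(torch.optim.SGD, lr=0.05, momentum=0.9),
        learning_rate_scheduler=partial(torch.optim.lr_scheduler.StepLR, step_size=10),
        learning_rate_scheduler_interval="epoch",
        batch_size=32,
        logged_metrics={"accuracy": torchmetrics.Accuracy(task="multiclass", num_classes=10)},
    )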
"""
def __init__(
self,
model: RenateModule,
loss_fn: torch.nn.Module,
optimizer: Callable[[List[Parameter]], Optimizer],
learning_rate_scheduler: Optional[Callable[[Optimizer], _LRScheduler]] = None,
learning_rate_scheduler_interval: defaults.SUPPORTED_LR_SCHEDULER_INTERVAL_TYPE = defaults.LR_SCHEDULER_INTERVAL, # noqa: E501
batch_size: int = defaults.BATCH_SIZE,
logged_metrics: Optional[Dict[str, torchmetrics.Metric]] = None,
seed: int = defaults.SEED,
mask_unused_classes: bool = defaults.MASK_UNUSED_CLASSES,
) -> None:
super().__init__()
self._model = model
self._loss_fn = loss_fn
self._optimizer = optimizer
self._learning_rate_scheduler = learning_rate_scheduler
self._learning_rate_scheduler_interval = learning_rate_scheduler_interval
self._batch_size = batch_size
self._seed = seed
self._mask_unused_classes = mask_unused_classes
self._class_mask = None
self._classes_in_current_task = None
self._task_id: str = defaults.TASK_ID
self._train_dataset: Optional[Dataset] = None
self._val_dataset: Optional[Dataset] = None
self.val_enabled = False
self._train_collate_fn: Optional[Callable] = None
self._val_collate_fn: Optional[Callable] = None
self._create_metrics_collections(logged_metrics)
self._rng = get_generator(self._seed)
self.save_hyperparameters(ignore=self._ignored_hyperparameters())
def _ignored_hyperparameters(self):
"""Hyperparameters to be ignored in the ``save_hyperparameters`` call."""
return [
"model",
"loss_fn",
"optimizer",
"learning_rate_scheduler",
"logged_metrics",
]
def _create_metrics_collections(
self, logged_metrics: Optional[Dict[str, torchmetrics.Metric]] = None
) -> None:
"""Creates all logged metrics."""
if logged_metrics is None:
logged_metrics = {}
metrics = torchmetrics.MetricCollection(logged_metrics)
train_metrics = metrics.clone(prefix="train_")
val_metrics = metrics.clone(prefix="val_")
train_losses = nn.ModuleDict(
{
"base_loss": torchmetrics.MeanMetric(),
"loss": torchmetrics.MeanMetric(),
}
)
val_losses = nn.ModuleDict({"loss": torchmetrics.MeanMetric()})
self._metric_collections = nn.ModuleDict(
{
"train_metrics": train_metrics,
"val_metrics": val_metrics,
}
)
self._loss_collections = nn.ModuleDict(
{
"train_losses": train_losses,
"val_losses": val_losses,
}
)
def is_logged_metric(self, metric_name: str) -> bool:
"""Returns `True` if there is a metric with name `metric_name`."""
if metric_name is None:
return True
logged_metrics = list()
for prefix in ["train", "val"]:
for collection, collection_name in zip(
[self._metric_collections, self._loss_collections], ["metrics", "losses"]
):
collection_key = f"{prefix}_{collection_name}"
if collection_key in collection:
logged_metrics += [
f"{prefix}_{logged_metric_name}"
for logged_metric_name in collection[collection_key]
]
return metric_name in logged_metrics
def on_model_update_start(
self,
train_dataset: Dataset,
val_dataset: Dataset,
train_dataset_collate_fn: Optional[Callable] = None,
val_dataset_collate_fn: Optional[Callable] = None,
task_id: Optional[str] = None,
) -> None:
self._train_dataset = train_dataset
self._val_dataset = val_dataset
self.val_enabled = val_dataset is not None and len(val_dataset) > 0
self._train_collate_fn = train_dataset_collate_fn
self._val_collate_fn = val_dataset_collate_fn
self._task_id = task_id
self._model.add_task_params(task_id=self._task_id)
if self._mask_unused_classes:
# The first forward pass will populate _class_mask from the unique classes computed below.
self._classes_in_current_task = unique_classes(self._train_dataset)
def train_dataloader(self) -> DataLoader:
"""Returns the dataloader for training the model."""
return DataLoader(
self._train_dataset,
batch_size=self._batch_size,
shuffle=True,
generator=self._rng,
pin_memory=True,
collate_fn=self._train_collate_fn,
)
def val_dataloader(self) -> Optional[DataLoader]:
if self._val_dataset is not None:
return DataLoader(
self._val_dataset,
batch_size=self._batch_size,
shuffle=False,
generator=self._rng,
pin_memory=True,
collate_fn=self._val_collate_fn,
)
def on_model_update_end(self) -> None:
"""Called right before a model update terminates."""
pass
def forward(self, inputs: NestedTensors, task_id: Optional[str] = None) -> torch.Tensor:
"""Forward pass of the model."""
if task_id is None:
task_id = self._task_id
return self._model(inputs, task_id=task_id)
def training_step_unpack_batch(self, batch: Tuple[Any, Any]) -> Tuple[Any, Any]:
inputs, targets = batch
return inputs, targets
def training_step(
self, batch: Tuple[NestedTensors, torch.Tensor], batch_idx: int
) -> STEP_OUTPUT:
"""PyTorch Lightning function to return the training loss."""
inputs, targets = self.training_step_unpack_batch(batch)
outputs = self(inputs)
outputs, self._class_mask = maybe_populate_mask_and_ignore_logits(
self._mask_unused_classes, self._class_mask, self._classes_in_current_task, outputs
)
intermediate_representation = self._model.get_intermediate_representation()
self._model.reset_intermediate_representation_cache()
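# Note: loss_fn may return per-sample losses (e.g., constructed with reduction="none"); the
# explicit .mean() below reduces them to a scalar, while an already-reduced scalar loss
# passes through unchanged.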
loss = self._loss_fn(outputs, targets).mean()
self._update_metrics(outputs, targets, "train")
self._loss_collections["train_losses"]["base_loss"](loss)
return {
"loss": loss,
"outputs": outputs,
"intermediate_representation": intermediate_representation,
}
def training_step_end(self, step_output: STEP_OUTPUT) -> STEP_OUTPUT:
"""PyTorch Lightning function to perform after the training step."""
super().training_step_end(step_output)
self._loss_collections["train_losses"]["loss"](step_output["loss"])
return step_output
def training_epoch_end(self, outputs: List[Union[Tensor, Dict[str, Any]]]) -> None:
"""PyTorch Lightning function to run at the end of training epoch."""
super().training_epoch_end(outputs)
if not self.val_enabled:
self._log_metrics()
def validation_step_unpack_batch(self, batch: Tuple[Any, Any]) -> Tuple[Any, Any]:
inputs, targets = batch
return inputs, targets
def validation_step(self, batch: Tuple[NestedTensors, torch.Tensor], batch_idx: int) -> None:
"""PyTorch Lightning function to estimate validation metrics."""
inputs, targets = self.validation_step_unpack_batch(batch)
outputs = self(inputs)
loss = self._loss_fn(outputs, targets)
self._update_metrics(outputs, targets, "val")
self._loss_collections["val_losses"]["loss"](loss)
def validation_epoch_end(self, outputs: List[Union[Tensor, Dict[str, Any]]]) -> None:
"""PyTorch Lightning function to run at the end of validation epoch."""
super().validation_epoch_end(outputs)
self._log_metrics()
def _update_metrics(
self,
outputs: torch.Tensor,
y: torch.Tensor,
prefix: Literal["train", "val"],
) -> None:
"""Shared logic for updating metrics."""
self._metric_collections[f"{prefix}_metrics"](outputs, y)
def _log_metrics(
self,
) -> None:
"""Shared logic for logging metrics, including the loss."""
if self.trainer.sanity_checking:
return
prefixes = ["train", "val"] if self.val_enabled else ["train"]
for prefix in prefixes:
self.log_dict(
self._metric_collections[f"{prefix}_metrics"].compute(),
on_step=False,
on_epoch=True,
logger=True,
sync_dist=True,
)
self._metric_collections[f"{prefix}_metrics"].reset()
for loss_name, loss in self._loss_collections[f"{prefix}_losses"].items():
self.log(
f"{prefix}_{loss_name}",
loss.compute(),
on_step=False,
on_epoch=True,
logger=True,
sync_dist=True,
)
loss.reset()
class Learner(RenateLightningModule, abc.ABC):
"""Base class for Learners, which encapsulate the core CL methodologies.
The `Learner` is a `LightningModule`, but provides additional hook functions
called by `ModelUpdater`. These hooks are:
- `Learner.on_model_update_start`, which is called at the beginning of a model update. It
receives the new training and (optionally) validation datasets, from which the train and
validation data loaders are created.
- `Learner.on_model_update_end`, which is called at the end of a model update.
This base class implements a basic training loop without any mechanism to
counteract forgetting.
Args:
model: The model to be trained.
loss_fn: The loss function optimized during training.
optimizer: Partial optimizer used to create an optimizer by passing the model parameters.
learning_rate_scheduler: Partial object of learning rate scheduler that will be created by
passing the optimizer.
learning_rate_scheduler_interval: When to update the learning rate scheduler.
Options: `epoch` and `step`.
batch_size: Training batch size.
train_transform: The transformation applied during training.
train_target_transform: The target transformation applied during training.
test_transform: The transformation at test time.
test_target_transform: The target transformation at test time.
logged_metrics: Metrics logged in addition to the default ones.
seed: See :func:`renate.utils.pytorch.get_generator`.
mask_unused_classes: Flag indicating whether logits corresponding to unused classes should be
ignored in the loss computation. Possibly useful for class-incremental learning.
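Example of the model-update hooks in use (illustrative sketch; `learner`, `trainer` and the
datasets are placeholders created elsewhere)::

    learner.on_model_update_start(train_dataset, val_dataset, task_id="task_0")
    trainer.fit(learner)  # consumes learner.train_dataloader() / learner.val_dataloader()
    learner.on_model_update_end()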
"""
def __init__(
self,
model: RenateModule,
loss_fn: torch.nn.Module,
optimizer: Callable[[List[Parameter]], Optimizer],
learning_rate_scheduler: Optional[Callable[[Optimizer], _LRScheduler]] = None,
learning_rate_scheduler_interval: defaults.SUPPORTED_LR_SCHEDULER_INTERVAL_TYPE = defaults.LR_SCHEDULER_INTERVAL, # noqa: E501
batch_size: int = defaults.BATCH_SIZE,
train_transform: Optional[Callable] = None,
train_target_transform: Optional[Callable] = None,
test_transform: Optional[Callable] = None,
test_target_transform: Optional[Callable] = None,
logged_metrics: Optional[Dict[str, torchmetrics.Metric]] = None,
seed: int = defaults.SEED,
mask_unused_classes: bool = defaults.MASK_UNUSED_CLASSES,
) -> None:
super().__init__(
model=model,
loss_fn=loss_fn,
optimizer=optimizer,
learning_rate_scheduler=learning_rate_scheduler,
learning_rate_scheduler_interval=learning_rate_scheduler_interval,
batch_size=batch_size,
logged_metrics=logged_metrics,
seed=seed,
mask_unused_classes=mask_unused_classes,
)
self._train_transform = train_transform
self._train_target_transform = train_target_transform
self._test_transform = test_transform
self._test_target_transform = test_target_transform
self._val_memory_buffer: DataBuffer = InfiniteBuffer()
def _ignored_hyperparameters(self):
"""Hyperparameters to be ignored in the ``save_hyperparameters`` call."""
return super()._ignored_hyperparameters() + [
"components",
"train_transform",
"train_target_transform",
"test_transform",
"test_target_transform",
"buffer_transform",
"buffer_target_transform",
]
def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
learner_state_dict = {
"learner_class_name": self.__class__.__name__,
"val_memory_buffer": self._val_memory_buffer.state_dict(),
}
checkpoint.update(learner_state_dict)
def on_load_checkpoint(self, checkpoint: Dict[str, Any]):
self._val_memory_buffer.load_state_dict(checkpoint["val_memory_buffer"])
def save(self, output_state_dir: str) -> None:
val_buffer_dir = os.path.join(output_state_dir, "val_memory_buffer")
os.makedirs(val_buffer_dir, exist_ok=True)
self._val_memory_buffer.save(val_buffer_dir)
def load(self, input_state_dir: str) -> None:
self._val_memory_buffer.load(os.path.join(input_state_dir, "val_memory_buffer"))
def on_model_update_start(
self,
train_dataset: Dataset,
val_dataset: Dataset,
train_dataset_collate_fn: Optional[Callable] = None,
val_dataset_collate_fn: Optional[Callable] = None,
task_id: Optional[str] = None,
) -> None:
super().on_model_update_start(
train_dataset=train_dataset,
val_dataset=val_dataset,
train_dataset_collate_fn=train_dataset_collate_fn,
val_dataset_collate_fn=val_dataset_collate_fn,
task_id=task_id,
)
self._model.add_task_params(task_id=self._task_id)
def train_dataloader(self) -> DataLoader:
"""Returns the dataloader for training the model."""
train_dataset = _TransformedDataset(
self._train_dataset,
transform=self._train_transform,
target_transform=self._train_target_transform,
)
return DataLoader(
train_dataset,
batch_size=self._batch_size,
shuffle=True,
generator=self._rng,
pin_memory=True,
collate_fn=self._train_collate_fn,
)
def val_dataloader(self) -> Optional[DataLoader]:
if self._val_dataset is not None:
val_dataset = _TransformedDataset(
self._val_dataset,
transform=self._test_transform,
target_transform=self._test_target_transform,
)
self._val_memory_buffer.update(val_dataset)
if len(self._val_memory_buffer):
return DataLoader(
self._val_memory_buffer,
batch_size=self._batch_size,
shuffle=False,
generator=self._rng,
pin_memory=True,
collate_fn=self._val_collate_fn,
)
def validation_step_unpack_batch(
self, batch: Tuple[NestedTensors, torch.Tensor]
) -> Tuple[NestedTensors, Any]:
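# The validation memory buffer yields (data_point, metadata) pairs; only the data point,
# itself an (inputs, targets) tuple, is needed here and the metadata is discarded.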
(inputs, targets), _ = batch
return inputs, targets
class ReplayLearner(Learner, abc.ABC):
"""Base class for Learners which use a buffer to store data and reuse it in future updates.
Args:
memory_size: The maximum size of the memory.
batch_size: Training batch size; each batch is split between new data and samples drawn
from the rehearsal memory according to `batch_memory_frac`.
batch_memory_frac: Fraction of the batch that is sampled from rehearsal memory.
buffer_transform: The transformation applied to samples drawn from the memory buffer.
buffer_target_transform: The transformation applied to the targets of samples drawn from
the memory buffer.
seed: See :func:`renate.utils.pytorch.get_generator`.
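Example of the resulting batch composition (illustrative numbers): with `batch_size=32`,
`batch_memory_frac=0.25` and a sufficiently large `memory_size`, each training batch combines
24 freshly drawn samples with 8 samples from the rehearsal memory.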
"""
def __init__(
self,
memory_size: int,
batch_size: int = defaults.BATCH_SIZE,
batch_memory_frac: float = defaults.BATCH_MEMORY_FRAC,
buffer_transform: Optional[Callable] = None,
buffer_target_transform: Optional[Callable] = None,
seed: int = defaults.SEED,
**kwargs,
) -> None:
if not (0 <= batch_memory_frac <= 1):
raise ValueError(
f"Expecting batch_memory_frac to be in [0, 1], received {batch_memory_frac}."
)
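# Reserve part of each batch for memory samples; the memory share cannot exceed the buffer
# capacity, and the remainder of the batch is filled with new data.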
memory_batch_size = min(memory_size, int(batch_memory_frac * batch_size))
batch_size = batch_size - memory_batch_size
super().__init__(batch_size=batch_size, seed=seed, **kwargs)
self._memory_batch_size = memory_batch_size
self._memory_buffer = ReservoirBuffer(
max_size=memory_size,
seed=seed,
transform=buffer_transform,
target_transform=buffer_target_transform,
)
def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
super().on_save_checkpoint(checkpoint)
checkpoint["memory_buffer"] = self._memory_buffer.state_dict()
def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
super().on_load_checkpoint(checkpoint)
self._memory_buffer.load_state_dict(checkpoint["memory_buffer"])
def save(self, output_state_dir: str) -> None:
super().save(output_state_dir)
buffer_dir = os.path.join(output_state_dir, "memory_buffer")
os.makedirs(buffer_dir, exist_ok=True)
self._memory_buffer.save(buffer_dir)
def load(self, input_state_dir: str) -> None:
super().load(input_state_dir)
self._memory_buffer.load(os.path.join(input_state_dir, "memory_buffer"))
def on_model_update_start(
self,
train_dataset: Dataset,
val_dataset: Dataset,
train_dataset_collate_fn: Optional[Callable] = None,
val_dataset_collate_fn: Optional[Callable] = None,
task_id: Optional[str] = None,
) -> None:
super().on_model_update_start(
train_dataset, val_dataset, train_dataset_collate_fn, val_dataset_collate_fn, task_id
)
if self._mask_unused_classes:
self._classes_in_current_task = self._classes_in_current_task.union(
unique_classes(self._memory_buffer)
)