Source code for renate.benchmark.models.base
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from abc import ABC
from typing import Any, List, Optional
import torch
from torch import nn
from renate import defaults
from renate.models import RenateModule
from renate.models.prediction_strategies import ICaRLClassificationStrategy, PredictionStrategy
from renate.utils.deepspeed import convert_to_tensor, recover_object_from_tensor
# TODO: merge unit tests for the submodules
[docs]
class RenateBenchmarkingModule(RenateModule, ABC):
"""Base class for all models provided by Renate.
This class ensures that each models works with all ModelUpdaters when using the benchmarking
feature of Renate. New models can extend this class or alternatively extend the RenateModule
and make sure they are compatible with the considered ModelUpdater.
Args:
embedding_size: Representation size of the model after the backbone.
num_outputs: The number of outputs of the model.
constructor_arguments: Arguments needed to instantiate the model.
prediction_strategy: By default a forward pass through the model. Some ModelUpdater must
be combined with specific prediction strategies to work as intended.
add_icarl_class_means: Specific parameters for iCaRL. Can be set to ``False`` if any other
ModelUpdater is used.
"""
def __init__(
self,
embedding_size: int,
num_outputs: int,
constructor_arguments: dict,
prediction_strategy: Optional[PredictionStrategy] = None,
add_icarl_class_means: bool = True,
):
constructor_arguments["num_outputs"] = num_outputs
constructor_arguments["add_icarl_class_means"] = add_icarl_class_means
super().__init__(
constructor_arguments=constructor_arguments,
)
self._embedding_size = embedding_size
self._num_outputs = num_outputs
self._prediction_strategy = prediction_strategy
self._tasks_params: torch.nn.ModuleDict = torch.nn.ModuleDict()
self.add_task_params(defaults.TASK_ID)
if add_icarl_class_means:
self.class_means = torch.nn.Parameter(
torch.zeros((embedding_size, num_outputs)), requires_grad=False
)
[docs]
def forward(self, x: torch.Tensor, task_id: str = defaults.TASK_ID) -> torch.Tensor:
x = self.get_backbone(task_id=task_id)(x)
if isinstance(self._prediction_strategy, ICaRLClassificationStrategy):
return self._prediction_strategy(x, self.training, class_means=self.class_means)
else:
assert (
self._prediction_strategy is None
), f"Unknown prediction strategy of type {type(self._prediction_strategy)}."
return self.get_predictor(task_id)(x)
[docs]
def get_backbone(self, task_id: str = defaults.TASK_ID) -> nn.Module:
"""Returns the model without the prediction head."""
return self._backbone
[docs]
def get_predictor(self, task_id: str = defaults.TASK_ID) -> nn.Module:
"""Returns the model without the backbone."""
return self._tasks_params[task_id]
def _add_task_params(self, task_id: str = defaults.TASK_ID) -> None:
"""Adds new parameters associated to a specific task to the model."""
self._tasks_params[task_id] = torch.nn.Linear(self._embedding_size, self._num_outputs)
[docs]
def get_params(self, task_id: str = defaults.TASK_ID) -> List[torch.nn.Parameter]:
"""Returns the list of parameters for the core model and a specific `task_id`."""
return list(self.get_backbone(task_id=task_id).parameters()) + list(
self.get_predictor(task_id=task_id).parameters()
)