renate.models package
- class renate.models.RenateModule(constructor_arguments)
Bases: torch.nn.Module, abc.ABC
A class for torch models with some additional functionality for continual learning.
RenateModule derives from torch.nn.Module and provides additional functionality relevant to continual learning. In particular, this concerns saving and reloading the model when model hyperparameters (which might affect the architecture) change during hyperparameter optimization. There is also functionality to retrieve internal-layer representations for use in replay-based CL methods.
When implementing a subclass of RenateModule, make sure to call the base class constructor and provide your model's constructor arguments. Besides that, you can define a RenateModule just like a torch.nn.Module.
Example:
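```python
from typing import Optional

import torch
from renate.models import RenateModule


class MyMNISTMLP(RenateModule):
    def __init__(self, num_hidden: int):
        # Pass all constructor arguments to the base class so that the model
        # can later be re-instantiated by RenateModule.from_state_dict.
        super().__init__(constructor_arguments={"num_hidden": num_hidden})
        self._fc1 = torch.nn.Linear(28 * 28, num_hidden)
        self._fc2 = torch.nn.Linear(num_hidden, 10)

    def forward(self, x: torch.Tensor, task_id: Optional[str] = None) -> torch.Tensor:
        # x is expected to be a batch of flattened 28x28 images.
        x = self._fc1(x)
        x = torch.nn.functional.relu(x)
        return self._fc2(x)
```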
The state of a RenateModule can be retrieved via the RenateModule.state_dict() method, just as in torch.nn.Module. When reloading a RenateModule from a stored state dict, use RenateModule.from_state_dict. It will automatically recover the hyperparameters and re-instantiate your model accordingly.
Note: Some methods of RenateModule accept an optional task_id argument. This is in anticipation of future methods for continual learning scenarios where task identifiers are provided. It is currently not used.
- Parameters:
constructor_arguments (dict) – Arguments needed to instantiate the model.
- classmethod from_state_dict(state_dict)
Load the model from a state dict.
- Parameters:
state_dict – The state dict of the model. This method works under the assumption that this has been created by RenateModule.state_dict().
- get_extra_state(encode=True)
Gets the constructor_arguments and task IDs needed to reconstruct the model.
- Return type:
Any
- set_extra_state(state, decode=True)
Extracts the content of _extra_state and sets the related values in the module.
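To make the reload mechanism described above concrete, here is a minimal plain-PyTorch sketch of the same idea. It is not Renate's implementation: the class TinyReloadable and the layout of its extra state (a dict with "constructor_arguments" and "task_ids" keys) are hypothetical, and the encode/decode options of the real methods are omitted. It relies on standard torch.nn.Module behavior: whatever get_extra_state returns is stored under the "_extra_state" key of state_dict(), and load_state_dict() hands it back to set_extra_state.

```python
from typing import Any, Dict, List

import torch


class TinyReloadable(torch.nn.Module):
    """Hypothetical stand-in illustrating the RenateModule reload pattern."""

    def __init__(self, num_hidden: int):
        super().__init__()
        # Remember exactly what is needed to rebuild this architecture.
        self._constructor_arguments = {"num_hidden": num_hidden}
        self._task_ids: List[str] = []
        self.fc = torch.nn.Linear(4, num_hidden)

    def get_extra_state(self) -> Dict[str, Any]:
        # torch stores this object under the "_extra_state" key of state_dict().
        return {
            "constructor_arguments": dict(self._constructor_arguments),
            "task_ids": list(self._task_ids),
        }

    def set_extra_state(self, state: Dict[str, Any]) -> None:
        # Called by load_state_dict() with the object returned above.
        self._constructor_arguments = dict(state["constructor_arguments"])
        self._task_ids = list(state["task_ids"])

    @classmethod
    def from_state_dict(cls, state_dict: Dict[str, Any]) -> "TinyReloadable":
        extra = state_dict["_extra_state"]
        # Re-instantiate with the stored hyperparameters, then load the weights.
        model = cls(**extra["constructor_arguments"])
        model.load_state_dict(state_dict)
        return model


original = TinyReloadable(num_hidden=8)
restored = TinyReloadable.from_state_dict(original.state_dict())
print(restored.fc.out_features)  # 8
```

A fuller version would also recreate any task-specific parameters from the stored task IDs before calling load_state_dict, because load_state_dict (strict by default) expects the rebuilt model to have the same set of keys.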
- abstract forward(x, task_id=None)
Performs a forward pass on the inputs and returns the predictions.
This method accepts a task ID, which may be provided by some continual learning scenarios. For example, the task ID may be used to switch between multiple output heads.
- Parameters:
x (Union[Tensor, Tuple[Union[Tensor, Tuple[NestedTensors], Dict[str, NestedTensors]]], Dict[str, Union[Tensor, Tuple[NestedTensors], Dict[str, NestedTensors]]]]) – Input(s) to the model. Can be a single tensor, a tuple of tensors, or a dictionary mapping strings to tensors.
task_id (Optional[str]) – The identifier of the task for which predictions are made.
- Return type:
Tensor
- Returns:
The model’s predictions.
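As a hypothetical illustration of the multi-head use case mentioned for forward, the sketch below routes a shared representation through a per-task output head selected by task_id. MultiHeadMLP, the task names, and the fallback to a default head when task_id is None are assumptions made for this example; it subclasses plain torch.nn.Module rather than RenateModule.

```python
from typing import Optional

import torch


class MultiHeadMLP(torch.nn.Module):
    """Hypothetical example: one shared trunk, one output head per task ID."""

    def __init__(self, in_features: int, num_hidden: int, num_outputs: int):
        super().__init__()
        self.trunk = torch.nn.Sequential(
            torch.nn.Linear(in_features, num_hidden), torch.nn.ReLU()
        )
        self.heads = torch.nn.ModuleDict({
            "task_a": torch.nn.Linear(num_hidden, num_outputs),
            "task_b": torch.nn.Linear(num_hidden, num_outputs),
        })
        self.default_task = "task_a"

    def forward(self, x: torch.Tensor, task_id: Optional[str] = None) -> torch.Tensor:
        features = self.trunk(x)
        # Fall back to a default head when no task ID is supplied.
        head = self.heads[task_id if task_id is not None else self.default_task]
        return head(features)


torch.manual_seed(0)
model = MultiHeadMLP(in_features=4, num_hidden=8, num_outputs=3)
x = torch.randn(2, 4)
out_a = model(x)            # no task ID: uses the "task_a" head
out_b = model(x, "task_b")  # explicitly selects the "task_b" head
print(out_a.shape)  # torch.Size([2, 3])
```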
- get_params(task_id=None)
User-facing function which returns the list of parameters.
If a task_id is given, this should return only the parameters used for that specific task.
- Parameters:
task_id (Optional[str]) – The task ID for which to retrieve parameters.
- Return type:
List[Parameter]
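A possible get_params for a model with shared and task-specific parts, again as a hypothetical plain-PyTorch sketch (TwoHeadModel and its module layout are made up for illustration): without a task ID it returns every parameter, and with one it returns the shared trunk plus that task's head. The resulting list can be passed directly to a torch.optim optimizer.

```python
from typing import List, Optional

import torch


class TwoHeadModel(torch.nn.Module):
    """Hypothetical model with a shared trunk and per-task heads."""

    def __init__(self):
        super().__init__()
        self.trunk = torch.nn.Linear(4, 8)
        self.heads = torch.nn.ModuleDict(
            {"task_a": torch.nn.Linear(8, 2), "task_b": torch.nn.Linear(8, 2)}
        )

    def get_params(self, task_id: Optional[str] = None) -> List[torch.nn.Parameter]:
        if task_id is None:
            # No task given: every parameter of the model.
            return list(self.parameters())
        # Task given: shared parameters plus only that task's head.
        return list(self.trunk.parameters()) + list(self.heads[task_id].parameters())


model = TwoHeadModel()
print(len(model.get_params()))          # 6 (three linear layers, weight and bias each)
print(len(model.get_params("task_a")))  # 4
# Optimize only what task_a uses; task_b's head is left untouched.
optimizer = torch.optim.SGD(model.get_params("task_a"), lr=0.1)
```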
- add_task_params(task_id=None)
Adds new parameters, associated with a specific task, to the model.
This method should not be overridden; override _add_task_params instead.
- Parameters:
task_id (Optional[str]) – The task ID for which the new parameters are added.
- Return type:
None
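The split between add_task_params and _add_task_params follows a template-method pattern: the public method handles bookkeeping and delegates the actual parameter creation to an overridable hook. The sketch below assumes that the bookkeeping consists of recording task IDs (consistent with get_extra_state storing them) and skipping tasks that were already added; both choices, and the names TaskParamsBase and HeadPerTask, are hypothetical and not taken from Renate's source.

```python
from typing import List, Optional

import torch


class TaskParamsBase(torch.nn.Module):
    """Hypothetical base class showing the add_task_params/_add_task_params split."""

    def __init__(self):
        super().__init__()
        self._task_ids: List[Optional[str]] = []

    def add_task_params(self, task_id: Optional[str] = None) -> None:
        # Public entry point: bookkeeping first, then delegate to the hook.
        if task_id in self._task_ids:
            return  # assumption: parameters for this task already exist
        self._add_task_params(task_id)
        self._task_ids.append(task_id)

    def _add_task_params(self, task_id: Optional[str] = None) -> None:
        # Subclasses override this hook; the default adds nothing.
        pass


class HeadPerTask(TaskParamsBase):
    def __init__(self, num_hidden: int, num_outputs: int):
        super().__init__()
        self.num_hidden, self.num_outputs = num_hidden, num_outputs
        self.heads = torch.nn.ModuleDict()

    def _add_task_params(self, task_id: Optional[str] = None) -> None:
        # ModuleDict keys must be strings, so None maps to a default key.
        key = task_id if task_id is not None else "default"
        self.heads[key] = torch.nn.Linear(self.num_hidden, self.num_outputs)


model = HeadPerTask(num_hidden=8, num_outputs=3)
model.add_task_params("task_a")
model.add_task_params("task_a")  # no-op: already registered
model.add_task_params("task_b")
print(sorted(model.heads.keys()))  # ['task_a', 'task_b']
```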
- get_logits(x, task_id=None)
Returns the logits for a given pair of input and task id.
By default, this method returns the output of the forward pass. It may be overridden with custom behavior if necessary.
- Parameters:
x (Union[Tensor, Tuple[Union[Tensor, Tuple[NestedTensors], Dict[str, NestedTensors]]], Dict[str, Union[Tensor, Tuple[NestedTensors], Dict[str, NestedTensors]]]]) – Input(s) to the model. Can be a single tensor, a tuple of tensors, or a dictionary mapping strings to tensors.
task_id (Optional[str]) – The task ID.
- Return type:
Tensor
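One reason to override get_logits is a model whose forward pass returns something other than raw logits. In this hypothetical sketch (ProbabilityModel is not part of Renate), forward returns class probabilities as its predictions, while get_logits returns the pre-softmax scores that a loss such as torch.nn.functional.cross_entropy expects.

```python
from typing import Optional

import torch


class ProbabilityModel(torch.nn.Module):
    """Hypothetical model whose predictions are class probabilities."""

    def __init__(self, in_features: int = 4, num_classes: int = 3):
        super().__init__()
        self.linear = torch.nn.Linear(in_features, num_classes)

    def forward(self, x: torch.Tensor, task_id: Optional[str] = None) -> torch.Tensor:
        # Predictions: softmax over the raw scores.
        return torch.softmax(self.linear(x), dim=-1)

    def get_logits(self, x: torch.Tensor, task_id: Optional[str] = None) -> torch.Tensor:
        # The default would return forward(x, task_id), i.e. probabilities;
        # return the pre-softmax scores instead.
        return self.linear(x)


torch.manual_seed(0)
model = ProbabilityModel()
x = torch.randn(5, 4)
targets = torch.tensor([0, 1, 2, 0, 1])
# cross_entropy applies log-softmax itself, so it needs logits, not probabilities.
loss = torch.nn.functional.cross_entropy(model.get_logits(x), targets)
```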
- get_intermediate_representation()
Returns the cached intermediate representation.
- Return type:
List[Tensor]
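Intermediate representations are commonly captured with PyTorch forward hooks. The following is a minimal sketch of that technique, assuming a cache that accumulates one tensor per forward pass until it is reset; CachingMLP, reset_cache, and the choice of which layer to hook are hypothetical and need not match how Renate's register_intermediate_representation_caching_hook works.

```python
from typing import List

import torch


class CachingMLP(torch.nn.Module):
    """Hypothetical sketch: cache a hidden layer's output with a forward hook."""

    def __init__(self):
        super().__init__()
        self.hidden = torch.nn.Linear(4, 8)
        self.out = torch.nn.Linear(8, 2)
        self._cache: List[torch.Tensor] = []
        # The hook is called as hook(module, inputs, output) after each forward.
        self._hook_handle = self.hidden.register_forward_hook(
            lambda module, inputs, output: self._cache.append(output)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.out(torch.relu(self.hidden(x)))

    def get_intermediate_representation(self) -> List[torch.Tensor]:
        return self._cache

    def reset_cache(self) -> None:
        self._cache.clear()


model = CachingMLP()
x = torch.randn(3, 4)
model(x)
print(model.get_intermediate_representation()[0].shape)  # torch.Size([3, 8])
model.reset_cache()
```

Calling self._hook_handle.remove() detaches the hook, after which forward passes no longer add to the cache.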
- replace_batch_norm_with_continual_norm(num_groups=32)
Replaces every occurrence of batch normalization with continual normalization.
Pham, Q., Liu, C., & Hoi, S. (2022). Continual normalization: Rethinking batch normalization for online continual learning. arXiv preprint arXiv:2203.16102.
- Parameters:
num_groups (int) – Number of groups used by the group normalization step within continual normalization.
- Return type:
None
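Continual normalization, as described by Pham et al. (2022), first applies group normalization without affine parameters and then applies batch normalization. The sketch below is one possible plain-PyTorch reading of that description, not Renate's implementation; ContinualNorm2d and replace_bn are hypothetical names, only BatchNorm2d layers are handled, and the replaced layers' learned statistics and affine parameters are not carried over. Note that torch.nn.GroupNorm requires the channel count to be divisible by num_groups.

```python
import torch


class ContinualNorm2d(torch.nn.Module):
    """Group norm without affine parameters, followed by batch norm."""

    def __init__(self, num_features: int, num_groups: int = 32):
        super().__init__()
        self.gn = torch.nn.GroupNorm(num_groups, num_features, affine=False)
        self.bn = torch.nn.BatchNorm2d(num_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.bn(self.gn(x))


def replace_bn(module: torch.nn.Module, num_groups: int = 32) -> None:
    """Recursively swap every BatchNorm2d child for a ContinualNorm2d, in place."""
    for name, child in list(module.named_children()):
        if isinstance(child, torch.nn.BatchNorm2d):
            setattr(module, name, ContinualNorm2d(child.num_features, num_groups))
        else:
            replace_bn(child, num_groups)


net = torch.nn.Sequential(
    torch.nn.Conv2d(3, 64, kernel_size=3, padding=1),
    torch.nn.BatchNorm2d(64),
    torch.nn.ReLU(),
)
replace_bn(net, num_groups=32)
print(type(net[1]).__name__)  # ContinualNorm2d
```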
Subpackages
Submodules
- renate.models.prediction_strategies module
- renate.models.renate_module module
  - RenateModule
    - RenateModule.from_state_dict()
    - RenateModule.get_extra_state()
    - RenateModule.set_extra_state()
    - RenateModule.forward()
    - RenateModule.get_params()
    - RenateModule.add_task_params()
    - RenateModule.get_logits()
    - RenateModule.get_intermediate_representation()
    - RenateModule.replace_batch_norm_with_continual_norm()
    - RenateModule.register_intermediate_representation_caching_hook()
    - RenateModule.deregister_hooks()
    - RenateModule.reset_intermediate_representation_cache()
  - RenateWrapper
- renate.models.task_identification_strategies module