renate.models.renate_module module
- class renate.models.renate_module.RenateModule(constructor_arguments)
Bases: Module, ABC
A class for torch models with some additional functionality for continual learning.
RenateModule derives from torch.nn.Module and provides some additional functionality relevant to continual learning. In particular, this concerns saving and reloading the model when model hyperparameters (which might affect the architecture) change during hyperparameter optimization. There is also functionality to retrieve internal-layer representations for use in replay-based CL methods.

When implementing a subclass of RenateModule, make sure to call the base class’s constructor and provide your model’s constructor arguments. Besides that, you can define a RenateModule just like a torch.nn.Module.

Example:
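```python
from typing import Optional

import torch
from renate.models.renate_module import RenateModule


class MyMNISTMLP(RenateModule):
    def __init__(self, num_hidden: int):
        # Pass the constructor arguments to the base class so that the model
        # can be reinstantiated later from a stored state dict.
        super().__init__(constructor_arguments={"num_hidden": num_hidden})
        self._fc1 = torch.nn.Linear(28 * 28, num_hidden)
        self._fc2 = torch.nn.Linear(num_hidden, 10)

    def forward(self, x: torch.Tensor, task_id: Optional[str] = None) -> torch.Tensor:
        # task_id is accepted to match RenateModule.forward but is unused here.
        x = self._fc1(x)
        x = torch.nn.functional.relu(x)
        return self._fc2(x)
```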
The state of a RenateModule can be retrieved via the RenateModule.state_dict() method, just as in torch.nn.Module. When reloading a RenateModule from a stored state dict, use RenateModule.from_state_dict. It will automatically recover the hyperparameters and reinstantiate your model accordingly.

Note: Some methods of RenateModule accept an optional task_id argument. This is in anticipation of future methods for continual learning scenarios where task identifiers are provided. It is currently not used.
- Parameters:
constructor_arguments (dict) – Arguments needed to instantiate the model.
- classmethod from_state_dict(state_dict)
Load the model from a state dict.
- Parameters:
state_dict – The state dict of the model. This method assumes that the state dict was created by RenateModule.state_dict().
- get_extra_state(encode=True)
Get the constructor_arguments and task IDs necessary to reconstruct the model.
- Return type:
Any
- set_extra_state(state, decode=True)
Extract the content of _extra_state and set the related values in the module.
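The saving and reloading behavior described for from_state_dict, get_extra_state, and set_extra_state appears to build on PyTorch's extra-state hooks: torch.nn.Module.state_dict() stores whatever get_extra_state() returns under an _extra_state key, and load_state_dict() hands that value back to set_extra_state(). The following is a minimal sketch that relies only on this PyTorch behavior. SketchModule and TinyMLP are hypothetical names, and the dictionary layout of the extra state is an assumption rather than Renate's actual format; the encode/decode flags are not modeled.

```python
from typing import Any, Dict

import torch


class SketchModule(torch.nn.Module):
    """Hypothetical stand-in for RenateModule's save/reload logic."""

    def __init__(self, constructor_arguments: Dict[str, Any]):
        super().__init__()
        self._constructor_arguments = dict(constructor_arguments)

    def get_extra_state(self) -> Any:
        # Called by state_dict(); the result is stored under "_extra_state".
        return {"constructor_arguments": self._constructor_arguments}

    def set_extra_state(self, state: Any) -> None:
        # Called by load_state_dict() with the stored "_extra_state" value.
        self._constructor_arguments = dict(state["constructor_arguments"])

    @classmethod
    def from_state_dict(cls, state_dict: Dict[str, Any]) -> "SketchModule":
        # Rebuild the architecture from the saved hyperparameters, then load weights.
        args = state_dict["_extra_state"]["constructor_arguments"]
        model = cls(**args)
        model.load_state_dict(state_dict)
        return model


class TinyMLP(SketchModule):
    def __init__(self, num_hidden: int):
        super().__init__(constructor_arguments={"num_hidden": num_hidden})
        self._fc1 = torch.nn.Linear(4, num_hidden)
        self._fc2 = torch.nn.Linear(num_hidden, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self._fc2(torch.relu(self._fc1(x)))


# Round trip: the hidden size is recovered from the state dict itself.
original = TinyMLP(num_hidden=8)
restored = TinyMLP.from_state_dict(original.state_dict())
```

Because the hidden size travels inside the state dict, the caller does not need to know which num_hidden value was chosen during hyperparameter optimization in order to rebuild the model.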
- abstract forward(x, task_id=None)
Performs a forward pass on the inputs and returns the predictions.
This method accepts a task ID, which may be provided by some continual learning scenarios. As an example, the task id may be used to switch between multiple output heads.
- Parameters:
x (Union[Tensor, Tuple[Union[Tensor, Tuple[NestedTensors], Dict[str, NestedTensors]]], Dict[str, Union[Tensor, Tuple[NestedTensors], Dict[str, NestedTensors]]]]) – Input(s) to the model. Can be a single tensor, a tuple of tensors, or a dictionary mapping strings to tensors.
task_id (Optional[str]) – The identifier of the task for which predictions are made.
- Return type:
Tensor
- Returns:
The model’s predictions.
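The forward description mentions using the task ID to switch between multiple output heads. A minimal plain-PyTorch sketch of that pattern follows. MultiHeadNet, its shared trunk, and the fallback to a "default" head when task_id is None are illustrative assumptions, not Renate's implementation.

```python
from typing import Optional

import torch


class MultiHeadNet(torch.nn.Module):
    """Hypothetical model with a shared trunk and one output head per task."""

    def __init__(self, in_features: int, num_hidden: int, num_outputs: int):
        super().__init__()
        self.trunk = torch.nn.Sequential(
            torch.nn.Linear(in_features, num_hidden), torch.nn.ReLU()
        )
        self.heads = torch.nn.ModuleDict(
            {"default": torch.nn.Linear(num_hidden, num_outputs)}
        )

    def forward(self, x: torch.Tensor, task_id: Optional[str] = None) -> torch.Tensor:
        features = self.trunk(x)
        # Fall back to the default head when no task identifier is provided.
        head = self.heads[task_id if task_id is not None else "default"]
        return head(features)


model = MultiHeadNet(in_features=4, num_hidden=16, num_outputs=3)
# Add a head for a new task; it may have a different number of outputs.
model.heads["task_b"] = torch.nn.Linear(16, 5)
```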
- get_params(task_id=None)
User-facing function which returns the list of parameters.
If a task_id is given, this should return only the parameters used for that specific task.
- Parameters:
task_id (Optional[str]) – The task ID for which to retrieve parameters.
- Return type:
List[Parameter]
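For a model with a shared trunk and per-task heads, one reasonable reading of get_params(task_id) is "the shared parameters plus those of that task's head". The sketch below assumes this layout; HeadsModel is a hypothetical name. Because the result is a plain list of parameters, it can be passed directly to a torch optimizer.

```python
from typing import List, Optional

import torch


class HeadsModel(torch.nn.Module):
    """Hypothetical model: shared trunk plus per-task heads."""

    def __init__(self):
        super().__init__()
        self.trunk = torch.nn.Linear(4, 8)
        self.heads = torch.nn.ModuleDict(
            {"a": torch.nn.Linear(8, 2), "b": torch.nn.Linear(8, 3)}
        )

    def get_params(self, task_id: Optional[str] = None) -> List[torch.nn.Parameter]:
        if task_id is None:
            # No task given: return every parameter of the model.
            return list(self.parameters())
        # Otherwise return only the shared trunk and the head used by this task.
        return list(self.trunk.parameters()) + list(self.heads[task_id].parameters())


model = HeadsModel()
# Only the trunk and head "a" are updated by this optimizer.
optimizer = torch.optim.SGD(model.get_params("a"), lr=0.1)
```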
- add_task_params(task_id=None)
Adds new parameters, associated to a specific task, to the model.
This function should not be overridden; override _add_task_params instead.
- Parameters:
task_id (Optional[str]) – The task ID for which the new parameters are added.
- Return type:
None
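The split between add_task_params and _add_task_params follows the template-method pattern: the public method holds logic shared by every subclass, and subclasses customize only the protected hook. The sketch below illustrates the pattern. TaskParamsSketch, the duplicate-task check, and the "default" key for task_id=None are hypothetical choices, not a description of what Renate's add_task_params does internally.

```python
from typing import Optional, Set

import torch


class TaskParamsSketch(torch.nn.Module):
    """Hypothetical illustration of the add_task_params/_add_task_params split."""

    def __init__(self, num_hidden: int = 8):
        super().__init__()
        self.num_hidden = num_hidden
        self.heads = torch.nn.ModuleDict()
        self._tasks: Set[str] = set()

    def add_task_params(self, task_id: Optional[str] = None) -> None:
        # Public entry point: shared bookkeeping lives here and is not overridden.
        key = task_id if task_id is not None else "default"
        if key in self._tasks:
            return  # Parameters for this task already exist.
        self._add_task_params(key)
        self._tasks.add(key)

    def _add_task_params(self, task_id: str) -> None:
        # Subclass hook: create whatever parameters the new task needs.
        self.heads[task_id] = torch.nn.Linear(self.num_hidden, 2)
```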
- get_logits(x, task_id=None)
Returns the logits for a given pair of input and task id.
By default, this method returns the output of the forward pass. Subclasses may override it with custom behavior if necessary.
- Parameters:
x (Union[Tensor, Tuple[Union[Tensor, Tuple[NestedTensors], Dict[str, NestedTensors]]], Dict[str, Union[Tensor, Tuple[NestedTensors], Dict[str, NestedTensors]]]]) – Input(s) to the model. Can be a single tensor, a tuple of tensors, or a dictionary mapping strings to tensors.
task_id (Optional[str]) – The task ID.
- Return type:
Tensor
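The default get_logits simply returns the forward output, and a subclass can override it when the logits need post-processing. In the sketch below, LinearModel mimics the default behavior and TemperatureScaledModel is a hypothetical override that divides the output by a fixed temperature. Temperature scaling is only an example of custom behavior; the original documentation does not prescribe it.

```python
from typing import Optional

import torch


class LinearModel(torch.nn.Module):
    """Hypothetical model using the default get_logits behavior."""

    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(4, 3)

    def forward(self, x: torch.Tensor, task_id: Optional[str] = None) -> torch.Tensor:
        return self.fc(x)

    def get_logits(self, x: torch.Tensor, task_id: Optional[str] = None) -> torch.Tensor:
        # Default: the logits are exactly the output of the forward pass.
        return self.forward(x, task_id)


class TemperatureScaledModel(LinearModel):
    """Hypothetical override: divide the forward output by a fixed temperature."""

    def __init__(self, temperature: float = 2.0):
        super().__init__()
        self.temperature = temperature

    def get_logits(self, x: torch.Tensor, task_id: Optional[str] = None) -> torch.Tensor:
        return self.forward(x, task_id) / self.temperature
```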
- get_intermediate_representation()
Returns the cached intermediate representation.
- Return type:
List[Tensor]
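This page does not say how the intermediate representation is cached. One way to populate such a cache in plain PyTorch is with forward hooks (torch.nn.Module.register_forward_hook), as sketched below. CachingMLP, the choice of hooked layers, and detaching the cached tensors are assumptions; a method that needs gradients through the cached features would skip the detach.

```python
from typing import List

import torch


class CachingMLP(torch.nn.Module):
    """Hypothetical model that caches hidden activations during forward."""

    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(4, 8)
        self.fc2 = torch.nn.Linear(8, 8)
        self.out = torch.nn.Linear(8, 3)
        self._intermediate_representation: List[torch.Tensor] = []
        # Record the outputs of the hidden linear layers (before ReLU) via hooks.
        for layer in (self.fc1, self.fc2):
            layer.register_forward_hook(self._cache_output)

    def _cache_output(self, module, inputs, output) -> None:
        # Detached so cached tensors do not keep the autograd graph alive.
        self._intermediate_representation.append(output.detach())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Start a fresh cache for every forward pass.
        self._intermediate_representation = []
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.out(x)

    def get_intermediate_representation(self) -> List[torch.Tensor]:
        return self._intermediate_representation
```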
- replace_batch_norm_with_continual_norm(num_groups=32)
Replaces every occurrence of batch normalization with continual normalization.
Pham, Q., Liu, C., & Hoi, S. (2022). Continual normalization: Rethinking batch normalization for online continual learning. arXiv preprint arXiv:2203.16102.
- Parameters:
num_groups (int) – Number of groups used by the group normalization step of continual normalization.
- Return type:
None
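Pham et al. describe continual normalization as group normalization without an affine transform, followed by batch normalization. The sketch below implements that composition for 2-D feature maps and swaps it in recursively. ContinualNorm2d and replace_bn_with_cn are hypothetical names, only torch.nn.BatchNorm2d is handled, and num_groups must divide the channel count of every replaced layer.

```python
import torch


class ContinualNorm2d(torch.nn.Module):
    """Sketch of continual normalization: group norm (no affine), then batch norm."""

    def __init__(self, num_features: int, num_groups: int = 32):
        super().__init__()
        # Per-sample normalization over channel groups; num_features % num_groups == 0.
        self.gn = torch.nn.GroupNorm(num_groups, num_features, affine=False)
        # Batch norm keeps the learnable affine parameters and running statistics.
        self.bn = torch.nn.BatchNorm2d(num_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.bn(self.gn(x))


def replace_bn_with_cn(module: torch.nn.Module, num_groups: int = 32) -> None:
    """Recursively swap every BatchNorm2d child for a ContinualNorm2d (in place)."""
    # Snapshot the children so replacing them does not disturb the iteration.
    for name, child in list(module.named_children()):
        if isinstance(child, torch.nn.BatchNorm2d):
            setattr(module, name, ContinualNorm2d(child.num_features, num_groups))
        else:
            replace_bn_with_cn(child, num_groups)


net = torch.nn.Sequential(
    torch.nn.Conv2d(3, 64, kernel_size=3, padding=1),
    torch.nn.BatchNorm2d(64),
    torch.nn.ReLU(),
)
replace_bn_with_cn(net, num_groups=32)
```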
- class renate.models.renate_module.RenateWrapper(model)
Bases: RenateModule
A simple wrapper around a torch model.
If you are using a torch model with fixed hyperparameters, you can use this wrapper to expose it as a RenateModule. In this case, do _not_ use the from_state_dict method; instead, reinstantiate the model, wrap it, and call load_state_dict. If a tuple or a dictionary of tensors is passed to the RenateWrapper’s forward function, it is unpacked before being passed to the torch model’s forward function.

Example:
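```python
import torch
from renate.models.renate_module import RenateWrapper

my_torch_model = torch.nn.Linear(28 * 28, 10)  # Instantiate your torch model.
model = RenateWrapper(my_torch_model)  # Expose it as a RenateModule.
state_dict = torch.load("my_state_dict.pt")  # A previously saved state dict.
model.load_state_dict(state_dict)
```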
- Parameters:
model (Module) – The torch model to be wrapped.
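The unpacking rule for RenateWrapper inputs can be expressed in a few lines: a tuple becomes positional arguments and a dictionary becomes keyword arguments. The sketch below is a hypothetical stand-in (UnpackingWrapper, TwoInputNet) rather than the library's class; it accepts task_id but ignores it.

```python
from typing import Any, Dict, Optional, Tuple, Union

import torch


class UnpackingWrapper(torch.nn.Module):
    """Hypothetical sketch of a wrapper that unpacks tuple/dict inputs."""

    def __init__(self, model: torch.nn.Module):
        super().__init__()
        self._model = model

    def forward(
        self,
        x: Union[torch.Tensor, Tuple[Any, ...], Dict[str, Any]],
        task_id: Optional[str] = None,
    ) -> torch.Tensor:
        if isinstance(x, tuple):
            return self._model(*x)  # Tuple entries become positional arguments.
        if isinstance(x, dict):
            return self._model(**x)  # Dict entries become keyword arguments.
        return self._model(x)  # A single tensor is passed through unchanged.


class TwoInputNet(torch.nn.Module):
    """Hypothetical torch model with two inputs."""

    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(4, 2)

    def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        return self.fc(a + b)


wrapped = UnpackingWrapper(TwoInputNet())
```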
- forward(x, task_id=None)
Performs a forward pass on the inputs and returns the predictions.
This method accepts a task ID, which may be provided by some continual learning scenarios. As an example, the task id may be used to switch between multiple output heads.
- Parameters:
x (Union[Tensor, Tuple[Union[Tensor, Tuple[NestedTensors], Dict[str, NestedTensors]]], Dict[str, Union[Tensor, Tuple[NestedTensors], Dict[str, NestedTensors]]]]) – Input(s) to the model. Can be a single tensor, a tuple of tensors, or a dictionary mapping strings to tensors.
task_id (Optional[str]) – The identifier of the task for which predictions are made.
- Return type:
Tensor
- Returns:
The model’s predictions.