renate.models.renate_module module

class renate.models.renate_module.RenateModule(constructor_arguments)

Bases: Module, ABC

A class for torch models with some additional functionality for continual learning.

RenateModule derives from torch.nn.Module and provides some additional functionality relevant to continual learning (CL). In particular, this concerns saving and reloading the model when model hyperparameters (which might affect the architecture) change during hyperparameter optimization. There is also functionality to retrieve internal-layer representations for use in replay-based CL methods.

When implementing a subclass of RenateModule, make sure to call the base class’ constructor and provide your model’s constructor arguments. Besides that, you can define a RenateModule just like torch.nn.Module.

Example:

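from typing import Optional

import torch

from renate.models.renate_module import RenateModule


class MyMNISTMLP(RenateModule):

    def __init__(self, num_hidden: int):
        # Pass the arguments needed to rebuild this model to the base class.
        super().__init__(constructor_arguments={"num_hidden": num_hidden})
        self._fc1 = torch.nn.Linear(28 * 28, num_hidden)
        self._fc2 = torch.nn.Linear(num_hidden, 10)

    def forward(self, x: torch.Tensor, task_id: Optional[str] = None) -> torch.Tensor:
        x = torch.flatten(x, start_dim=1)  # (N, 1, 28, 28) -> (N, 784); no-op if already flat
        x = self._fc1(x)
        x = torch.nn.functional.relu(x)
        return self._fc2(x)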

The state of a RenateModule can be retrieved via the RenateModule.state_dict() method, just as in torch.nn.Module. When reloading a RenateModule from a stored state dict, use RenateModule.from_state_dict. It will automatically recover the hyperparameters and reinstantiate your model accordingly.
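To make the save-and-reload idea concrete, here is a framework-free sketch of the pattern. It assumes the constructor arguments travel inside the state dict under an "_extra_state" entry (mirroring PyTorch's extra-state convention); the classes ToyModule and ToyMLP and the plain-dict "weights" are hypothetical stand-ins, not Renate's implementation.

```python
from typing import Any, Dict


class ToyModule:
    """Hypothetical, framework-free stand-in for a RenateModule-like base class."""

    def __init__(self, constructor_arguments: Dict[str, Any]) -> None:
        # Remember how this instance was built so it can be rebuilt later.
        self._constructor_arguments = dict(constructor_arguments)
        self.weights: Dict[str, Any] = {}

    def state_dict(self) -> Dict[str, Any]:
        # Store the weights plus the constructor arguments as "extra state".
        state = dict(self.weights)
        state["_extra_state"] = {"constructor_arguments": self._constructor_arguments}
        return state

    @classmethod
    def from_state_dict(cls, state_dict: Dict[str, Any]) -> "ToyModule":
        # Re-instantiate with the stored hyperparameters, then load the weights.
        args = state_dict["_extra_state"]["constructor_arguments"]
        model = cls(**args)
        model.weights = {k: v for k, v in state_dict.items() if k != "_extra_state"}
        return model


class ToyMLP(ToyModule):
    def __init__(self, num_hidden: int) -> None:
        super().__init__(constructor_arguments={"num_hidden": num_hidden})
        self.num_hidden = num_hidden
        self.weights = {"fc1": [0.0] * num_hidden}


original = ToyMLP(num_hidden=4)
original.weights["fc1"][0] = 1.5
restored = ToyMLP.from_state_dict(original.state_dict())  # rebuilt with num_hidden=4
```

The caller never passes num_hidden when reloading: the hyperparameter is recovered from the state dict itself, which is what lets architectures vary across hyperparameter optimization trials.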

Note: Some methods of RenateModule accept an optional task_id argument. This is in anticipation of future methods for continual learning scenarios where task identifiers are provided. It is currently not used.

Parameters:

constructor_arguments (dict) – Arguments needed to instantiate the model.

classmethod from_state_dict(state_dict)

Load the model from a state dict.

Parameters:

state_dict – The state dict of the model. This method works under the assumption that this has been created by RenateModule.state_dict().

get_extra_state(encode=True)

Returns the constructor arguments and task IDs needed to reconstruct the model.

Return type:

Any

set_extra_state(state, decode=True)

Extracts the content of _extra_state and sets the corresponding values in the module.
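The encode and decode flags suggest that the extra state can be converted to and from a storable form. The sketch below pictures this round trip under an explicit assumption: "encoding" means pickling to bytes. That format, the class ExtraStateSketch, and the dictionary keys are hypothetical and not taken from Renate.

```python
import pickle
from typing import Any, Dict, List, Optional


class ExtraStateSketch:
    """Hypothetical holder for constructor arguments and task ids."""

    def __init__(self, constructor_arguments: Dict[str, Any]) -> None:
        self._constructor_arguments = constructor_arguments
        self._task_ids: List[Optional[str]] = []

    def get_extra_state(self, encode: bool = True) -> Any:
        state = {
            "constructor_arguments": self._constructor_arguments,
            "task_ids": list(self._task_ids),
        }
        # Assumption: encoding means serializing the dict to bytes with pickle.
        return pickle.dumps(state) if encode else state

    def set_extra_state(self, state: Any, decode: bool = True) -> None:
        if decode:
            state = pickle.loads(state)  # reverse of the assumed encoding
        self._constructor_arguments = state["constructor_arguments"]
        self._task_ids = list(state["task_ids"])


src = ExtraStateSketch({"num_hidden": 32})
src._task_ids.append("task_a")
dst = ExtraStateSketch({})
dst.set_extra_state(src.get_extra_state())  # dst now mirrors src's arguments and task ids
```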

abstract forward(x, task_id=None)

Performs a forward pass on the inputs and returns the predictions.

This method accepts a task ID, which may be provided by some continual learning scenarios. As an example, the task id may be used to switch between multiple output heads.

Parameters:
  • x (Union[Tensor, Tuple[Union[Tensor, Tuple[NestedTensors], Dict[str, NestedTensors]]], Dict[str, Union[Tensor, Tuple[NestedTensors], Dict[str, NestedTensors]]]]) – Input(s) to the model. Can be a single tensor, a tuple of tensors, or a dictionary mapping strings to tensors.

  • task_id (Optional[str]) – The identifier of the task for which predictions are made.

Return type:

Tensor

Returns:

The model’s predictions.
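The example of switching between output heads can be sketched without any framework: a shared feature extractor runs first, and the task ID selects which head turns the features into predictions. MultiHeadSketch, its toy backbone, and the lambda heads are hypothetical illustrations, not part of Renate.

```python
from typing import Callable, Dict, List, Optional

Vector = List[float]


class MultiHeadSketch:
    """Shared backbone plus one output head per task (hypothetical)."""

    def __init__(self) -> None:
        self.heads: Dict[Optional[str], Callable[[Vector], Vector]] = {}

    def backbone(self, x: Vector) -> Vector:
        # Toy shared layer: element-wise ReLU.
        return [max(0.0, v) for v in x]

    def forward(self, x: Vector, task_id: Optional[str] = None) -> Vector:
        features = self.backbone(x)
        # The task id picks the head; an unknown id raises KeyError.
        return self.heads[task_id](features)


model = MultiHeadSketch()
model.heads["task_a"] = lambda f: [sum(f)]
model.heads["task_b"] = lambda f: [2.0 * v for v in f]
```

Here model.forward([1.0, -2.0, 3.0], task_id="task_a") returns [4.0], while the same input with "task_b" returns [2.0, 0.0, 6.0].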

get_params(task_id=None)

User-facing function which returns the list of parameters.

If a task_id is given, this should return only parameters used for the specific task.

Parameters:

task_id (Optional[str]) – The task id for which we want to retrieve parameters.

Return type:

List[Parameter]
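A minimal sketch of the intended filtering, assuming a model with shared parameters plus per-task head parameters. Parameter names stand in for Parameter objects, and returning everything when task_id is None is an assumption; ParamRegistrySketch is hypothetical.

```python
from typing import Dict, List, Optional


class ParamRegistrySketch:
    """Hypothetical model with shared and per-task parameters (names only)."""

    def __init__(self) -> None:
        self.shared: List[str] = ["backbone.weight", "backbone.bias"]
        self.task_params: Dict[str, List[str]] = {}

    def get_params(self, task_id: Optional[str] = None) -> List[str]:
        if task_id is None:
            # Assumption: no task given means all parameters of the model.
            return self.shared + [p for ps in self.task_params.values() for p in ps]
        # Task given: the shared parameters plus that task's own parameters.
        return self.shared + self.task_params.get(task_id, [])


reg = ParamRegistrySketch()
reg.task_params = {"a": ["head_a.weight"], "b": ["head_b.weight"]}
```

With this setup, reg.get_params("a") omits head_b.weight, so an optimizer built from it would leave task b's head untouched.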

add_task_params(task_id=None)

Adds new parameters, associated with a specific task, to the model.

This function should not be overridden; override _add_task_params instead.

Parameters:

task_id (Optional[str]) – The task id for which the new parameters are added.

Return type:

None
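The split between add_task_params and _add_task_params follows the template-method pattern: the public method owns the shared bookkeeping and delegates only the parameter creation to a protected hook. A sketch, assuming the bookkeeping is a list of known task ids and that repeated calls for the same task are skipped (both assumptions); the class names are hypothetical.

```python
from typing import Dict, List, Optional


class TaskParamsBase:
    """Hypothetical base class illustrating the public method / protected hook split."""

    def __init__(self) -> None:
        self._task_ids: List[Optional[str]] = []
        self.heads: Dict[Optional[str], List[float]] = {}

    def add_task_params(self, task_id: Optional[str] = None) -> None:
        # Public entry point with shared bookkeeping; subclasses leave it alone.
        if task_id in self._task_ids:
            return  # assumption: parameters for this task already exist
        self._add_task_params(task_id)
        self._task_ids.append(task_id)

    def _add_task_params(self, task_id: Optional[str]) -> None:
        # Hook that subclasses override to create the task's parameters.
        raise NotImplementedError


class HeadPerTaskSketch(TaskParamsBase):
    def _add_task_params(self, task_id: Optional[str]) -> None:
        self.heads[task_id] = [0.0] * 10  # a new 10-way "head" for this task
```

Because subclasses only override the hook, the bookkeeping cannot be skipped by accident.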

get_logits(x, task_id=None)

Returns the logits for a given pair of input and task id.

By default, this method returns the output of the forward pass. It may be overridden with custom behavior, if necessary.

Parameters:
  • x (Union[Tensor, Tuple[Union[Tensor, Tuple[NestedTensors], Dict[str, NestedTensors]]], Dict[str, Union[Tensor, Tuple[NestedTensors], Dict[str, NestedTensors]]]]) – Input(s) to the model. Can be a single tensor, a tuple of tensors, or a dictionary mapping strings to tensors.

  • task_id (Optional[str]) – The task id.

Return type:

Tensor

get_intermediate_representation()

Returns the cached intermediate representation.

Return type:

List[Tensor]

replace_batch_norm_with_continual_norm(num_groups=32)

Replaces every occurrence of batch normalization with continual normalization.

Pham, Q., Liu, C., & Hoi, S. (2022). Continual normalization: Rethinking batch normalization for online continual learning. arXiv preprint arXiv:2203.16102.

Parameters:

num_groups (int) – Number of groups used by the group normalization step of continual normalization.

Return type:

None
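In the cited paper, continual normalization first applies group normalization without an affine transform and then applies batch normalization. The sketch below shows that two-step computation on a plain (batch, channels) list of lists. It is a simplification, not Renate's layer: spatial dimensions, learnable per-channel affine parameters, and running statistics for inference are all omitted, and gamma/beta are fixed scalars.

```python
import math
from typing import List

Matrix = List[List[float]]  # shape (batch N, channels C); spatial dims omitted


def group_norm(x: Matrix, num_groups: int, eps: float = 1e-5) -> Matrix:
    """Normalize each sample over the channels of each group (no affine)."""
    n_channels = len(x[0])
    if n_channels % num_groups != 0:
        raise ValueError("number of channels must be divisible by num_groups")
    size = n_channels // num_groups
    out = []
    for row in x:
        new_row: List[float] = []
        for g in range(num_groups):
            chunk = row[g * size:(g + 1) * size]
            mean = sum(chunk) / size
            var = sum((v - mean) ** 2 for v in chunk) / size
            new_row.extend((v - mean) / math.sqrt(var + eps) for v in chunk)
        out.append(new_row)
    return out


def batch_norm(x: Matrix, gamma: float = 1.0, beta: float = 0.0, eps: float = 1e-5) -> Matrix:
    """Normalize each channel over the batch, then scale and shift."""
    n = len(x)
    normed_cols = []
    for col in zip(*x):  # iterate over channels
        mean = sum(col) / n
        var = sum((v - mean) ** 2 for v in col) / n
        normed_cols.append([gamma * (v - mean) / math.sqrt(var + eps) + beta for v in col])
    return [list(row) for row in zip(*normed_cols)]


def continual_norm(x: Matrix, num_groups: int) -> Matrix:
    # Continual normalization: group norm (no affine) followed by batch norm.
    return batch_norm(group_norm(x, num_groups))
```

The group normalization step depends only on each individual sample, which is the property the paper relies on to reduce the influence of batch statistics that drift in online continual learning.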

register_intermediate_representation_caching_hook(module)

Add a hook to cache intermediate representations during training.

A reference to the hook is stored so that it can be removed later.

Parameters:

module (Module) – The module to be hooked.

Return type:

None

deregister_hooks()

Remove all the hooks that were registered.

Return type:

None

reset_intermediate_representation_cache()

Resets the intermediate representation cache.

Return type:

None
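The four caching methods above fit together as register, read, reset, and deregister. In PyTorch this kind of caching would typically use torch.nn.Module.register_forward_hook, whose returned handle has a remove() method; whether Renate uses exactly that is an assumption here. To stay framework-free, the sketch below reimplements a tiny forward-hook mechanism; Layer, HookHandle, CachingModel, and the cache attribute name are all hypothetical.

```python
from typing import Callable, List


class HookHandle:
    """Minimal handle that can unregister a hook, like PyTorch's RemovableHandle."""

    def __init__(self, hooks: List[Callable], hook: Callable) -> None:
        self._hooks, self._hook = hooks, hook

    def remove(self) -> None:
        if self._hook in self._hooks:
            self._hooks.remove(self._hook)


class Layer:
    """Stand-in for a submodule that calls its forward hooks after each call."""

    def __init__(self, fn: Callable[[float], float]) -> None:
        self.fn = fn
        self._forward_hooks: List[Callable] = []

    def register_forward_hook(self, hook: Callable) -> HookHandle:
        self._forward_hooks.append(hook)
        return HookHandle(self._forward_hooks, hook)

    def __call__(self, x: float) -> float:
        out = self.fn(x)
        for hook in self._forward_hooks:
            hook(self, x, out)  # hook signature: (module, input, output)
        return out


class CachingModel:
    def __init__(self) -> None:
        self.hidden = Layer(lambda x: 2.0 * x)
        self.head = Layer(lambda x: x + 1.0)
        self._intermediate_representation_cache: List[float] = []
        self._hooks: List[HookHandle] = []

    def __call__(self, x: float) -> float:
        return self.head(self.hidden(x))

    def register_intermediate_representation_caching_hook(self, module: Layer) -> None:
        def cache(_module: Layer, _inputs: float, output: float) -> None:
            self._intermediate_representation_cache.append(output)

        # Keep the handle so the hook can be removed later.
        self._hooks.append(module.register_forward_hook(cache))

    def get_intermediate_representation(self) -> List[float]:
        return self._intermediate_representation_cache

    def reset_intermediate_representation_cache(self) -> None:
        self._intermediate_representation_cache = []

    def deregister_hooks(self) -> None:
        for handle in self._hooks:
            handle.remove()
        self._hooks = []
```

A replay method would register the hook on the layer whose output it needs, read the cache after each forward pass, and reset it before the next batch so representations do not accumulate across steps.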

training: bool

class renate.models.renate_module.RenateWrapper(model)

Bases: RenateModule

A simple wrapper around a torch model.

If you are using a torch model with fixed hyperparameters, you can use this wrapper to expose it as a RenateModule. In this case, do not use the from_state_dict method; instead, reinstantiate the model, wrap it, and call load_state_dict. If a tuple or a dictionary of tensors is passed to the RenateWrapper's forward function, it is unpacked before being passed to the torch model's forward function.

Example:

import torch

from renate.models.renate_module import RenateWrapper

my_torch_model = torch.nn.Linear(28 * 28, 10)  # Instantiate your torch model.
model = RenateWrapper(my_torch_model)
state_dict = torch.load("my_state_dict.pt")
model.load_state_dict(state_dict)

Parameters:

model (Module) – The torch model to be wrapped.

forward(x, task_id=None)

Performs a forward pass on the inputs and returns the predictions.

This method accepts a task ID, which may be provided by some continual learning scenarios. As an example, the task id may be used to switch between multiple output heads.

Parameters:
  • x (Union[Tensor, Tuple[Union[Tensor, Tuple[NestedTensors], Dict[str, NestedTensors]]], Dict[str, Union[Tensor, Tuple[NestedTensors], Dict[str, NestedTensors]]]]) – Input(s) to the model. Can be a single tensor, a tuple of tensors, or a dictionary mapping strings to tensors.

  • task_id (Optional[str]) – The identifier of the task for which predictions are made.

Return type:

Tensor

Returns:

The model’s predictions.
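The unpacking described for RenateWrapper can be pictured as dispatch on the input's type: a tuple becomes positional arguments, a dictionary becomes keyword arguments, and anything else is passed through unchanged. The sketch below assumes that reading of "unpacked" and that task_id is ignored for a plain torch model; call_wrapped is a hypothetical helper, not Renate's code.

```python
from typing import Any, Callable, Dict, Optional, Tuple, Union

Inputs = Union[Any, Tuple[Any, ...], Dict[str, Any]]


def call_wrapped(model: Callable[..., Any], x: Inputs, task_id: Optional[str] = None) -> Any:
    # task_id is accepted for interface compatibility and ignored (assumption).
    if isinstance(x, tuple):
        return model(*x)   # tuple -> positional arguments
    if isinstance(x, dict):
        return model(**x)  # dict -> keyword arguments
    return model(x)        # single input passed through unchanged


def two_input_model(a: float, b: float) -> float:
    return a - b
```

Because dictionaries are unpacked by keyword, call_wrapped(two_input_model, {"b": 2, "a": 5}) returns 3 regardless of key order, whereas a tuple relies on the argument order of the wrapped model.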

classmethod from_state_dict(state_dict)

Load the model from a state dict.

Parameters:

state_dict – The state dict of the model. This method works under the assumption that this has been created by RenateModule.state_dict().

training: bool