renate.benchmark.models.base module

class renate.benchmark.models.base.RenateBenchmarkingModule(embedding_size, num_outputs, constructor_arguments, prediction_strategy=None, add_icarl_class_means=True)

Bases: RenateModule, ABC

Base class for all models provided by Renate.

This class ensures that each model works with all ModelUpdaters when the benchmarking feature of Renate is used. New models can either extend this class or extend RenateModule directly, in which case they must be made compatible with the ModelUpdater being used.

Parameters:
  • embedding_size (int) – Size of the representation produced by the backbone, i.e., the input size of the prediction head.

  • num_outputs (int) – The number of outputs of the model.

  • constructor_arguments (dict) – Arguments needed to instantiate the model.

  • prediction_strategy (Optional[PredictionStrategy]) – Defaults to a plain forward pass through the model. Some ModelUpdaters must be combined with specific prediction strategies to work as intended.

  • add_icarl_class_means (bool) – Whether to add the class-mean parameters that iCaRL requires. Can be set to False if any other ModelUpdater is used.
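The constructor_arguments parameter exists so that a model can be re-created from a record of how it was built. The framework-free sketch below illustrates only that pattern: TinyBenchmarkModel and rebuild are hypothetical names, not Renate classes, and a real subclass would instead forward these values to RenateBenchmarkingModule.__init__ through its constructor_arguments parameter.

```python
from typing import Any, Dict


class TinyBenchmarkModel:
    """Hypothetical model that records the arguments it was built with."""

    def __init__(self, embedding_size: int, num_outputs: int, hidden: int = 16) -> None:
        # Keep exactly the keyword arguments needed to rebuild this model,
        # including defaults, so the record is complete.
        self.constructor_arguments: Dict[str, Any] = {
            "embedding_size": embedding_size,
            "num_outputs": num_outputs,
            "hidden": hidden,
        }
        self.embedding_size = embedding_size
        self.num_outputs = num_outputs
        self.hidden = hidden


def rebuild(model: TinyBenchmarkModel) -> TinyBenchmarkModel:
    # Re-instantiate the same class from the recorded arguments.
    return type(model)(**model.constructor_arguments)


original = TinyBenchmarkModel(embedding_size=32, num_outputs=10)
clone = rebuild(original)
```

This rebuilds the structure only; in a real model, learned weights would still have to be restored separately.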

forward(x, task_id='default_task')

Performs a forward pass on the inputs and returns the predictions.

This method accepts a task ID, which some continual learning scenarios provide. For example, the task ID may be used to switch between multiple output heads.

Parameters:
  • x (Tensor) – Input(s) to the model. Can be a single tensor, a tuple of tensors, or a dictionary mapping strings to tensors.

  • task_id (str) – The identifier of the task for which predictions are made.

Return type:

Tensor

Returns:

The model’s predictions.
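As a framework-free illustration of switching output heads by task ID, the sketch below sends every input through a shared backbone and then through the head registered for task_id. MultiHeadSketch and the toy lambdas are hypothetical; in Renate the backbone and heads would be PyTorch modules, and this is not its implementation.

```python
from typing import Callable, Dict, List

Vector = List[float]


class MultiHeadSketch:
    """Hypothetical model: one shared backbone, one output head per task ID."""

    def __init__(
        self,
        backbone: Callable[[Vector], Vector],
        heads: Dict[str, Callable[[Vector], Vector]],
    ) -> None:
        self.backbone = backbone
        self.heads = heads

    def forward(self, x: Vector, task_id: str = "default_task") -> Vector:
        # The backbone is shared across tasks; task_id only selects the head.
        features = self.backbone(x)
        if task_id not in self.heads:
            raise KeyError(f"no output head registered for task {task_id!r}")
        return self.heads[task_id](features)


model = MultiHeadSketch(
    backbone=lambda x: [2.0 * v for v in x],  # toy feature extractor
    heads={
        "default_task": lambda f: [sum(f)],    # one-output head
        "task_b": lambda f: [min(f), max(f)],  # two-output head
    },
)
print(model.forward([1.0, 2.0, 3.0]))            # [12.0]
print(model.forward([1.0, 2.0, 3.0], "task_b"))  # [2.0, 6.0]
```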

get_backbone(task_id='default_task')

Returns the model without the prediction head.

Return type:

Module

get_predictor(task_id='default_task')

Returns the model without the backbone.

Return type:

Module

get_params(task_id='default_task')

Returns the list of parameters of the core model together with those specific to the given task_id.

Return type:

List[Parameter]
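The accessors get_backbone, get_predictor and get_params split the model into a shared core and task-specific heads. The sketch below assumes, without confirmation from this page, that get_params returns the backbone's parameters followed by those of the head for task_id, which is one natural reading of "parameters for the core model and a specific task_id". Param, Block and SplitModelSketch are hypothetical stand-ins for PyTorch parameters and modules.

```python
from typing import Dict, List


class Param:
    """Hypothetical stand-in for a trainable parameter (e.g. a weight tensor)."""

    def __init__(self, name: str) -> None:
        self.name = name


class Block:
    """Hypothetical stand-in for a module that owns a list of parameters."""

    def __init__(self, *names: str) -> None:
        self._params = [Param(n) for n in names]

    def parameters(self) -> List[Param]:
        # Return a new list that holds the same Param objects.
        return list(self._params)


class SplitModelSketch:
    def __init__(self) -> None:
        self.backbone = Block("backbone.w", "backbone.b")  # shared by all tasks
        self.heads: Dict[str, Block] = {
            "default_task": Block("head0.w"),
            "task_b": Block("head1.w"),
        }

    def get_backbone(self, task_id: str = "default_task") -> Block:
        # Model without the prediction head (the same backbone for every task here).
        return self.backbone

    def get_predictor(self, task_id: str = "default_task") -> Block:
        # Model without the backbone: just the head for this task.
        return self.heads[task_id]

    def get_params(self, task_id: str = "default_task") -> List[Param]:
        # Shared core parameters followed by the parameters of this task's head.
        return self.get_backbone(task_id).parameters() + self.get_predictor(task_id).parameters()


model = SplitModelSketch()
print([p.name for p in model.get_params("task_b")])
# ['backbone.w', 'backbone.b', 'head1.w']
```

Under this assumption, the returned list could be handed to an optimizer so that only the shared core and the current task's head are updated.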

get_extra_state(encode=True)

Returns the constructor_arguments and task IDs needed to reconstruct the model.

If encode is True, the state is converted into a torch tensor so that DeepSpeed serialization works. The states returned by super() calls are not encoded individually; only the final dict is encoded.

Return type:

Any

set_extra_state(state, decode=True)

Extracts the content of _extra_state and sets the corresponding values in the module.

If decode is True, state is expected to be a tensor of pickled bytes and is decoded before the values are set.
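get_extra_state and set_extra_state describe a round trip: a dict is pickled to bytes and, when encode is True, stored as a tensor so that DeepSpeed can serialize it; decode reverses this. The stdlib-only sketch below mirrors that round trip, using a plain list of byte values in place of a uint8 tensor. The helper names encode_state and decode_state and the dict keys are hypothetical and do not reproduce Renate's actual implementation.

```python
import pickle
from typing import Any, Dict, List


def encode_state(state: Dict[str, Any]) -> List[int]:
    # Pickle the dict, then expose the bytes as integers in [0, 255];
    # in PyTorch this list would be a uint8 tensor.
    return list(pickle.dumps(state))


def decode_state(encoded: List[int]) -> Dict[str, Any]:
    # Reassemble the byte string and unpickle it back into a dict.
    return pickle.loads(bytes(encoded))


extra_state = {
    "constructor_arguments": {"embedding_size": 32, "num_outputs": 10},  # hypothetical keys
    "task_ids": ["default_task", "task_b"],
}
encoded = encode_state(extra_state)
restored = decode_state(encoded)
print(restored == extra_state)  # True
```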

training: bool