renate.benchmark.models.base module
- class renate.benchmark.models.base.RenateBenchmarkingModule(embedding_size, num_outputs, constructor_arguments, prediction_strategy=None, add_icarl_class_means=True)
Bases: RenateModule, ABC
Base class for all models provided by Renate.
This class ensures that each model works with all ModelUpdaters when using the benchmarking feature of Renate. New models can extend this class or, alternatively, extend RenateModule and make sure they are compatible with the considered ModelUpdater.
- Parameters:
  - embedding_size (int) – Representation size of the model after the backbone.
  - num_outputs (int) – The number of outputs of the model.
  - constructor_arguments (dict) – Arguments needed to instantiate the model.
  - prediction_strategy (Optional[PredictionStrategy]) – Defaults to a plain forward pass through the model. Some ModelUpdaters must be combined with specific prediction strategies to work as intended.
  - add_icarl_class_means (bool) – Whether to add the class means specific to iCaRL. Can be set to False if any other ModelUpdater is used.
- forward(x, task_id='default_task')
Performs a forward pass on the inputs and returns the predictions.
This method accepts a task ID, which may be provided by some continual learning scenarios. For example, the task ID may be used to switch between multiple output heads.
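The following sketch illustrates the multi-head idea in plain Python, assuming a shared backbone followed by one linear head per task ID. MultiHeadSketch, make_linear and add_head are hypothetical names used only for illustration; they are not part of Renate's API, and a real model would be built from torch.nn modules.

```python
from typing import Callable, Dict, List

Vector = List[float]

DEFAULT_TASK_ID = "default_task"  # same default as in the forward signature


def make_linear(weights: List[Vector]) -> Callable[[Vector], Vector]:
    """Return a bias-free linear map; each row of `weights` produces one output."""
    def apply(x: Vector) -> Vector:
        return [sum(w * xi for w, xi in zip(row, x)) for row in weights]
    return apply


class MultiHeadSketch:
    """Hypothetical model: one shared backbone, one prediction head per task ID."""

    def __init__(self, backbone: Callable[[Vector], Vector]) -> None:
        self.backbone = backbone
        self.heads: Dict[str, Callable[[Vector], Vector]] = {}

    def add_head(self, task_id: str, head: Callable[[Vector], Vector]) -> None:
        self.heads[task_id] = head

    def forward(self, x: Vector, task_id: str = DEFAULT_TASK_ID) -> Vector:
        embedding = self.backbone(x)           # shared representation
        return self.heads[task_id](embedding)  # task-specific outputs


# Backbone maps 3 inputs to an embedding of size 2; each head maps it to 2 outputs.
model = MultiHeadSketch(make_linear([[1.0, 0.0, 0.0], [0.0, 1.0, 1.0]]))
model.add_head(DEFAULT_TASK_ID, make_linear([[1.0, 0.0], [0.0, 1.0]]))
model.add_head("task_b", make_linear([[0.0, 1.0], [1.0, 0.0]]))

x = [2.0, 3.0, 4.0]
```

Here model.forward(x) uses the default head, while model.forward(x, task_id="task_b") routes the same embedding through a different head.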
- get_backbone(task_id='default_task')
Returns the model without the prediction head.
- Return type: Module
- get_predictor(task_id='default_task')
Returns the model without the backbone.
- Return type: Module
- get_params(task_id='default_task')
Returns the list of parameters of the core model together with those specific to the given task_id.
- Return type: List[Parameter]
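The relationship between get_backbone, get_predictor and get_params can be sketched as follows, assuming the backbone is shared by all tasks and each task has its own head. Component and SplitModelSketch are hypothetical, and strings stand in for torch Parameter objects; this illustrates the split, not Renate's internals.

```python
from typing import Dict, List


class Component:
    """Hypothetical stand-in for a module; `params` plays the role of its Parameters."""

    def __init__(self, *params: str) -> None:
        self.params = list(params)


class SplitModelSketch:
    def __init__(self) -> None:
        self._backbone = Component("backbone.weight", "backbone.bias")
        self._heads: Dict[str, Component] = {
            "default_task": Component("head.default_task.weight"),
            "task_b": Component("head.task_b.weight"),
        }

    def get_backbone(self, task_id: str = "default_task") -> Component:
        # The backbone is shared in this sketch, so task_id does not change the result.
        return self._backbone

    def get_predictor(self, task_id: str = "default_task") -> Component:
        return self._heads[task_id]

    def get_params(self, task_id: str = "default_task") -> List[str]:
        # Shared backbone parameters followed by the parameters of this task's head.
        return self.get_backbone(task_id).params + self.get_predictor(task_id).params


model = SplitModelSketch()
```

A split like this would, for instance, let an optimizer receive only the parameters relevant to the current task, leaving other tasks' heads untouched.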
- get_extra_state(encode=True)
Returns the constructor_arguments and task IDs necessary to reconstruct the model.
If encode is True, the state is converted into a torch tensor so that DeepSpeed serialization works. The results of the super() calls are not encoded; only the final dict is.
- Return type: Any
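A minimal sketch of the encoding step described above, assuming the state is a plain dict that is pickled and stored as a sequence of bytes. To stay dependency-free it uses array("B") as a stand-in for a torch.uint8 tensor; encode_extra_state and the dict keys are hypothetical, not Renate's.

```python
import pickle
from array import array
from typing import Any, Dict


def encode_extra_state(state: Dict[str, Any]) -> array:
    """Pickle a dict and store its bytes as unsigned 8-bit integers.

    array("B", ...) stands in for a uint8 tensor; with PyTorch,
    torch.tensor(list(payload), dtype=torch.uint8) would hold the same values.
    """
    payload = pickle.dumps(state)
    return array("B", payload)


state = {
    "constructor_arguments": {"embedding_size": 128, "num_outputs": 10},
    "task_ids": ["default_task"],  # hypothetical key name
}
encoded = encode_extra_state(state)
```

Storing the pickle as numeric data means a serializer that only handles tensors can still carry arbitrary Python state.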
- set_extra_state(state, decode=True)
Extracts the content of _extra_state and sets the related values in the module. If decode is True, state is first decoded from a tensor of pickled bytes.
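The decoding side can be sketched the same way, again assuming a pickled dict stored as unsigned bytes, with array("B") standing in for a torch.uint8 tensor. decode_extra_state, ExtraStateSketch and the dict keys are hypothetical illustrations.

```python
import pickle
from array import array
from typing import Any, Dict, List


def decode_extra_state(encoded: array) -> Dict[str, Any]:
    """Turn a uint8 sequence of pickled bytes back into the original dict.

    For a real uint8 tensor, bytes(tensor.tolist()) would recover the same payload.
    Only unpickle data from trusted sources.
    """
    return pickle.loads(encoded.tobytes())


class ExtraStateSketch:
    """Hypothetical module that restores attributes from its extra state."""

    def __init__(self) -> None:
        self.constructor_arguments: Dict[str, Any] = {}
        self.task_ids: List[str] = []

    def set_extra_state(self, state: Any, decode: bool = True) -> None:
        if decode:
            state = decode_extra_state(state)
        self.constructor_arguments = state["constructor_arguments"]
        self.task_ids = state["task_ids"]


payload = array("B", pickle.dumps({"constructor_arguments": {"num_outputs": 10},
                                   "task_ids": ["default_task"]}))
module = ExtraStateSketch()
module.set_extra_state(payload)
```

Passing decode=False lets the same method accept a state dict that was never encoded.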
- training: bool