renate.benchmark.models.base module
- class renate.benchmark.models.base.RenateBenchmarkingModule(embedding_size, num_outputs, constructor_arguments, prediction_strategy=None, add_icarl_class_means=True)
Bases: RenateModule, ABC
Base class for all models provided by Renate.
This class ensures that each model works with all ModelUpdaters when the benchmarking feature of Renate is used. New models can extend this class or, alternatively, extend RenateModule directly, in which case they must be made compatible with the ModelUpdater being used.
- Parameters:
  - embedding_size (int) – Representation size of the model after the backbone.
  - num_outputs (int) – The number of outputs of the model.
  - constructor_arguments (dict) – Arguments needed to instantiate the model.
  - prediction_strategy (Optional[PredictionStrategy]) – Strategy used to produce predictions; by default, a plain forward pass through the model. Some ModelUpdaters must be combined with specific prediction strategies to work as intended.
  - add_icarl_class_means (bool) – Specific to iCaRL. Can be set to False if any other ModelUpdater is used.
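To make constructor_arguments concrete, here is a minimal plain-Python sketch. It assumes (this page does not say so) that the stored dictionary is enough to rebuild a fresh instance of the same class, for example when a saved model is reloaded. ToyBenchmarkModel, its hidden argument and the rebuild helper are hypothetical illustrations, not Renate API, and the class does not actually subclass RenateBenchmarkingModule.

```python
from typing import Any, Dict


class ToyBenchmarkModel:
    """Hypothetical stand-in whose constructor mirrors RenateBenchmarkingModule's."""

    def __init__(self, embedding_size: int, num_outputs: int, hidden: int = 8) -> None:
        self.embedding_size = embedding_size
        self.num_outputs = num_outputs
        self.hidden = hidden
        # Record everything needed to build an identical, freshly initialised model.
        self.constructor_arguments: Dict[str, Any] = {
            "embedding_size": embedding_size,
            "num_outputs": num_outputs,
            "hidden": hidden,
        }


def rebuild(model: ToyBenchmarkModel) -> ToyBenchmarkModel:
    # Assumption: re-instantiation needs nothing beyond the recorded arguments.
    return type(model)(**model.constructor_arguments)


original = ToyBenchmarkModel(embedding_size=16, num_outputs=10, hidden=32)
clone = rebuild(original)
print(clone.constructor_arguments == original.constructor_arguments)  # True
```

Storing the arguments as a plain dictionary keeps them easy to serialize alongside the weights, which is one plausible reason for the design.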
- forward(x, task_id='default_task')
Performs a forward pass on the inputs and returns the predictions.
This method accepts a task ID, which some continual learning scenarios provide. For example, the task ID can be used to switch between multiple output heads.
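The head-switching example in the forward description can be sketched as follows. This is a hypothetical plain-Python model (MultiHeadToy, add_head and the toy backbone are illustrative names, not Renate API) with one shared backbone and one linear head per task ID. A real Renate model would use PyTorch modules instead of lists.

```python
from typing import Dict, List

Vector = List[float]
Matrix = List[Vector]


class MultiHeadToy:
    """Hypothetical model: one shared backbone and one linear head per task ID."""

    def __init__(self, embedding_size: int) -> None:
        self.embedding_size = embedding_size
        self.heads: Dict[str, Matrix] = {}

    def add_head(self, task_id: str, weights: Matrix) -> None:
        # Each row of `weights` produces one output from the embedding.
        self.heads[task_id] = weights

    def backbone(self, x: Vector) -> Vector:
        # Toy feature extractor shared by all tasks: keep and double the first inputs.
        return [2.0 * v for v in x[: self.embedding_size]]

    def forward(self, x: Vector, task_id: str = "default_task") -> Vector:
        if task_id not in self.heads:
            raise KeyError(f"no output head for task {task_id!r}")
        embedding = self.backbone(x)
        # The task ID selects which head turns the shared embedding into predictions.
        return [sum(w * e for w, e in zip(row, embedding)) for row in self.heads[task_id]]


model = MultiHeadToy(embedding_size=2)
model.add_head("default_task", [[1.0, 0.0], [0.0, 1.0]])
model.add_head("task_b", [[1.0, 1.0]])
print(model.forward([1.0, 3.0]))            # [2.0, 6.0]
print(model.forward([1.0, 3.0], "task_b"))  # [8.0]
```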
- get_backbone(task_id='default_task')
Returns the model without the prediction head.
- Return type: Module
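One reason to expose the backbone on its own is feature extraction. The sketch below uses a hypothetical SplitToy model and a class_means helper (neither is Renate API) to average backbone embeddings per class. This is the kind of quantity iCaRL's nearest-mean-of-exemplars classifier relies on. It is not a claim about how Renate computes its iCaRL class means.

```python
from collections import defaultdict
from typing import Callable, Dict, Hashable, List, Sequence, Tuple

Vector = List[float]


class SplitToy:
    """Hypothetical model that exposes its backbone separately from its head."""

    def __init__(self, embedding_size: int) -> None:
        self.embedding_size = embedding_size

    def get_backbone(self, task_id: str = "default_task") -> Callable[[Sequence[float]], Vector]:
        # This toy backbone ignores task_id and keeps the first embedding_size inputs.
        return lambda x: [float(v) for v in x[: self.embedding_size]]


def class_means(
    backbone: Callable[[Sequence[float]], Vector],
    data: List[Tuple[Sequence[float], Hashable]],
) -> Dict[Hashable, Vector]:
    """Average the backbone embeddings of all samples that share a label."""
    sums: Dict[Hashable, Vector] = {}
    counts: Dict[Hashable, int] = defaultdict(int)
    for x, label in data:
        embedding = backbone(x)
        running = sums.get(label, [0.0] * len(embedding))
        sums[label] = [s + e for s, e in zip(running, embedding)]
        counts[label] += 1
    return {label: [s / counts[label] for s in total] for label, total in sums.items()}


model = SplitToy(embedding_size=2)
data = [([1, 2, 9], 0), ([3, 4, 9], 0), ([5, 5, 9], 1)]
print(class_means(model.get_backbone(), data))  # {0: [2.0, 3.0], 1: [5.0, 5.0]}
```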
- get_predictor(task_id='default_task')
Returns the model without the backbone.
- Return type: Module
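Taken together with get_backbone, get_predictor suggests that a forward pass can be decomposed into the backbone followed by the predictor. The sketch below assumes exactly that contract. ComposedToy and its toy ReLU backbone are hypothetical and written in plain Python for brevity.

```python
from typing import Callable, List

Vector = List[float]


class ComposedToy:
    """Hypothetical model whose forward pass is predictor(backbone(x))."""

    def __init__(self, weights: List[Vector]) -> None:
        self.weights = weights  # one row of weights per output

    def get_backbone(self, task_id: str = "default_task") -> Callable[[Vector], Vector]:
        # Toy feature extractor: element-wise ReLU.
        return lambda x: [max(0.0, v) for v in x]

    def get_predictor(self, task_id: str = "default_task") -> Callable[[Vector], Vector]:
        # Toy linear head mapping an embedding to one score per output.
        return lambda e: [sum(w * v for w, v in zip(row, e)) for row in self.weights]

    def forward(self, x: Vector, task_id: str = "default_task") -> Vector:
        # Assumed contract: backbone first, then the predictor for the same task.
        return self.get_predictor(task_id)(self.get_backbone(task_id)(x))


model = ComposedToy(weights=[[1.0, -1.0], [0.5, 0.5]])
x = [-2.0, 4.0]
features = model.get_backbone()(x)  # [0.0, 4.0]
print(model.get_predictor()(features))  # [-4.0, 2.0]
print(model.forward(x) == model.get_predictor()(features))  # True
```

Keeping the two halves separately addressable lets, for example, the predictor be inspected or retrained on fixed features without touching the backbone.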
- get_params(task_id='default_task')
Returns the list of parameters of the core model together with those specific to the given task_id.
- Return type: List[Parameter]
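Here is a minimal sketch of how core and task-specific parameters might be combined. It assumes a model that keeps shared backbone parameters plus a separate parameter list per task. Param and ParamsToy are hypothetical stand-ins; a real model would return torch.nn.Parameter objects.

```python
from typing import Dict, List


class Param:
    """Hypothetical stand-in for a learnable tensor such as torch.nn.Parameter."""

    def __init__(self, name: str) -> None:
        self.name = name


class ParamsToy:
    """Hypothetical model with shared core parameters and one head per task."""

    def __init__(self, task_ids: List[str]) -> None:
        self.core_params = [Param("backbone.weight"), Param("backbone.bias")]
        self.task_params: Dict[str, List[Param]] = {
            t: [Param(f"heads.{t}.weight")] for t in task_ids
        }

    def get_params(self, task_id: str = "default_task") -> List[Param]:
        # Shared parameters first, then only those belonging to the requested task.
        return self.core_params + self.task_params[task_id]


model = ParamsToy(["default_task", "task_b"])
print([p.name for p in model.get_params("task_b")])
# ['backbone.weight', 'backbone.bias', 'heads.task_b.weight']
```

Returning only the requested task's parameters alongside the shared ones would, for instance, let an optimizer update one task's head without touching the others. Whether Renate uses get_params this way is not stated here.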