renate.updaters.learner_components.component module

class renate.updaters.learner_components.component.Component(weight=0, sample_new_memory_batch=False)

Bases: ABC

The abstract class implementing a Component, usable in the BaseExperienceReplayLearner.

This is the abstract class from which every other component, e.g. an additional regularising loss or a module updater, should inherit. Components should be modular and independent enough that they can be composed in an ordered list and deployed in the BaseExperienceReplayLearner.

Parameters:
  • weight (float) – A scaling coefficient applied to the loss that is returned.

  • sample_new_memory_batch (bool) – Whether a new batch of data should be sampled from the memory buffer when the loss is calculated.
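
To make the subclassing pattern concrete, here is a minimal sketch that does not import Renate. `SketchComponent` is a hypothetical stand-in that only mirrors the constructor arguments and hook names documented on this page. `MeanSquaredOutputPenalty` is an invented example, and plain Python floats stand in for tensors. Which methods the real `Component` declares abstract is not stated here; making `loss` abstract is an assumption of the sketch.

```python
from abc import ABC, abstractmethod


class SketchComponent(ABC):
    """Hypothetical stand-in mirroring the documented constructor and hooks."""

    def __init__(self, weight=0.0, sample_new_memory_batch=False):
        self.weight = weight
        self.sample_new_memory_batch = sample_new_memory_batch

    @abstractmethod  # assumption: the real class may not mark this abstract
    def loss(self, outputs_memory, batch_memory, intermediate_representation_memory):
        ...

    def on_train_start(self, model):
        pass

    def on_train_batch_end(self, model):
        pass


class MeanSquaredOutputPenalty(SketchComponent):
    """Illustrative regularising loss: weighted mean of the squared outputs."""

    def loss(self, outputs_memory, batch_memory, intermediate_representation_memory):
        mean_square = sum(o * o for o in outputs_memory) / len(outputs_memory)
        # The weight scales the loss that is returned.
        return self.weight * mean_square


penalty = MeanSquaredOutputPenalty(weight=0.5)
value = penalty.loss([1.0, 3.0], batch_memory=None, intermediate_representation_memory=None)
```

In a real subclass the outputs would be tensors and the arithmetic would use tensor operations, but the role of `weight` would be the same.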

loss(outputs_memory, batch_memory, intermediate_representation_memory)

Computes some user-defined loss which is added to the main training loss in the training step.

Parameters:
  • outputs_memory (Tensor) – The model outputs for the memory data (batch_memory).

  • batch_memory (Tuple[Tuple[Union[Tensor, Tuple[Union[Tensor, Tuple[NestedTensors], Dict[str, NestedTensors]]], Dict[str, Union[Tensor, Tuple[NestedTensors], Dict[str, NestedTensors]]]], Tensor], Dict[str, Tensor]]) – The batch of data sampled from the memory buffer, including the metadata.

  • intermediate_representation_memory (Optional[List[Tensor]]) – Intermediate feature representations produced when the input is passed through the network.

Return type:

Tensor
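
The sketch below shows one way the losses of an ordered list of components could be added to the main training loss. It also unpacks `batch_memory` as `((inputs, targets), metadata)`, the shape suggested by the type annotation above. `AbsoluteErrorOnMemory` and `total_training_loss` are hypothetical names, and floats and lists stand in for tensors. How the learner actually does this is not shown on this page.

```python
class AbsoluteErrorOnMemory:
    """Hypothetical component: weighted mean absolute error on memory targets."""

    def __init__(self, weight):
        self.weight = weight

    def loss(self, outputs_memory, batch_memory, intermediate_representation_memory=None):
        # Assumed layout from the type annotation: ((inputs, targets), metadata).
        (inputs, targets), metadata = batch_memory
        errors = [abs(o - t) for o, t in zip(outputs_memory, targets)]
        return self.weight * sum(errors) / len(errors)


def total_training_loss(main_loss, components, outputs_memory, batch_memory):
    """Add each component's loss, in list order, to the main training loss."""
    total = main_loss
    for component in components:
        total += component.loss(outputs_memory, batch_memory, None)
    return total


batch_memory = (([0.1, 0.2], [0.0, 4.0]), {"sample_id": [7, 8]})
components = [AbsoluteErrorOnMemory(weight=2.0), AbsoluteErrorOnMemory(weight=0.0)]
total = total_training_loss(0.25, components, [1.0, 2.0], batch_memory)
```

A component with `weight=0` contributes nothing to the total, which matches the constructor default shown in the class signature.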

on_train_start(model)

Updates the model parameters.

Parameters:

model (RenateModule) – The model used for training.

Return type:

None

on_train_batch_end(model)

Internally records a training and optimizer step in the component.

Parameters:

model (RenateModule) – The model used for training.

Return type:

None

on_load_checkpoint(checkpoint)

Loads relevant information from the checkpoint.

Return type:

None

on_save_checkpoint(checkpoint)

Adds relevant information to the checkpoint.

Return type:

None
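
As an illustration of how the batch-end and checkpoint hooks might work together, the sketch below counts training steps and carries the count through a save and load. It assumes the checkpoint is a mutable dictionary, which this page does not state. `StepCountingComponent` and the key `"sketch_component_steps"` are hypothetical names.

```python
class StepCountingComponent:
    """Hypothetical component that records training steps and checkpoints them."""

    CHECKPOINT_KEY = "sketch_component_steps"  # hypothetical key name

    def __init__(self):
        self.steps = 0

    def on_train_batch_end(self, model):
        # Record that one training and optimizer step has finished.
        self.steps += 1

    def on_save_checkpoint(self, checkpoint):
        # Assumption: the checkpoint behaves like a dict.
        checkpoint[self.CHECKPOINT_KEY] = self.steps

    def on_load_checkpoint(self, checkpoint):
        self.steps = checkpoint.get(self.CHECKPOINT_KEY, 0)


original = StepCountingComponent()
for _ in range(3):
    original.on_train_batch_end(model=None)

checkpoint = {}
original.on_save_checkpoint(checkpoint)

restored = StepCountingComponent()
restored.on_load_checkpoint(checkpoint)
```

After the load, `restored.steps` equals the count recorded by `original`, so step-dependent state can survive an interrupted run.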