renate.updaters.learner_components.component module
- class renate.updaters.learner_components.component.Component(weight=0, sample_new_memory_batch=False)
Bases: ABC
The abstract class implementing a Component, usable in the BaseExperienceReplayLearner.
Every other component, e.g. an additional regularising loss or a module updater, should inherit from this abstract class. Components should be modular and independent enough that they can be composed in an ordered list and deployed in the BaseExperienceReplayLearner.
- Parameters:
- loss(outputs_memory, batch_memory, intermediate_representation_memory)
Computes a user-defined loss that is added to the main training loss in the training step.
- Parameters:
outputs_memory (Tensor) – The outputs of the model with respect to memory data (batch_memory).
batch_memory (Tuple[Tuple[Union[Tensor,Tuple[Union[Tensor,Tuple[NestedTensors],Dict[str, NestedTensors]]],Dict[str,Union[Tensor,Tuple[NestedTensors],Dict[str, NestedTensors]]]],Tensor],Dict[str,Tensor]]) – The batch of data sampled from the memory buffer, including the metadata.
intermediate_representation_memory (Optional[List[Tensor]]) – Intermediate feature representations produced by passing the input through the network.
- Return type:
Tensor
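As a rough illustration of how such a loss might be combined with the main training loss, the sketch below uses a stand-in base class instead of importing Renate. `ComponentSketch`, `L2OutputPenalty`, and `total_loss` are hypothetical names, plain Python floats stand in for tensors, and scaling each component loss by `weight` in the caller is an assumption, not something this page states.

```python
from abc import ABC, abstractmethod


class ComponentSketch(ABC):
    """Stand-in mirroring the documented constructor and loss signature (not Renate's class)."""

    def __init__(self, weight=0, sample_new_memory_batch=False):
        self.weight = weight
        self.sample_new_memory_batch = sample_new_memory_batch

    @abstractmethod
    def loss(self, outputs_memory, batch_memory, intermediate_representation_memory):
        """Return a scalar loss computed on the memory batch."""


class L2OutputPenalty(ComponentSketch):
    """Hypothetical component: mean squared magnitude of the outputs on memory data."""

    def loss(self, outputs_memory, batch_memory, intermediate_representation_memory):
        return sum(o * o for o in outputs_memory) / len(outputs_memory)


def total_loss(main_loss, components, outputs_memory, batch_memory):
    # Assumption: each component loss is scaled by its weight and added to the main loss.
    extra = sum(c.weight * c.loss(outputs_memory, batch_memory, None) for c in components)
    return main_loss + extra


# Components are kept in an ordered list, as the class description suggests.
components = [L2OutputPenalty(weight=0.5)]
result = total_loss(1.0, components, outputs_memory=[1.0, 2.0, 3.0], batch_memory=None)
```

With a real model, `outputs_memory` would be a tensor and the penalty would be written with tensor operations so that gradients flow through it.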
- on_train_start(model)
Updates the model parameters.
- Parameters:
model (RenateModule) – The model used for training.
- Return type:
None
- on_train_batch_end(model)
Internally records a training and optimizer step in the component.
- Parameters:
model (RenateModule) – The model used for training.
- Return type:
None
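The two hooks above could be driven by a training loop roughly as follows. This is a minimal sketch under stated assumptions: `HookedComponentSketch` is a hypothetical stand-in rather than a Renate class, a plain dict stands in for a `RenateModule`, and copying the parameters in `on_train_start` is only one possible behaviour chosen for illustration.

```python
class HookedComponentSketch:
    """Stand-in exposing the two documented hooks (not Renate's class)."""

    def __init__(self):
        self.num_steps = 0
        self.snapshot = None

    def on_train_start(self, model):
        # Hypothetical behaviour: copy the current parameters when training starts.
        self.snapshot = dict(model)

    def on_train_batch_end(self, model):
        # Record that one training/optimizer step has completed.
        self.num_steps += 1


model = {"w": 0.0}  # stand-in for a RenateModule
components = [HookedComponentSketch(), HookedComponentSketch()]

# Call the start hook once per component, in list order.
for component in components:
    component.on_train_start(model)

for _ in range(3):  # three mock training batches
    model["w"] += 0.1  # stand-in for an optimizer step
    for component in components:
        component.on_train_batch_end(model)
```

Each component keeps its own state, so the hooks can be called in list order without the components needing to know about one another.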