renate.updaters.learner_components.component module
- class renate.updaters.learner_components.component.Component(weight=0, sample_new_memory_batch=False)
Bases: ABC
The abstract class implementing a Component, usable in the BaseExperienceReplayLearner.
This is an abstract class from which every other component, e.g. an additional regularising loss or a module updater, should inherit. Components should be modular and independent to the extent that they can be composed together in an ordered list and deployed in the BaseExperienceReplayLearner.
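The idea of composing components in an ordered list can be illustrated with a stdlib-only sketch. `MiniComponent`, `L2Penalty`, `L1Penalty` and `total_loss` are hypothetical stand-ins, not Renate's API; plain floats replace tensors, and the assumption that `weight` scales each component's loss term is ours.

```python
from abc import ABC, abstractmethod
from typing import List, Sequence


class MiniComponent(ABC):
    """Hypothetical stand-in mirroring the documented Component constructor."""

    def __init__(self, weight: float = 0, sample_new_memory_batch: bool = False) -> None:
        self.weight = weight  # assumed to scale the returned loss
        self.sample_new_memory_batch = sample_new_memory_batch

    @abstractmethod
    def loss(self, outputs_memory: Sequence[float]) -> float:
        """Return an extra term to add to the main training loss."""


class L2Penalty(MiniComponent):
    # Hypothetical regulariser: weighted mean of squared outputs.
    def loss(self, outputs_memory: Sequence[float]) -> float:
        return self.weight * sum(o * o for o in outputs_memory) / len(outputs_memory)


class L1Penalty(MiniComponent):
    # Hypothetical regulariser: weighted mean of absolute outputs.
    def loss(self, outputs_memory: Sequence[float]) -> float:
        return self.weight * sum(abs(o) for o in outputs_memory) / len(outputs_memory)


def total_loss(main_loss: float, components: List[MiniComponent],
               outputs_memory: Sequence[float]) -> float:
    # Components are visited in list order; each contributes an additive term.
    loss = main_loss
    for component in components:
        loss += component.loss(outputs_memory)
    return loss


components = [L2Penalty(weight=0.5), L1Penalty(weight=0.1)]
print(total_loss(1.0, components, [1.0, -3.0]))
```

Because each component only sees the data passed to it and returns a scalar, components can be added, removed or reordered without touching one another, which is the kind of independence the class description asks for.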
- Parameters:
- loss(outputs_memory, batch_memory, intermediate_representation_memory)
Computes some user-defined loss which is added to the main training loss in the training step.
- Parameters:
  - outputs_memory (Tensor) – The outputs of the model with respect to the memory data (batch_memory).
  - batch_memory (Tuple[Tuple[Union[Tensor, Tuple[Union[Tensor, Tuple[NestedTensors], Dict[str, NestedTensors]]], Dict[str, Union[Tensor, Tuple[NestedTensors], Dict[str, NestedTensors]]]], Tensor], Dict[str, Tensor]]) – The batch of data sampled from the memory buffer, including the metadata.
  - intermediate_representation_memory (Optional[List[Tensor]]) – Intermediate feature representations produced as the input passes through the network.
- Return type: Tensor
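Stripped of its type aliases, `batch_memory` has the shape `((inputs, targets), metadata)`. The stdlib-only sketch below shows how a loss might unpack it. The function name, the metadata key `"outputs"` (logits stored when a sample entered the buffer, in the style of dark experience replay) and the `weight` argument are hypothetical; lists of floats stand in for tensors.

```python
from typing import Dict, List, Optional, Tuple

# Hypothetical simplified aliases: floats/lists stand in for torch Tensors.
Inputs = List[float]
Targets = List[int]
Metadata = Dict[str, List[float]]
MemoryBatch = Tuple[Tuple[Inputs, Targets], Metadata]


def replay_logit_loss(
    outputs_memory: List[float],
    batch_memory: MemoryBatch,
    intermediate_representation_memory: Optional[List[List[float]]] = None,
    weight: float = 1.0,
) -> float:
    """Weighted MSE between current outputs and logits stored in the metadata."""
    # Nested unpacking mirrors the documented ((inputs, targets), metadata) layout.
    (_inputs, _targets), metadata = batch_memory
    stored = metadata["outputs"]  # hypothetical metadata key
    if len(stored) != len(outputs_memory):
        raise ValueError("outputs and stored logits must have the same length")
    # intermediate_representation_memory is optional and unused in this sketch.
    mse = sum((o - s) ** 2 for o, s in zip(outputs_memory, stored)) / len(stored)
    return weight * mse


batch = (([0.2, 0.4], [1, 0]), {"outputs": [1.0, 2.0]})
print(replay_logit_loss([1.5, 1.0], batch, weight=0.5))  # 0.3125
```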
- on_train_start(model)
Updates the model parameters.
- Parameters:
  - model (RenateModule) – The model used for training.
- Return type: None
- on_train_batch_end(model)
Internally records a training and optimizer step in the component.
- Parameters:
  - model (RenateModule) – The model used for training.
- Return type: None
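The two lifecycle hooks can be sketched together. Here a plain dict of float lists stands in for a `RenateModule`, the class `ShrinkPerturbSketch` is hypothetical, and using shrink-and-perturb (scale every parameter, then add Gaussian noise) as the way `on_train_start` "updates the model parameters" is our illustrative choice, not necessarily what any Renate component does. `on_train_batch_end` simply counts recorded steps.

```python
import random
from typing import Dict, List

# Hypothetical stand-in for a RenateModule: parameter name -> list of floats.
Model = Dict[str, List[float]]


class ShrinkPerturbSketch:
    """Hypothetical component sketching both lifecycle hooks."""

    def __init__(self, shrink: float = 0.5, sigma: float = 0.01, seed: int = 0) -> None:
        self.shrink = shrink
        self.sigma = sigma
        self._rng = random.Random(seed)  # seeded for reproducible noise
        self.training_steps = 0  # number of recorded training/optimizer steps

    def on_train_start(self, model: Model) -> None:
        # Shrink-and-perturb, in place: theta <- shrink * theta + N(0, sigma^2) noise.
        for values in model.values():
            values[:] = [self.shrink * v + self._rng.gauss(0.0, self.sigma) for v in values]

    def on_train_batch_end(self, model: Model) -> None:
        # Record that one training and optimizer step has happened.
        self.training_steps += 1


model = {"w": [2.0, -4.0]}
component = ShrinkPerturbSketch(sigma=0.0)  # no noise, so the result is deterministic
component.on_train_start(model)
for _ in range(3):
    component.on_train_batch_end(model)
print(model, component.training_steps)  # {'w': [1.0, -2.0]} 3
```

A step counter like this lets a component act only every N steps (for example, refreshing some internal state) without relying on the learner to track the schedule.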