renate.updaters.learner_components.losses module#

class renate.updaters.learner_components.losses.WeightedLossComponent(weight=0, sample_new_memory_batch=False)[source]#

Bases: Component, ABC

The abstract class implementing a weighted loss function.

Every other loss component should inherit from this class.

loss(outputs_memory, batch_memory, intermediate_representation_memory)[source]#

Computes a user-defined loss that is added to the main training loss in the training step.

Parameters:
  • outputs_memory (Tensor) – The outputs of the model with respect to memory data (batch_memory).

  • batch_memory (Tuple[Tuple[Union[Tensor, Tuple[Union[Tensor, Tuple[NestedTensors], Dict[str, NestedTensors]]], Dict[str, Union[Tensor, Tuple[NestedTensors], Dict[str, NestedTensors]]]], Tensor], Dict[str, Tensor]]) – The batch of data sampled from the memory buffer, including the meta data.

  • intermediate_representation_memory (Optional[List[Tensor]]) – Intermediate feature representations of the network upon passing the input through the network.

Return type:

Tensor
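The names below are illustrative, not Renate's actual internals; they sketch the pattern a WeightedLossComponent implements: the component returns a user-defined loss, which is scaled by weight and added to the main training loss.

```python
# Illustrative sketch (hypothetical names, not Renate's code): the
# weighted-loss pattern. The component's loss is scaled by `weight`
# and added to the main task loss each training step.

def combined_loss(task_loss, component_loss, weight):
    """Total training loss = task loss + weighted component loss."""
    return task_loss + weight * component_loss

# With the default weight=0, the component has no effect on training.
print(combined_loss(0.5, 0.25, 0.0))  # 0.5
print(combined_loss(0.5, 0.25, 1.0))  # 0.75
```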

class renate.updaters.learner_components.losses.WeightedCustomLossComponent(loss_fn, weight, sample_new_memory_batch)[source]#

Bases: WeightedLossComponent

Adds a (weighted) user-provided custom loss contribution.

Parameters:
  • loss_fn (Callable) – The loss function to apply.

  • weight (float) – A scaling coefficient applied to the returned loss.

  • sample_new_memory_batch (bool) – Whether a new batch of data should be sampled from the memory buffer when the loss is calculated.

class renate.updaters.learner_components.losses.WeightedMeanSquaredErrorLossComponent(weight=0, sample_new_memory_batch=False)[source]#

Bases: WeightedLossComponent

Mean squared error between the current and previous logits computed with respect to the memory sample.
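A minimal sketch of the quantity this component computes, using plain Python lists in place of torch Tensors: the mean squared error between the current model's logits and the cached previous logits for a memory sample.

```python
# Hedged sketch (pure Python, not Renate's implementation): MSE between
# current and previously cached logits on a memory sample.

def mse(current_logits, previous_logits):
    """Mean squared error between two equal-length logit vectors."""
    assert len(current_logits) == len(previous_logits)
    n = len(current_logits)
    return sum((c - p) ** 2 for c, p in zip(current_logits, previous_logits)) / n

# Identical logits incur no distillation penalty.
print(mse([1.0, 2.0, 3.0], [1.0, 2.0, 3.0]))  # 0.0
```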

class renate.updaters.learner_components.losses.WeightedPooledOutputDistillationLossComponent(weight, sample_new_memory_batch, distillation_type='spatial', normalize=True)[source]#

Bases: WeightedLossComponent

Pooled output feature distillation with respect to intermediate network features.

As described in: Douillard, Arthur, et al. “Podnet: Pooled outputs distillation for small-tasks incremental learning.” European Conference on Computer Vision. Springer, Cham, 2020.

Given the intermediate representations collected at different parts of the network, minimise their Euclidean distance with respect to the cached representations. The different distillation_type options trade off plasticity and stability of the resulting representations. The normalize flag lets the user normalize the resulting feature representations so that they are less affected by their magnitude.

Parameters:
  • weight (float) – Scaling coefficient which scales the loss with respect to all intermediate representations.

  • sample_new_memory_batch (bool) – Whether a new batch of data should be sampled from the memory buffer when the loss is calculated.

  • distillation_type (str) – Which distillation type to apply with respect to all intermediate representations.

  • normalize (bool) – Whether to normalize both the current and cached features before computing the Frobenius norm.
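The following is a hedged sketch of the "spatial" distillation variant, with plain Python lists standing in for tensors (the real component operates on torch Tensors per layer): a 2-D feature map is pooled along each spatial axis, the pooled vectors are optionally L2-normalized, and the Euclidean distance to the cached counterpart is the distillation loss.

```python
import math

# Illustrative sketch of PODNet-style "spatial" pooled-output
# distillation. Function names are hypothetical, not Renate's API.

def spatial_pool(fmap):
    """Sum an H x W map over width and over height; concatenate both."""
    pooled_h = [sum(row) for row in fmap]        # length H
    pooled_w = [sum(col) for col in zip(*fmap)]  # length W
    return pooled_h + pooled_w

def l2_normalize(vec):
    norm = math.sqrt(sum(x * x for x in vec)) or 1.0
    return [x / norm for x in vec]

def pod_spatial_loss(current_fmap, cached_fmap, normalize=True):
    cur, old = spatial_pool(current_fmap), spatial_pool(cached_fmap)
    if normalize:
        cur, old = l2_normalize(cur), l2_normalize(old)
    return math.sqrt(sum((c - o) ** 2 for c, o in zip(cur, old)))

fmap = [[1.0, 2.0], [3.0, 4.0]]
print(pod_spatial_loss(fmap, fmap))  # identical features -> 0.0
```

Pooling before comparing constrains the spatial statistics of the features rather than every activation, which is what gives the method its plasticity/stability trade-off.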

class renate.updaters.learner_components.losses.WeightedCLSLossComponent(weight, sample_new_memory_batch, model, stable_model_update_weight, plastic_model_update_weight, stable_model_update_probability, plastic_model_update_probability)[source]#

Bases: WeightedLossComponent

Complementary Learning Systems Based Experience Replay.

Arani, Elahe, Fahad Sarfraz, and Bahram Zonooz. “Learning fast, learning slow: A general continual learning method based on complementary learning system.” arXiv preprint arXiv:2201.12604 (2022).

The implementation follows Algorithm 1 in the paper. The complete Learner implementing this loss is the CLSExperienceReplayLearner.

Parameters:
  • weight (float) – A scaling coefficient applied to the returned loss.

  • sample_new_memory_batch (bool) – Whether a new batch of data should be sampled from the memory buffer when the loss is calculated.

  • model (RenateModule) – The model that is being trained.

  • stable_model_update_weight (float) – The weight used in the update of the stable model.

  • plastic_model_update_weight (float) – The weight used in the update of the plastic model.

  • stable_model_update_probability (float) – The probability of updating the stable model at each training step.

  • plastic_model_update_probability (float) – The probability of updating the plastic model at each training step.
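The stochastic update rule behind these parameters can be sketched as follows (hypothetical names, pure Python): with some probability per training step, a model copy is updated as an exponential moving average of its own weights and the working model's weights. The stable model uses a high EMA weight and a low update probability, so it changes slowly; the plastic model tracks the working model faster.

```python
import random

# Illustrative sketch (not Renate's code) of the CLS-ER update rule:
# with probability `update_probability`, blend the copy's weights with
# the current model's weights via an exponential moving average.

def maybe_ema_update(copy_weights, current_weights,
                     update_weight, update_probability, rng=random):
    """EMA-update the copy with the given probability, else leave it unchanged."""
    if rng.random() < update_probability:
        return [update_weight * c + (1.0 - update_weight) * w
                for c, w in zip(copy_weights, current_weights)]
    return copy_weights

# Forced update (probability 1.0): copy moves a fraction toward current.
print(maybe_ema_update([0.0, 0.0], [1.0, 1.0],
                       update_weight=0.75, update_probability=1.0))
# [0.25, 0.25]
```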

on_train_batch_end(model)[source]#

Updates the model copies with the current weights, according to the specified update probabilities, and increments the iteration counter.

Return type:

None

on_load_checkpoint(checkpoint)[source]#

Loads relevant information from the checkpoint.

Return type:

None

on_save_checkpoint(checkpoint)[source]#

Adds the plastic and stable models to the checkpoint.

Return type:

None