Source code for renate.updaters.learner_components.losses

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import copy
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch
import torch.nn.functional as F

from renate.models import RenateModule
from renate.types import NestedTensors
from renate.updaters.learner_components.component import Component


class WeightedLossComponent(Component, ABC):
    """The abstract class implementing a weighted loss function.

    This is an abstract class from which all other loss components should inherit.
    """

    def _verify_attributes(self) -> None:
        """Verify if attributes have valid values."""
        super()._verify_attributes()
        assert self.weight >= 0, "Weight must be non-negative."

    def loss(
        self,
        outputs_memory: torch.Tensor,
        batch_memory: Tuple[Tuple[NestedTensors, torch.Tensor], Dict[str, torch.Tensor]],
        intermediate_representation_memory: Optional[List[torch.Tensor]],
    ) -> torch.Tensor:
        if self.weight == 0:
            return torch.tensor(0.0)
        return self._loss(
            outputs_memory=outputs_memory,
            batch_memory=batch_memory,
            intermediate_representation_memory=intermediate_representation_memory,
        )

    @abstractmethod
    def _loss(
        self,
        outputs_memory: torch.Tensor,
        batch_memory: Tuple[Tuple[NestedTensors, torch.Tensor], Dict[str, torch.Tensor]],
        intermediate_representation_memory: Optional[List[torch.Tensor]],
    ) -> torch.Tensor:
        pass
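

# Illustrative sketch, not part of the original module: the minimal contract a concrete loss
# component has to fulfil. Subclasses only implement ``_loss``; the public ``loss`` wrapper above
# short-circuits to a zero tensor when ``weight == 0``, so a component can be disabled without
# touching the training loop. The class name and the toy penalty below are hypothetical.
class _ExampleL2PenaltyComponent(WeightedLossComponent):
    """Toy component penalizing the squared magnitude of the memory outputs."""

    def _loss(
        self,
        outputs_memory: torch.Tensor,
        batch_memory: Tuple[Tuple[NestedTensors, torch.Tensor], Dict[str, torch.Tensor]],
        intermediate_representation_memory: Optional[List[torch.Tensor]],
    ) -> torch.Tensor:
        return self.weight * outputs_memory.pow(2).mean()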


class WeightedCustomLossComponent(WeightedLossComponent):
    """Adds a (weighted) user-provided custom loss contribution.

    Args:
        loss_fn: The loss function to apply.
        weight: A scaling coefficient applied to the loss before it is returned.
        sample_new_memory_batch: Whether a new batch of data should be sampled from the memory
            buffer when the loss is calculated.
    """

    def __init__(self, loss_fn: Callable, weight: float, sample_new_memory_batch: bool) -> None:
        super().__init__(weight=weight, sample_new_memory_batch=sample_new_memory_batch)
        self._loss_fn = loss_fn

    def _loss(
        self,
        outputs_memory: torch.Tensor,
        batch_memory: Tuple[Tuple[NestedTensors, torch.Tensor], Dict[str, torch.Tensor]],
        intermediate_representation_memory: Optional[List[torch.Tensor]],
    ) -> torch.Tensor:
        """Returns the user-provided loss evaluated on the memory batch."""
        (_, targets_memory), _ = batch_memory
        return self.weight * self._loss_fn(outputs_memory, targets_memory)
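

# Illustrative usage sketch, not part of the original module. It assumes the base ``Component``
# can be constructed with only ``weight`` and ``sample_new_memory_batch``; the batch below is a
# toy stand-in for what the learner passes in (``((inputs, targets), metadata)``), and all shapes
# are hypothetical.
def _example_weighted_custom_loss() -> torch.Tensor:
    component = WeightedCustomLossComponent(
        loss_fn=torch.nn.CrossEntropyLoss(), weight=0.5, sample_new_memory_batch=True
    )
    outputs_memory = torch.randn(8, 10)  # logits for a memory batch of 8 samples
    targets_memory = torch.randint(0, 10, (8,))
    batch_memory = ((torch.randn(8, 3), targets_memory), {})
    # Returns 0.5 * CrossEntropyLoss(logits, targets) evaluated on the memory batch.
    return component.loss(outputs_memory, batch_memory, None)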


class WeightedMeanSquaredErrorLossComponent(WeightedLossComponent):
    """Mean squared error between the current and previous logits computed with respect to the
    memory sample.
    """

    def _loss(
        self,
        outputs_memory: torch.Tensor,
        batch_memory: Tuple[Tuple[NestedTensors, torch.Tensor], Dict[str, torch.Tensor]],
        intermediate_representation_memory: Optional[List[torch.Tensor]],
    ) -> torch.Tensor:
        """Mean-squared error between current and previous logits on memory."""
        logits = outputs_memory
        _, meta_data = batch_memory
        previous_logits = meta_data["outputs"]
        return self.weight * F.mse_loss(logits, previous_logits, reduction="mean")
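

# Illustrative sketch, not part of the original module: the memory metadata is expected to carry
# the logits recorded when an example entered the buffer under the key ``"outputs"``. The same
# constructor assumption as above applies, and all shapes are hypothetical.
def _example_mse_distillation() -> torch.Tensor:
    component = WeightedMeanSquaredErrorLossComponent(weight=1.0, sample_new_memory_batch=True)
    current_logits = torch.randn(8, 10)
    previous_logits = torch.randn(8, 10)
    batch_memory = (
        (torch.randn(8, 3), torch.randint(0, 10, (8,))),
        {"outputs": previous_logits},
    )
    # Returns the mean squared error between current and cached logits, scaled by ``weight``.
    return component.loss(current_logits, batch_memory, None)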


class WeightedPooledOutputDistillationLossComponent(WeightedLossComponent):
    """Pooled output feature distillation with respect to intermediate network features.

    As described in:

    Douillard, Arthur, et al. "PODNet: Pooled outputs distillation for small-tasks incremental
    learning." European Conference on Computer Vision. Springer, Cham, 2020.

    Given the intermediate representations collected at different parts of the network, minimise
    their Euclidean distance with respect to the cached representations. The different
    `distillation_type`s trade off plasticity and stability of the resulting representations.
    `normalize` enables the user to normalize the pooled feature representations so that they are
    less affected by their magnitude.

    Args:
        weight: Scaling coefficient which scales the loss with respect to all intermediate
            representations.
        sample_new_memory_batch: Whether a new batch of data should be sampled from the memory
            buffer when the loss is calculated.
        distillation_type: Which distillation type to apply with respect to all intermediate
            representations.
        normalize: Whether to normalize both the current and cached features before computing
            the Frobenius norm.
    """

    def __init__(
        self,
        weight: float,
        sample_new_memory_batch: bool,
        distillation_type: str = "spatial",
        normalize: bool = True,
    ) -> None:
        self._distillation_type = distillation_type
        super().__init__(weight=weight, sample_new_memory_batch=sample_new_memory_batch)
        self._normalize = normalize

    def _verify_attributes(self) -> None:
        """Verify if attributes have valid values."""
        super()._verify_attributes()
        if self._distillation_type not in [
            "pixel",
            "channel",
            "width",
            "height",
            "gap",
            "spatial",
        ]:
            raise ValueError(f"Invalid distillation type: {self._distillation_type}")

    def _sum_reshape(self, x: torch.Tensor, dim: int) -> torch.Tensor:
        """Sum the tensor along a specific dimension and reshape to (batch_size, -1)."""
        batch_size = x.shape[0]
        return x.sum(dim=dim).reshape(batch_size, -1)

    def _pod(self, features: torch.Tensor, features_memory: torch.Tensor) -> torch.Tensor:
        """Pooled output distillation between current and cached intermediate features.

        Args:
            features: Current intermediate features.
            features_memory: Cached intermediate features.
        """
        if features.shape != features_memory.shape:
            raise ValueError(
                "The shape of the features and the cached features should be the same: "
                f"{features.shape}, and: {features_memory.shape}"
            )
        features = features.pow(2)
        features_memory = features_memory.pow(2)
        if self._distillation_type == "channel":
            features, features_memory = self._sum_reshape(features, 1), self._sum_reshape(
                features_memory, 1
            )
        elif self._distillation_type == "width":
            features, features_memory = self._sum_reshape(features, 2), self._sum_reshape(
                features_memory, 2
            )
        elif self._distillation_type == "height":
            features, features_memory = self._sum_reshape(features, 3), self._sum_reshape(
                features_memory, 3
            )
        elif self._distillation_type == "gap":
            features = F.adaptive_avg_pool2d(features, (1, 1))[..., 0, 0]
            features_memory = F.adaptive_avg_pool2d(features_memory, (1, 1))[..., 0, 0]
        elif self._distillation_type == "spatial":
            features_h, features_memory_h = self._sum_reshape(features, 3), self._sum_reshape(
                features_memory, 3
            )
            features_w, features_memory_w = self._sum_reshape(features, 2), self._sum_reshape(
                features_memory, 2
            )
            features = torch.cat([features_h, features_w], dim=-1)
            features_memory = torch.cat([features_memory_h, features_memory_w], dim=-1)
        if self._normalize:
            features = F.normalize(features, dim=1, p=2)
            features_memory = F.normalize(features_memory, dim=1, p=2)
        return torch.frobenius_norm(features - features_memory, dim=-1).mean(dim=0)

    def _loss(
        self,
        outputs_memory: torch.Tensor,
        batch_memory: Tuple[Tuple[NestedTensors, torch.Tensor], Dict[str, torch.Tensor]],
        intermediate_representation_memory: List[torch.Tensor],
    ) -> torch.Tensor:
        """Compute the pooled output distillation loss between the current and cached
        intermediate representations on the memory batch.
        """
        loss = torch.tensor(0.0)
        _, meta_data = batch_memory
        for n in range(len(intermediate_representation_memory)):
            features = intermediate_representation_memory[n]
            features_memory = meta_data[f"intermediate_representation_{n}"]
            loss += self._pod(features, features_memory)
        return (self.weight * loss) / len(intermediate_representation_memory)


class WeightedCLSLossComponent(WeightedLossComponent):
    """Complementary Learning Systems Based Experience Replay.

    Arani, Elahe, Fahad Sarfraz, and Bahram Zonooz. "Learning fast, learning slow: A general
    continual learning method based on complementary learning system." arXiv preprint
    arXiv:2201.12604 (2022).

    The implementation follows Algorithm 1 in the respective paper. The complete `Learner`
    implementing this loss is the `CLSExperienceReplayLearner`.

    Args:
        weight: A scaling coefficient which should scale the loss which gets returned.
        sample_new_memory_batch: Whether a new batch of data should be sampled from the memory
            buffer when the loss is calculated.
        model: The model that is being trained.
        stable_model_update_weight: The weight used in the update of the stable model.
        plastic_model_update_weight: The weight used in the update of the plastic model.
        stable_model_update_probability: The probability of updating the stable model at each
            training step.
        plastic_model_update_probability: The probability of updating the plastic model at each
            training step.
    """

    def __init__(
        self,
        weight: float,
        sample_new_memory_batch: bool,
        model: RenateModule,
        stable_model_update_weight: float,
        plastic_model_update_weight: float,
        stable_model_update_probability: float,
        plastic_model_update_probability: float,
    ) -> None:
        self._stable_model_update_weight = stable_model_update_weight
        self._plastic_model_update_weight = plastic_model_update_weight
        self._stable_model_update_probability = stable_model_update_probability
        self._plastic_model_update_probability = plastic_model_update_probability
        self._iteration = 0
        super().__init__(weight=weight, sample_new_memory_batch=sample_new_memory_batch)
        self._plastic_model: RenateModule = copy.deepcopy(model)
        self._stable_model: RenateModule = copy.deepcopy(model)
        self._plastic_model.deregister_hooks()
        self._stable_model.deregister_hooks()

    def _verify_attributes(self) -> None:
        """Verify if attributes have valid values."""
        super()._verify_attributes()
        assert 0.0 <= self._stable_model_update_weight
        assert 0.0 <= self._plastic_model_update_weight
        assert 0.0 <= self._stable_model_update_probability <= 1.0
        assert 0.0 <= self._plastic_model_update_probability <= 1.0
        assert self._plastic_model_update_probability > self._stable_model_update_probability
        assert self._plastic_model_update_weight <= self._stable_model_update_weight

    def _loss(
        self,
        outputs_memory: torch.Tensor,
        batch_memory: Tuple[Tuple[NestedTensors, torch.Tensor], Dict[str, torch.Tensor]],
        intermediate_representation_memory: Optional[List[torch.Tensor]],
    ) -> torch.Tensor:
        """Computes the consistency loss with respect to averaged plastic and stable models."""
        (inputs_memory, targets_memory), _ = batch_memory
        with torch.no_grad():
            outputs_plastic = self._plastic_model(inputs_memory)
            outputs_stable = self._stable_model(inputs_memory)
            probs_plastic = F.softmax(outputs_plastic, dim=-1)
            probs_stable = F.softmax(outputs_stable, dim=-1)
            label_mask = F.one_hot(targets_memory, num_classes=outputs_stable.shape[-1]) > 0
            idx = (probs_stable[label_mask] > probs_plastic[label_mask]).unsqueeze(1)
            outputs = torch.where(idx, outputs_stable, outputs_plastic)
        consistency_loss = F.mse_loss(outputs_memory, outputs.detach(), reduction="mean")
        return self.weight * consistency_loss

    @torch.no_grad()
    def _update_model_variables(
        self, model: RenateModule, original_model: RenateModule, weight: float
    ) -> None:
        """Performs an exponential moving average update on a stored model copy.

        Args:
            model: The model copy (plastic or stable) that is updated.
            original_model: The model currently being trained, whose parameters are averaged
                into the copy.
            weight: The maximum weight used in the exponential moving average update of the
                model.
        """
        alpha = min(1.0 - 1.0 / (self._iteration + 1), weight)
        for ema_p, p in zip(model.parameters(), original_model.parameters()):
            ema_p.data.mul_(alpha).add_(p.data, alpha=1 - alpha)

    def on_train_batch_end(self, model: RenateModule) -> None:
        """Updates the model copies with the current weights, given the specified probabilities
        of update, and increments the iteration counter.
        """
        self._iteration += 1
        if torch.rand(1) < self._plastic_model_update_probability:
            self._update_model_variables(
                self._plastic_model, model, self._plastic_model_update_weight
            )
        if torch.rand(1) < self._stable_model_update_probability:
            self._update_model_variables(
                self._stable_model, model, self._stable_model_update_weight
            )

    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        """Load relevant information from checkpoint."""
        super().on_load_checkpoint(checkpoint)
        self._plastic_model.load_state_dict(checkpoint["component-cls-plastic-model"])
        self._stable_model.load_state_dict(checkpoint["component-cls-stable-model"])
        self._iteration = checkpoint["component-cls-iteration"]

    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        """Add plastic and stable model to checkpoint."""
        super().on_save_checkpoint(checkpoint)
        checkpoint["component-cls-plastic-model"] = self._plastic_model.state_dict()
        checkpoint["component-cls-stable-model"] = self._stable_model.state_dict()
        checkpoint["component-cls-iteration"] = self._iteration
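

# Illustrative sketch, not part of the original module: the exponential moving average rule used
# by ``_update_model_variables``. For iteration ``t`` and update weight ``w`` the stored copy is
# updated as ``ema = alpha * ema + (1 - alpha) * current`` with ``alpha = min(1 - 1 / (t + 1), w)``,
# so early in training the copy tracks the live model closely and later moves at most at rate
# ``1 - w``. Tensor sizes and values below are hypothetical.
def _example_cls_ema_update() -> torch.Tensor:
    ema_param = torch.zeros(3)
    current_param = torch.ones(3)
    iteration, update_weight = 10, 0.999
    alpha = min(1.0 - 1.0 / (iteration + 1), update_weight)  # == 10 / 11 here
    ema_param.mul_(alpha).add_(current_param, alpha=1 - alpha)
    return ema_param  # approximately tensor([0.0909, 0.0909, 0.0909])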