# Source code for renate.updaters.learner_components.component
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import abc
from typing import Any, Dict, List, Optional, Tuple
import torch
from renate.models import RenateModule
from renate.types import NestedTensors
class Component(abc.ABC):
    """The abstract class implementing a Component, usable in the BaseExperienceReplayLearner.

    This is an abstract class from which every other component, e.g. an additional
    regularising loss or a module updater, should inherit.
    Components should be modular and independent to an extent where they can be composed
    together in an ordered list to be deployed in the BaseExperienceReplayLearner.

    All hooks defined here have do-nothing default implementations, so a subclass only
    overrides the ones it needs.

    Args:
        weight: A scaling coefficient which should scale the loss which gets returned.
        sample_new_memory_batch: Whether a new batch of data should be sampled from the memory
            buffer when the loss is calculated.
    """

    def __init__(self, weight: float = 0, sample_new_memory_batch: bool = False) -> None:
        self.weight = weight
        self.sample_new_memory_batch = sample_new_memory_batch
        # Validation runs last so that an overriding `_verify_attributes` can inspect the
        # attributes assigned above.
        self._verify_attributes()

    def loss(
        self,
        outputs_memory: torch.Tensor,
        batch_memory: Tuple[Tuple[NestedTensors, torch.Tensor], Dict[str, torch.Tensor]],
        intermediate_representation_memory: Optional[List[torch.Tensor]],
    ) -> torch.Tensor:
        """Computes some user-defined loss which is added to the main training loss in the training
        step.

        The base implementation ignores all of its arguments (and `self.weight`) and returns
        a constant zero.

        Args:
            outputs_memory: The outputs of the model with respect to memory data (batch_memory).
            batch_memory: The batch of data sampled from the memory buffer, including the meta data.
            intermediate_representation_memory: Intermediate feature representations of the network
                upon passing the input through the network.

        Returns:
            A zero-dimensional tensor holding `0.0`.
        """
        return torch.tensor(0.0)

    def on_train_start(self, model: RenateModule) -> None:
        """Updates the model parameters.

        The base implementation does nothing.

        Args:
            model: The model used for training.
        """

    def on_train_batch_end(self, model: RenateModule) -> None:
        """Internally records a training and optimizer step in the component.

        The base implementation does nothing.

        Args:
            model: The model used for training.
        """

    def _verify_attributes(self) -> None:
        """Verify if attributes have valid values.

        Called at the end of `__init__`. The base implementation accepts any values.
        """

    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        """Load relevant information from checkpoint.

        The base implementation does nothing.

        Args:
            checkpoint: The checkpoint dictionary to read from.
        """

    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        """Add relevant information to checkpoint.

        The base implementation does nothing and leaves `checkpoint` unmodified.

        Args:
            checkpoint: The checkpoint dictionary to add to.
        """