# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import abc
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Tuple
import torch
import torchmetrics
from pytorch_lightning.loggers.logger import Logger
from pytorch_lightning.utilities.types import STEP_OUTPUT
from torch.nn import Parameter
from torch.optim import Optimizer
from torch.utils.data import DataLoader, Dataset, Subset
from renate import defaults
from renate.data.datasets import _EnumeratedDataset, _TransformedDataset
from renate.memory.buffer import DataDict
from renate.models import RenateModule
from renate.types import NestedTensors
from renate.updaters.learner import ReplayLearner
from renate.updaters.learner_components.component import Component
from renate.updaters.learner_components.losses import (
WeightedCLSLossComponent,
WeightedCustomLossComponent,
WeightedMeanSquaredErrorLossComponent,
WeightedPooledOutputDistillationLossComponent,
)
from renate.updaters.learner_components.reinitialization import (
ShrinkAndPerturbReinitializationComponent,
)
from renate.updaters.model_updater import SingleTrainingLoopUpdater
from renate.utils.misc import maybe_populate_mask_and_ignore_logits
from renate.utils.pytorch import move_tensors_to_device
class BaseExperienceReplayLearner(ReplayLearner, abc.ABC):
    """A base implementation of experience replay.

    It is designed for the online CL setting, where only one pass over each new chunk of data is
    allowed. The Learner maintains a Reservoir buffer. In the training step, it samples a batch of
    data from the memory and appends it to the batch of current-task data. At the end of the
    training step, the memory is updated.

    Args:
        components: An ordered dictionary of components that are part of the experience replay
            learner. The insertion order matters: components are evaluated in that order in
            ``training_step``.
        loss_weight: A scalar weight factor for the base loss function to trade it off with other
            loss functions added by `components`.
        ema_memory_update_gamma: The gamma used for exponential moving average to update the meta
            data with respect to the logits and intermediate representation, if there is some.
        loss_normalization: Whether to normalize the loss by the weights of all the components.
    """

    def __init__(
        self,
        components: Dict[str, Component],
        loss_weight: float = defaults.LOSS_WEIGHT,
        ema_memory_update_gamma: float = defaults.EMA_MEMORY_UPDATE_GAMMA,
        loss_normalization: int = defaults.LOSS_NORMALIZATION,
        **kwargs: Any,
    ) -> None:
        # Capture the component names before `super().__init__` runs so that they are already
        # available to `_create_metrics_collections` (presumably invoked during base
        # initialization — TODO confirm against ReplayLearner).
        self._components_names = list(components.keys())
        super().__init__(**kwargs)
        # Created lazily by `_set_memory_loader` once the buffer holds a full memory batch.
        self._memory_loader: Optional[DataLoader] = None
        self._components = components
        self._loss_weight = loss_weight
        self._ema_memory_update_gamma = ema_memory_update_gamma
        # Accepted as an int flag (cf. defaults.LOSS_NORMALIZATION) but used as a bool.
        self._use_loss_normalization = bool(loss_normalization)

    def _create_metrics_collections(
        self, logged_metrics: Optional[Dict[str, torchmetrics.Metric]] = None
    ) -> None:
        """Create the metric collections and add one mean tracker per component loss.

        Raises:
            ValueError: If a component name collides with an already registered loss name.
        """
        super()._create_metrics_collections(logged_metrics)
        for name in self._components_names:
            # NOTE(review): the clash check is against the keys of `_loss_collections` itself,
            # while the metric is inserted into the nested "train_losses" collection — confirm
            # this is the intended level for the collision test.
            if name in self._loss_collections:
                raise ValueError(
                    f"Component name {name} is already used as a loss name. Please pick a "
                    "different name."
                )
            self._loss_collections["train_losses"].update({name: torchmetrics.MeanMetric()})

    def on_model_update_start(
        self,
        train_dataset: Dataset,
        val_dataset: Dataset,
        train_dataset_collate_fn: Optional[Callable] = None,
        val_dataset_collate_fn: Optional[Callable] = None,
        task_id: Optional[str] = None,
    ) -> None:
        """Called before a model update starts."""
        super().on_model_update_start(
            train_dataset=train_dataset,
            val_dataset=val_dataset,
            train_dataset_collate_fn=train_dataset_collate_fn,
            val_dataset_collate_fn=val_dataset_collate_fn,
            task_id=task_id,
        )
        # The buffer may already contain enough samples (e.g. restored state), in which case the
        # memory loader can be created right away.
        self._set_memory_loader()

    def train_dataloader(self) -> DataLoader:
        """Return the training dataloader.

        The dataset is wrapped in `_EnumeratedDataset` so each batch carries the dataset indices
        of its samples; `training_step`/`_update_memory_buffer` use those indices to copy the
        trained-on samples into the memory buffer.
        """
        train_dataset = _EnumeratedDataset(
            _TransformedDataset(
                self._train_dataset,
                transform=self._train_transform,
                target_transform=self._train_target_transform,
            )
        )
        return DataLoader(
            train_dataset,
            batch_size=self._batch_size,
            shuffle=True,
            generator=self._rng,
            pin_memory=True,
            collate_fn=self._train_collate_fn,
        )

    def on_train_start(self) -> None:
        """PyTorch Lightning function to be run at the start of the training."""
        super().on_train_start()
        for component in self._components.values():
            component.on_train_start(model=self._model)

    def training_step(
        self, batch: Tuple[torch.Tensor, Tuple[NestedTensors, torch.Tensor]], batch_idx: int
    ) -> STEP_OUTPUT:
        """PyTorch Lightning function to return the training loss."""
        # `batch` is (sample indices, (inputs, targets)); the indices come from
        # `_EnumeratedDataset` and are forwarded so the memory buffer can be updated later.
        idx, (inputs, targets) = batch
        step_output = super().training_step(batch=(inputs, targets), batch_idx=batch_idx)
        step_output["train_data_idx"] = idx
        # Weight the base loss so it can be traded off against the component losses below.
        step_output["loss"] *= self._loss_weight
        batch_memory: Optional[torch.Tensor] = None
        metadata_memory: Optional[torch.Tensor] = None
        outputs_memory: Optional[torch.Tensor] = None
        intermediate_representation_memory: Optional[List[torch.Tensor]] = None
        # Accumulates the total weight of all loss terms for the optional normalization below.
        loss_normalization = self._loss_weight
        if self._memory_loader is not None:
            for name, component in self._components.items():
                memory_sampled = False
                # Sample a fresh memory batch if the component requests one or none exists yet;
                # otherwise reuse the previous component's batch, outputs and representations.
                if component.sample_new_memory_batch or batch_memory is None:
                    batch_memory = self._sample_from_buffer(device=step_output["loss"].device)
                    (inputs_memory, _), metadata_memory = batch_memory
                    outputs_memory = self(inputs_memory)
                    outputs_memory, self._class_mask = maybe_populate_mask_and_ignore_logits(
                        self._mask_unused_classes,
                        self._class_mask,
                        self._classes_in_current_task,
                        outputs_memory,
                    )
                    intermediate_representation_memory = (
                        self._model.get_intermediate_representation()
                    )
                    self._model.reset_intermediate_representation_cache()
                    memory_sampled = True
                component_loss = component.loss(
                    outputs_memory=outputs_memory,
                    batch_memory=batch_memory,
                    intermediate_representation_memory=intermediate_representation_memory,
                ).mean()
                self._loss_collections["train_losses"][name](component_loss)
                step_output["loss"] += component_loss
                loss_normalization += component.weight
                # Refresh the buffered logits of the freshly sampled memory points via an
                # exponential moving average towards the current model outputs.
                if memory_sampled and self._ema_memory_update_gamma < 1.0:
                    mem_idx = metadata_memory["idx"].cpu()
                    self._memory_buffer.metadata["outputs"][mem_idx] = self._memory_buffer.metadata[
                        "outputs"
                    ][mem_idx].cpu() * self._ema_memory_update_gamma + outputs_memory.detach().cpu() * (
                        1.0 - self._ema_memory_update_gamma
                    )
        if self._use_loss_normalization:
            step_output["loss"] /= loss_normalization
        return step_output

    def _sample_from_buffer(self, device: torch.device) -> Optional[Tuple[NestedTensors, DataDict]]:
        """Function to sample from the buffer, if buffer is populated.

        Returns ``None`` when the buffer does not yet contain a full memory batch.
        """
        if self._memory_loader is not None and len(self._memory_buffer) >= self._memory_batch_size:
            memory_batch = next(iter(self._memory_loader))
            return move_tensors_to_device(memory_batch, device)
        else:
            return None

    def training_step_end(self, step_output: STEP_OUTPUT) -> STEP_OUTPUT:
        """PyTorch Lightning function to perform after the training step."""
        super().training_step_end(step_output)
        self._update_memory_buffer(step_output)
        return step_output

    def _update_memory_buffer(self, step_output: STEP_OUTPUT) -> None:
        """Insert the samples of the current batch (plus metadata) into the memory buffer."""
        outputs = step_output["outputs"]
        metadata = {"outputs": outputs.detach().cpu()}
        for i, intermediate_representation in enumerate(step_output["intermediate_representation"]):
            metadata[
                f"intermediate_representation_{i}"
            ] = intermediate_representation.detach().cpu()
        # Some datasets have problems using tensors as subset indices, convert to list of ints.
        train_data_idx = [int(idx) for idx in step_output["train_data_idx"]]
        dataset = Subset(self._train_dataset, train_data_idx)
        self._memory_buffer.update(dataset, metadata)
        # The buffer may have just reached the minimum size, so try to create the loader.
        self._set_memory_loader()

    def _set_memory_loader(self) -> None:
        """Create a memory loader from a memory buffer."""
        # No-op until the buffer can provide one full batch; once created, the loader is kept
        # (it reads from the live buffer object).
        if self._memory_loader is None and len(self._memory_buffer) >= self._memory_batch_size:
            self._memory_loader = DataLoader(
                dataset=self._memory_buffer,
                batch_size=self._memory_batch_size,
                drop_last=True,
                shuffle=True,
                generator=self._rng,
                pin_memory=True,
                collate_fn=self._train_collate_fn,
            )

    def on_train_batch_end(self, outputs: STEP_OUTPUT, batch: Any, batch_idx: int) -> None:
        """PyTorch Lightning function to perform after the training and optimizer step."""
        super().on_train_batch_end(outputs, batch, batch_idx)
        for component in self._components.values():
            component.on_train_batch_end(model=self._model)

    @abc.abstractmethod
    def components(self, **kwargs) -> Dict[str, Component]:
        """Returns the components of the learner.

        This is a user-defined function that should return a dictionary of components.
        """

    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        """Load states of components."""
        super().on_load_checkpoint(checkpoint)
        for component in self._components.values():
            component.on_load_checkpoint(checkpoint)

    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        """Save states of components."""
        super().on_save_checkpoint(checkpoint)
        for component in self._components.values():
            component.on_save_checkpoint(checkpoint)
class ExperienceReplayLearner(BaseExperienceReplayLearner):
    """Experience replay as proposed by

    Chaudhry, Arslan, et al. "On tiny episodic memories in continual learning."
    arXiv preprint arXiv:1902.10486 (2019).

    A single weighted loss component re-applies the training loss to samples drawn from the
    memory buffer.

    Args:
        alpha: Weight of the loss component evaluated on the memory samples.
    """

    def __init__(self, alpha: float = defaults.ER_ALPHA, **kwargs) -> None:
        super().__init__(
            components=self.components(loss_fn=kwargs["loss_fn"], alpha=alpha), **kwargs
        )

    def components(
        self, loss_fn: Optional[torch.nn.Module] = None, alpha: float = defaults.ER_ALPHA
    ) -> Dict[str, Component]:
        """Create the single replay-loss component, keyed ``"memory_loss"``."""
        memory_loss = WeightedCustomLossComponent(
            loss_fn=loss_fn, weight=alpha, sample_new_memory_batch=True
        )
        return {"memory_loss": memory_loss}
class DarkExperienceReplayLearner(ExperienceReplayLearner):
    """A Learner that implements Dark Experience Replay.

    Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara:
    Dark Experience for General Continual Learning: a Strong, Simple Baseline. NeurIPS 2020

    Args:
        alpha: The weight of the mean squared error loss component between memorised logits and the
            current logits on the memory data.
        beta: The weight of the cross-entropy loss component between memorised targets and the
            current logits on the memory data.
    """

    def __init__(
        self, alpha: float = defaults.DER_ALPHA, beta: float = defaults.DER_BETA, **kwargs
    ) -> None:
        # The parent constructor calls `self.components(...)` which, through dynamic dispatch,
        # resolves to the override below — so both component names ("memory_loss", "mse_loss")
        # are registered for metrics. The weights passed there are incomplete (only `beta` is
        # forwarded, as `alpha`), hence the dict is rebuilt with the correct weights afterwards.
        super().__init__(alpha=beta, **kwargs)
        self._components = self.components(loss_fn=kwargs["loss_fn"], alpha=alpha, beta=beta)

    def components(
        self,
        loss_fn: Optional[torch.nn.Module] = None,
        alpha: float = defaults.DER_ALPHA,
        beta: float = defaults.DER_BETA,
    ) -> Dict[str, Component]:
        """Return the replay loss (weight ``beta``) plus the logit-MSE loss (weight ``alpha``)."""
        # Parent's "memory_loss" component carries the `beta` weight.
        components = super().components(loss_fn=loss_fn, alpha=beta)
        components.update(
            {
                "mse_loss": WeightedMeanSquaredErrorLossComponent(
                    weight=alpha, sample_new_memory_batch=False
                )
            }
        )
        return components
class PooledOutputDistillationExperienceReplayLearner(BaseExperienceReplayLearner):
    """Experience replay with a pooled-output distillation (PODNet) loss.

    Douillard, Arthur, et al. "Podnet: Pooled outputs distillation for small-tasks incremental
    learning." European Conference on Computer Vision. Springer, Cham, 2020.

    Args:
        alpha: Weight applied to the distillation loss over all intermediate representations.
        distillation_type: Kind of distillation applied to the intermediate representations.
        normalize: If ``True``, both current and cached features are normalized before the
            Frobenius norm is computed.
    """

    def __init__(
        self,
        alpha: float = defaults.POD_ALPHA,
        distillation_type: str = defaults.POD_DISTILLATION_TYPE,
        normalize: bool = defaults.POD_NORMALIZE,
        **kwargs,
    ) -> None:
        pod_components = self.components(
            alpha=alpha, distillation_type=distillation_type, normalize=normalize
        )
        super().__init__(components=pod_components, **kwargs)

    def components(
        self,
        alpha: float = defaults.POD_ALPHA,
        distillation_type: str = defaults.POD_DISTILLATION_TYPE,
        normalize: bool = defaults.POD_NORMALIZE,
    ) -> Dict[str, Component]:
        """Create the single PODNet distillation component, keyed ``"pod_loss"``."""
        pod_loss = WeightedPooledOutputDistillationLossComponent(
            weight=alpha,
            sample_new_memory_batch=True,
            distillation_type=distillation_type,
            normalize=normalize,
        )
        return {"pod_loss": pod_loss}
class CLSExperienceReplayLearner(BaseExperienceReplayLearner):
    """Experience replay based on a Complementary Learning Systems approach (CLS-ER).

    Arani, Elahe, Fahad Sarfraz, and Bahram Zonooz.
    "Learning fast, learning slow: A general continual learning method based on complementary
    learning system." arXiv preprint arXiv:2201.12604 (2022).

    Args:
        alpha: Scaling value for the cross-entropy loss.
        beta: Scaling value for the consistency loss.
        stable_model_update_weight: Starting weight of the exponential moving average updating
            the stable model copy.
        plastic_model_update_weight: Starting weight of the exponential moving average updating
            the plastic model copy.
        stable_model_update_probability: Probability of updating the stable model copy.
        plastic_model_update_probability: Probability of updating the plastic model copy.
    """

    def __init__(
        self,
        alpha: float = defaults.CLS_ALPHA,
        beta: float = defaults.CLS_BETA,
        stable_model_update_weight: float = defaults.CLS_STABLE_MODEL_UPDATE_WEIGHT,
        plastic_model_update_weight: float = defaults.CLS_PLASTIC_MODEL_UPDATE_WEIGHT,
        stable_model_update_probability: float = defaults.CLS_STABLE_MODEL_UPDATE_PROBABILITY,
        plastic_model_update_probability: float = defaults.CLS_PLASTIC_MODEL_UPDATE_PROBABILITY,
        **kwargs,
    ):
        cls_components = self.components(
            model=kwargs["model"],
            loss_fn=kwargs["loss_fn"],
            alpha=alpha,
            beta=beta,
            stable_model_update_weight=stable_model_update_weight,
            plastic_model_update_weight=plastic_model_update_weight,
            stable_model_update_probability=stable_model_update_probability,
            plastic_model_update_probability=plastic_model_update_probability,
        )
        super().__init__(components=cls_components, **kwargs)

    def components(
        self,
        model: RenateModule,
        loss_fn: torch.nn.Module,
        alpha: float = defaults.CLS_ALPHA,
        beta: float = defaults.CLS_BETA,
        plastic_model_update_weight: float = defaults.CLS_PLASTIC_MODEL_UPDATE_WEIGHT,
        stable_model_update_weight: float = defaults.CLS_STABLE_MODEL_UPDATE_WEIGHT,
        plastic_model_update_probability: float = defaults.CLS_PLASTIC_MODEL_UPDATE_PROBABILITY,
        stable_model_update_probability: float = defaults.CLS_STABLE_MODEL_UPDATE_PROBABILITY,
    ) -> Dict[str, Component]:
        """Create the replay-loss and CLS consistency-loss components (in that order)."""
        memory_loss = WeightedCustomLossComponent(
            loss_fn=loss_fn, weight=alpha, sample_new_memory_batch=True
        )
        cls_loss = WeightedCLSLossComponent(
            weight=beta,
            sample_new_memory_batch=False,
            model=model,
            plastic_model_update_weight=plastic_model_update_weight,
            stable_model_update_weight=stable_model_update_weight,
            plastic_model_update_probability=plastic_model_update_probability,
            stable_model_update_probability=stable_model_update_probability,
        )
        return {"memory_loss": memory_loss, "cls_loss": cls_loss}
class SuperExperienceReplayLearner(BaseExperienceReplayLearner):
    """A learner combining DER, CLS, PODNet and shrink-and-perturb techniques.

    Args:
        der_alpha: Weight of the mean squared error loss between memorised logits and the current
            logits on the memory data.
        der_beta: Weight of the cross-entropy loss between memorised targets and the current
            logits on the memory data.
        sp_shrink_factor: Shrinking value applied with respect to shrink and perturbation.
        sp_sigma: Standard deviation applied with respect to shrink and perturbation.
        cls_alpha: Scaling value for the consistency loss added to the base cross-entropy loss.
        cls_stable_model_update_weight: Starting EMA weight for updating the stable model copy.
        cls_plastic_model_update_weight: Starting EMA weight for updating the plastic model copy.
        cls_stable_model_update_probability: Probability of updating the stable model copy.
        cls_plastic_model_update_probability: Probability of updating the plastic model copy.
        pod_alpha: Weight of the distillation loss over all intermediate representations.
        pod_distillation_type: Distillation type applied to the intermediate representations.
        pod_normalize: Whether to normalize current and cached features before computing the
            Frobenius norm.
        ema_memory_update_gamma: Gamma of the exponential moving average used to refresh the
            buffered logits and intermediate representations.
    """

    def __init__(
        self,
        der_alpha: float = defaults.SER_DER_ALPHA,
        der_beta: float = defaults.SER_DER_BETA,
        sp_shrink_factor: float = defaults.SER_SP_SHRINK_FACTOR,
        sp_sigma: float = defaults.SER_SP_SIGMA,
        cls_alpha: float = defaults.SER_CLS_ALPHA,
        cls_stable_model_update_weight: float = defaults.SER_CLS_STABLE_MODEL_UPDATE_WEIGHT,
        cls_plastic_model_update_weight: float = defaults.SER_CLS_PLASTIC_MODEL_UPDATE_WEIGHT,
        cls_stable_model_update_probability: float = defaults.SER_CLS_STABLE_MODEL_UPDATE_PROBABILITY,  # noqa: E501
        cls_plastic_model_update_probability: float = defaults.SER_CLS_PLASTIC_MODEL_UPDATE_PROBABILITY,  # noqa: E501
        pod_alpha: float = defaults.SER_POD_ALPHA,
        pod_distillation_type: str = defaults.SER_POD_DISTILLATION_TYPE,
        pod_normalize: bool = defaults.SER_POD_NORMALIZE,
        ema_memory_update_gamma: float = defaults.EMA_MEMORY_UPDATE_GAMMA,
        **kwargs,
    ) -> None:
        super().__init__(
            components=self.components(
                model=kwargs["model"],
                loss_fn=kwargs["loss_fn"],
                der_alpha=der_alpha,
                der_beta=der_beta,
                sp_shrink_factor=sp_shrink_factor,
                sp_sigma=sp_sigma,
                cls_alpha=cls_alpha,
                cls_stable_model_update_weight=cls_stable_model_update_weight,
                cls_plastic_model_update_weight=cls_plastic_model_update_weight,
                cls_stable_model_update_probability=cls_stable_model_update_probability,
                cls_plastic_model_update_probability=cls_plastic_model_update_probability,
                pod_alpha=pod_alpha,
                pod_distillation_type=pod_distillation_type,
                pod_normalize=pod_normalize,
            ),
            ema_memory_update_gamma=ema_memory_update_gamma,
            **kwargs,
        )

    def components(
        self,
        model: RenateModule,
        loss_fn: torch.nn.Module,
        der_alpha: float = defaults.SER_DER_ALPHA,
        der_beta: float = defaults.SER_DER_BETA,
        sp_shrink_factor: float = defaults.SER_SP_SHRINK_FACTOR,
        sp_sigma: float = defaults.SER_SP_SIGMA,
        cls_alpha: float = defaults.SER_CLS_ALPHA,
        cls_stable_model_update_weight: float = defaults.SER_CLS_STABLE_MODEL_UPDATE_WEIGHT,
        cls_plastic_model_update_weight: float = defaults.SER_CLS_PLASTIC_MODEL_UPDATE_WEIGHT,
        cls_stable_model_update_probability: float = defaults.SER_CLS_STABLE_MODEL_UPDATE_PROBABILITY,  # noqa: E501
        cls_plastic_model_update_probability: float = defaults.SER_CLS_PLASTIC_MODEL_UPDATE_PROBABILITY,  # noqa: E501
        pod_alpha: float = defaults.SER_POD_ALPHA,
        pod_distillation_type: str = defaults.SER_POD_DISTILLATION_TYPE,
        pod_normalize: bool = defaults.SER_POD_NORMALIZE,
    ) -> Dict[str, Component]:
        """Create all loss/reinitialization components.

        Insertion order is significant (``training_step`` iterates the dict in order) and is
        therefore kept fixed here.
        """
        components: Dict[str, Component] = {}
        components["mse_loss"] = WeightedMeanSquaredErrorLossComponent(
            weight=der_alpha, sample_new_memory_batch=True
        )
        components["memory_loss"] = WeightedCustomLossComponent(
            loss_fn=loss_fn, weight=der_beta, sample_new_memory_batch=True
        )
        components["cls_loss"] = WeightedCLSLossComponent(
            weight=cls_alpha,
            sample_new_memory_batch=False,
            model=model,
            stable_model_update_weight=cls_stable_model_update_weight,
            plastic_model_update_weight=cls_plastic_model_update_weight,
            stable_model_update_probability=cls_stable_model_update_probability,
            plastic_model_update_probability=cls_plastic_model_update_probability,
        )
        components["shrink_perturb"] = ShrinkAndPerturbReinitializationComponent(
            shrink_factor=sp_shrink_factor, sigma=sp_sigma
        )
        components["pod_loss"] = WeightedPooledOutputDistillationLossComponent(
            weight=pod_alpha,
            sample_new_memory_batch=True,
            distillation_type=pod_distillation_type,
            normalize=pod_normalize,
        )
        return components
class ExperienceReplayModelUpdater(SingleTrainingLoopUpdater):
    """Model updater that trains with :class:`ExperienceReplayLearner`.

    Args:
        model: The model to be updated incrementally.
        loss_fn: The loss function optimized during training.
        optimizer: Callable creating the optimizer from the model parameters.
        memory_size: Total number of data points the memory buffer can hold.
        batch_memory_frac: Fraction of each training batch drawn from the memory buffer.
        loss_weight: Weight of the base loss relative to the component losses.
        ema_memory_update_gamma: EMA factor used to refresh buffered logits.
        loss_normalization: Whether to normalize the loss by the sum of component weights.
        alpha: Weight of the replay loss component applied to memory samples.

    All remaining arguments configure the generic training loop and are forwarded unchanged to
    :class:`SingleTrainingLoopUpdater`.
    """

    def __init__(
        self,
        model: RenateModule,
        loss_fn: torch.nn.Module,
        optimizer: Callable[[List[Parameter]], Optimizer],
        memory_size: int,
        # Annotation fixed from `int` to `float`: this is a fraction of the batch in [0, 1]
        # (cf. the `_FRAC` default name), not a count.
        batch_memory_frac: float = defaults.BATCH_MEMORY_FRAC,
        loss_weight: float = defaults.LOSS_WEIGHT,
        ema_memory_update_gamma: float = defaults.EMA_MEMORY_UPDATE_GAMMA,
        loss_normalization: int = defaults.LOSS_NORMALIZATION,
        alpha: float = defaults.ER_ALPHA,
        learning_rate_scheduler: Optional[partial] = None,
        learning_rate_scheduler_interval: defaults.SUPPORTED_LR_SCHEDULER_INTERVAL_TYPE = defaults.LR_SCHEDULER_INTERVAL,  # noqa: E501
        batch_size: int = defaults.BATCH_SIZE,
        input_state_folder: Optional[str] = None,
        output_state_folder: Optional[str] = None,
        max_epochs: int = defaults.MAX_EPOCHS,
        train_transform: Optional[Callable] = None,
        train_target_transform: Optional[Callable] = None,
        test_transform: Optional[Callable] = None,
        test_target_transform: Optional[Callable] = None,
        buffer_transform: Optional[Callable] = None,
        buffer_target_transform: Optional[Callable] = None,
        metric: Optional[str] = None,
        mode: defaults.SUPPORTED_TUNING_MODE_TYPE = "min",
        logged_metrics: Optional[Dict[str, torchmetrics.Metric]] = None,
        early_stopping_enabled: bool = False,
        logger: Logger = defaults.LOGGER(**defaults.LOGGER_KWARGS),
        accelerator: defaults.SUPPORTED_ACCELERATORS_TYPE = defaults.ACCELERATOR,
        devices: Optional[int] = None,
        strategy: str = defaults.DISTRIBUTED_STRATEGY,
        precision: str = defaults.PRECISION,
        seed: int = defaults.SEED,
        deterministic_trainer: bool = defaults.DETERMINISTIC_TRAINER,
        gradient_clip_val: Optional[float] = defaults.GRADIENT_CLIP_VAL,
        gradient_clip_algorithm: Optional[str] = defaults.GRADIENT_CLIP_ALGORITHM,
        mask_unused_classes: bool = defaults.MASK_UNUSED_CLASSES,
    ):
        # Arguments consumed by the learner itself; the rest is handled by the updater.
        learner_kwargs = {
            "memory_size": memory_size,
            "batch_memory_frac": batch_memory_frac,
            "loss_weight": loss_weight,
            "ema_memory_update_gamma": ema_memory_update_gamma,
            "loss_normalization": loss_normalization,
            "alpha": alpha,
            "batch_size": batch_size,
            "seed": seed,
        }
        super().__init__(
            model,
            loss_fn=loss_fn,
            optimizer=optimizer,
            learner_class=ExperienceReplayLearner,
            learner_kwargs=learner_kwargs,
            input_state_folder=input_state_folder,
            output_state_folder=output_state_folder,
            max_epochs=max_epochs,
            learning_rate_scheduler=learning_rate_scheduler,
            learning_rate_scheduler_interval=learning_rate_scheduler_interval,
            train_transform=train_transform,
            train_target_transform=train_target_transform,
            test_transform=test_transform,
            test_target_transform=test_target_transform,
            buffer_transform=buffer_transform,
            buffer_target_transform=buffer_target_transform,
            metric=metric,
            mode=mode,
            logged_metrics=logged_metrics,
            early_stopping_enabled=early_stopping_enabled,
            logger=logger,
            accelerator=accelerator,
            devices=devices,
            strategy=strategy,
            precision=precision,
            deterministic_trainer=deterministic_trainer,
            gradient_clip_algorithm=gradient_clip_algorithm,
            gradient_clip_val=gradient_clip_val,
            mask_unused_classes=mask_unused_classes,
        )
class DarkExperienceReplayModelUpdater(SingleTrainingLoopUpdater):
    """Model updater that trains with :class:`DarkExperienceReplayLearner`.

    Args:
        model: The model to be updated incrementally.
        loss_fn: The loss function optimized during training.
        optimizer: Callable creating the optimizer from the model parameters.
        memory_size: Total number of data points the memory buffer can hold.
        batch_memory_frac: Fraction of each training batch drawn from the memory buffer.
        loss_weight: Weight of the base loss relative to the component losses.
        ema_memory_update_gamma: EMA factor used to refresh buffered logits.
        loss_normalization: Whether to normalize the loss by the sum of component weights.
        alpha: Weight of the MSE loss between memorised logits and current logits.
        beta: Weight of the loss between memorised targets and current logits.

    All remaining arguments configure the generic training loop and are forwarded unchanged to
    :class:`SingleTrainingLoopUpdater`.
    """

    def __init__(
        self,
        model: RenateModule,
        loss_fn: torch.nn.Module,
        optimizer: Callable[[List[Parameter]], Optimizer],
        memory_size: int,
        # Annotation fixed from `int` to `float`: this is a fraction of the batch in [0, 1].
        batch_memory_frac: float = defaults.BATCH_MEMORY_FRAC,
        loss_weight: float = defaults.LOSS_WEIGHT,
        ema_memory_update_gamma: float = defaults.EMA_MEMORY_UPDATE_GAMMA,
        loss_normalization: int = defaults.LOSS_NORMALIZATION,
        alpha: float = defaults.DER_ALPHA,
        beta: float = defaults.DER_BETA,
        learning_rate_scheduler: Optional[partial] = None,
        learning_rate_scheduler_interval: defaults.SUPPORTED_LR_SCHEDULER_INTERVAL_TYPE = defaults.LR_SCHEDULER_INTERVAL,  # noqa: E501
        batch_size: int = defaults.BATCH_SIZE,
        input_state_folder: Optional[str] = None,
        output_state_folder: Optional[str] = None,
        max_epochs: int = defaults.MAX_EPOCHS,
        train_transform: Optional[Callable] = None,
        train_target_transform: Optional[Callable] = None,
        test_transform: Optional[Callable] = None,
        test_target_transform: Optional[Callable] = None,
        buffer_transform: Optional[Callable] = None,
        buffer_target_transform: Optional[Callable] = None,
        metric: Optional[str] = None,
        mode: defaults.SUPPORTED_TUNING_MODE_TYPE = "min",
        logged_metrics: Optional[Dict[str, torchmetrics.Metric]] = None,
        early_stopping_enabled: bool = False,
        logger: Logger = defaults.LOGGER(**defaults.LOGGER_KWARGS),
        accelerator: defaults.SUPPORTED_ACCELERATORS_TYPE = defaults.ACCELERATOR,
        devices: Optional[int] = None,
        strategy: str = defaults.DISTRIBUTED_STRATEGY,
        precision: str = defaults.PRECISION,
        seed: int = defaults.SEED,
        deterministic_trainer: bool = defaults.DETERMINISTIC_TRAINER,
        gradient_clip_val: Optional[float] = defaults.GRADIENT_CLIP_VAL,
        gradient_clip_algorithm: Optional[str] = defaults.GRADIENT_CLIP_ALGORITHM,
        mask_unused_classes: bool = defaults.MASK_UNUSED_CLASSES,
    ):
        # Arguments consumed by the learner itself; the rest is handled by the updater.
        learner_kwargs = {
            "memory_size": memory_size,
            "batch_memory_frac": batch_memory_frac,
            "loss_weight": loss_weight,
            "ema_memory_update_gamma": ema_memory_update_gamma,
            "loss_normalization": loss_normalization,
            "alpha": alpha,
            "beta": beta,
            "batch_size": batch_size,
            "seed": seed,
        }
        super().__init__(
            model,
            loss_fn=loss_fn,
            optimizer=optimizer,
            learner_class=DarkExperienceReplayLearner,
            learner_kwargs=learner_kwargs,
            input_state_folder=input_state_folder,
            output_state_folder=output_state_folder,
            max_epochs=max_epochs,
            learning_rate_scheduler=learning_rate_scheduler,
            learning_rate_scheduler_interval=learning_rate_scheduler_interval,
            train_transform=train_transform,
            train_target_transform=train_target_transform,
            test_transform=test_transform,
            test_target_transform=test_target_transform,
            buffer_transform=buffer_transform,
            buffer_target_transform=buffer_target_transform,
            metric=metric,
            mode=mode,
            logged_metrics=logged_metrics,
            early_stopping_enabled=early_stopping_enabled,
            logger=logger,
            accelerator=accelerator,
            devices=devices,
            strategy=strategy,
            precision=precision,
            deterministic_trainer=deterministic_trainer,
            gradient_clip_algorithm=gradient_clip_algorithm,
            gradient_clip_val=gradient_clip_val,
            mask_unused_classes=mask_unused_classes,
        )
class PooledOutputDistillationExperienceReplayModelUpdater(SingleTrainingLoopUpdater):
    """Model updater that trains with :class:`PooledOutputDistillationExperienceReplayLearner`.

    Args:
        model: The model to be updated incrementally.
        loss_fn: The loss function optimized during training.
        optimizer: Callable creating the optimizer from the model parameters.
        memory_size: Total number of data points the memory buffer can hold.
        batch_memory_frac: Fraction of each training batch drawn from the memory buffer.
        loss_weight: Weight of the base loss relative to the component losses.
        ema_memory_update_gamma: EMA factor used to refresh buffered logits.
        loss_normalization: Whether to normalize the loss by the sum of component weights.
        alpha: Weight of the distillation loss over intermediate representations.
        distillation_type: Distillation type applied to the intermediate representations.
        normalize: Whether to normalize features before computing the Frobenius norm.

    All remaining arguments configure the generic training loop and are forwarded unchanged to
    :class:`SingleTrainingLoopUpdater`.
    """

    def __init__(
        self,
        model: RenateModule,
        loss_fn: torch.nn.Module,
        optimizer: Callable[[List[Parameter]], Optimizer],
        memory_size: int,
        # Annotation fixed from `int` to `float`: this is a fraction of the batch in [0, 1].
        batch_memory_frac: float = defaults.BATCH_MEMORY_FRAC,
        loss_weight: float = defaults.LOSS_WEIGHT,
        ema_memory_update_gamma: float = defaults.EMA_MEMORY_UPDATE_GAMMA,
        loss_normalization: int = defaults.LOSS_NORMALIZATION,
        alpha: float = defaults.POD_ALPHA,
        distillation_type: str = defaults.POD_DISTILLATION_TYPE,
        normalize: bool = defaults.POD_NORMALIZE,
        learning_rate_scheduler: Optional[partial] = None,
        learning_rate_scheduler_interval: defaults.SUPPORTED_LR_SCHEDULER_INTERVAL_TYPE = defaults.LR_SCHEDULER_INTERVAL,  # noqa: E501
        batch_size: int = defaults.BATCH_SIZE,
        input_state_folder: Optional[str] = None,
        output_state_folder: Optional[str] = None,
        max_epochs: int = defaults.MAX_EPOCHS,
        train_transform: Optional[Callable] = None,
        train_target_transform: Optional[Callable] = None,
        test_transform: Optional[Callable] = None,
        test_target_transform: Optional[Callable] = None,
        buffer_transform: Optional[Callable] = None,
        buffer_target_transform: Optional[Callable] = None,
        metric: Optional[str] = None,
        mode: defaults.SUPPORTED_TUNING_MODE_TYPE = "min",
        logged_metrics: Optional[Dict[str, torchmetrics.Metric]] = None,
        early_stopping_enabled: bool = False,
        logger: Logger = defaults.LOGGER(**defaults.LOGGER_KWARGS),
        accelerator: defaults.SUPPORTED_ACCELERATORS_TYPE = defaults.ACCELERATOR,
        devices: Optional[int] = None,
        strategy: str = defaults.DISTRIBUTED_STRATEGY,
        precision: str = defaults.PRECISION,
        seed: int = defaults.SEED,
        deterministic_trainer: bool = defaults.DETERMINISTIC_TRAINER,
        gradient_clip_val: Optional[float] = defaults.GRADIENT_CLIP_VAL,
        gradient_clip_algorithm: Optional[str] = defaults.GRADIENT_CLIP_ALGORITHM,
        mask_unused_classes: bool = defaults.MASK_UNUSED_CLASSES,
    ):
        # Arguments consumed by the learner itself; the rest is handled by the updater.
        learner_kwargs = {
            "memory_size": memory_size,
            "batch_memory_frac": batch_memory_frac,
            "loss_weight": loss_weight,
            "ema_memory_update_gamma": ema_memory_update_gamma,
            "loss_normalization": loss_normalization,
            "alpha": alpha,
            "distillation_type": distillation_type,
            "normalize": normalize,
            "batch_size": batch_size,
            "seed": seed,
        }
        super().__init__(
            model,
            loss_fn=loss_fn,
            optimizer=optimizer,
            learner_class=PooledOutputDistillationExperienceReplayLearner,
            learner_kwargs=learner_kwargs,
            input_state_folder=input_state_folder,
            output_state_folder=output_state_folder,
            max_epochs=max_epochs,
            learning_rate_scheduler=learning_rate_scheduler,
            learning_rate_scheduler_interval=learning_rate_scheduler_interval,
            train_transform=train_transform,
            train_target_transform=train_target_transform,
            test_transform=test_transform,
            test_target_transform=test_target_transform,
            buffer_transform=buffer_transform,
            buffer_target_transform=buffer_target_transform,
            metric=metric,
            mode=mode,
            logged_metrics=logged_metrics,
            early_stopping_enabled=early_stopping_enabled,
            logger=logger,
            accelerator=accelerator,
            devices=devices,
            strategy=strategy,
            precision=precision,
            deterministic_trainer=deterministic_trainer,
            gradient_clip_algorithm=gradient_clip_algorithm,
            gradient_clip_val=gradient_clip_val,
            mask_unused_classes=mask_unused_classes,
        )
class CLSExperienceReplayModelUpdater(SingleTrainingLoopUpdater):
    """Model updater that trains with :class:`CLSExperienceReplayLearner`.

    Args:
        model: The model to be updated incrementally.
        loss_fn: The loss function optimized during training.
        optimizer: Callable creating the optimizer from the model parameters.
        memory_size: Total number of data points the memory buffer can hold.
        batch_memory_frac: Fraction of each training batch drawn from the memory buffer.
        loss_weight: Weight of the base loss relative to the component losses.
        ema_memory_update_gamma: EMA factor used to refresh buffered logits.
        loss_normalization: Whether to normalize the loss by the sum of component weights.
        alpha: Scaling value for the cross-entropy loss on memory samples.
        beta: Scaling value for the consistency loss.
        stable_model_update_weight: Starting EMA weight for updating the stable model copy.
        plastic_model_update_weight: Starting EMA weight for updating the plastic model copy.
        stable_model_update_probability: Probability of updating the stable model copy.
        plastic_model_update_probability: Probability of updating the plastic model copy.

    All remaining arguments configure the generic training loop and are forwarded unchanged to
    :class:`SingleTrainingLoopUpdater`.
    """

    def __init__(
        self,
        model: RenateModule,
        loss_fn: torch.nn.Module,
        optimizer: Callable[[List[Parameter]], Optimizer],
        memory_size: int,
        # Annotation fixed from `int` to `float`: this is a fraction of the batch in [0, 1].
        batch_memory_frac: float = defaults.BATCH_MEMORY_FRAC,
        loss_weight: float = defaults.LOSS_WEIGHT,
        ema_memory_update_gamma: float = defaults.EMA_MEMORY_UPDATE_GAMMA,
        loss_normalization: int = defaults.LOSS_NORMALIZATION,
        alpha: float = defaults.CLS_ALPHA,
        beta: float = defaults.CLS_BETA,
        stable_model_update_weight: float = defaults.CLS_STABLE_MODEL_UPDATE_WEIGHT,
        plastic_model_update_weight: float = defaults.CLS_PLASTIC_MODEL_UPDATE_WEIGHT,
        stable_model_update_probability: float = defaults.CLS_STABLE_MODEL_UPDATE_PROBABILITY,
        plastic_model_update_probability: float = defaults.CLS_PLASTIC_MODEL_UPDATE_PROBABILITY,
        learning_rate_scheduler: Optional[partial] = None,
        learning_rate_scheduler_interval: defaults.SUPPORTED_LR_SCHEDULER_INTERVAL_TYPE = defaults.LR_SCHEDULER_INTERVAL,  # noqa: E501
        batch_size: int = defaults.BATCH_SIZE,
        input_state_folder: Optional[str] = None,
        output_state_folder: Optional[str] = None,
        max_epochs: int = defaults.MAX_EPOCHS,
        train_transform: Optional[Callable] = None,
        train_target_transform: Optional[Callable] = None,
        test_transform: Optional[Callable] = None,
        test_target_transform: Optional[Callable] = None,
        buffer_transform: Optional[Callable] = None,
        buffer_target_transform: Optional[Callable] = None,
        metric: Optional[str] = None,
        mode: defaults.SUPPORTED_TUNING_MODE_TYPE = "min",
        logged_metrics: Optional[Dict[str, torchmetrics.Metric]] = None,
        early_stopping_enabled: bool = False,
        logger: Logger = defaults.LOGGER(**defaults.LOGGER_KWARGS),
        accelerator: defaults.SUPPORTED_ACCELERATORS_TYPE = defaults.ACCELERATOR,
        devices: Optional[int] = None,
        strategy: str = defaults.DISTRIBUTED_STRATEGY,
        precision: str = defaults.PRECISION,
        seed: int = defaults.SEED,
        deterministic_trainer: bool = defaults.DETERMINISTIC_TRAINER,
        gradient_clip_val: Optional[float] = defaults.GRADIENT_CLIP_VAL,
        gradient_clip_algorithm: Optional[str] = defaults.GRADIENT_CLIP_ALGORITHM,
        mask_unused_classes: bool = defaults.MASK_UNUSED_CLASSES,
    ):
        # Arguments consumed by the learner itself; the rest is handled by the updater.
        learner_kwargs = {
            "memory_size": memory_size,
            "batch_memory_frac": batch_memory_frac,
            "loss_weight": loss_weight,
            "ema_memory_update_gamma": ema_memory_update_gamma,
            "loss_normalization": loss_normalization,
            "alpha": alpha,
            "beta": beta,
            "stable_model_update_weight": stable_model_update_weight,
            "plastic_model_update_weight": plastic_model_update_weight,
            "stable_model_update_probability": stable_model_update_probability,
            "plastic_model_update_probability": plastic_model_update_probability,
            "batch_size": batch_size,
            "seed": seed,
        }
        super().__init__(
            model,
            loss_fn=loss_fn,
            optimizer=optimizer,
            learner_class=CLSExperienceReplayLearner,
            learner_kwargs=learner_kwargs,
            input_state_folder=input_state_folder,
            output_state_folder=output_state_folder,
            max_epochs=max_epochs,
            learning_rate_scheduler=learning_rate_scheduler,
            learning_rate_scheduler_interval=learning_rate_scheduler_interval,
            train_transform=train_transform,
            train_target_transform=train_target_transform,
            test_transform=test_transform,
            test_target_transform=test_target_transform,
            buffer_transform=buffer_transform,
            buffer_target_transform=buffer_target_transform,
            metric=metric,
            mode=mode,
            logged_metrics=logged_metrics,
            early_stopping_enabled=early_stopping_enabled,
            logger=logger,
            accelerator=accelerator,
            devices=devices,
            strategy=strategy,
            precision=precision,
            deterministic_trainer=deterministic_trainer,
            gradient_clip_algorithm=gradient_clip_algorithm,
            gradient_clip_val=gradient_clip_val,
            mask_unused_classes=mask_unused_classes,
        )
class SuperExperienceReplayModelUpdater(SingleTrainingLoopUpdater):
    """Model updater performing a single training loop with Super Experience Replay.

    Collects the learner-specific hyperparameters into ``learner_kwargs`` and delegates
    all training/orchestration concerns to :py:class:`SingleTrainingLoopUpdater`, fixing
    the learner class to ``SuperExperienceReplayLearner``.

    Args:
        model: The model to be trained and updated.
        loss_fn: The loss function to be optimized during training.
        optimizer: Factory mapping model parameters to a ``torch.optim.Optimizer``.
        memory_size: The maximum size of the replay memory buffer.
        batch_memory_frac: Fraction of each training batch that is drawn from the
            memory buffer (the remainder comes from the current data chunk).
        loss_weight: Scalar weight for the base loss function.
        ema_memory_update_gamma: Gamma for the exponential moving average used when
            updating the meta-data stored in memory.
        loss_normalization: Whether to normalize the loss terms (0 or 1).
        der_alpha: Weight of the DER mean-squared-error (logit distillation) loss term.
        der_beta: Weight of the DER classification loss term on memory samples.
        sp_shrink_factor: Shrink factor for shrink-and-perturb reinitialization.
        sp_sigma: Perturbation noise scale for shrink-and-perturb reinitialization.
        cls_alpha: Weight of the CLS consistency loss term.
        cls_stable_model_update_weight: EMA weight for updating the stable model copy.
        cls_plastic_model_update_weight: EMA weight for updating the plastic model copy.
        cls_stable_model_update_probability: Probability of updating the stable model.
        cls_plastic_model_update_probability: Probability of updating the plastic model.
        pod_alpha: Weight of the pooled-output-distillation loss term.
        pod_distillation_type: Type of pooled-output distillation to apply.
        pod_normalize: Whether to normalize features in the distillation loss.
        learning_rate_scheduler: Optional partial creating a learning rate scheduler.
        learning_rate_scheduler_interval: When the scheduler steps (``epoch``/``step``).
        batch_size: Total training batch size (current data + memory).
        input_state_folder: Folder from which a previous updater state is loaded.
        output_state_folder: Folder to which the updater state is saved.
        max_epochs: Maximum number of training epochs.
        train_transform: Transform applied to training inputs.
        train_target_transform: Transform applied to training targets.
        test_transform: Transform applied to test inputs.
        test_target_transform: Transform applied to test targets.
        buffer_transform: Transform applied to inputs sampled from the memory buffer.
        buffer_target_transform: Transform applied to targets sampled from the buffer.
        metric: Name of the metric monitored for checkpointing/early stopping.
        mode: Whether the monitored metric is minimized (``min``) or maximized (``max``).
        logged_metrics: Additional torchmetrics to log during training.
        early_stopping_enabled: Whether early stopping is used.
        logger: PyTorch Lightning logger instance.
        accelerator: Accelerator used for training (e.g. ``cpu``, ``gpu``).
        devices: Number of devices to use.
        strategy: Distributed training strategy.
        precision: Numerical precision used during training.
        seed: Random seed forwarded to the learner.
        deterministic_trainer: Whether the trainer runs deterministically.
        gradient_clip_val: Value at which gradients are clipped.
        gradient_clip_algorithm: Gradient clipping algorithm (``norm`` or ``value``).
        mask_unused_classes: Whether logits of unused classes are masked.
    """

    def __init__(
        self,
        model: RenateModule,
        loss_fn: torch.nn.Module,
        optimizer: Callable[[List[Parameter]], Optimizer],
        memory_size: int,
        # NOTE: annotated `float` (was `int`) — this is a fraction of the batch.
        batch_memory_frac: float = defaults.BATCH_MEMORY_FRAC,
        loss_weight: float = defaults.LOSS_WEIGHT,
        ema_memory_update_gamma: float = defaults.EMA_MEMORY_UPDATE_GAMMA,
        loss_normalization: int = defaults.LOSS_NORMALIZATION,
        der_alpha: float = defaults.SER_DER_ALPHA,
        der_beta: float = defaults.SER_DER_BETA,
        sp_shrink_factor: float = defaults.SER_SP_SHRINK_FACTOR,
        sp_sigma: float = defaults.SER_SP_SIGMA,
        cls_alpha: float = defaults.SER_CLS_ALPHA,
        cls_stable_model_update_weight: float = defaults.SER_CLS_STABLE_MODEL_UPDATE_WEIGHT,
        cls_plastic_model_update_weight: float = defaults.SER_CLS_PLASTIC_MODEL_UPDATE_WEIGHT,
        cls_stable_model_update_probability: float = defaults.SER_CLS_STABLE_MODEL_UPDATE_PROBABILITY,  # noqa: E501
        cls_plastic_model_update_probability: float = defaults.SER_CLS_PLASTIC_MODEL_UPDATE_PROBABILITY,  # noqa: E501
        pod_alpha: float = defaults.SER_POD_ALPHA,
        pod_distillation_type: str = defaults.SER_POD_DISTILLATION_TYPE,
        pod_normalize: bool = defaults.SER_POD_NORMALIZE,
        learning_rate_scheduler: Optional[partial] = None,
        learning_rate_scheduler_interval: defaults.SUPPORTED_LR_SCHEDULER_INTERVAL_TYPE = defaults.LR_SCHEDULER_INTERVAL,  # noqa: E501
        batch_size: int = defaults.BATCH_SIZE,
        input_state_folder: Optional[str] = None,
        output_state_folder: Optional[str] = None,
        max_epochs: int = defaults.MAX_EPOCHS,
        train_transform: Optional[Callable] = None,
        train_target_transform: Optional[Callable] = None,
        test_transform: Optional[Callable] = None,
        test_target_transform: Optional[Callable] = None,
        buffer_transform: Optional[Callable] = None,
        buffer_target_transform: Optional[Callable] = None,
        metric: Optional[str] = None,
        mode: defaults.SUPPORTED_TUNING_MODE_TYPE = "min",
        logged_metrics: Optional[Dict[str, torchmetrics.Metric]] = None,
        early_stopping_enabled: bool = False,
        logger: Logger = defaults.LOGGER(**defaults.LOGGER_KWARGS),
        accelerator: defaults.SUPPORTED_ACCELERATORS_TYPE = defaults.ACCELERATOR,
        devices: Optional[int] = None,
        strategy: str = defaults.DISTRIBUTED_STRATEGY,
        precision: str = defaults.PRECISION,
        seed: int = defaults.SEED,
        deterministic_trainer: bool = defaults.DETERMINISTIC_TRAINER,
        gradient_clip_val: Optional[float] = defaults.GRADIENT_CLIP_VAL,
        gradient_clip_algorithm: Optional[str] = defaults.GRADIENT_CLIP_ALGORITHM,
        mask_unused_classes: bool = defaults.MASK_UNUSED_CLASSES,
    ):
        # Only arguments consumed by the learner itself go into `learner_kwargs`;
        # everything else is handled by the generic single-training-loop updater.
        learner_kwargs = {
            "memory_size": memory_size,
            "batch_memory_frac": batch_memory_frac,
            "loss_weight": loss_weight,
            "ema_memory_update_gamma": ema_memory_update_gamma,
            "loss_normalization": loss_normalization,
            "der_alpha": der_alpha,
            "der_beta": der_beta,
            "sp_shrink_factor": sp_shrink_factor,
            "sp_sigma": sp_sigma,
            "cls_alpha": cls_alpha,
            "cls_stable_model_update_weight": cls_stable_model_update_weight,
            "cls_plastic_model_update_weight": cls_plastic_model_update_weight,
            "cls_stable_model_update_probability": cls_stable_model_update_probability,
            "cls_plastic_model_update_probability": cls_plastic_model_update_probability,
            "pod_alpha": pod_alpha,
            "pod_distillation_type": pod_distillation_type,
            "pod_normalize": pod_normalize,
            "batch_size": batch_size,
            "seed": seed,
        }
        super().__init__(
            model,
            loss_fn=loss_fn,
            optimizer=optimizer,
            learner_class=SuperExperienceReplayLearner,
            learner_kwargs=learner_kwargs,
            input_state_folder=input_state_folder,
            output_state_folder=output_state_folder,
            max_epochs=max_epochs,
            learning_rate_scheduler=learning_rate_scheduler,
            learning_rate_scheduler_interval=learning_rate_scheduler_interval,
            train_transform=train_transform,
            train_target_transform=train_target_transform,
            test_transform=test_transform,
            test_target_transform=test_target_transform,
            buffer_transform=buffer_transform,
            buffer_target_transform=buffer_target_transform,
            metric=metric,
            mode=mode,
            logged_metrics=logged_metrics,
            early_stopping_enabled=early_stopping_enabled,
            logger=logger,
            accelerator=accelerator,
            devices=devices,
            strategy=strategy,
            precision=precision,
            deterministic_trainer=deterministic_trainer,
            gradient_clip_algorithm=gradient_clip_algorithm,
            gradient_clip_val=gradient_clip_val,
            mask_unused_classes=mask_unused_classes,
        )