# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from typing import Any, List
import torch
from avalanche.core import BasePlugin, SupervisedPlugin
from avalanche.models import NCMClassifier, TrainEvalModel
from avalanche.training import ICaRL, ICaRLLossPlugin
from avalanche.training.plugins import EWCPlugin, LwFPlugin, ReplayPlugin
from avalanche.training.plugins.evaluation import default_evaluator
from avalanche.training.supervised.icarl import _ICaRLPlugin
from avalanche.training.templates import BaseSGDTemplate, SupervisedTemplate
from torch.optim import Optimizer
from renate.updaters.avalanche.plugins import RenateCheckpointPlugin
from renate.updaters.learner import Learner, ReplayLearner
from renate.utils.avalanche import plugin_by_class, remove_plugin, replace_plugin
[docs]
class AvalancheLoaderMixin:
    """Mixin for Avalanche dummy learner classes.

    The methods read ``_model``, ``_loss_fn`` and ``_batch_size`` from the host object. If the
    host also defines ``_memory_batch_size``, it is added to the evaluation batch size.
    """

    def _avalanche_eval_batch_size(self) -> int:
        """Returns ``_batch_size`` plus ``_memory_batch_size`` (0 when that attribute is absent)."""
        return self._batch_size + getattr(self, "_memory_batch_size", 0)

    def update_settings(
        self,
        avalanche_learner: BaseSGDTemplate,
        plugins: List[BasePlugin],
        optimizer: Optimizer,
        max_epochs: int,
        device: torch.device,
        eval_every: int,
    ) -> None:
        """Updates settings of Avalanche learner after reloading.

        Args:
            avalanche_learner: The Avalanche learner whose attributes are overwritten in place.
            plugins: Plugins handed one by one to ``replace_plugin``.
            optimizer: Assigned to ``avalanche_learner.optimizer``.
            max_epochs: Assigned to ``avalanche_learner.train_epochs``.
            device: Assigned to ``avalanche_learner.device``.
            eval_every: Assigned to ``avalanche_learner.eval_every``.
        """
        learner = avalanche_learner

        # The plugin list first goes through `remove_plugin` for `RenateCheckpointPlugin`,
        # then through `replace_plugin` once per provided plugin.
        learner.plugins = remove_plugin(RenateCheckpointPlugin, learner.plugins)
        for new_plugin in plugins:
            learner.plugins = replace_plugin(new_plugin, learner.plugins)

        # Model, optimizer and loss.
        learner.model = self._model
        learner.optimizer = optimizer
        learner._criterion = self._loss_fn

        # Training/evaluation schedule and batch sizes.
        learner.train_epochs = max_epochs
        learner.train_mb_size = self._batch_size
        learner.eval_mb_size = self._avalanche_eval_batch_size()
        learner.device = device
        learner.eval_every = eval_every

    def _create_avalanche_learner(
        self,
        optimizer: Optimizer,
        train_epochs: int,
        plugins: List[SupervisedPlugin],
        device: torch.device,
        eval_every: int,
        **kwargs: Any,
    ) -> BaseSGDTemplate:
        """Returns Avalanche object that this dummy learner wraps around.

        Args:
            optimizer: Optimizer passed to the ``SupervisedTemplate``.
            train_epochs: Value for the template's ``train_epochs`` argument.
            plugins: Plugins passed to the template.
            device: Device passed to the template.
            eval_every: Value for the template's ``eval_every`` argument.
            **kwargs: Forwarded unchanged to ``SupervisedTemplate``.

        Returns:
            A ``SupervisedTemplate`` built from this object's model, loss function and batch size.
        """
        template = SupervisedTemplate(
            model=self._model,
            optimizer=optimizer,
            criterion=self._loss_fn,
            train_mb_size=self._batch_size,
            eval_mb_size=self._avalanche_eval_batch_size(),
            train_epochs=train_epochs,
            plugins=plugins,
            evaluator=default_evaluator(),
            device=device,
            eval_every=eval_every,
            **kwargs,
        )
        return template
[docs]
class AvalancheReplayLearner(ReplayLearner, AvalancheLoaderMixin):
    """Renate wrapper around Avalanche Experience Replay."""

    def create_avalanche_learner(
        self, plugins: List[SupervisedPlugin], **kwargs: Any
    ) -> BaseSGDTemplate:
        """Creates the wrapped Avalanche learner with a ``ReplayPlugin`` added to ``plugins``.

        Args:
            plugins: Plugins to attach to the learner. The list itself is not modified.
            **kwargs: Forwarded to ``_create_avalanche_learner``.

        Returns:
            The object returned by ``_create_avalanche_learner``.
        """
        replay_plugin = ReplayPlugin(
            mem_size=self._memory_buffer._max_size,
            batch_size=self._batch_size,
            batch_size_mem=self._memory_batch_size,
        )
        # Build a new list rather than appending in place: mutating the caller's list would
        # accumulate duplicate replay plugins if the same list is passed in again.
        return self._create_avalanche_learner(plugins=[*plugins, replay_plugin], **kwargs)
[docs]
class AvalancheEWCLearner(Learner, AvalancheLoaderMixin):
    """Renate wrapper around Avalanche EWC."""

    def __init__(self, ewc_lambda: float, **kwargs: Any) -> None:
        """Stores ``ewc_lambda`` and forwards all remaining keyword arguments to the parent."""
        super().__init__(**kwargs)
        self._ewc_lambda = ewc_lambda

    def update_settings(self, avalanche_learner: BaseSGDTemplate, **kwargs: Any) -> None:
        """Updates the reloaded learner's settings and sets ``ewc_lambda`` on its ``EWCPlugin``.

        Args:
            avalanche_learner: The Avalanche learner to update in place.
            **kwargs: Forwarded to the parent ``update_settings``.
        """
        super().update_settings(avalanche_learner=avalanche_learner, **kwargs)
        plugin_by_class(EWCPlugin, avalanche_learner.plugins).ewc_lambda = self._ewc_lambda

    def create_avalanche_learner(
        self, plugins: List[SupervisedPlugin], **kwargs: Any
    ) -> BaseSGDTemplate:
        """Creates the wrapped Avalanche learner with an ``EWCPlugin`` added to ``plugins``.

        Args:
            plugins: Plugins to attach to the learner. The list itself is not modified.
            **kwargs: Forwarded to ``_create_avalanche_learner``.

        Returns:
            The object returned by ``_create_avalanche_learner``.
        """
        ewc_plugin = EWCPlugin(ewc_lambda=self._ewc_lambda)
        # Build a new list rather than appending in place: mutating the caller's list would
        # accumulate duplicate EWC plugins if the same list is passed in again.
        return self._create_avalanche_learner(plugins=[*plugins, ewc_plugin], **kwargs)
[docs]
class AvalancheLwFLearner(Learner, AvalancheLoaderMixin):
    """Renate wrapper around Avalanche LwF."""

    def __init__(self, alpha: float, temperature: float, **kwargs: Any) -> None:
        """Stores ``alpha`` and ``temperature``; remaining keyword arguments go to the parent."""
        super().__init__(**kwargs)
        self._alpha = alpha
        self._temperature = temperature

    def update_settings(self, avalanche_learner: BaseSGDTemplate, **kwargs: Any) -> None:
        """Updates the reloaded learner's settings and the LwF hyperparameters of its plugin.

        Args:
            avalanche_learner: The Avalanche learner to update in place.
            **kwargs: Forwarded to the parent ``update_settings``.
        """
        super().update_settings(avalanche_learner=avalanche_learner, **kwargs)
        lwf_plugin = plugin_by_class(LwFPlugin, avalanche_learner.plugins)
        # The hyperparameters are set on the plugin's `lwf` attribute, not on the plugin itself.
        lwf_plugin.lwf.alpha = self._alpha
        lwf_plugin.lwf.temperature = self._temperature

    def create_avalanche_learner(
        self, plugins: List[SupervisedPlugin], **kwargs: Any
    ) -> BaseSGDTemplate:
        """Creates the wrapped Avalanche learner with an ``LwFPlugin`` added to ``plugins``.

        Args:
            plugins: Plugins to attach to the learner. The list itself is not modified.
            **kwargs: Forwarded to ``_create_avalanche_learner``.

        Returns:
            The object returned by ``_create_avalanche_learner``.
        """
        lwf_plugin = LwFPlugin(alpha=self._alpha, temperature=self._temperature)
        # Build a new list rather than appending in place: mutating the caller's list would
        # accumulate duplicate LwF plugins if the same list is passed in again.
        return self._create_avalanche_learner(plugins=[*plugins, lwf_plugin], **kwargs)
[docs]
class AvalancheICaRLLearner(Learner, AvalancheLoaderMixin):
    """Renate wrapper around Avalanche ICaRL."""

    def __init__(self, memory_size: int, **kwargs: Any) -> None:
        """Stores ``memory_size`` and forwards all remaining keyword arguments to the parent."""
        super().__init__(**kwargs)
        self._memory_size = memory_size

    def create_avalanche_learner(
        self,
        optimizer: Optimizer,
        train_epochs: int,
        plugins: List[SupervisedPlugin],
        device: torch.device,
        eval_every: int,
    ) -> BaseSGDTemplate:
        """Builds an Avalanche ``ICaRL`` strategy from the model's backbone and predictor.

        Args:
            optimizer: Optimizer passed to ``ICaRL``.
            train_epochs: Value for the strategy's ``train_epochs`` argument.
            plugins: Plugins passed to ``ICaRL``.
            device: Device passed to ``ICaRL``.
            eval_every: Accepted but not used; ``-1`` is always passed (see TODO below).

        Returns:
            The ``ICaRL`` strategy, whose ``_ICaRLPlugin`` has ``class_means`` set to the model's
            ``class_means``.

        Raises:
            RuntimeError: If the model lacks ``class_means``, ``get_backbone`` or
                ``get_predictor``.
        """
        model = self._model
        if not hasattr(model, "class_means"):
            raise RuntimeError(
                "The RenateModule must contain an attribute `class_means`.\n"
                "Please add something like\n"
                "self.class_means = torch.nn.Parameter(\n"
                "torch.zeros((embedding_size, num_outputs)), requires_grad=False\n"
                ")\n"
            )
        has_split = hasattr(model, "get_backbone") and hasattr(model, "get_predictor")
        if not has_split:
            raise RuntimeError(
                "The RenateModule must be explicitly split into backbone and predictor module. "
                "Please implement functions `get_backbone()` and `get_predictor()` to return these "
                "modules."
            )

        strategy = ICaRL(
            feature_extractor=model.get_backbone(),
            classifier=model.get_predictor(),
            optimizer=optimizer,
            memory_size=self._memory_size,
            buffer_transform=None,  # TODO
            fixed_memory=True,
            train_mb_size=self._batch_size,
            train_epochs=train_epochs,
            eval_mb_size=self._batch_size,
            device=device,
            plugins=plugins,
            eval_every=-1,  # TODO: https://github.com/ContinualAI/avalanche/issues/1281
        )
        icarl_plugin = plugin_by_class(_ICaRLPlugin, strategy.plugins)
        icarl_plugin.class_means = model.class_means
        return strategy

    def update_settings(self, avalanche_learner: BaseSGDTemplate, **kwargs) -> None:
        """Updates the reloaded learner's settings, then applies the ICaRL-specific overrides.

        After the parent update, the model is replaced by a ``TrainEvalModel``, the criterion by
        the learner's ``ICaRLLossPlugin``, and ``eval_every`` is set to ``-1``.

        Args:
            avalanche_learner: The Avalanche learner to update in place.
            **kwargs: Forwarded to the parent ``update_settings``.
        """
        super().update_settings(avalanche_learner=avalanche_learner, **kwargs)

        backbone = self._model.get_backbone()
        predictor = self._model.get_predictor()
        avalanche_learner.model = TrainEvalModel(
            feature_extractor=backbone,
            train_classifier=predictor,
            eval_classifier=NCMClassifier(),
        )
        avalanche_learner._criterion = plugin_by_class(ICaRLLossPlugin, avalanche_learner.plugins)
        # TODO: https://github.com/ContinualAI/avalanche/issues/1281
        avalanche_learner.eval_every = -1