Source code for renate.utils.avalanche

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from collections import Counter
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union

import torch
from avalanche.benchmarks import dataset_benchmark
from avalanche.core import BasePlugin
from torch import Tensor
from torch.utils.data import DataLoader, Dataset

from renate.data.datasets import _TransformedDataset
from renate.memory import DataBuffer


class BaseAvalancheDataset(Dataset):
    """Base class for all datasets consumable by Avalanche updaters."""

    def __init__(
        self,
        targets: List[int],
        collate_fn: Optional[Callable] = None,
    ):
        self._targets = targets
        self.targets = torch.tensor(targets, dtype=torch.long)
        if collate_fn is not None:
            self.collate_fn = collate_fn

    def __len__(self) -> int:
        return len(self._targets)

class AvalancheDataset(BaseAvalancheDataset):
    """A wrapper around a Dataset consumable by Avalanche updaters."""

    def __init__(
        self,
        dataset: Union[Dataset, DataBuffer],
        targets: List[int],
        collate_fn: Optional[Callable] = None,
    ):
        super().__init__(targets, collate_fn)
        self._dataset = dataset

    def __getitem__(self, idx) -> Tuple[Tensor, int]:
        return self._dataset[idx][0], self._targets[idx]

class AvalancheDatasetForBuffer(BaseAvalancheDataset):
    """A wrapper around a DataBuffer consumable by Avalanche updaters."""

    def __init__(
        self, buffer: DataBuffer, targets: List[int], collate_fn: Optional[Callable] = None
    ):
        super().__init__(targets, collate_fn)
        self._indices = buffer._indices
        self._datasets = buffer._datasets

    def __getitem__(self, idx) -> Tuple[Tensor, int]:
        i, j = self._indices[idx]
        return self._datasets[i][j][0], self._targets[idx]

def to_avalanche_dataset(
    dataset: Union[Dataset, DataBuffer], collate_fn: Optional[Callable] = None
) -> BaseAvalancheDataset:
    """Converts a DataBuffer or Dataset into an Avalanche-compatible Dataset."""
    y_data = []
    is_buffer = isinstance(dataset, DataBuffer)
    for i in range(len(dataset)):
        if is_buffer:
            (_, y), _ = dataset[i]
        else:
            _, y = dataset[i]
        if not isinstance(y, int):
            y = y.item()
        y_data.append(y)
    if is_buffer:
        return AvalancheDatasetForBuffer(dataset, y_data, collate_fn)
    return AvalancheDataset(dataset, y_data, collate_fn)

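A minimal usage sketch, assuming a plain in-memory ``TensorDataset`` of ``(x, y)`` pairs; the data below is synthetic and purely illustrative:

# Hypothetical usage: wrap a plain (x, y) dataset so Avalanche updaters can consume it.
import torch
from torch.utils.data import TensorDataset

inputs = torch.randn(10, 3)
labels = torch.randint(0, 2, (10,))
avalanche_dataset = to_avalanche_dataset(TensorDataset(inputs, labels))
assert len(avalanche_dataset) == 10
x0, y0 = avalanche_dataset[0]  # returns (input tensor, integer target)
assert avalanche_dataset.targets.dtype == torch.long  # targets are exposed as a LongTensor
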
class AvalancheBenchmarkWrapper:
    """Wraps a train/validation dataset pair into a single-experience Avalanche benchmark."""

    def __init__(
        self,
        train_dataset,
        val_dataset,
        train_transform,
        train_target_transform,
        test_transform,
        test_target_transform,
    ):
        self._n_classes_per_exp = None
        self._classes_order = None
        self._n_classes = 0
        self._train_dataset = train_dataset
        self._test_transform = test_transform
        self._train_target_transform = train_target_transform
        self._benchmark = dataset_benchmark(
            [train_dataset],
            [val_dataset],
            train_transform=train_transform,
            train_target_transform=train_target_transform,
            eval_transform=test_transform,
            eval_target_transform=test_target_transform,
        )
        self.train_stream = self._benchmark.train_stream
        self.test_stream = self._benchmark.test_stream

    def update_benchmark_properties(self):
        """Derives the class statistics of the current experience from the train data."""
        dataset = _TransformedDataset(
            dataset=self._train_dataset,
            transform=self._test_transform,
            target_transform=self._train_target_transform,
        )
        dataloader = DataLoader(dataset)
        unique_classes = set()
        for batch in dataloader:
            unique_classes.add(batch[1].item())
        if self._n_classes_per_exp is None:
            self._n_classes_per_exp = [len(unique_classes)]
            self._classes_order = list(sorted(unique_classes))
        else:
            self._n_classes_per_exp.append(len(unique_classes))
            self._classes_order += list(sorted(unique_classes))
        # The total class count is the sum of per-experience counts, not of the class
        # labels themselves.
        self._n_classes = sum(self._n_classes_per_exp)
        self._benchmark.n_classes_per_exp = self._n_classes_per_exp
        self._benchmark.classes_order = self._classes_order

    def state_dict(self) -> Dict[str, Any]:
        """Returns the state of the benchmark."""
        state_dict = {
            "n_classes_per_exp": self._n_classes_per_exp,
            "classes_order": self._classes_order,
        }
        return state_dict

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Restores the state of the benchmark."""
        self._n_classes_per_exp = state_dict["n_classes_per_exp"]
        self._classes_order = state_dict["classes_order"]
        self._benchmark.n_classes_per_exp = self._n_classes_per_exp
        self._benchmark.classes_order = self._classes_order

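The two methods above let the derived class statistics survive across updater runs. A sketch of the intended round trip, assuming ``_TransformedDataset`` and ``dataset_benchmark`` pass plain datasets through unchanged when all transforms are ``None``:

# Hypothetical example: persist and restore the benchmark properties.
import torch
from torch.utils.data import TensorDataset

train = TensorDataset(torch.randn(8, 3), torch.randint(0, 4, (8,)))
val = TensorDataset(torch.randn(4, 3), torch.randint(0, 4, (4,)))

wrapper = AvalancheBenchmarkWrapper(train, val, None, None, None, None)
wrapper.update_benchmark_properties()  # derives class counts from the train data
state = wrapper.state_dict()  # {"n_classes_per_exp": [...], "classes_order": [...]}

restored = AvalancheBenchmarkWrapper(train, val, None, None, None, None)
restored.load_state_dict(state)  # propagates both properties to the inner benchmark
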
def replace_plugin(plugin: Optional[BasePlugin], plugins: List[BasePlugin]) -> List[BasePlugin]:
    """Replaces a plugin if one of the same type already exists, appends it otherwise.

    Args:
        plugin: New plugin that replaces an existing one.
        plugins: List of current plugins.

    Returns:
        Reference to ``plugins``.
    """
    idx = _plugin_index(type(plugin), plugins)
    if idx >= 0:
        plugins[idx] = plugin
    else:
        plugins.append(plugin)
    return plugins

def remove_plugin(plugin_class: Type[BasePlugin], plugins: List[BasePlugin]) -> List[BasePlugin]:
    """Removes a plugin of the given class if it exists.

    Args:
        plugin_class: Remove a plugin of this class if it exists.
        plugins: List of current plugins.

    Returns:
        Reference to ``plugins``.
    """
    idx = _plugin_index(plugin_class, plugins)
    if idx >= 0:
        del plugins[idx]
    return plugins

def plugin_by_class(
    plugin_class: Type[BasePlugin], plugins: List[BasePlugin]
) -> Optional[BasePlugin]:
    """Returns the plugin of the given class from a list of plugins.

    Args:
        plugin_class: Class type of interest in ``plugins``.
        plugins: List of plugins searched for an object of type ``plugin_class``.

    Returns:
        ``None`` if no such plugin exists, otherwise the respective object.
    """
    idx = _plugin_index(plugin_class, plugins)
    if idx >= 0:
        return plugins[idx]
    return None

def _plugin_index(plugin_class: Type[BasePlugin], plugins: List[BasePlugin]) -> int:
    """Returns the index at which a plugin of the given type is located in the list.

    Returns:
        Location of the plugin, or ``-1`` if it does not exist.
    """
    plugins_types = [type(p) for p in plugins]
    # ``default=0`` guards against an empty plugin list, where ``max`` would otherwise raise.
    if max(Counter(plugins_types).values(), default=0) > 1:
        raise ValueError("Multiple occurrences of the same type in `plugins` are not supported.")
    if plugin_class in plugins_types:
        return plugins_types.index(plugin_class)
    return -1
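
A brief sketch of how the three public helpers compose; the two plugin classes below are hypothetical stand-ins for real Avalanche plugins, not part of renate:

# Hypothetical example: maintain a plugin list with at most one plugin per type.
from avalanche.core import BasePlugin

class _LoggingPlugin(BasePlugin):  # illustrative only
    pass

class _ReplayLikePlugin(BasePlugin):  # illustrative only
    pass

plugins = [_LoggingPlugin()]
plugins = replace_plugin(_ReplayLikePlugin(), plugins)  # appended: no plugin of this type yet
plugins = replace_plugin(_ReplayLikePlugin(), plugins)  # replaced in place, length unchanged
assert isinstance(plugin_by_class(_ReplayLikePlugin, plugins), _ReplayLikePlugin)
plugins = remove_plugin(_LoggingPlugin, plugins)  # drops the logging plugin, if present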