# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from functools import partial
from typing import Callable, Dict, List, Optional, Tuple, Union
import torch
import wild_time_data
from torch.optim import AdamW, Optimizer
from torch.optim.lr_scheduler import CosineAnnealingLR, StepLR, _LRScheduler
from torchmetrics.classification import MulticlassAccuracy
from torchvision.transforms import transforms
from transformers import AutoTokenizer
from wild_time_data import default_transform
from renate.benchmark.datasets.nlp_datasets import HuggingFaceTextDataModule, MultiTextDataModule
from renate.benchmark.datasets.vision_datasets import (
CDDBDataModule,
CLEARDataModule,
CORE50DataModule,
DomainNetDataModule,
TorchVisionDataModule,
)
from renate.benchmark.datasets.wild_time_data import WildTimeDataModule
from renate.benchmark.models import (
MultiLayerPerceptron,
LearningToPromptTransformer,
ResNet18,
ResNet18CIFAR,
ResNet34,
ResNet34CIFAR,
ResNet50,
ResNet50CIFAR,
VisionTransformerB16,
VisionTransformerB32,
VisionTransformerCIFAR,
VisionTransformerH14,
VisionTransformerL16,
VisionTransformerL32,
)
from renate.benchmark.models.transformer import HuggingFaceSequenceClassificationTransformer
from renate.benchmark.scenarios import (
ClassIncrementalScenario,
DataIncrementalScenario,
FeatureSortingScenario,
HueShiftScenario,
IIDScenario,
ImageRotationScenario,
PermutationScenario,
Scenario,
)
from renate.data.data_module import RenateDataModule
from renate.models import RenateModule
from renate.models.prediction_strategies import ICaRLClassificationStrategy
from renate.benchmark.models.spromptmodel import SPromptTransformer
models = {
"MultiLayerPerceptron": MultiLayerPerceptron,
"ResNet18CIFAR": ResNet18CIFAR,
"ResNet34CIFAR": ResNet34CIFAR,
"ResNet50CIFAR": ResNet50CIFAR,
"ResNet18": ResNet18,
"ResNet34": ResNet34,
"ResNet50": ResNet50,
"VisionTransformerCIFAR": VisionTransformerCIFAR,
"VisionTransformerB16": VisionTransformerB16,
"VisionTransformerB32": VisionTransformerB32,
"VisionTransformerL16": VisionTransformerL16,
"VisionTransformerL32": VisionTransformerL32,
"VisionTransformerH14": VisionTransformerH14,
"HuggingFaceTransformer": HuggingFaceSequenceClassificationTransformer,
"LearningToPromptTransformer": LearningToPromptTransformer,
"SPromptTransformer": SPromptTransformer,
}
[docs]
def model_fn(
num_outputs: int,
model_state_url: Optional[str] = None,
updater: Optional[str] = None,
model_name: Optional[str] = None,
num_inputs: Optional[int] = None,
num_hidden_layers: Optional[int] = None,
hidden_size: Optional[Tuple[int]] = None,
dataset_name: Optional[str] = None,
pretrained_model_name_or_path: Optional[str] = None,
prompt_size: int = 10,
clusters_per_task: int = 5,
per_task_classifier: bool = True,
) -> RenateModule:
"""Returns a model instance."""
if model_name not in models:
raise ValueError(f"Unknown model `{model_name}`")
model_class = models[model_name]
model_kwargs = {"num_outputs": num_outputs}
if updater == "Avalanche-iCaRL":
model_kwargs["prediction_strategy"] = ICaRLClassificationStrategy()
if model_name == "MultiLayerPerceptron":
model_kwargs.update(
{
"num_inputs": num_inputs,
"num_hidden_layers": num_hidden_layers,
"hidden_size": hidden_size,
}
)
elif model_name.startswith("ResNet") and dataset_name in ["FashionMNIST", "MNIST", "yearbook"]:
model_kwargs["gray_scale"] = True
elif model_name == "HuggingFaceTransformer":
if updater == "Avalanche-iCaRL":
raise ValueError("Transformers do not support iCaRL.")
model_kwargs["pretrained_model_name_or_path"] = pretrained_model_name_or_path
elif (updater is not None) and ("LearningToPrompt" in updater):
if not model_name.startswith("LearningToPrompt"):
raise ValueError(
"L2P model updaters are designed to work only with "
f"LearningToPromptTransformer, but model name specified is {model_name}."
)
model_kwargs["pretrained_model_name_or_path"] = pretrained_model_name_or_path
elif (updater is not None) and ("SPeft" in updater):
if not model_name.startswith("SPrompt"):
raise ValueError(
"SPrompt model updater is designed to work only with "
f"SPromptTransformer, but model name specified is {model_name}."
)
model_kwargs["pretrained_model_name_or_path"] = pretrained_model_name_or_path
model_kwargs["prompt_size"] = prompt_size
model_kwargs["clusters_per_task"] = clusters_per_task
model_kwargs["per_task_classifier"] = per_task_classifier
if model_state_url is None:
model = model_class(**model_kwargs)
else:
state_dict = torch.load(model_state_url)
model = model_class.from_state_dict(state_dict)
return model
[docs]
def get_data_module(
data_path: str,
src_bucket: Optional[str],
src_object_name: Optional[str],
dataset_name: str,
val_size: float,
seed: int,
pretrained_model_name_or_path: Optional[str],
input_column: Optional[str],
target_column: Optional[str],
) -> RenateDataModule:
tokenizer = None
if pretrained_model_name_or_path is not None and "vit" not in pretrained_model_name_or_path:
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
if dataset_name in TorchVisionDataModule.dataset_dict:
return TorchVisionDataModule(
data_path, dataset_name=dataset_name, val_size=val_size, seed=seed
)
if dataset_name in ["CLEAR10", "CLEAR100"]:
return CLEARDataModule(data_path, dataset_name=dataset_name, val_size=val_size, seed=seed)
if dataset_name in wild_time_data.list_datasets():
data_module_kwargs = {
"data_path": data_path,
"src_bucket": src_bucket,
"src_object_name": src_object_name,
"dataset_name": dataset_name,
"val_size": val_size,
"seed": seed,
}
if pretrained_model_name_or_path is not None:
data_module_kwargs["tokenizer"] = tokenizer
return WildTimeDataModule(**data_module_kwargs)
if dataset_name == "DomainNet":
return DomainNetDataModule(
data_path=data_path,
src_bucket=src_bucket,
src_object_name=src_object_name,
val_size=val_size,
seed=seed,
)
if dataset_name == "MultiText":
return MultiTextDataModule(
data_path=data_path,
tokenizer=tokenizer,
data_id="ag_news",
val_size=val_size,
seed=seed,
)
if dataset_name.startswith("hfd-"):
return HuggingFaceTextDataModule(
data_path=data_path,
dataset_name=dataset_name[4:],
input_column=input_column,
target_column=target_column,
tokenizer=tokenizer,
val_size=val_size,
seed=seed,
)
if dataset_name == "CDDB":
return CDDBDataModule(
data_path=data_path,
src_bucket=src_bucket,
src_object_name=src_object_name,
val_size=val_size,
seed=seed,
)
if dataset_name == "Core50":
return CORE50DataModule(
data_path=data_path,
src_bucket=src_bucket,
src_object_name=src_object_name,
val_size=val_size,
seed=seed,
)
raise ValueError(f"Unknown dataset `{dataset_name}`.")
[docs]
def get_scenario(
scenario_name: str,
data_module: RenateDataModule,
chunk_id: int,
seed: int,
num_tasks: Optional[int] = None,
groupings: Optional[Tuple[Tuple[int]]] = None,
degrees: Optional[List[int]] = None,
input_dim: Optional[Union[List[int], Tuple[int], int]] = None,
feature_idx: Optional[int] = None,
randomness: Optional[float] = None,
data_ids: Optional[Tuple[Union[int, str]]] = None,
) -> Scenario:
"""Function to create scenario based on name and arguments.
Args:
scenario_name: Name to identify scenario.
data_module: The base data module.
chunk_id: The data chunk to load in for the training or validation data.
seed: A random seed to fix the created scenario.
num_tasks: The total number of expected tasks for experimentation.
groupings: Used for scenario `ClassIncrementalScenario` to partition datasets into chunks by
class. Used by `DataIncrementalScenario` to group domains to chunks..
degrees: Used for scenario `ImageRotationScenario`. Rotations applied for each chunk.
input_dim: Used for scenario `PermutationScenario`. Input dimensionality.
feature_idx: Used for scenario `SoftSortingScenario`. Index of feature to sort by.
randomness: Used for all `_SortingScenario`. Randomness strength in [0, 1].
data_ids: List of data_ids for the `DataIncrementalScenario`.
Returns:
An instance of the requested scenario.
Raises:
ValueError: If scenario name is unknown.
"""
if scenario_name == "ClassIncrementalScenario":
assert groupings is not None, "Provide `groupings` for the class-incremental scenario."
return ClassIncrementalScenario(
data_module=data_module,
groupings=groupings,
chunk_id=chunk_id,
)
if scenario_name == "IIDScenario":
return IIDScenario(
data_module=data_module, num_tasks=num_tasks, chunk_id=chunk_id, seed=seed
)
if scenario_name == "ImageRotationScenario":
return ImageRotationScenario(
data_module=data_module, degrees=degrees, chunk_id=chunk_id, seed=seed
)
if scenario_name == "PermutationScenario":
return PermutationScenario(
data_module=data_module,
num_tasks=num_tasks,
input_dim=input_dim,
chunk_id=chunk_id,
seed=seed,
)
if scenario_name == "FeatureSortingScenario":
return FeatureSortingScenario(
data_module=data_module,
num_tasks=num_tasks,
feature_idx=feature_idx,
randomness=randomness,
chunk_id=chunk_id,
seed=seed,
)
if scenario_name == "HueShiftScenario":
return HueShiftScenario(
data_module=data_module,
num_tasks=num_tasks,
randomness=randomness,
chunk_id=chunk_id,
seed=seed,
)
if scenario_name == "DataIncrementalScenario":
if data_ids is None and groupings is None:
data_ids = [data_id for data_id in range(num_tasks)]
return DataIncrementalScenario(
data_module=data_module,
chunk_id=chunk_id,
data_ids=data_ids,
groupings=groupings,
seed=seed,
)
raise ValueError(f"Unknown scenario `{scenario_name}`.")
[docs]
def loss_fn(updater: Optional[str] = None) -> torch.nn.Module:
if updater.startswith("Avalanche-"):
return torch.nn.CrossEntropyLoss()
return torch.nn.CrossEntropyLoss(reduction="none")
[docs]
def data_module_fn(
data_path: str,
chunk_id: int,
seed: int,
scenario_name: str,
dataset_name: str,
val_size: float = 0.0,
num_tasks: Optional[int] = None,
groupings: Optional[Tuple[Tuple[int]]] = None,
degrees: Optional[Tuple[int]] = None,
input_dim: Optional[Tuple[int]] = None,
feature_idx: Optional[int] = None,
randomness: Optional[float] = None,
src_bucket: Optional[str] = None,
src_object_name: Optional[str] = None,
pretrained_model_name_or_path: Optional[str] = None,
input_column: Optional[str] = None,
target_column: Optional[str] = None,
data_ids: Optional[List[Union[int, str]]] = None,
):
data_module = get_data_module(
data_path=data_path,
src_bucket=src_bucket,
src_object_name=src_object_name,
dataset_name=dataset_name,
val_size=val_size,
seed=seed,
pretrained_model_name_or_path=pretrained_model_name_or_path,
input_column=input_column,
target_column=target_column,
)
if dataset_name in wild_time_data.list_datasets() and num_tasks is None:
num_tasks = len(wild_time_data.available_time_steps(dataset_name))
return get_scenario(
scenario_name=scenario_name,
data_module=data_module,
chunk_id=chunk_id,
seed=seed,
num_tasks=num_tasks,
groupings=groupings,
degrees=degrees,
input_dim=input_dim,
feature_idx=feature_idx,
randomness=randomness,
data_ids=data_ids,
)
def _get_normalize_transform(dataset_name):
if dataset_name in TorchVisionDataModule.dataset_stats:
return transforms.Normalize(
TorchVisionDataModule.dataset_stats[dataset_name]["mean"],
TorchVisionDataModule.dataset_stats[dataset_name]["std"],
)
if dataset_name in ["CLEAR10", "CLEAR100"]:
return transforms.Normalize(
CLEARDataModule.dataset_stats[dataset_name]["mean"],
CLEARDataModule.dataset_stats[dataset_name]["std"],
)
if dataset_name == "DomainNet":
return transforms.Normalize(
DomainNetDataModule.dataset_stats["all"]["mean"],
DomainNetDataModule.dataset_stats["all"]["std"],
)
if dataset_name == "CDDB":
return transforms.Normalize(
CDDBDataModule.dataset_stats["CDDB"]["mean"],
CDDBDataModule.dataset_stats["CDDB"]["std"],
)
if dataset_name == "Core50":
return transforms.Normalize(
CORE50DataModule.dataset_stats["Core50"]["mean"],
CORE50DataModule.dataset_stats["Core50"]["std"],
)
[docs]
def lr_scheduler_fn(
learning_rate_scheduler: Optional[str] = None,
learning_rate_scheduler_step_size: int = 30,
learning_rate_scheduler_gamma: float = 0.1,
learning_rate_scheduler_interval: str = "epoch",
learning_rate_scheduler_t_max: Optional[int] = None,
learning_rate_scheduler_eta_min: float = 0,
) -> Tuple[Optional[Callable[[Optimizer], _LRScheduler]], str]:
if learning_rate_scheduler == "StepLR":
return (
partial(
StepLR,
step_size=learning_rate_scheduler_step_size,
gamma=learning_rate_scheduler_gamma,
),
learning_rate_scheduler_interval,
)
elif learning_rate_scheduler == "CosineAnnealingLR":
return (
partial(
CosineAnnealingLR,
T_max=learning_rate_scheduler_t_max,
eta_min=learning_rate_scheduler_eta_min,
),
learning_rate_scheduler_interval,
)
elif learning_rate_scheduler is None:
return None, learning_rate_scheduler_interval
raise ValueError(f"Unknown scheduler `{learning_rate_scheduler}`.")
[docs]
def metrics_fn(num_outputs: int) -> Dict:
return {"accuracy": MulticlassAccuracy(num_classes=num_outputs, average="micro")}
[docs]
def optimizer_fn(
optimizer: str,
learning_rate: float,
weight_decay: float,
momentum: float = 0.0, # TODO: fix problem that occurs when removing this
) -> Callable:
if optimizer == "AdamW":
return partial(AdamW, lr=learning_rate, weight_decay=weight_decay)