Source code for renate.benchmark.models.mlp
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from typing import List, Optional, Tuple, Union
import torch.nn as nn
from renate.benchmark.models.base import RenateBenchmarkingModule
from renate.models.prediction_strategies import PredictionStrategy
class MultiLayerPerceptron(RenateBenchmarkingModule):
    """A simple multi-layer perceptron with hidden layers, activation functions, and optional
    batch normalization.

    Args:
        num_inputs: Number of input nodes.
        num_outputs: Number of output nodes.
        num_hidden_layers: Number of hidden layers.
        hidden_size: A single hidden size used for all hidden layers, or a list or tuple of
            per-layer hidden sizes.
        activation: Name of an activation class in `torch.nn`, applied after each hidden
            layer.
        batch_normalization: Whether to apply batch normalization after each activation. By
            default, the batch normalization layers track running statistics.
        prediction_strategy: Continual learning strategies may alter the prediction at train or
            test time.
        add_icarl_class_means: If ``True``, additional parameters used only by the
            ``ICaRLModelUpdater`` are added. Only required when using that updater.
    """
    def __init__(
        self,
        num_inputs: int,
        num_outputs: int,
        num_hidden_layers: int,
        hidden_size: Union[int, List[int], Tuple[int, ...]],
        activation: str = "ReLU",
        batch_normalization: bool = False,
        prediction_strategy: Optional[PredictionStrategy] = None,
        add_icarl_class_means: bool = True,
    ) -> None:
        embedding_size = hidden_size if isinstance(hidden_size, int) else hidden_size[-1]
        super().__init__(
            embedding_size=embedding_size,
            num_outputs=num_outputs,
            constructor_arguments={
                "num_inputs": num_inputs,
                "num_hidden_layers": num_hidden_layers,
                "hidden_size": hidden_size,
                "activation": activation,
                "batch_normalization": batch_normalization,
            },
            prediction_strategy=prediction_strategy,
            add_icarl_class_means=add_icarl_class_means,
        )
        # Expand a scalar hidden size into one size per hidden layer.
        if isinstance(hidden_size, int):
            hidden_size = [hidden_size for _ in range(num_hidden_layers + 1)]
        assert len(hidden_size) == num_hidden_layers + 1
        activation = getattr(nn, activation)
        # Prepend the input size; ``list(...)`` also handles the tuple case, which would
        # otherwise fail when concatenated with a list.
        hidden_size = [num_inputs] + list(hidden_size)
        layers = [nn.Flatten()]
        for i in range(num_hidden_layers + 1):
            layers.append(nn.Linear(hidden_size[i], hidden_size[i + 1]))
            layers.append(activation())
            if batch_normalization:
                layers.append(nn.BatchNorm1d(hidden_size[i + 1]))
        self._backbone = nn.Sequential(*layers)
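
# A minimal usage sketch. The input/output sizes below are illustrative assumptions, not
# part of the module, and the forward call assumes the default single-task behavior
# inherited from ``RenateBenchmarkingModule`` (backbone followed by an output head).
#
#     import torch
#
#     # Hypothetical setup: 784-dimensional inputs (e.g. flattened 28x28 images),
#     # two hidden layers of size 128, and 10 output classes.
#     model = MultiLayerPerceptron(
#         num_inputs=784,
#         num_outputs=10,
#         num_hidden_layers=2,
#         hidden_size=128,
#     )
#     logits = model(torch.randn(32, 784))  # expected shape: (32, 10)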