How to Write a Config File#
User input is passed to Renate via a config file. It contains the definition of the model you want to train, code to load and preprocess the data, and (optionally) a set of transforms to be applied to that data. These components are provided by implementing functions with a fixed name. When accessing your config file, Renate will inspect it for these functions.
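To give a rough picture before going through the individual pieces, the sketch below shows the overall shape of such a config file. The function bodies are omitted, and the import paths are assumptions (please check the Renate API reference); each function is explained in the sections that follow.

# Schematic outline of a config file (a sketch, not a complete example).
# Import paths are assumptions; see the Renate API reference for the exact ones.
from typing import Optional

import torch

from renate.data.data_module import RenateDataModule
from renate.models import RenateModule


def model_fn(model_state_url: Optional[str] = None) -> RenateModule:
    ...  # Create or reload the model, see "Model Definition".


def loss_fn() -> torch.nn.Module:
    ...  # Return an unreduced loss, see "Loss Definition".


def data_module_fn(data_path: str, seed: int) -> RenateDataModule:
    ...  # Return a RenateDataModule, see "Data Preparation".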
Model Definition#
This function takes a path to a model state and returns a model in the form of a RenateModule. Its signature is

def model_fn(model_state_url: Optional[str] = None) -> RenateModule:

A RenateModule is a torch.nn.Module with some additional functionality relevant to continual learning. If no path is given (i.e., when we first train a model), your model_fn should create the model from scratch. Otherwise, the model should be reloaded from the stored state, for which RenateModule provides a from_state_dict() method that automatically handles model hyperparameters.
class MyMNISTMLP(RenateModule):
    def __init__(self, num_hidden: int) -> None:
        # Model hyperparameters need to be registered via RenateModule's
        # constructor, see documentation. Otherwise, this is a standard torch model.
        super().__init__(constructor_arguments={"num_hidden": num_hidden})
        self._fc1 = torch.nn.Linear(28 * 28, num_hidden)
        self._fc2 = torch.nn.Linear(num_hidden, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self._fc1(x)
        x = torch.nn.functional.relu(x)
        return self._fc2(x)


def model_fn(model_state_url: Optional[str] = None) -> RenateModule:
    if model_state_url is None:
        # If no model state is given, we create the model from scratch with initial model
        # hyperparameters.
        model = MyMNISTMLP(num_hidden=100)
    else:
        # If a model state is passed, we reload the model via RenateModule's from_state_dict,
        # which also restores the model hyperparameters from the saved state.
        state_dict = torch.load(model_state_url)
        model = MyMNISTMLP.from_state_dict(state_dict)
    return model
If you are using a torch model with no or fixed hyperparameters, you can use RenateWrapper. In this case, do not use the from_state_dict() method, but simply reinstantiate your model and call load_state_dict.
def model_fn(model_state_url: Optional[str] = None) -> RenateModule:
    my_torch_model = torch.nn.Linear(28 * 28, 10)  # Instantiate your torch model.
    model = RenateWrapper(my_torch_model)
    if model_state_url is not None:
        state_dict = torch.load(str(model_state_url))
        model.load_state_dict(state_dict)
    return model
Loss Definition#
This function returns a torch.nn.Module object that computes the loss, with the signature

def loss_fn() -> torch.nn.Module:
An example of this for the MNIST classification task above is

def loss_fn() -> torch.nn.Module:
    return torch.nn.CrossEntropyLoss(reduction="none")

Please note that loss functions should not be reduced.
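For instance, for a regression task one could instead return an unreduced mean squared error (a sketch; the choice of loss is a placeholder):

def loss_fn() -> torch.nn.Module:
    # reduction="none" keeps one loss value per sample rather than averaging over the batch.
    return torch.nn.MSELoss(reduction="none")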
Data Preparation#
This function takes a path to a data folder and returns data in the form of a RenateDataModule. Its signature is

def data_module_fn(data_path: str, seed: int = defaults.SEED) -> RenateDataModule:

RenateDataModule provides a structured interface to download, set up, and access train/val/test datasets. The function also accepts a seed, which should be used for any randomized operations, such as data subsampling or splitting.
class MyMNISTDataModule(RenateDataModule):
    def __init__(self, data_path: str, val_size: float, seed: int = 42) -> None:
        super().__init__(data_path, val_size=val_size, seed=seed)

    def prepare_data(self) -> None:
        # This is only to download the data. We separate downloading from the remaining
        # set-up to streamline data loading when using multiple training jobs during HPO.
        torchvision.datasets.MNIST(self._data_path, download=True)

    def setup(self) -> None:
        # This sets up train/val/test datasets, assuming data has already been downloaded.
        train_data = torchvision.datasets.MNIST(
            self._data_path,
            train=True,
            transform=transforms.ToTensor(),
            target_transform=transforms.Lambda(lambda x: torch.tensor(x, dtype=torch.long)),
        )
        self._train_data, self._val_data = self._split_train_val_data(train_data)
        self._test_data = torchvision.datasets.MNIST(
            self._data_path,
            train=False,
            transform=transforms.ToTensor(),
            target_transform=transforms.Lambda(lambda x: torch.tensor(x, dtype=torch.long)),
        )


def data_module_fn(data_path: str, seed: int) -> RenateDataModule:
    return MyMNISTDataModule(data_path, val_size=0.2, seed=seed)
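If your data does not come from torchvision, the same structure applies. The following is a sketch of a hypothetical data module that loads pre-saved tensors from data_path; the file names and formats are assumptions.

import os

class MyTensorDataModule(RenateDataModule):
    def __init__(self, data_path: str, val_size: float, seed: int = 42) -> None:
        super().__init__(data_path, val_size=val_size, seed=seed)

    def prepare_data(self) -> None:
        # Nothing to download; the tensor files are assumed to already exist locally.
        pass

    def setup(self) -> None:
        # Hypothetical file layout: train.pt and test.pt each store a (features, labels) tuple.
        X, y = torch.load(os.path.join(self._data_path, "train.pt"))
        self._train_data, self._val_data = self._split_train_val_data(
            torch.utils.data.TensorDataset(X, y)
        )
        X_test, y_test = torch.load(os.path.join(self._data_path, "test.pt"))
        self._test_data = torch.utils.data.TensorDataset(X_test, y_test)


def data_module_fn(data_path: str, seed: int) -> RenateDataModule:
    return MyTensorDataModule(data_path, val_size=0.2, seed=seed)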
Optimizer#
Optimizers such as SGD or Adam can be selected by passing the corresponding arguments. If you want to use other optimizers, you can do so by returning a partial optimizer object as outlined in the example below.
def optimizer_fn() -> Callable[[Generator[Parameter]], Optimizer]:
    return partial(AdamW, lr=0.01, weight_decay=0.0)
Learning Rate Schedulers#
A learning rate scheduler can be provided by creating a function as demonstrated below. This function needs to return a partial object of a learning rate scheduler as well as a string that indicates whether the scheduler is updated after each epoch or after each step.
def lr_scheduler_fn() -> Tuple[Callable[[Optimizer], _LRScheduler], str]:
    return partial(StepLR, step_size=10, gamma=0.1), "epoch"
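As another sketch, a cosine schedule that is updated after every optimizer step rather than every epoch could be returned like this (the T_max value is a placeholder):

from torch.optim.lr_scheduler import CosineAnnealingLR

def lr_scheduler_fn() -> Tuple[Callable[[Optimizer], _LRScheduler], str]:
    # "step" indicates that the scheduler is updated after each optimizer step.
    return partial(CosineAnnealingLR, T_max=1000), "step"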
Transforms#
Transforms for data preprocessing or augmentation are often applied as part of torch datasets. That is, x, y = dataset[i] returns a fully-preprocessed and potentially augmented data point, ready to be passed to a torch model.
In Renate, transforms should, to some extent, be handled _outside_ of the dataset object. This is because many continual learning methods maintain a memory of previously encountered data points. Having access to the _raw_, _untransformed_ data points allows us to store this data in a memory-efficient way and ensures that data augmentation operations do not accumulate over time. Explicit access to the preprocessing transforms is also useful when deploying a trained model.
It is up to the user to decide which transforms to apply inside the dataset and which to pass to Renate explicitly. As a general rule, dataset[i] should return a torch.Tensor of fixed size and data type. Randomized data augmentation operations should be passed explicitly.
Transforms are specified in the config file via four functions

def train_transform() -> Callable
def train_target_transform() -> Callable
def test_transform() -> Callable
def test_target_transform() -> Callable

which return a transform in the form of a (single) callable. These are applied to train and test data, as well as inputs (X) and targets (y), respectively. The transform functions are optional and each of them can be omitted if no respective transform should be applied.
Some methods apply a separate set of transforms to data kept in a memory buffer, e.g., for enhanced augmentation. These can be set via two additional transform functions

def buffer_transform() -> Callable
def buffer_target_transform() -> Callable

These are optional as well; if omitted, Renate will use train_transform and train_target_transform, respectively.
def train_transform() -> Callable:
    return torchvision.transforms.Compose(
        [torchvision.transforms.RandomCrop((28, 28), padding=4), torch.nn.Flatten()]
    )


def test_transform() -> Callable:
    return torch.nn.Flatten()
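If a method maintains a memory buffer and you want stronger augmentation for the replayed data, a buffer_transform can be provided in the same way. The following is a sketch; the particular augmentations are placeholders:

def buffer_transform() -> Callable:
    return torchvision.transforms.Compose(
        [
            torchvision.transforms.RandomCrop((28, 28), padding=4),
            torchvision.transforms.RandomRotation(15),
            torch.nn.Flatten(),
        ]
    )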
Custom Metrics#
It is possible to specify a set of custom metrics to be measured during the training process. The metrics can be either imported from torchmetrics, which offers a vast collection, or created ad-hoc by implementing the same interface (see this tutorial).

def metrics_fn() -> Dict:
    return {"accuracy": Accuracy(task="multiclass", num_classes=10)}
To enable the usage of additional metrics in Renate, it is sufficient to implement the metrics_fn function, returning a dictionary where each key is a string containing the metric's name and the value is an instantiation of the metric class. In the example above, we add a metric called accuracy by instantiating the accuracy metric from torchmetrics.
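As a sketch of the ad-hoc option, a custom metric can subclass torchmetrics.Metric and implement update and compute. The example below assumes the metric is updated with model outputs and integer targets; the metric itself is hypothetical.

from typing import Dict

import torch
from torchmetrics import Metric


class ErrorCount(Metric):
    """Hypothetical metric counting the number of misclassified samples."""

    def __init__(self) -> None:
        super().__init__()
        self.add_state("errors", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
        self.errors += (preds.argmax(dim=-1) != target).sum()

    def compute(self) -> torch.Tensor:
        return self.errors


def metrics_fn() -> Dict:
    return {"error_count": ErrorCount()}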
Custom Function Arguments#
In many cases, the standard arguments passed to all functions described above are not sufficient. More arguments can be added by simply adding them to the interface (with some limitations). We will demonstrate this using the example of model_fn, but the same rules apply to all other functions introduced in this chapter.

Let us assume we already have a config file in which we implemented a simple linear model:
def model_fn(model_state_url: Optional[str] = None) -> RenateModule:
    my_torch_model = torch.nn.Linear(28 * 28, 10)
    model = RenateWrapper(my_torch_model)
    if model_state_url is not None:
        state_dict = torch.load(model_state_url)
        model.load_state_dict(state_dict)
    return model
However, we have different datasets, and each of them has different input and output dimensions. The natural solution would be to change the function to something like
def model_fn(num_inputs: int, num_outputs: int, model_state_url: Optional[str] = None) -> RenateModule:
    my_torch_model = torch.nn.Linear(num_inputs, num_outputs)
    model = RenateWrapper(my_torch_model)
    if model_state_url is not None:
        state_dict = torch.load(model_state_url)
        model.load_state_dict(state_dict)
    return model
And in fact, this is exactly how it works. However, there are a few limitations:

- Typing is required.
- Only the following types are allowed: bool, float, int, str, list, and tuple. (Typing with List, Tuple, or Optional is okay.)
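For example, data_module_fn from the section above could be extended with a hypothetical val_size argument of type float:

def data_module_fn(data_path: str, seed: int, val_size: float) -> RenateDataModule:
    return MyMNISTDataModule(data_path, val_size=val_size, seed=seed)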
How to set the actual values will be discussed in the next chapter.
Note
You can use an argument with the same name in different functions as long as it has the same type. The same value will be provided to all of them.
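For example, a hypothetical num_outputs argument could appear in both model_fn and metrics_fn; Renate would then pass the same value to both. A sketch:

def model_fn(num_outputs: int, model_state_url: Optional[str] = None) -> RenateModule:
    model = RenateWrapper(torch.nn.Linear(28 * 28, num_outputs))
    if model_state_url is not None:
        model.load_state_dict(torch.load(model_state_url))
    return model


def metrics_fn(num_outputs: int) -> Dict:
    return {"accuracy": Accuracy(task="multiclass", num_classes=num_outputs)}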