How to Write a Config File

User input is passed to Renate via a config file. It contains the definition of the model you want to train, code to load and preprocess the data, and (optionally) a set of transforms to be applied to that data. Each component is provided by implementing a function with a fixed name; when loading your config file, Renate inspects it for these functions.
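The lookup by fixed function names can be pictured with the following minimal sketch. It is not Renate's loader: load_config, find_config_functions, and the KNOWN_FUNCTIONS tuple are hypothetical names, and we simply assume the config is an ordinary Python file that can be imported with importlib.

```python
import importlib.util
import inspect
from types import ModuleType
from typing import Callable, Dict

# Function names described in this guide (hypothetical list; Renate's own set may differ).
KNOWN_FUNCTIONS = (
    "model_fn", "loss_fn", "data_module_fn", "optimizer_fn", "lr_scheduler_fn",
    "train_transform", "train_target_transform", "test_transform",
    "test_target_transform", "buffer_transform", "buffer_target_transform", "metrics_fn",
)


def load_config(path: str) -> ModuleType:
    """Import a Python file as a module, even if it is not on sys.path."""
    spec = importlib.util.spec_from_file_location("renate_config", path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module


def find_config_functions(module: ModuleType) -> Dict[str, Callable]:
    """Return only the known functions that the config module actually defines."""
    return {
        name: getattr(module, name)
        for name in KNOWN_FUNCTIONS
        if inspect.isfunction(getattr(module, name, None))
    }
```

Helper functions with other names (for example a model class or a utility) are simply ignored by this kind of lookup.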

Model Definition

This function takes a path to a model state and returns a model in the form of a RenateModule. Its signature is

def model_fn(model_state_url: Optional[str] = None) -> RenateModule:

A RenateModule is a torch.nn.Module with some additional functionality relevant to continual learning. If no path is given (i.e., when the model is trained for the first time), your model_fn should create the model from scratch. Otherwise, it should reload the model from the stored state; for this, RenateModule provides a from_state_dict() method, which automatically restores the model hyperparameters.

Example
from typing import Optional

import torch

from renate.models import RenateModule


class MyMNISTMLP(RenateModule):
    def __init__(self, num_hidden: int) -> None:
        # Model hyperparameters need to be registered via RenateModule's
        # constructor, see documentation. Otherwise, this is a standard torch model.
        super().__init__(constructor_arguments={"num_hidden": num_hidden})
        self._fc1 = torch.nn.Linear(28 * 28, num_hidden)
        self._fc2 = torch.nn.Linear(num_hidden, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self._fc1(x)
        x = torch.nn.functional.relu(x)
        return self._fc2(x)


def model_fn(model_state_url: Optional[str] = None) -> RenateModule:
    if model_state_url is None:
        # If no model state is given, we create the model from scratch with initial model
        # hyperparameters.
        model = MyMNISTMLP(num_hidden=100)
    else:
        # If a model state is passed, we reload the model using RenateModule's from_state_dict.
        # In this case, model hyperparameters are restored from the saved state.
        state_dict = torch.load(model_state_url)
        model = MyMNISTMLP.from_state_dict(state_dict)
    return model
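To make the role of constructor_arguments concrete, here is a torch-free toy sketch of the idea behind from_state_dict(): the hyperparameters are saved next to the weights, so the model can be rebuilt without knowing them in advance. ToyModule, ToyMLP, and the storage key are hypothetical; this is not how RenateModule is implemented internally.

```python
from typing import Any, Dict


class ToyModule:
    """Toy stand-in for a module that records its constructor arguments."""

    # Hypothetical key under which the constructor arguments are stored.
    _ARGS_KEY = "constructor_arguments"

    def __init__(self, constructor_arguments: Dict[str, Any]) -> None:
        self._constructor_arguments = dict(constructor_arguments)
        self.weights: Dict[str, float] = {}

    def state_dict(self) -> Dict[str, Any]:
        # Save the hyperparameters alongside the weights.
        return {self._ARGS_KEY: dict(self._constructor_arguments), "weights": dict(self.weights)}

    @classmethod
    def from_state_dict(cls, state_dict: Dict[str, Any]) -> "ToyModule":
        # Rebuild the object from the stored hyperparameters, then restore the weights.
        model = cls(**state_dict[cls._ARGS_KEY])
        model.weights = dict(state_dict["weights"])
        return model


class ToyMLP(ToyModule):
    def __init__(self, num_hidden: int) -> None:
        super().__init__(constructor_arguments={"num_hidden": num_hidden})
        self.num_hidden = num_hidden
```

Because num_hidden travels with the saved state, the reloading branch of model_fn does not need to repeat it.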

If you are using a torch model with no or fixed hyperparameters, you can use RenateWrapper. In this case, do not use the from_state_dict() method, but simply reinstantiate your model and call load_state_dict.

Example
def model_fn(model_state_url: Optional[str] = None) -> RenateModule:
    my_torch_model = torch.nn.Linear(28 * 28, 10)  # Instantiate your torch model.
    model = RenateWrapper(my_torch_model)
    if model_state_url is not None:
        state_dict = torch.load(str(model_state_url))
        model.load_state_dict(state_dict)
    return model

Loss Definition

This function returns a torch.nn.Module object that computes the loss. Its signature is

def loss_fn() -> torch.nn.Module:

For the MNIST classification task above, the loss function could be defined as follows:

Loss function example
def loss_fn() -> torch.nn.Module:
    return torch.nn.CrossEntropyLoss(reduction="none")

Note that loss functions must not reduce the loss: they should return one loss value per sample, which is why the example sets reduction="none".
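The following sketch, written in plain Python rather than torch, shows what an unreduced cross-entropy returns: one value per sample, just as CrossEntropyLoss(reduction="none") does. We assume, without the guide stating it, that Renate performs the reduction itself and may weight samples differently (for instance those coming from a memory buffer); weighted_mean is a hypothetical illustration of such a step.

```python
import math
from typing import List


def cross_entropy_per_sample(logits: List[List[float]], targets: List[int]) -> List[float]:
    """Cross-entropy for each sample separately (no reduction)."""
    losses = []
    for row, target in zip(logits, targets):
        # Numerically stable log-sum-exp over the class scores.
        m = max(row)
        log_sum_exp = m + math.log(sum(math.exp(v - m) for v in row))
        losses.append(log_sum_exp - row[target])
    return losses


def weighted_mean(losses: List[float], weights: List[float]) -> float:
    """One possible reduction a training loop could apply to per-sample losses."""
    return sum(loss * weight for loss, weight in zip(losses, weights)) / sum(weights)
```

If the loss function already returned a single averaged scalar, a step like weighted_mean would no longer be possible.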

Data Preparation

This function takes a path to a data folder and returns data in the form of a RenateDataModule. Its signature is

def data_module_fn(data_path: str, seed: int = defaults.SEED) -> RenateDataModule:

RenateDataModule provides a structured interface to download, set up, and access train/val/test datasets. The function also accepts a seed, which should be used for any randomized operations, such as data subsampling or splitting.
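As an illustration of how the seed can be used, here is a minimal sketch of a reproducible train/validation split. split_train_val is a hypothetical helper, not RenateDataModule's internal splitting method; the point is only that drawing all randomness from a generator seeded with seed makes repeated runs produce the same split.

```python
import random
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def split_train_val(data: Sequence[T], val_size: float, seed: int) -> Tuple[List[T], List[T]]:
    """Randomly split data into train/val parts; the same seed yields the same split."""
    indices = list(range(len(data)))
    # A local generator avoids depending on (or modifying) the global random state.
    random.Random(seed).shuffle(indices)
    num_val = int(round(val_size * len(data)))
    val_idx, train_idx = indices[:num_val], indices[num_val:]
    return [data[i] for i in train_idx], [data[i] for i in val_idx]
```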

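Example
import torch
import torchvision
from torchvision import transforms

import renate.defaults as defaults
from renate.data.data_module import RenateDataModule


class MyMNISTDataModule(RenateDataModule):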
    def __init__(self, data_path: str, val_size: float, seed: int = 42) -> None:
        super().__init__(data_path, val_size=val_size, seed=seed)

    def prepare_data(self) -> None:
        # This is only to download the data. We separate downloading from the remaining set-up to
        # streamline data loading when using multiple training jobs during HPO.
        torchvision.datasets.MNIST(self._data_path, download=True)

    def setup(self) -> None:
        # This sets up train/val/test datasets, assuming data has already been downloaded.
        train_data = torchvision.datasets.MNIST(
            self._data_path,
            train=True,
            transform=transforms.ToTensor(),
            target_transform=transforms.Lambda(lambda x: torch.tensor(x, dtype=torch.long)),
        )
        self._train_data, self._val_data = self._split_train_val_data(train_data)
        self._test_data = torchvision.datasets.MNIST(
            self._data_path,
            train=False,
            transform=transforms.ToTensor(),
            target_transform=transforms.Lambda(lambda x: torch.tensor(x, dtype=torch.long)),
        )


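def data_module_fn(data_path: str, seed: int = defaults.SEED) -> RenateDataModule:
    # data_path must be passed on; the data module uses it to download and load MNIST.
    return MyMNISTDataModule(data_path=data_path, val_size=0.2, seed=seed)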

Optimizer

Optimizers such as SGD or Adam can be selected by passing the corresponding arguments. If you want to use another optimizer, you can instead return a partial optimizer object (created with functools.partial), as outlined in the example below.

Example
def optimizer_fn() -> Callable[[Iterable[Parameter]], Optimizer]:
    # torch optimizers take an iterable of parameters as their first argument.
    return partial(AdamW, lr=0.01, weight_decay=0.0)
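The partial object defers construction: the hyperparameters are fixed in the config, and the parameters are supplied once the model exists. The sketch below mimics this with a hypothetical ToyOptimizer instead of a torch optimizer, so the calling convention (parameters first, keyword hyperparameters bound by functools.partial) is visible without torch; exactly how Renate calls the returned object is an assumption here.

```python
from functools import partial
from typing import Callable, Iterable, List


class ToyOptimizer:
    """Hypothetical stand-in for a torch optimizer: parameters first, then hyperparameters."""

    def __init__(self, params: Iterable[float], lr: float, weight_decay: float = 0.0) -> None:
        self.params: List[float] = list(params)
        self.lr = lr
        self.weight_decay = weight_decay


def optimizer_fn() -> Callable[[Iterable[float]], ToyOptimizer]:
    # Only the hyperparameters are bound here; the parameters are still missing.
    return partial(ToyOptimizer, lr=0.01, weight_decay=0.0)


# Assumed framework side: once the model exists, its parameters complete the call.
make_optimizer = optimizer_fn()
optimizer = make_optimizer([0.1, 0.2, 0.3])
```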

Learning Rate Schedulers

A learning rate scheduler can be provided by implementing a function as demonstrated below. It must return a partial object of a learning rate scheduler together with a string that indicates whether the scheduler is updated after each epoch ("epoch") or after each step ("step").

Example
def lr_scheduler_fn() -> Tuple[Callable[[Optimizer], _LRScheduler], str]:
    return partial(StepLR, step_size=10, gamma=0.1), "epoch"
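With StepLR, the learning rate after n scheduler updates is initial_lr * gamma ** (n // step_size), so the interval string changes what step_size means: 10 epochs versus 10 optimization steps. The plain-Python sketch below uses hypothetical helpers step_lr and scheduler_updates, and assumes "step" is the value for per-step updates.

```python
def step_lr(initial_lr: float, step_size: int, gamma: float, num_updates: int) -> float:
    """Learning rate after `num_updates` scheduler updates under StepLR's decay rule."""
    return initial_lr * gamma ** (num_updates // step_size)


def scheduler_updates(interval: str, epochs: int, steps_per_epoch: int) -> int:
    """Number of scheduler updates during training, depending on the interval."""
    if interval == "epoch":
        return epochs
    if interval == "step":
        return epochs * steps_per_epoch
    raise ValueError(f"unknown interval: {interval!r}")
```

For example, with step_size=10 and gamma=0.1, 25 epochs at the "epoch" interval decay the rate twice, whereas the "step" interval with 100 batches per epoch would decay it 250 times.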

Transforms

Transforms for data preprocessing or augmentation are often applied as part of torch datasets. That is, x, y = dataset[i] returns a fully-preprocessed and potentially augmented data point, ready to be passed to a torch model.

In Renate, transforms should, to some extent, be handled _outside_ of the dataset object. This is because many continual learning methods maintain a memory of previously encountered data points. Having access to the _raw_, _untransformed_ data points allows us to store this data in a memory-efficient way and ensures that data augmentation operations do not accumulate over time. Explicit access to the preprocessing transforms is also useful when deploying a trained model.

It is up to the user to decide which transforms to apply inside the dataset and which to pass to Renate explicitly. As a general rule, dataset[i] should return a torch.Tensor of fixed size and data type, and randomized data augmentation operations should be passed explicitly.
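The benefit of keeping augmentation out of stored data can be shown with a small plain-Python sketch: a buffer keeps the raw items and applies a random transform only when an item is read, so each read sees a fresh augmentation of the original rather than an augmentation of an already-augmented copy. RawBuffer and random_shift are hypothetical and stand in for Renate's memory and a real image augmentation.

```python
import random
from typing import Callable, List


def random_shift(x: List[int], rng: random.Random) -> List[int]:
    """Toy augmentation: cyclically shift a 1-D 'image' by -1, 0, or +1 positions."""
    k = rng.choice([-1, 0, 1])
    return x[-k:] + x[:-k] if k else list(x)


class RawBuffer:
    """Stores untransformed items; the transform runs only when an item is read."""

    def __init__(self, transform: Callable[[List[int]], List[int]]) -> None:
        self._items: List[List[int]] = []
        self._transform = transform

    def add(self, x: List[int]) -> None:
        self._items.append(list(x))  # keep a copy of the raw data point

    def get(self, i: int) -> List[int]:
        return self._transform(self._items[i])  # fresh augmentation on every access
```

Had the buffer stored augmented outputs and augmented them again on replay, shifts would compound over time; storing raw items avoids that.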

Transforms are specified in the config file via four functions

  • def train_transform() -> Callable

  • def train_target_transform() -> Callable

  • def test_transform() -> Callable

  • def test_target_transform() -> Callable

which each return a single callable. train_transform and test_transform are applied to the inputs (X) of the training and test data, and the corresponding *_target_transform functions to the targets (y). All four functions are optional; omit any of them if the corresponding transform is not needed.

Some methods apply a separate set of transforms to data kept in a memory buffer, e.g., for stronger augmentation. These can be set via two additional transform functions

  • def buffer_transform() -> Callable

  • def buffer_target_transform() -> Callable

These are optional as well but, if omitted, Renate will use train_transform and train_target_transform, respectively.
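The lookup and fallback rules just described can be sketched as follows. resolve_transform is a hypothetical helper, and modelling an omitted transform as the identity function is our assumption about the effect of "no transform", not Renate's actual code.

```python
from typing import Any, Callable, Optional


def _identity(x: Any) -> Any:
    return x


def resolve_transform(config: Any, name: str, fallback: Optional[str] = None) -> Callable:
    """Call config.<name>() if defined, else config.<fallback>(), else use the identity."""
    for candidate in (name, fallback):
        fn = getattr(config, candidate, None) if candidate else None
        if fn is not None:
            return fn()
    return _identity
```

A buffer transform would then be resolved with resolve_transform(config, "buffer_transform", fallback="train_transform").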

Example
from typing import Callable

import torch
import torchvision


def train_transform() -> Callable:
    return torchvision.transforms.Compose(
        [torchvision.transforms.RandomCrop((28, 28), padding=4), torch.nn.Flatten()]
    )


def test_transform() -> Callable:
    return torch.nn.Flatten()

Custom Metrics

It is possible to specify a set of custom metrics to be measured during the training process. The metrics can be either imported from torchmetrics, which offers a vast collection, or created ad-hoc by implementing the same interface (see this tutorial).

Example
def metrics_fn() -> Dict:
    return {"accuracy": Accuracy(task="multiclass", num_classes=10)}

To enable additional metrics in Renate, it is sufficient to implement the metrics_fn function, which returns a dictionary whose keys are the metric names (strings) and whose values are instances of the metric classes. In the example above, we add a metric called accuracy by instantiating the Accuracy metric from torchmetrics.

Custom Function Arguments

In many cases, the standard arguments passed to the functions described above are not sufficient. Further arguments can be added by simply including them in the function signature (with some limitations). We demonstrate this using model_fn as an example, but the same rules apply to all other functions introduced in this chapter.

Let us assume we already have a config file in which we implemented a simple linear model:

def model_fn(model_state_url: Optional[str] = None) -> RenateModule:
    my_torch_model = torch.nn.Linear(28 * 28, 10)
    model = RenateWrapper(my_torch_model)
    if model_state_url is not None:
        state_dict = torch.load(model_state_url)
        model.load_state_dict(state_dict)
    return model

However, suppose we have several datasets, each with different input and output dimensions. The natural change is to turn these dimensions into arguments:

def model_fn(num_inputs: int, num_outputs: int, model_state_url: Optional[str] = None) -> RenateModule:
    my_torch_model = torch.nn.Linear(num_inputs, num_outputs)
    model = RenateWrapper(my_torch_model)
    if model_state_url is not None:
        state_dict = torch.load(model_state_url)
        model.load_state_dict(state_dict)
    return model

And in fact, this is exactly how it works. However, there are a few limitations:

  • Typing is required.

  • Only the following types are allowed: bool, float, int, str, list, and tuple (annotations using List, Tuple, or Optional are also accepted).
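A check of these two rules could look like the following sketch, which inspects a function's signature with the standard inspect and typing modules. check_signature and is_allowed are hypothetical names, and Renate's actual validation may be stricter or more lenient (for example regarding the element types inside List or Tuple).

```python
import inspect
import typing
from typing import Callable

# Base types listed as allowed above.
ALLOWED_TYPES = (bool, float, int, str, list, tuple)


def is_allowed(annotation: object) -> bool:
    """Accept the base types and List[...], Tuple[...] or Optional[...] annotations."""
    if annotation in ALLOWED_TYPES:
        return True
    origin = typing.get_origin(annotation)
    if origin in (list, tuple):  # List[int], Tuple[int, int], ...
        return True
    if origin is typing.Union:  # Optional[X] is Union[X, None]
        args = [a for a in typing.get_args(annotation) if a is not type(None)]
        return len(args) == 1 and is_allowed(args[0])
    return False


def check_signature(fn: Callable) -> None:
    """Raise TypeError if a parameter is untyped or uses an unsupported type."""
    for name, param in inspect.signature(fn).parameters.items():
        if param.annotation is inspect.Parameter.empty:
            raise TypeError(f"{fn.__name__}: argument '{name}' needs a type annotation")
        if not is_allowed(param.annotation):
            raise TypeError(f"{fn.__name__}: type of argument '{name}' is not supported")
```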

How to set the actual values is discussed in the next chapter.

Note

You can use an argument with the same name in different functions as long as it has the same type annotation in each of them. The same value will then be provided to all of them.
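The note above can be pictured as one dictionary of values shared by all functions: each function receives the entries matching its parameter names, after checking that each name is annotated the same way everywhere. shared_argument_types and call_with are hypothetical helpers for illustration only.

```python
import inspect
from typing import Any, Callable, Dict


def shared_argument_types(*fns: Callable) -> Dict[str, Any]:
    """Collect parameter annotations; raise if one name is annotated differently."""
    seen: Dict[str, Any] = {}
    for fn in fns:
        for name, param in inspect.signature(fn).parameters.items():
            if name in seen and seen[name] != param.annotation:
                raise TypeError(f"argument '{name}' has conflicting types in {fn.__name__}")
            seen[name] = param.annotation
    return seen


def call_with(fn: Callable, values: Dict[str, Any]) -> Any:
    """Call fn with only those entries of `values` that match its parameter names."""
    params = inspect.signature(fn).parameters
    return fn(**{name: value for name, value in values.items() if name in params})
```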