Local Training with MNIST

This example demonstrates Renate’s capabilities using a small dataset and a small model, so that the training can run locally in a few minutes.


To this end, the model_fn function defined in renate_config.py returns a Multi-Layer Perceptron (MLP). The same file also defines data_module_fn, a function that loads the MNIST dataset from a public repository.
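For readers unfamiliar with this architecture, the following plain-PyTorch sketch builds an MLP with the same sizes that model_fn passes to MultiLayerPerceptron (784 inputs, two hidden layers of 128 units, 10 outputs). The layer layout (each hidden layer is a Linear followed by a ReLU) and the helper name build_mlp_sketch are assumptions made for illustration; Renate's MultiLayerPerceptron may be structured differently.

```python
import torch
from torch import nn


def build_mlp_sketch(
    num_inputs: int = 784,
    num_outputs: int = 10,
    num_hidden_layers: int = 2,
    hidden_size: int = 128,
) -> nn.Module:
    """Hypothetical plain-PyTorch stand-in for the MLP created in model_fn.

    Assumption: each hidden layer is Linear followed by ReLU. Renate's
    MultiLayerPerceptron may use a different layout.
    """
    layers = []
    in_features = num_inputs
    for _ in range(num_hidden_layers):
        layers += [nn.Linear(in_features, hidden_size), nn.ReLU()]
        in_features = hidden_size
    # Final projection to one logit per MNIST class.
    layers.append(nn.Linear(in_features, num_outputs))
    return nn.Sequential(*layers)


model = build_mlp_sketch()
logits = model(torch.rand(4, 784))  # a batch of 4 flattened 28x28 images
num_params = sum(p.numel() for p in model.parameters())
print(logits.shape, num_params)
```

With this layout the model has 118,282 trainable parameters, most of them in the first layer, which maps the 784 pixel values to 128 hidden units.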


The MNIST dataset is split into two chunks (classes 0-4 and classes 5-9, respectively) using the ClassIncrementalScenario. This split is not necessary in real-world applications, but it is useful for running experiments when testing the library, and it allowed us to build a simple example from a single dataset.
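The idea behind a class-incremental split can be sketched in a few lines of plain Python: each chunk keeps only the samples whose label belongs to its group of classes. chunk_indices is a hypothetical helper written for this illustration, not Renate's ClassIncrementalScenario implementation, and the labels are toy values standing in for MNIST targets.

```python
from typing import List, Sequence, Tuple

# Same groupings as in renate_config.py: chunk 0 = classes 0-4, chunk 1 = classes 5-9.
GROUPINGS: Tuple[Tuple[int, ...], ...] = ((0, 1, 2, 3, 4), (5, 6, 7, 8, 9))


def chunk_indices(
    labels: Sequence[int], groupings: Tuple[Tuple[int, ...], ...], chunk_id: int
) -> List[int]:
    """Return the indices of the samples whose label belongs to the selected chunk.

    Hypothetical helper illustrating the class-incremental idea only.
    """
    selected = set(groupings[chunk_id])
    return [i for i, label in enumerate(labels) if label in selected]


labels = [3, 7, 0, 9, 5, 4, 1, 8]  # toy labels
first = chunk_indices(labels, GROUPINGS, chunk_id=0)
second = chunk_indices(labels, GROUPINGS, chunk_id=1)
print(first, second)
```

Because the groupings are disjoint and cover all ten classes, every sample ends up in exactly one chunk.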

In renate_config.py we also define a simple transformation that flattens the 28x28 input images into 784-dimensional vectors, matching num_inputs=784 of the MLP. Transformations are not limited to reshaping inputs: they can also be used for normalization, augmentation, and other purposes. More details on how to write a configuration file are available in How to Write a Config File.
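As a sketch of a transformation that does more than reshaping, the function below standardizes the pixel values and then flattens the image. It assumes the input is a (1, 28, 28) float tensor with values in [0, 1], and it uses the commonly quoted MNIST mean and standard deviation (0.1307 and 0.3081). The function name and these constants are assumptions, not part of the example.

```python
import torch

# Commonly used MNIST statistics for pixels scaled to [0, 1] (assumed values).
MNIST_MEAN = 0.1307
MNIST_STD = 0.3081


def normalize_and_flatten(x: torch.Tensor) -> torch.Tensor:
    """Standardize pixel values, then flatten a (1, 28, 28) image into a 784-vector."""
    return torch.flatten((x - MNIST_MEAN) / MNIST_STD)


image = torch.rand(1, 28, 28)  # stand-in for one MNIST image already converted to a tensor
vector = normalize_and_flatten(image)
print(vector.shape)
```

Inside renate_config.py, such a function could be wrapped the same way as the existing lambda, for instance `return transforms.Lambda(normalize_and_flatten)`. Combining torchvision's `transforms.Normalize` with `transforms.Compose` is a more idiomatic alternative.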

from typing import Callable, Dict, Optional

import torch
from torchmetrics import Accuracy
from torchvision.transforms import transforms

from renate import defaults
from renate.benchmark.datasets.vision_datasets import TorchVisionDataModule
from renate.benchmark.models.mlp import MultiLayerPerceptron
from renate.benchmark.scenarios import ClassIncrementalScenario, Scenario
from renate.models import RenateModule

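def data_module_fn(data_path: str, chunk_id: int, seed: int = defaults.SEED) -> Scenario:
    """Returns a class-incremental scenario instance.

    The transformations passed to prepare the input data are required to convert the data to
    PyTorch tensors.
    """
    # Loads MNIST into data_path, downloading it if necessary.
    data_module = TorchVisionDataModule(
        data_path,
        dataset_name="MNIST",
        download=True,
        seed=seed,
    )

    # Splits the ten MNIST classes into two chunks and selects the one given by chunk_id.
    class_incremental_scenario = ClassIncrementalScenario(
        data_module=data_module,
        groupings=((0, 1, 2, 3, 4), (5, 6, 7, 8, 9)),
        chunk_id=chunk_id,
    )
    return class_incremental_scenario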

def model_fn(model_state_url: Optional[str] = None) -> RenateModule:
    """Returns a model instance."""
    if model_state_url is None:
        # New MLP: 784 inputs (flattened 28x28 images), two hidden layers, 10 classes.
        model = MultiLayerPerceptron(
            num_inputs=784, num_outputs=10, num_hidden_layers=2, hidden_size=128
        )
    else:
        # Restore a previously trained model from its saved state.
        state_dict = torch.load(model_state_url)
        model = MultiLayerPerceptron.from_state_dict(state_dict)
    return model

def train_transform() -> Callable:
    """Returns a transform function to be used in the training."""
    return transforms.Lambda(lambda x: torch.flatten(x))

def loss_fn() -> torch.nn.Module:
    return torch.nn.CrossEntropyLoss(reduction="none")

def metrics_fn() -> Dict:
    return {"accuracy": Accuracy(task="multiclass", num_classes=10)}
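loss_fn above returns CrossEntropyLoss with reduction="none", which yields one loss value per sample instead of a single scalar and leaves the aggregation to the caller. The snippet below uses plain PyTorch and toy tensors to show the difference. We assume the Renate learner uses the per-sample values, for example to weight samples differently; the weights shown here are purely illustrative.

```python
import torch

logits = torch.tensor([[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]])  # 2 samples, 3 classes
targets = torch.tensor([0, 2])

per_sample = torch.nn.CrossEntropyLoss(reduction="none")(logits, targets)
averaged = torch.nn.CrossEntropyLoss()(logits, targets)  # default reduction="mean"

print(per_sample.shape)  # torch.Size([2]): one loss value per sample
print(torch.allclose(per_sample.mean(), averaged))  # True

# With per-sample losses, the caller chooses the aggregation, e.g. a weighted mean.
weights = torch.tensor([1.0, 0.5])  # illustrative weights
weighted = (weights * per_sample).sum() / weights.sum()
```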

The configuration file uses a scenario in the definition of the data module function. The scenario simply splits the dataset into several chunks, which lets us train the model on different parts of the dataset without adding complex code to the example. In most practical applications, the scenario can be removed from the function.


The example also contains start_training_without_hpo.py, the script that launches the training jobs. In this file we define a configuration in the config_space dictionary and pass it to run_training_job, which launches a training job. The configuration controls several aspects of the learning process, for example the learning rate and the optimizer; the exact set of parameters depends on the learning algorithm used for training. Other parameters are passed directly to the function, such as the folder in which the learner state will be saved (next_state_url). To update an existing model, you also need to provide the path to the previously saved state via state_url, as done in our example. More details about running training jobs are available in How to Run a Training Job.
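The pattern of passing the previous job's next_state_url as the next job's state_url generalizes to any number of chunks. The helper below is a hypothetical sketch that only builds the per-chunk keyword arguments; its name and the folder naming scheme are assumptions, and it does not call Renate.

```python
from typing import Dict, List, Optional, Union


def plan_sequential_jobs(
    num_chunks: int, state_dir_template: str = "./state_dump_chunk_{}/"
) -> List[Dict[str, Union[int, str]]]:
    """Build per-chunk keyword arguments that chain learner states across jobs.

    Hypothetical helper: each job writes its state to a new folder
    (next_state_url) and, except for the first, loads the previous one (state_url).
    """
    jobs: List[Dict[str, Union[int, str]]] = []
    previous_state: Optional[str] = None  # the first job has no state to load
    for chunk_id in range(num_chunks):
        job: Dict[str, Union[int, str]] = {
            "chunk_id": chunk_id,
            "next_state_url": state_dir_template.format(chunk_id),
        }
        if previous_state is not None:
            # Continue from the learner state written by the previous job.
            job["state_url"] = previous_state
        jobs.append(job)
        previous_state = str(job["next_state_url"])
    return jobs


for job in plan_sequential_jobs(2):
    print(job)
```

Assuming all other arguments stay the same across chunks, each dictionary could be unpacked into a run_training_job call with `**job`.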

from renate.training import run_training_job

# hyperparameters of the learner; the accepted keys depend on the learning algorithm
config_space = {
    "optimizer": "SGD",
    "momentum": 0.0,
    "weight_decay": 0.0,
    "learning_rate": 0.1,
    "alpha": 0.5,
    "batch_size": 64,
    "batch_memory_frac": 0.5,
    "memory_size": 500,
    "loss_normalization": 0,
    "loss_weight": 0.5,
    "early_stopping": True,
}

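if __name__ == "__main__":
    # we run the first training job on the MNIST classes [0-4]
    run_training_job(
        config_space=config_space,
        mode="max",  # assumed: maximize the validation metric
        metric="val_accuracy",  # assumed metric name
        updater="ER",  # assumed: Experience Replay
        max_epochs=50,  # assumed value
        config_file="renate_config.py",
        chunk_id=0,  # this selects the first chunk of the dataset
        # this is where the model will be stored
        next_state_url="./state_dump_first_model/",
        # the training job will run on the local machine
        backend="local",
    )

    # retrieve the model from `./state_dump_first_model/` if you want
    # do not delete the model, we are going to use it below

    # we run the second training job on the MNIST classes [5-9]
    run_training_job(
        config_space=config_space,
        mode="max",
        metric="val_accuracy",
        updater="ER",
        max_epochs=50,
        config_file="renate_config.py",
        chunk_id=1,  # this time we use the second chunk of the dataset
        # the output of the first training job is loaded
        state_url="./state_dump_first_model/",
        # the new model will be stored in this folder
        next_state_url="./state_dump_second_model/",
        backend="local",
    )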


This simple example usually yields good results: near-perfect accuracy on the first chunk of data, and accuracy still above 90% after the second chunk has been processed. Once execution completes, you can inspect the two folders containing the learner states.

Another example, using Repeated Distillation and hyperparameter optimization (HPO), is available in start_training_with_hpo.py.