Training and Tuning on SageMaker

This example demonstrates how to use Renate on Amazon SageMaker both to train a model and to tune its hyperparameters. For this purpose, we train a ResNet model on CIFAR10 and tune some of its hyperparameters using ASHA (Asynchronous Successive Halving Algorithm), a scheduler that quickly terminates underperforming hyperparameter configurations.
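To give an intuition for how ASHA stops unpromising configurations early, here is a minimal sketch of synchronous successive halving, the simpler algorithm that ASHA builds on. It is not Syne Tune's implementation: the functions successive_halving and toy_score and the budget values are hypothetical, chosen only for illustration. Each round (rung) trains the surviving configurations for a given number of epochs, keeps the best 1/eta of them, and multiplies the epoch budget by eta. ASHA performs the same kind of promotions asynchronously, so workers do not wait for a whole rung to finish.

```python
import math
from typing import Callable, Dict, List


def successive_halving(
    configs: List[Dict],
    evaluate: Callable[[Dict, int], float],
    min_epochs: int = 1,
    max_epochs: int = 8,
    eta: int = 2,
) -> Dict:
    """Hypothetical sketch of synchronous successive halving (not Syne Tune's ASHA).

    Higher scores are better, as with mode="max" and metric="val_accuracy".
    """
    survivors = list(configs)
    epochs = min_epochs
    while len(survivors) > 1 and epochs <= max_epochs:
        # Evaluate every surviving configuration with the current epoch budget.
        ranked = sorted(survivors, key=lambda cfg: evaluate(cfg, epochs), reverse=True)
        # Keep the best 1/eta configurations; the others are stopped early.
        survivors = ranked[: max(1, len(ranked) // eta)]
        # Survivors get eta times more epochs in the next rung.
        epochs *= eta
    return survivors[0]


def toy_score(cfg: Dict, epochs: int) -> float:
    """Made-up validation accuracy that peaks at learning_rate=1e-2 (ignores epochs)."""
    return 1.0 - abs(math.log10(cfg["learning_rate"]) + 2) / 10


configs = [{"learning_rate": lr} for lr in (1e-4, 1e-3, 1e-2, 1e-1)]
best = successive_halving(configs, toy_score)
print(best)  # {'learning_rate': 0.01}
```

With eta=2, four configurations are reduced to two after one epoch and to one after two epochs, so most of the training budget is spent on the most promising configuration.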

Configuration

The model and dataset definitions are in the file renate_config.py. The model_fn function instantiates a ResNet neural network (a common choice for many image classifiers) and the data_module_fn function loads the CIFAR10 dataset.

Note

The CIFAR10 dataset is split into two chunks (the first five classes and the last five classes, respectively) using the ClassIncrementalScenario. This split is not needed in real-world applications, but it is useful for running experiments when testing the library, and it allowed us to build a simple example from a single dataset :)
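The idea behind this split can be sketched in a few lines of plain Python. The helper chunk_indices below is hypothetical and only mirrors the groupings used in renate_config.py; it is not how ClassIncrementalScenario is implemented internally.

```python
from typing import List, Sequence, Tuple

# Same groupings as in renate_config.py: classes 0-4 form chunk 0, classes 5-9 form chunk 1.
GROUPINGS: Tuple[Tuple[int, ...], ...] = ((0, 1, 2, 3, 4), (5, 6, 7, 8, 9))


def chunk_indices(
    labels: Sequence[int], groupings: Tuple[Tuple[int, ...], ...], chunk_id: int
) -> List[int]:
    """Hypothetical helper: indices of the samples whose label belongs to chunk `chunk_id`."""
    wanted = set(groupings[chunk_id])
    return [i for i, label in enumerate(labels) if label in wanted]


labels = [0, 7, 3, 9, 5, 1]
print(chunk_indices(labels, GROUPINGS, chunk_id=0))  # [0, 2, 5]
print(chunk_indices(labels, GROUPINGS, chunk_id=1))  # [1, 3, 4]
```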

In the renate_config.py file we also define some simple transformations to normalize and augment the data. More transformations can be added if needed; details on how to write a configuration file are available in How to Write a Config File.
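transforms.Normalize standardizes each channel as (value - mean) / std. The sketch below applies the same formula to a single RGB pixel, using the CIFAR10 channel statistics from renate_config.py; normalize_pixel is a hypothetical helper, whereas torchvision applies the operation to every pixel of an image tensor.

```python
from typing import Sequence, Tuple

# Per-channel CIFAR10 statistics used in renate_config.py.
MEAN = (0.4914, 0.4822, 0.4465)
STD = (0.2470, 0.2435, 0.2615)


def normalize_pixel(
    rgb: Sequence[float], mean: Sequence[float] = MEAN, std: Sequence[float] = STD
) -> Tuple[float, ...]:
    """Hypothetical helper: apply (value - mean) / std to each channel of one pixel."""
    return tuple((v - m) / s for v, m, s in zip(rgb, mean, std))


# A pixel equal to the dataset mean is mapped to zero in every channel.
print(normalize_pixel(MEAN))  # (0.0, 0.0, 0.0)
```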

from typing import Callable, Dict, Optional

import torch
from torchmetrics import Accuracy
from torchvision import transforms

import renate.defaults as defaults
from renate.benchmark.datasets.vision_datasets import TorchVisionDataModule
from renate.benchmark.models import ResNet18CIFAR
from renate.benchmark.scenarios import ClassIncrementalScenario, Scenario
from renate.models import RenateModule


def model_fn(model_state_url: Optional[str] = None) -> RenateModule:
    """Returns a model instance."""
    if model_state_url is None:
        model = ResNet18CIFAR()
    else:
        state_dict = torch.load(model_state_url)
        model = ResNet18CIFAR.from_state_dict(state_dict)
    return model


def data_module_fn(data_path: str, chunk_id: int, seed: int = defaults.SEED) -> Scenario:
    """Returns a class-incremental scenario instance.

    CIFAR10 is split into two chunks (classes 0-4 and classes 5-9); chunk_id selects the
    chunk used for training.
    """
    data_module = TorchVisionDataModule(
        data_path,
        dataset_name="CIFAR10",
        val_size=0.2,
        seed=seed,
    )
    class_incremental_scenario = ClassIncrementalScenario(
        data_module=data_module,
        groupings=((0, 1, 2, 3, 4), (5, 6, 7, 8, 9)),
        chunk_id=chunk_id,
    )
    return class_incremental_scenario


def train_transform() -> Callable:
    """Returns a transform function to be used in the training."""
    return transforms.Compose(
        [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2615)),
        ]
    )


def test_transform() -> Callable:
    """Returns a transform function to be used for validation or testing."""
    return transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2615))


def buffer_transform() -> Callable:
    """Returns a transform function to be used in the Memory Buffer."""
    return train_transform()


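def loss_fn() -> torch.nn.Module:
    """Returns the loss function; reduction="none" yields one loss value per sample."""
    return torch.nn.CrossEntropyLoss(reduction="none")


def metrics_fn() -> Dict:
    """Returns the metrics to track, keyed by name."""
    return {"accuracy": Accuracy(task="multiclass", num_classes=10)}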

The configuration file uses a scenario in the data module function. The scenario simply splits the dataset into several chunks, which lets us train the model on different parts of the dataset without adding complex code to the example. In most practical applications the scenario can be removed and data_module_fn can return the data module directly.

Training

The example also contains start_with_hpo.py, which launches a training job with integrated hyperparameter optimization (HPO). In this file we define a dictionary containing the configuration of the learning algorithm. For some hyperparameters, instead of a single value, we define a range (e.g., uniform(0.0, 1.0)) within which the optimizer searches for the best value. We select the optimization algorithm with the argument scheduler="asha", so ASHA evaluates up to 100 hyperparameter configurations (max_num_trials_finished=100); the commented-out n_workers argument can be used to evaluate several configurations in parallel on an instance with multiple GPUs. The model and the output of the HPO process will be saved in the S3 bucket provided in output_state_url. To keep the example simple, we build this URL from two variables holding the AWS account ID (AWS_ID) and the AWS region (AWS_REGION), but any accessible S3 bucket can be used to store the output. The other arguments are described, together with a high-level overview of how to run a training job, in How to Run a Training Job.

import boto3
from syne_tune.backend.sagemaker_backend.sagemaker_utils import get_execution_role
from syne_tune.config_space import choice, loguniform, uniform

from renate.training import run_training_job

config_space = {
    "optimizer": "SGD",
    "momentum": uniform(0.1, 0.9),
    "weight_decay": 0.0,
    "learning_rate": loguniform(1e-4, 1e-1),
    "alpha": uniform(0.0, 1.0),
    "batch_size": choice([32, 64, 128, 256]),
    "batch_memory_frac": 0.5,
    "memory_size": 1000,
    "loss_normalization": 0,
    "loss_weight": uniform(0.0, 1.0),
}

if __name__ == "__main__":
    AWS_ID = boto3.client("sts").get_caller_identity().get("Account")
    AWS_REGION = "us-west-2"  # use your AWS preferred region here

    run_training_job(
        config_space=config_space,
        mode="max",
        metric="val_accuracy",
        updater="ER",  # we train with Experience Replay
        max_epochs=50,
        # we select the first chunk of our dataset, you will probably not need this in practice
        chunk_id=0,
        config_file="renate_config.py",
        requirements_file="requirements.txt",
        # replace the url below with a different one if you already ran it and you want to avoid
        # overwriting
        output_state_url=f"s3://sagemaker-{AWS_REGION}-{AWS_ID}/renate-cifar10/",
        # uncomment the line below only if you already created a model with this script and you want
        # to update it
        # input_state_url=f"s3://sagemaker-{AWS_REGION}-{AWS_ID}/renate-cifar10/",
        backend="sagemaker",  # run on SageMaker, select "local" to run this locally
        role=get_execution_role(),
        instance_count=1,
        instance_type="ml.g4dn.xlarge",
        max_num_trials_finished=100,
        scheduler="asha",  # run ASHA to optimize our hyperparameters
        # if you use a big instance with multiple GPUs, you can have multiple workers evaluate
        # configurations in parallel
        # n_workers=4,
        job_name="job-name",
    )
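Entries of config_space that are ranges (uniform, loguniform, choice) are sampled by the optimizer, while plain values such as "optimizer": "SGD" stay fixed for every trial. The sketch below illustrates this with pure random sampling; the sample_* helpers and sample_config are hypothetical stand-ins, not the Syne Tune domains imported in the script, and ASHA additionally decides which of the sampled configurations to stop early.

```python
import math
import random
from typing import Any, Callable, Dict, Sequence

# A sampler draws one value from a range using the given random generator.
Sampler = Callable[[random.Random], Any]


def sample_uniform(low: float, high: float) -> Sampler:
    return lambda rng: rng.uniform(low, high)


def sample_loguniform(low: float, high: float) -> Sampler:
    # Uniform in log space: every order of magnitude between low and high is equally likely.
    return lambda rng: math.exp(rng.uniform(math.log(low), math.log(high)))


def sample_choice(options: Sequence[Any]) -> Sampler:
    return lambda rng: rng.choice(list(options))


def sample_config(space: Dict[str, Any], rng: random.Random) -> Dict[str, Any]:
    """Draw one configuration: ranges are sampled, constants are copied unchanged."""
    return {key: value(rng) if callable(value) else value for key, value in space.items()}


# A reduced, hypothetical version of the config_space used in start_with_hpo.py.
space = {
    "optimizer": "SGD",
    "learning_rate": sample_loguniform(1e-4, 1e-1),
    "momentum": sample_uniform(0.1, 0.9),
    "batch_size": sample_choice([32, 64, 128, 256]),
}
print(sample_config(space, random.Random(0)))
```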

Once the training job terminates, the output will be available in the S3 bucket indicated in output_state_url. For more information about how to interpret the output, see Renate’s output.
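To inspect the output programmatically, the URL passed as output_state_url first has to be split into a bucket name and a key prefix. A minimal sketch using only the standard library follows; split_s3_url is a hypothetical helper and the account ID in the URL is a placeholder.

```python
from typing import Tuple
from urllib.parse import urlparse


def split_s3_url(url: str) -> Tuple[str, str]:
    """Hypothetical helper: split an s3://bucket/prefix URL into (bucket, prefix)."""
    parsed = urlparse(url)
    if parsed.scheme != "s3":
        raise ValueError(f"not an S3 URL: {url}")
    # The bucket is the network location; the key prefix is the path without its leading "/".
    return parsed.netloc, parsed.path.lstrip("/")


# Placeholder account ID together with the region used in start_with_hpo.py.
url = "s3://sagemaker-us-west-2-123456789012/renate-cifar10/"
print(split_s3_url(url))  # ('sagemaker-us-west-2-123456789012', 'renate-cifar10/')
```

The resulting pair can then be passed, for example, as Bucket and Prefix to the list_objects_v2 method of a boto3 S3 client to list the saved artefacts.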

To simulate an application where data become available incrementally over time, once the first training job has finished we can re-train the model on the second chunk of the dataset, which we intentionally left untouched during the first training job.

To do this, it is sufficient to modify the arguments passed to the run_training_job() function. In particular:

  1. select the second chunk of the dataset by setting chunk_id=1;

  2. load the model trained in the first training job by adding the input_state_url argument, pointing to the S3 location used as output_state_url in the first job. In this case it is also advisable to change output_state_url to avoid overwriting the old artefacts.
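These changes can be summarized as a transformation of the keyword arguments of run_training_job(). The sketch below uses a hypothetical helper, next_update_kwargs, and placeholder bucket names; all other arguments of the script (config_space, updater, backend and the rest) would stay unchanged.

```python
from typing import Any, Dict

# Placeholder values standing in for the arguments of the first job.
first_job: Dict[str, Any] = {
    "chunk_id": 0,
    "output_state_url": "s3://my-bucket/renate-cifar10/",
}


def next_update_kwargs(previous: Dict[str, Any], new_output_url: str) -> Dict[str, Any]:
    """Hypothetical helper: derive the arguments of the next update from the previous job."""
    kwargs = dict(previous)  # copy, so the previous job's arguments are left untouched
    kwargs["chunk_id"] = previous["chunk_id"] + 1
    # Start from the model saved by the previous job...
    kwargs["input_state_url"] = previous["output_state_url"]
    # ...and write the new artefacts elsewhere so the old ones are not overwritten.
    kwargs["output_state_url"] = new_output_url
    return kwargs


second_job = next_update_kwargs(first_job, "s3://my-bucket/renate-cifar10-update-1/")
print(second_job)
```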

Note that in our example we specified requirements_file="requirements.txt" even though it is not strictly necessary, since the only dependency in the file is Renate itself. Its only purpose is to show how additional dependencies can be added when needed.