Using Renate in a Custom Training Script

Usually, we use Renate by writing a configuration file and launching training jobs via the run_training_job() function. In this example, we demonstrate how to write your own training script and use Renate in a functional way. This can be useful, for example, when debugging new components.

Here, we use Renate to fine-tune a pretrained Transformer model on a sequence classification dataset. First, we create the model and a loss function. Since this is a static model, we simply wrap it using the RenateWrapper class. Recall that loss functions should produce one loss value per input example (reduction="none" for PyTorch’s built-in losses), as explained in Loss Definition.

import torch
import transformers
transformer_model = transformers.DistilBertForSequenceClassification.from_pretrained(
    "distilbert-base-uncased", num_labels=2, return_dict=False
)
model = RenateWrapper(transformer_model)
loss_fn = torch.nn.CrossEntropyLoss(reduction="none")
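The reduction="none" requirement can be checked with plain PyTorch. The sketch below feeds random two-class logits for a batch of four examples (illustrative stand-ins for classifier outputs, not real model predictions) to both the per-example loss and the default averaging loss. It shows that the former returns one value per example and that their mean matches the default reduction.

```python
import torch

# reduction="none" keeps one loss value per input example.
loss_fn = torch.nn.CrossEntropyLoss(reduction="none")

torch.manual_seed(0)
logits = torch.randn(4, 2)           # batch of 4 examples, 2 classes
labels = torch.tensor([0, 1, 1, 0])  # one target class per example

per_example = loss_fn(logits, labels)
print(per_example.shape)  # torch.Size([4])

# The default reduction ("mean") collapses the batch to a single scalar,
# which equals the mean of the per-example values.
mean_loss = torch.nn.CrossEntropyLoss()(logits, labels)
print(torch.allclose(per_example.mean(), mean_loss))  # True
```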

Next, we prepare the dataset on which we want to fine-tune the model. Here, we use the HuggingFaceTextDataModule to load the "rotten_tomatoes" dataset from the Hugging Face Hub. It also takes care of tokenization if we pass it the corresponding tokenizer.

tokenizer = transformers.DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
data_module = HuggingFaceTextDataModule(
    "data", dataset_name="rotten_tomatoes", tokenizer=tokenizer, val_size=0.2
)
data_module.prepare_data()  # For multi-GPU, call only on rank 0.
data_module.setup()  # Builds the datasets returned by train_data() and val_data().
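The val_size=0.2 argument presumably holds out 20% of the training data for validation. The following sketch shows a comparable split with PyTorch's random_split on a hypothetical toy dataset of 100 examples; it illustrates the idea only and is not how HuggingFaceTextDataModule is implemented internally.

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Hypothetical toy dataset standing in for tokenized reviews: 100 examples.
features = torch.arange(100).unsqueeze(1)
labels = torch.randint(0, 2, (100,))
dataset = TensorDataset(features, labels)

val_size = 0.2  # fraction of examples held out for validation
num_val = int(len(dataset) * val_size)
num_train = len(dataset) - num_val

# A fixed generator makes the split reproducible across runs.
train_data, val_data = random_split(
    dataset, [num_train, num_val], generator=torch.Generator().manual_seed(0)
)
print(len(train_data), len(val_data))  # 80 20
```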

Now we can instantiate a ModelUpdater to perform the training. Since we just want to fine-tune the model on a single dataset here, we use the FineTuningModelUpdater. We pass it our model as well as training details, such as the optimizer to use and its hyperparameters. The model updater also receives all options related to distributed training, as explained in Support for training large models. Once the model updater is created, we start training by calling its update() method, passing the training dataset and, optionally, a validation dataset.

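from functools import partial
from torch.optim import Adam
optimizer = partial(Adam, lr=3e-4)  # PyTorch optimizers expect `lr`, not `learning_rate`.
# Further training and distributed-training options can also be passed here.
updater = FineTuningModelUpdater(model, loss_fn, optimizer=optimizer)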
updater.update(data_module.train_data(), data_module.val_data())
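Passing the optimizer as a functools.partial binds its hyperparameters while leaving the model parameters to be supplied later by whatever code builds the optimizer. The sketch below assumes a small stand-in nn.Linear model and random data (both hypothetical) and consumes the factory in a plain PyTorch loop; it mirrors the idea of fine-tuning but is not Renate's internal training code. It also shows that PyTorch optimizers name the learning rate lr, and that a misspelled keyword is only rejected when the factory is actually called.

```python
from functools import partial

import torch
from torch.optim import Adam

torch.manual_seed(0)

# Optimizer factory: hyperparameters are bound now, model parameters later.
optimizer_factory = partial(Adam, lr=3e-4)

# Hypothetical stand-in model and data for a two-class problem.
model = torch.nn.Linear(8, 2)
inputs = torch.randn(16, 8)
targets = torch.randint(0, 2, (16,))
loss_fn = torch.nn.CrossEntropyLoss(reduction="none")

# The training code calls the factory with the model's parameters.
optimizer = optimizer_factory(model.parameters())
print(optimizer.param_groups[0]["lr"])  # 0.0003

with torch.no_grad():
    initial_loss = loss_fn(model(inputs), targets).mean().item()

for _ in range(200):
    optimizer.zero_grad()
    # Average the per-example losses into one scalar before backpropagating.
    loss = loss_fn(model(inputs), targets).mean()
    loss.backward()
    optimizer.step()

with torch.no_grad():
    final_loss = loss_fn(model(inputs), targets).mean().item()
print(f"loss: {initial_loss:.4f} -> {final_loss:.4f}")

# Adam has no `learning_rate` argument; the error surfaces only at call time.
try:
    partial(Adam, learning_rate=3e-4)(model.parameters())
    bad_kwarg_rejected = False
except TypeError:
    bad_kwarg_rejected = True
print(bad_kwarg_rejected)  # True
```

Binding hyperparameters up front keeps the caller in control of optimizer settings, while the component that owns the model decides which parameters to optimize.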

Once training has finished, the model is ready to deploy. Here, we simply save its weights for later use with standard PyTorch functionality.
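Saving and restoring weights with torch.save and load_state_dict can be sketched as follows. To stay self-contained, the example uses a small stand-in nn.Linear and a hypothetical file name inside a temporary directory; for the model in this tutorial, the same calls would presumably be applied to model.state_dict().

```python
import os
import tempfile

import torch

torch.manual_seed(0)
model = torch.nn.Linear(8, 2)  # stand-in for the fine-tuned model

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "model_weights.pt")  # hypothetical file name
    # Save only the parameters (the state dict), not the pickled module.
    torch.save(model.state_dict(), path)

    # Later: rebuild the same architecture and load the saved weights.
    restored = torch.nn.Linear(8, 2)
    restored.load_state_dict(torch.load(path))

same = all(
    torch.equal(a, b)
    for a, b in zip(model.state_dict().values(), restored.state_dict().values())
)
print(same)  # True
```

Saving the state dict rather than the whole module keeps the file independent of the pickled class definition; the architecture must be rebuilt before the weights are loaded.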