Using Renate in a Custom Training Script
Usually, we use Renate by writing a renate_config.py and launching training jobs via the
run_training_job() function. In this example, we demonstrate how to write your own training script
and use Renate in a functional way. This can be useful, e.g., for debugging new components.
Here, we use Renate to fine-tune a pretrained Transformer model on a sequence classification
dataset. First, we create the model and a loss function. Since this is a static model, we simply
wrap it using the RenateWrapper class. Recall that loss functions should produce one loss value per
input example (reduction="none" for PyTorch’s built-in losses), as explained in Loss Definition.
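import torch
import transformers
# Module path follows Renate's source layout and may differ between versions.
from renate.models.renate_module import RenateWrapper

transformer_model = transformers.DistilBertForSequenceClassification.from_pretrained(
    "distilbert-base-uncased", num_labels=2, return_dict=False
)
model = RenateWrapper(transformer_model)
loss_fn = torch.nn.CrossEntropyLoss(reduction="none")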
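To make the per-example requirement concrete, the following minimal sketch compares reduction="none" with PyTorch's default reduction="mean". The logits and targets are made-up toy values, not outputs of the DistilBERT model above.

```python
import torch

# Toy logits for a batch of 4 examples and 2 classes (made-up values).
logits = torch.tensor([[2.0, 0.5], [0.1, 1.5], [1.0, 1.0], [3.0, -1.0]])
targets = torch.tensor([0, 1, 0, 1])

# reduction="none" keeps one loss value per input example.
per_example_loss = torch.nn.CrossEntropyLoss(reduction="none")(logits, targets)

# The default reduction="mean" collapses the batch into a single scalar.
mean_loss = torch.nn.CrossEntropyLoss(reduction="mean")(logits, targets)

print(per_example_loss.shape)  # torch.Size([4])
print(mean_loss.shape)  # torch.Size([])
```

Averaging the per-example values recovers the scalar loss, so no information is lost by deferring the reduction.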
Next, we prepare the dataset on which we want to fine-tune the model. Here, we use the
HuggingFaceTextDataModule to load the "rotten_tomatoes" dataset from the Hugging Face hub. This
will also take care of tokenization for us, if we pass it the corresponding tokenizer.
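import transformers
# Module path follows Renate's source layout and may differ between versions.
from renate.benchmark.datasets.nlp_datasets import HuggingFaceTextDataModule

tokenizer = transformers.DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
data_module = HuggingFaceTextDataModule(
    "data", dataset_name="rotten_tomatoes", tokenizer=tokenizer, val_size=0.2
)
data_module.prepare_data()  # For multi-GPU, call only on rank 0.
data_module.setup()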
Now we can instantiate a ModelUpdater to perform the training. Since we just want to fine-tune the
model on a single dataset here, we use the FineTuningModelUpdater. We pass our model as well as
training details, such as the optimizer to use and its hyperparameters. The model updater also
receives all options related to distributed training, as explained in Support for training large
models. Once the model updater is created, we initiate the training by calling its update() method
and passing training and (optionally) validation datasets.
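from functools import partial

from torch.optim import Adam
# Module path follows Renate's source layout and may differ between versions.
from renate.updaters.experimental.fine_tuning import FineTuningModelUpdater

# torch.optim.Adam names its learning-rate argument `lr`.
optimizer = partial(Adam, lr=3e-4)
updater = FineTuningModelUpdater(
    model,
    loss_fn,
    optimizer=optimizer,
    batch_size=32,
    max_epochs=3,
    input_state_folder=None,
    output_state_folder="renate_output",
)
updater.update(data_module.train_data(), data_module.val_data())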
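The optimizer is passed as a functools.partial rather than as an optimizer instance. The sketch below illustrates the pattern in plain PyTorch: the hyperparameters are bound up front and the parameters are supplied later. That the updater itself completes the call this way is an assumption here, and toy_model is a hypothetical stand-in for the wrapped Transformer.

```python
from functools import partial

import torch
from torch.optim import Adam

# Hypothetical stand-in model; any torch.nn.Module would do.
toy_model = torch.nn.Linear(4, 2)

# Bind the hyperparameters now, without touching any parameters yet.
optimizer_factory = partial(Adam, lr=3e-4)

# Later, whoever owns the model supplies its parameters to finish construction.
optimizer = optimizer_factory(toy_model.parameters())

print(type(optimizer).__name__)  # Adam
```

Deferring construction this way keeps the training script free of any reference to the model's parameters at the point where the hyperparameters are chosen.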
Once training has finished, your model is ready to deploy. Here, we just save its weights for
later use with standard PyTorch functionality.
torch.save(model.state_dict(), "model_weights.pt")
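To use the saved weights later, the same architecture has to be rebuilt before the state dict is loaded. The following sketch shows the round trip with a small hypothetical stand-in module and a temporary directory. For the model in this example, one would presumably recreate the wrapped DistilBERT model in the same way before calling load_state_dict().

```python
import os
import tempfile

import torch

# Hypothetical stand-in for the trained model.
toy_model = torch.nn.Linear(4, 2)

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "model_weights.pt")
    torch.save(toy_model.state_dict(), path)

    # Rebuild the same architecture, then load the saved weights into it.
    restored_model = torch.nn.Linear(4, 2)
    restored_model.load_state_dict(torch.load(path))

# Switch to evaluation mode before running inference.
restored_model.eval()
```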