renate.updaters.model_updater module#

class renate.updaters.model_updater.SyneTuneCallback(val_enabled)[source]#

Bases: Callback

Callback to report metrics to Syne Tune.

Parameters:

val_enabled (bool) – Whether validation was enabled in the Learner.
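Conceptually, the callback forwards the metrics logged on the trainer to Syne Tune at the end of each epoch. A minimal sketch of this pattern (not Renate's exact implementation), assuming Syne Tune's Reporter:

```python
from pytorch_lightning import Callback
from syne_tune import Reporter


class ReportToSyneTune(Callback):
    """Illustrative callback that reports logged metrics to Syne Tune."""

    def __init__(self, val_enabled: bool) -> None:
        super().__init__()
        self._report = Reporter()
        self._val_enabled = val_enabled

    def _report_metrics(self, trainer) -> None:
        # Syne Tune expects plain Python scalars, not tensors.
        self._report(**{name: float(value) for name, value in trainer.logged_metrics.items()})

    def on_train_epoch_end(self, trainer, pl_module) -> None:
        # Without a validation loop, the end of the train epoch is the
        # last chance to report; otherwise defer to the validation hook.
        if not self._val_enabled:
            self._report_metrics(trainer)

    def on_validation_epoch_end(self, trainer, pl_module) -> None:
        self._report_metrics(trainer)
```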

on_train_epoch_end(trainer, pl_module)[source]#

Called when the train epoch ends.

To access all batch outputs at the end of the epoch, either:

  1. Implement training_epoch_end in the LightningModule and access outputs via the module, OR

  2. Cache data across train batch hooks inside the callback implementation to post-process in this hook (see the sketch below).

Return type:

None
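For the second option, a generic Lightning pattern (not specific to Renate) looks like this:

```python
from pytorch_lightning import Callback


class CacheOutputsCallback(Callback):
    """Illustrative callback that caches per-batch outputs for epoch-end use."""

    def __init__(self) -> None:
        super().__init__()
        self._outputs = []

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx) -> None:
        # Collect whatever the training step returned for this batch.
        self._outputs.append(outputs)

    def on_train_epoch_end(self, trainer, pl_module) -> None:
        # Post-process the cached outputs here, then reset for the next epoch.
        self._outputs = []
```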

on_validation_epoch_end(trainer, pl_module)[source]#

Called when the val epoch ends.

Return type:

None

on_train_start(trainer, pl_module)[source]#

Called when the train begins.

Return type:

None

class renate.updaters.model_updater.RenateModelCheckpoint(model, output_state_folder, val_enabled, metric=None, mode='min', use_syne_tune_callback=True)[source]#

Bases: ModelCheckpoint

Callback to save Renate state after each epoch.

Parameters:
  • model (RenateModule) – Model to be saved when creating a checkpoint.

  • output_state_folder (str) – Checkpoint folder location.

  • val_enabled (bool) – Whether validation was enabled in the Learner. Forwarded to SyneTuneCallback.

  • metric (Optional[str]) – Monitored metric to decide when to write a new checkpoint. If no metric is provided or validation is not enabled, the latest model will be stored.

  • mode (Literal['min', 'max']) – min or max. Whether to minimize or maximize the monitored metric.

  • use_syne_tune_callback (bool) – Whether to use SyneTuneCallback.
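The metric/mode pair maps onto the standard Lightning ModelCheckpoint configuration that this class extends. A rough sketch of the corresponding plain-Lightning setup (values are placeholders, not Renate's internals):

```python
from pytorch_lightning.callbacks import ModelCheckpoint

checkpoint_callback = ModelCheckpoint(
    dirpath="output_state/",  # plays the role of output_state_folder
    monitor="val_loss",       # plays the role of metric; None keeps the latest model
    mode="min",               # plays the role of mode
    every_n_epochs=1,         # checkpoint after each epoch
)
```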

on_train_epoch_end(trainer, pl_module)[source]#

Save a checkpoint at the end of the training epoch.

Return type:

None

on_validation_epoch_end(trainer, pl_module)[source]#

Called when the val epoch ends.

Return type:

None

teardown(trainer, pl_module, stage)[source]#

Implements the separation of learner and model at the end of training.

There are two cases to handle.

1. If deepspeed is being used: The learner_state_path used by the checkpointing function is a directory, not a file. This directory contains sharded state dicts (of the model and optimizers), depending on which deepspeed stage is used. There are four steps here:

  1. Combine all the shards into one big state dict.

  2. Delete the learner_state_path directory (learner.ckpt/), since it must be replaced by a file.

  3. Write the combined state dict to learner.ckpt as a single file.

  4. Extract the state_dict element from the learner and save it as model.ckpt.

2. If not deepspeed (say DDP or single device): The steps are much simpler.

  1. Load the learner.ckpt and extract the state_dict element.

  2. Sanitize the extracted state_dict. The learner holds the model in a _model attribute,
    so strip the leading “_model.” prefix from the keys of the state_dict.
  3. Save the sanitized model to model.ckpt.

Case 2 needs to be performed even for Case 1 (its step 4). Hence, teardown calls itself recursively in Case 1, and that call automatically takes the Case 2 path because learner.ckpt is a file at that point.

Return type:

None
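A minimal sketch of the Case 2 sanitization under the assumptions above; the file names are illustrative, and the shard merging of Case 1 can be done beforehand, e.g., with pytorch_lightning.utilities.deepspeed.convert_zero_checkpoint_to_fp32_state_dict:

```python
import torch


def split_model_from_learner(learner_ckpt: str, model_ckpt: str) -> None:
    """Extract and sanitize the model state from a learner checkpoint file."""
    checkpoint = torch.load(learner_ckpt, map_location="cpu")
    state_dict = checkpoint["state_dict"]
    prefix = "_model."
    # The learner stores the model under `_model`, so keep only those entries
    # and strip the prefix so the keys match the bare model.
    model_state = {
        key[len(prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(prefix)
    }
    torch.save(model_state, model_ckpt)
```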

on_exception(trainer, pl_module, exception)[source]#

Called when any trainer execution is interrupted by an exception.

Return type:

None

on_fit_end(trainer, pl_module)[source]#

Called when fit ends.

Return type:

None

class renate.updaters.model_updater.ModelUpdater(model, loss_fn, optimizer, learner_class, learner_kwargs=None, input_state_folder=None, output_state_folder=None, max_epochs=50, learning_rate_scheduler=None, learning_rate_scheduler_interval='epoch', train_transform=None, train_target_transform=None, test_transform=None, test_target_transform=None, buffer_transform=None, buffer_target_transform=None, metric=None, mode='min', logged_metrics=None, early_stopping_enabled=False, logger=<pytorch_lightning.loggers.tensorboard.TensorBoardLogger object>, accelerator='auto', devices=None, strategy='ddp', precision='32', deterministic_trainer=False, gradient_clip_val=None, gradient_clip_algorithm=None, mask_unused_classes=False)[source]#

Bases: ABC

Updates a learner using the data provided.

Parameters:
  • model (RenateModule) – The potentially pretrained model to be updated with new data.

  • learner_class (Type[Learner]) – Class of the learner to be used for model update.

  • learner_kwargs (Optional[Dict[str, Any]]) – Arguments used either to create a new learner (if no previous state is available) or to replace the current arguments of the learner.

  • input_state_folder (Optional[str]) – Folder used by Renate to store files for current state.

  • output_state_folder (Optional[str]) – Folder used by Renate to store files for next state.

  • max_epochs (int) – The maximum number of epochs used to train the model. For comparability between methods, epochs are interpreted as “finetuning-equivalent”. That is, one epoch is defined as len(current_task_dataset) / batch_size update steps (see the sketch after this parameter list).

  • train_transform (Optional[Callable]) – The transformation applied during training.

  • train_target_transform (Optional[Callable]) – The target transformation applied during training.

  • test_transform (Optional[Callable]) – The transformation at test time.

  • test_target_transform (Optional[Callable]) – The target transformation at test time.

  • buffer_transform (Optional[Callable]) – Augmentations applied to the input data coming from the memory. Not all updaters require this. If required but not passed, train_transform will be used.

  • buffer_target_transform (Optional[Callable]) – Transformations applied to the target. Not all updaters require this. If required but not passed, train_target_transform will be used.

  • metric (Optional[str]) – Monitored metric to decide when to write a new checkpoint or early-stop the optimization. If no metric is provided, the latest model will be stored.

  • mode (Literal['min', 'max']) – min or max. Whether to minimize or maximize the monitored metric.

  • logged_metrics (Optional[Dict[str, Metric]]) – Metrics logged additional to the default ones.

  • early_stopping_enabled (bool) – Enables the early stopping of the optimization.

  • logger (Logger) – Logger used by PyTorch Lightning to log intermediate results.

  • accelerator (Literal['auto', 'cpu', 'gpu', 'tpu']) – Accelerator used by PyTorch Lightning to train the model.

  • devices (Optional[int]) – Devices used by PyTorch Lightning to train the model. If the devices flag is not defined, it will assume devices to be “auto” and fetch the auto_device_count from the accelerator.

  • deterministic_trainer (bool) – When set to True, makes the output of the training deterministic. The value is passed to the PyTorch Lightning Trainer as its deterministic flag.

  • gradient_clip_val (Optional[float]) – Gradient clipping value used in PyTorch Lightning. Defaults to not clipping by using a value of None.

  • gradient_clip_algorithm (Optional[str]) – Method to clip gradients (norm or value) used in PyTorch Lightning.
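To make the “finetuning-equivalent” epoch definition of max_epochs concrete, a small illustrative computation (all numbers are placeholders):

```python
import math


def updates_per_epoch(current_task_dataset_size: int, batch_size: int) -> int:
    # One "finetuning-equivalent" epoch corresponds to
    # len(current_task_dataset) / batch_size update steps, regardless of how
    # much additional (e.g., replay) data an updater mixes in.
    return math.ceil(current_task_dataset_size / batch_size)


# E.g., 10,000 new samples with batch size 32 give 313 updates per epoch,
# so max_epochs=50 amounts to 15,650 update steps in total.
total_updates = 50 * updates_per_epoch(10_000, 32)
```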

abstract update(train_dataset, val_dataset=None, train_dataset_collate_fn=None, val_dataset_collate_fn=None, task_id=None)[source]#

Updates the model using the data passed as input.

Parameters:
  • train_dataset (Dataset) – The training data.

  • val_dataset (Optional[Dataset]) – The validation data.

  • train_dataset_collate_fn (Optional[Callable]) – collate_fn used to merge a list of samples to form a mini-batch of Tensors for the training data.

  • val_dataset_collate_fn (Optional[Callable]) – collate_fn used to merge a list of samples to form a mini-batch of Tensors for the validation data.

  • task_id (Optional[str]) – The task id.

Return type:

None

class renate.updaters.model_updater.SingleTrainingLoopUpdater(model, loss_fn, optimizer, learner_class, learner_kwargs=None, input_state_folder=None, output_state_folder=None, max_epochs=50, learning_rate_scheduler=None, learning_rate_scheduler_interval='epoch', train_transform=None, train_target_transform=None, test_transform=None, test_target_transform=None, buffer_transform=None, buffer_target_transform=None, metric=None, mode='min', logged_metrics=None, early_stopping_enabled=False, logger=<pytorch_lightning.loggers.tensorboard.TensorBoardLogger object>, accelerator='auto', devices=None, strategy='ddp', precision='32', deterministic_trainer=False, gradient_clip_val=None, gradient_clip_algorithm=None, mask_unused_classes=False)[source]#

Bases: ModelUpdater

Simple ModelUpdater that requires only a single learner to update the model.

update(train_dataset, val_dataset=None, train_dataset_collate_fn=None, val_dataset_collate_fn=None, task_id=None)[source]#

Updates the model using the data passed as input.

Parameters:
  • train_dataset (Dataset) – The training data.

  • val_dataset (Optional[Dataset]) – The validation data.

  • train_dataset_collate_fn (Optional[Callable]) – collate_fn used to merge a list of samples to form a mini-batch of Tensors for the training data.

  • val_dataset_collate_fn (Optional[Callable]) – collate_fn used to merge a list of samples to form a mini-batch of Tensors for the validation data.

  • task_id (Optional[str]) – The task id.

Return type:

RenateModule
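A hypothetical end-to-end usage sketch of the class above. Here my_model, my_optimizer, MyLearner, and the datasets are placeholders for objects the caller provides, and the learner_kwargs contents are assumptions, not documented defaults:

```python
import torch

from renate.updaters.model_updater import SingleTrainingLoopUpdater

updater = SingleTrainingLoopUpdater(
    model=my_model,                     # a RenateModule (placeholder)
    loss_fn=torch.nn.CrossEntropyLoss(),
    optimizer=my_optimizer,             # optimizer as expected by the learner (placeholder)
    learner_class=MyLearner,            # concrete Learner subclass (placeholder)
    learner_kwargs={"batch_size": 32},  # assumed learner argument
    output_state_folder="renate_state/",
    max_epochs=5,
    metric="val_loss",
    mode="min",
)
# Returns the updated RenateModule, per the return type documented above.
updated_model = updater.update(
    train_dataset, val_dataset=val_dataset, task_id="task_0"
)
```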