renate.updaters.model_updater module
- class renate.updaters.model_updater.SyneTuneCallback(val_enabled)
Bases: Callback
Callback to report metrics to Syne Tune.
- Parameters:
val_enabled (bool) – Whether validation was enabled in the Learner.
- on_train_epoch_end(trainer, pl_module)
Called when the train epoch ends.
To access all batch outputs at the end of the epoch, either:
1. Implement training_epoch_end in the LightningModule and access outputs via the module, OR
2. Cache data across train batch hooks inside the callback implementation to post-process in this hook.
- Return type:
None
- class renate.updaters.model_updater.RenateModelCheckpoint(model, output_state_folder, val_enabled, metric=None, mode='min', use_syne_tune_callback=True)
Bases: ModelCheckpoint
Callback to save Renate state after each epoch.
- Parameters:
model (RenateModule) – Model to be saved when creating a checkpoint.
output_state_folder (str) – Checkpoint folder location.
val_enabled (bool) – Whether validation was enabled in the Learner. Forwarded to SyneTuneCallback.
metric (Optional[str]) – Monitored metric to decide when to write a new checkpoint. If no metric is provided or validation is not enabled, the latest model will be stored.
mode (Literal['min', 'max']) – min or max. Whether to minimize or maximize the monitored metric.
use_syne_tune_callback (bool) – Whether to use SyneTuneCallback.
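The way metric and mode interact can be illustrated with a minimal sketch. The helpers `is_improvement` and `track_best` below are hypothetical and not part of Renate; they only show how a mode of `min` or `max` could decide whether a newly observed metric value warrants a new checkpoint.

```python
from typing import List, Optional, Sequence


def is_improvement(current: float, best: Optional[float], mode: str = "min") -> bool:
    """Return True if `current` beats `best` under the given mode (hypothetical helper)."""
    if mode not in ("min", "max"):
        raise ValueError(f"mode must be 'min' or 'max', got {mode!r}")
    if best is None:
        # Nothing recorded yet, so the first value always counts as an improvement.
        return True
    return current < best if mode == "min" else current > best


def track_best(values: Sequence[float], mode: str = "min") -> List[int]:
    """Return the epochs at which a new best value would trigger a checkpoint."""
    best: Optional[float] = None
    checkpoint_epochs: List[int] = []
    for epoch, value in enumerate(values):
        if is_improvement(value, best, mode):
            best = value
            checkpoint_epochs.append(epoch)
    return checkpoint_epochs


# A validation loss is minimized, an accuracy is maximized.
loss_checkpoints = track_best([0.9, 0.7, 0.8, 0.6], mode="min")
accuracy_checkpoints = track_best([0.5, 0.7, 0.6], mode="max")
```

When no metric is given, the documented behavior is simpler: the latest model is stored, so no comparison of this kind is needed.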
- on_train_epoch_end(trainer, pl_module)
Save a checkpoint at the end of the training epoch.
- Return type:
None
- on_validation_epoch_end(trainer, pl_module)
Called when the val epoch ends.
- Return type:
None
- teardown(trainer, pl_module, stage)
Implements the separation of learner and model at the end of training.
There are two cases to handle.
1. If deepspeed is being used: The learner_state_path (which the checkpointing function uses) is a directory and not a file. This directory has sharded state_dicts (of model and optimizers), depending on which deepspeed stage is used. The steps here are:
a. Combine all the shards into one big state dict.
b. The learner_state_path is a dir (learner.ckpt/). This needs to be deleted first.
c. Write the combined state_dict to learner.ckpt as a single file.
d. Extract the state_dict element from the learner and save that as the model.ckpt.
2. If not deepspeed (say DDP or single device): The steps are much simpler.
a. Load the learner.ckpt and extract the state_dict element.
b. Sanitize the extracted state_dict. Learner has the model in a _model attribute, so strip the first “_model.” from the keys of the state_dict.
c. Save the sanitized model to model.ckpt.
Case 2 needs to be done even for Case 1 (step d). So teardown is a recursive call in Case 1, which automatically goes to Case 2 because learner.ckpt is a file now.
- Return type:
None
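The sanitization step of Case 2 can be sketched in a few lines. This is an illustration only, not Renate's implementation: `strip_model_prefix` is a hypothetical helper, plain lists stand in for tensors, and it is assumed that "the first `_model.`" means a leading prefix and that keys without the prefix are kept unchanged.

```python
from typing import Any, Dict


def strip_model_prefix(state_dict: Dict[str, Any], prefix: str = "_model.") -> Dict[str, Any]:
    """Strip one leading `prefix` from each key (hypothetical helper)."""
    sanitized: Dict[str, Any] = {}
    for key, value in state_dict.items():
        if key.startswith(prefix):
            # Remove only the first occurrence, so nested "_model." parts survive.
            key = key[len(prefix):]
        sanitized[key] = value
    return sanitized


# Keys as they might appear in a learner checkpoint (made-up example values).
learner_state_dict = {
    "_model.layer.weight": [1.0, 2.0],
    "_model._model.bias": [0.5],
}
model_state_dict = strip_model_prefix(learner_state_dict)
```

Removing only the leading prefix matters when the wrapped model itself has an attribute named `_model`: the second key above keeps its inner `_model.` part.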
- class renate.updaters.model_updater.ModelUpdater(model, loss_fn, optimizer, learner_class, learner_kwargs=None, input_state_folder=None, output_state_folder=None, max_epochs=50, learning_rate_scheduler=None, learning_rate_scheduler_interval='epoch', train_transform=None, train_target_transform=None, test_transform=None, test_target_transform=None, buffer_transform=None, buffer_target_transform=None, metric=None, mode='min', logged_metrics=None, early_stopping_enabled=False, logger=<pytorch_lightning.loggers.tensorboard.TensorBoardLogger object>, accelerator='auto', devices=None, strategy='ddp', precision='32', deterministic_trainer=False, gradient_clip_val=None, gradient_clip_algorithm=None, mask_unused_classes=False)
Bases: ABC
Updates a learner using the data provided.
- Parameters:
model (RenateModule) – The potentially pretrained model to be updated with new data.
learner_class (Type[Learner]) – Class of the learner to be used for model update.
learner_kwargs (Optional[Dict[str, Any]]) – Arguments either used for creating a new learner (no previous state available) or to replace current arguments of the learner.
input_state_folder (Optional[str]) – Folder used by Renate to store files for current state.
output_state_folder (Optional[str]) – Folder used by Renate to store files for next state.
max_epochs (int) – The maximum number of epochs used to train the model. For comparability between methods, epochs are interpreted as “finetuning-equivalent”. That is, one epoch is defined as len(current_task_dataset) / batch_size update steps.
train_transform (Optional[Callable]) – The transformation applied during training.
train_target_transform (Optional[Callable]) – The target transformation applied during training.
test_transform (Optional[Callable]) – The transformation at test time.
test_target_transform (Optional[Callable]) – The target transformation at test time.
buffer_transform (Optional[Callable]) – Augmentations applied to the input data coming from the memory. Not all updaters require this. If required but not passed, transform will be used.
buffer_target_transform (Optional[Callable]) – Transformations applied to the target. Not all updaters require this. If required but not passed, target_transform will be used.
metric (Optional[str]) – Monitored metric to decide when to write a new checkpoint or early-stop the optimization. If no metric is provided, the latest model will be stored.
mode (Literal['min', 'max']) – min or max. Whether to minimize or maximize the monitored metric.
logged_metrics (Optional[Dict[str, Metric]]) – Metrics logged in addition to the default ones.
early_stopping_enabled (bool) – Enables the early stopping of the optimization.
logger (Logger) – Logger used by PyTorch Lightning to log intermediate results.
accelerator (Literal['auto', 'cpu', 'gpu', 'tpu']) – Accelerator used by PyTorch Lightning to train the model.
devices (Optional[int]) – Devices used by PyTorch Lightning to train the model. If the devices flag is not defined, it will assume devices to be “auto” and fetch the auto_device_count from the accelerator.
deterministic_trainer (bool) – When set to True, makes the output of the training deterministic. The value is passed to the trainer.
gradient_clip_val (Optional[float]) – Gradient clipping value used in PyTorch Lightning. Defaults to not clipping by using a value of None.
gradient_clip_algorithm (Optional[str]) – Method to clip gradients (norm or value) used in PyTorch Lightning.
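The “finetuning-equivalent” epoch definition given for max_epochs can be made concrete with a small calculation. This is a sketch under assumptions: `finetuning_equivalent_steps` is a hypothetical helper, and rounding a partial last batch up to a full update step is an assumption that the description above does not specify.

```python
import math


def finetuning_equivalent_steps(num_samples: int, batch_size: int, max_epochs: int) -> int:
    """Total update steps when one epoch is len(dataset) / batch_size steps."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    # Assumption: a partial last batch still counts as one update step.
    steps_per_epoch = math.ceil(num_samples / batch_size)
    return steps_per_epoch * max_epochs


# 1000 samples in the current task, batch size 32, default max_epochs=50.
total_steps = finetuning_equivalent_steps(num_samples=1000, batch_size=32, max_epochs=50)
```

Because the step budget depends only on the size of the current task's dataset, a method that also draws samples from a memory buffer gets the same number of update steps as plain fine-tuning, which is what makes the methods comparable.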
- abstract update(train_dataset, val_dataset=None, train_dataset_collate_fn=None, val_dataset_collate_fn=None, task_id=None)
Updates the model using the data passed as input.
- Parameters:
train_dataset (Dataset) – The training data.
val_dataset (Optional[Dataset]) – The validation data.
train_dataset_collate_fn (Optional[Callable]) – collate_fn used to merge a list of samples to form a mini-batch of Tensors for the training data.
val_dataset_collate_fn (Optional[Callable]) – collate_fn used to merge a list of samples to form a mini-batch of Tensors for the validation data.
task_id (Optional[str]) – The task id.
- Return type:
None
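A collate function of the kind accepted by train_dataset_collate_fn and val_dataset_collate_fn takes a list of samples and merges it into one batch. The sketch below is illustrative only: `pad_collate` is a hypothetical name, the samples are assumed to be (sequence, label) pairs, and plain Python lists stand in for tensors so that the example runs without PyTorch.

```python
from typing import List, Sequence, Tuple


def pad_collate(
    samples: Sequence[Tuple[List[int], int]], pad_value: int = 0
) -> Tuple[List[List[int]], List[int]]:
    """Merge (sequence, label) samples into a batch padded to equal length."""
    max_len = max(len(seq) for seq, _ in samples)
    # Right-pad every sequence so the batch forms a rectangular array.
    inputs = [seq + [pad_value] * (max_len - len(seq)) for seq, _ in samples]
    labels = [label for _, label in samples]
    return inputs, labels


batch_inputs, batch_labels = pad_collate([([1, 2, 3], 0), ([4], 1)])
```

In a real setup the function would typically convert the padded lists to tensors before returning them.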
- class renate.updaters.model_updater.SingleTrainingLoopUpdater(model, loss_fn, optimizer, learner_class, learner_kwargs=None, input_state_folder=None, output_state_folder=None, max_epochs=50, learning_rate_scheduler=None, learning_rate_scheduler_interval='epoch', train_transform=None, train_target_transform=None, test_transform=None, test_target_transform=None, buffer_transform=None, buffer_target_transform=None, metric=None, mode='min', logged_metrics=None, early_stopping_enabled=False, logger=<pytorch_lightning.loggers.tensorboard.TensorBoardLogger object>, accelerator='auto', devices=None, strategy='ddp', precision='32', deterministic_trainer=False, gradient_clip_val=None, gradient_clip_algorithm=None, mask_unused_classes=False)
Bases: ModelUpdater
Simple ModelUpdater that requires only a single learner to update the model.
- update(train_dataset, val_dataset=None, train_dataset_collate_fn=None, val_dataset_collate_fn=None, task_id=None)
Updates the model using the data passed as input.
- Parameters:
train_dataset (Dataset) – The training data.
val_dataset (Optional[Dataset]) – The validation data.
train_dataset_collate_fn (Optional[Callable]) – collate_fn used to merge a list of samples to form a mini-batch of Tensors for the training data.
val_dataset_collate_fn (Optional[Callable]) – collate_fn used to merge a list of samples to form a mini-batch of Tensors for the validation data.
task_id (Optional[str]) – The task id.
- Return type: