renate.updaters.model_updater module#
- class renate.updaters.model_updater.SyneTuneCallback(val_enabled)[source]#
Bases: Callback
Callback to report metrics to Syne Tune.
- Parameters:
val_enabled (bool) – Whether validation was enabled in the Learner.
- on_train_epoch_end(trainer, pl_module)[source]#
Called when the train epoch ends.
To access all batch outputs at the end of the epoch, either:
1. Implement training_epoch_end in the LightningModule and access outputs via the module, OR
2. Cache data across train batch hooks inside the callback implementation to post-process in this hook (see the sketch below).
- Return type:
None
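For illustration, a minimal sketch of option 2 with a hypothetical callback that caches per-batch losses, assuming training_step returns a dict with a "loss" key:

    from pytorch_lightning.callbacks import Callback

    class LossAveragingCallback(Callback):
        """Cache per-batch losses and post-process them at epoch end."""

        def __init__(self):
            self._batch_losses = []

        def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
            # Cache data across train batch hooks (option 2 above).
            self._batch_losses.append(outputs["loss"].detach())

        def on_train_epoch_end(self, trainer, pl_module):
            # Post-process the cached data once the epoch is complete.
            mean_loss = sum(self._batch_losses) / len(self._batch_losses)
            pl_module.log("train_loss_epoch", mean_loss)
            self._batch_losses.clear()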
- class renate.updaters.model_updater.RenateModelCheckpoint(model, output_state_folder, val_enabled, metric=None, mode='min', use_syne_tune_callback=True)[source]#
Bases: ModelCheckpoint
Callback to save Renate state after each epoch.
- Parameters:
model (RenateModule) – Model to be saved when creating a checkpoint.
output_state_folder (str) – Checkpoint folder location.
val_enabled (bool) – Whether validation was enabled in the Learner. Forwarded to SyneTuneCallback.
metric (Optional[str]) – Monitored metric to decide when to write a new checkpoint. If no metric is provided or validation is not enabled, the latest model will be stored.
mode (Literal['min', 'max']) – min or max. Whether to minimize or maximize the monitored metric.
use_syne_tune_callback (bool) – Whether to use SyneTuneCallback.
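A hypothetical usage sketch based on the constructor signature above; my_model and the Trainer wiring are placeholders, not the library's prescribed setup:

    import pytorch_lightning as pl

    checkpoint_callback = RenateModelCheckpoint(
        model=my_model,                        # a RenateModule instance
        output_state_folder="./renate_state",
        val_enabled=True,
        metric="val_loss",                     # checkpoint whenever val_loss improves
        mode="min",
    )
    trainer = pl.Trainer(callbacks=[checkpoint_callback], max_epochs=10)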
- on_train_epoch_end(trainer, pl_module)[source]#
Save a checkpoint at the end of the training epoch.
- Return type:
None
- on_validation_epoch_end(trainer, pl_module)[source]#
Called when the val epoch ends.
- Return type:
None
- teardown(trainer, pl_module, stage)[source]#
Implements the separation of learner and model at the end of training.
There are two cases to handle:
1. If deepspeed is being used: the learner_state_path (which the checkpointing function uses) is a directory, not a file. This directory holds sharded state_dicts (of the model and optimizers), depending on which deepspeed stage is used. There are four steps here:
a. Combine all the shards into one big state_dict.
b. The learner_state_path is a directory (learner.ckpt/). This needs to be deleted first.
c. Write the combined state_dict to learner.ckpt as a single file.
d. Extract the state_dict element from the learner and save that as model.ckpt.
2. If deepspeed is not used (say, DDP or a single device), the steps are much simpler:
a. Load learner.ckpt and extract the state_dict element.
b. Sanitize the extracted state_dict. The Learner holds the model in a _model attribute, so strip the leading "_model." from the keys of the state_dict (see the sketch below).
c. Save the sanitized model state to model.ckpt.
Case 2 needs to be done even for Case 1 (step d), so teardown recurses in Case 1 and automatically proceeds to Case 2 once learner.ckpt is a file.
- Return type:
None
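A minimal sketch of the Case 2 sanitization, assuming a Lightning-style checkpoint with a "state_dict" entry; the function name is hypothetical:

    import torch

    def sanitize_learner_state_dict(learner_ckpt_path, model_ckpt_path):
        """Extract the wrapped model's weights from a learner checkpoint."""
        checkpoint = torch.load(learner_ckpt_path, map_location="cpu")
        state_dict = checkpoint["state_dict"]
        prefix = "_model."
        model_state = {
            key[len(prefix):]: value      # strip the leading "_model."
            for key, value in state_dict.items()
            if key.startswith(prefix)     # keep only the model's parameters
        }
        torch.save(model_state, model_ckpt_path)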
- class renate.updaters.model_updater.ModelUpdater(model, loss_fn, optimizer, learner_class, learner_kwargs=None, input_state_folder=None, output_state_folder=None, max_epochs=50, learning_rate_scheduler=None, learning_rate_scheduler_interval='epoch', train_transform=None, train_target_transform=None, test_transform=None, test_target_transform=None, buffer_transform=None, buffer_target_transform=None, metric=None, mode='min', logged_metrics=None, early_stopping_enabled=False, logger=<pytorch_lightning.loggers.tensorboard.TensorBoardLogger object>, accelerator='auto', devices=None, strategy='ddp', precision='32', deterministic_trainer=False, gradient_clip_val=None, gradient_clip_algorithm=None, mask_unused_classes=False)[source]#
Bases: ABC
Updates a learner using the data provided.
- Parameters:
model (RenateModule) – The potentially pretrained model to be updated with new data.
learner_class (Type[Learner]) – Class of the learner to be used for the model update.
learner_kwargs (Optional[Dict[str, Any]]) – Arguments either used for creating a new learner (if no previous state is available) or to replace the current arguments of the learner.
input_state_folder (Optional[str]) – Folder used by Renate to store files for the current state.
output_state_folder (Optional[str]) – Folder used by Renate to store files for the next state.
max_epochs (int) – The maximum number of epochs used to train the model. For comparability between methods, epochs are interpreted as "finetuning-equivalent". That is, one epoch is defined as len(current_task_dataset) / batch_size update steps (see the sketch after this list).
train_transform (Optional[Callable]) – The transformation applied during training.
train_target_transform (Optional[Callable]) – The target transformation applied during training.
test_transform (Optional[Callable]) – The transformation applied at test time.
test_target_transform (Optional[Callable]) – The target transformation applied at test time.
buffer_transform (Optional[Callable]) – Augmentations applied to the input data coming from the memory. Not all updaters require this. If required but not passed, train_transform will be used.
buffer_target_transform (Optional[Callable]) – Transformations applied to the target. Not all updaters require this. If required but not passed, train_target_transform will be used.
metric (Optional[str]) – Monitored metric to decide when to write a new checkpoint or early-stop the optimization. If no metric is provided, the latest model will be stored.
mode (Literal['min', 'max']) – min or max. Whether to minimize or maximize the monitored metric.
logged_metrics (Optional[Dict[str, Metric]]) – Metrics logged in addition to the default ones.
early_stopping_enabled (bool) – Enables early stopping of the optimization.
logger (Logger) – Logger used by PyTorch Lightning to log intermediate results.
accelerator (Literal['auto', 'cpu', 'gpu', 'tpu']) – Accelerator used by PyTorch Lightning to train the model.
devices (Optional[int]) – Devices used by PyTorch Lightning to train the model. If the devices flag is not defined, it will assume devices to be "auto" and fetch the auto_device_count from the accelerator.
deterministic_trainer (bool) – When set to True, makes the output of the training deterministic. The value is passed to the PyTorch Lightning Trainer.
gradient_clip_val (Optional[float]) – Gradient clipping value used in PyTorch Lightning. Defaults to not clipping by using a value of None.
gradient_clip_algorithm (Optional[str]) – Method to clip gradients (norm or value) used in PyTorch Lightning.
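For illustration, a small worked example of the "finetuning-equivalent" epoch definition above; the dataset size and batch size are made up:

    # One "epoch" equals len(current_task_dataset) / batch_size update steps.
    dataset_size = 10_000   # stands in for len(current_task_dataset)
    batch_size = 32
    max_epochs = 50

    steps_per_epoch = dataset_size // batch_size       # 312 update steps
    total_update_steps = steps_per_epoch * max_epochs  # 15600 update steps overall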
- abstract update(train_dataset, val_dataset=None, train_dataset_collate_fn=None, val_dataset_collate_fn=None, task_id=None)[source]#
Updates the model using the data passed as input.
- Parameters:
train_dataset (Dataset) – The training data.
val_dataset (Optional[Dataset]) – The validation data.
train_dataset_collate_fn (Optional[Callable]) – collate_fn used to merge a list of samples to form a mini-batch of Tensors for the training data.
val_dataset_collate_fn (Optional[Callable]) – collate_fn used to merge a list of samples to form a mini-batch of Tensors for the validation data.
task_id (Optional[str]) – The task id.
- Return type:
None
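As a sketch of what could be passed as train_dataset_collate_fn, assuming each sample is a (features, label) pair with variable-length feature tensors and tensor labels; pad_collate is a hypothetical helper:

    import torch

    def pad_collate(batch):
        """Merge a list of (features, label) samples into a padded mini-batch."""
        features, labels = zip(*batch)
        padded = torch.nn.utils.rnn.pad_sequence(features, batch_first=True)
        return padded, torch.stack(labels)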
- class renate.updaters.model_updater.SingleTrainingLoopUpdater(model, loss_fn, optimizer, learner_class, learner_kwargs=None, input_state_folder=None, output_state_folder=None, max_epochs=50, learning_rate_scheduler=None, learning_rate_scheduler_interval='epoch', train_transform=None, train_target_transform=None, test_transform=None, test_target_transform=None, buffer_transform=None, buffer_target_transform=None, metric=None, mode='min', logged_metrics=None, early_stopping_enabled=False, logger=<pytorch_lightning.loggers.tensorboard.TensorBoardLogger object>, accelerator='auto', devices=None, strategy='ddp', precision='32', deterministic_trainer=False, gradient_clip_val=None, gradient_clip_algorithm=None, mask_unused_classes=False)[source]#
Bases: ModelUpdater
Simple ModelUpdater that requires only a single learner to update the model.
- update(train_dataset, val_dataset=None, train_dataset_collate_fn=None, val_dataset_collate_fn=None, task_id=None)[source]#
Updates the model using the data passed as input.
- Parameters:
train_dataset (Dataset) – The training data.
val_dataset (Optional[Dataset]) – The validation data.
train_dataset_collate_fn (Optional[Callable]) – collate_fn used to merge a list of samples to form a mini-batch of Tensors for the training data.
val_dataset_collate_fn (Optional[Callable]) – collate_fn used to merge a list of samples to form a mini-batch of Tensors for the validation data.
task_id (Optional[str]) – The task id.
- Return type:
None
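A hypothetical end-to-end sketch based on the documented signature, assuming my_model (a RenateModule), my_loss_fn, my_optimizer, a Learner subclass MyLearner, and the datasets are defined elsewhere:

    updater = SingleTrainingLoopUpdater(
        model=my_model,
        loss_fn=my_loss_fn,
        optimizer=my_optimizer,
        learner_class=MyLearner,              # any Learner subclass
        output_state_folder="./renate_state",
        max_epochs=10,
        metric="val_loss",
        mode="min",
    )
    updater.update(train_dataset=train_data, val_dataset=val_data, task_id="task_1")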