renate.updaters.experimental.joint module
- class renate.updaters.experimental.joint.JointLearner(**kwargs)
Bases: Learner
A Learner that implements the Joint strategy.
This is a simple strategy that trains the model on all previously observed data. Each time a new chunk of data is observed, the model is reinitialized and retrained on all the previously observed data and the new chunk of data. The buffer holding the previous data is updated with the new chunk of data.
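The strategy described above can be sketched in plain Python. This is a minimal illustration under stated assumptions, not Renate's implementation: `ToyJointLearner`, `model_factory`, and `update` are hypothetical names, and the "model" is just a dictionary holding the running mean of the buffered data.

```python
from typing import Callable, List, Sequence


class ToyJointLearner:
    """Hypothetical stand-in illustrating the Joint strategy; not Renate's JointLearner."""

    def __init__(self, model_factory: Callable[[], dict]) -> None:
        self._model_factory = model_factory  # builds a freshly initialized model
        self._buffer: List[float] = []  # all data observed so far
        self.model = model_factory()

    def update(self, new_chunk: Sequence[float]) -> dict:
        # 1. Add the new chunk to the buffer of previously observed data.
        self._buffer.extend(new_chunk)
        # 2. Reinitialize the model, discarding whatever was learned before.
        self.model = self._model_factory()
        # 3. Retrain from scratch on the old and new data together.
        #    "Training" here is only computing the mean of the buffer.
        self.model["mean"] = sum(self._buffer) / len(self._buffer)
        self.model["num_seen"] = len(self._buffer)
        return self.model


learner = ToyJointLearner(model_factory=lambda: {"mean": 0.0, "num_seen": 0})
learner.update([1.0, 2.0, 3.0])  # first chunk
model = learner.update([10.0])   # second chunk: retrained on all four values
print(model)
```

Because every update retrains on the full buffer, the cost of an update grows with the total amount of data observed so far, not just with the size of the new chunk.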
- on_save_checkpoint(checkpoint)
Called by Lightning when saving a checkpoint to give you a chance to store anything else you might want to save.
- Parameters:
checkpoint (Dict[str, Any]) – The full checkpoint dictionary before it gets dumped to a file. Implementations of this hook can insert additional data into this dictionary.
- Return type:
None
Example:
def on_save_checkpoint(self, checkpoint):
    # 99% of use cases you don't need to implement this method
    checkpoint['something_cool_i_want_to_save'] = my_cool_pickable_object
Note
Lightning saves all aspects of training (epoch, global step, etc…) including amp scaling. There is no need for you to store anything about training.
- on_load_checkpoint(checkpoint)
Called by Lightning to restore your model. If you saved something with on_save_checkpoint(), this is your chance to restore it.
- Parameters:
checkpoint (Dict[str, Any]) – Loaded checkpoint
- Return type:
None
Example:
def on_load_checkpoint(self, checkpoint):
    # 99% of the time you don't need to implement this method
    self.something_cool_i_want_to_save = checkpoint['something_cool_i_want_to_save']
Note
Lightning auto-restores global step, epoch, and train state including amp scaling. There is no need for you to restore anything regarding training.
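The two hooks are meant to be used as a pair: whatever one inserts into the checkpoint dictionary, the other reads back. The sketch below shows that round trip without importing Lightning. `BufferedLearner` and the `"buffer"` key are hypothetical, chosen only for illustration; the documentation above does not state what JointLearner itself stores in the checkpoint.

```python
from typing import Any, Dict, List


class BufferedLearner:
    """Hypothetical class that mimics the two hook signatures; it does not use Lightning."""

    def __init__(self) -> None:
        self.buffer: List[int] = []

    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        # Insert extra state under an assumed key. Copy the list so that later
        # changes to self.buffer do not alter the saved checkpoint.
        checkpoint["buffer"] = list(self.buffer)

    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        # Restore exactly what on_save_checkpoint stored.
        self.buffer = list(checkpoint["buffer"])


original = BufferedLearner()
original.buffer = [1, 2, 3]

# Stand-in for the dictionary that the framework would build and pass to the hook.
checkpoint: Dict[str, Any] = {"epoch": 5}
original.on_save_checkpoint(checkpoint)

restored = BufferedLearner()
restored.on_load_checkpoint(checkpoint)
print(restored.buffer)
```

The hook adds its own key and leaves existing entries such as `"epoch"` alone, which matches the note above that training state is saved and restored without any help from these hooks.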
- class renate.updaters.experimental.joint.JointModelUpdater(model, loss_fn, optimizer, learning_rate_scheduler=None, learning_rate_scheduler_interval='epoch', batch_size=32, input_state_folder=None, output_state_folder=None, max_epochs=50, train_transform=None, train_target_transform=None, test_transform=None, test_target_transform=None, metric=None, mode='min', logged_metrics=None, early_stopping_enabled=False, logger=<pytorch_lightning.loggers.tensorboard.TensorBoardLogger object>, accelerator='auto', devices=None, strategy='ddp', precision='32', seed=0, deterministic_trainer=False, gradient_clip_val=None, gradient_clip_algorithm=None, mask_unused_classes=False)
Bases: SingleTrainingLoopUpdater