renate.updaters.experimental.offline_er module#
- class renate.updaters.experimental.offline_er.OfflineExperienceReplayLearner(loss_weight_new_data=None, **kwargs)[source]#
Bases: ReplayLearner
Experience Replay in the offline version.
The model will be trained on a weighted mixture of losses computed on the new data and a replay buffer. In contrast to the online version, the buffer will only be updated after training has terminated.
- Parameters:
loss_weight_new_data (Optional[float]) – The training loss will be a convex combination of the loss on the new data and the loss on the memory data. If a float (must be in [0, 1]) is given here, it will be used as the weight for the new data. If None, the weight will be set dynamically to N_t / sum([N_1, ..., N_t]), where N_i denotes the size of task/chunk i and t is the current task.
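The dynamic weighting rule described above can be sketched as follows. This is a minimal illustration, not renate's implementation; the function name and list representation are hypothetical.

```python
def dynamic_loss_weight(chunk_sizes):
    # chunk_sizes holds [N_1, ..., N_t]; the last entry is the current task t.
    # Returns N_t / sum([N_1, ..., N_t]), the weight used for the new data
    # when loss_weight_new_data is None.
    return chunk_sizes[-1] / sum(chunk_sizes)

# With task sizes 100, 200, and 100, the current task contributes
# 100 / 400 = 0.25 of the total loss.
w = dynamic_loss_weight([100, 200, 100])
```

Note that as more tasks accumulate, the weight of the new data shrinks, so the replay buffer dominates the loss on long task sequences.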
- on_model_update_start(train_dataset, val_dataset, train_dataset_collate_fn=None, val_dataset_collate_fn=None, task_id=None)[source]#
Called before a model update starts.
- Return type:
None
- training_step(batch, batch_idx)[source]#
PyTorch Lightning function to return the training loss.
- Return type:
Union[Tensor, Dict[str, Any]]
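The loss combination computed during a training step can be sketched as below. This is an illustrative stand-alone function under assumed names, not the actual training_step body; in the learner, weight_new_data is either the configured loss_weight_new_data or the dynamic value described above.

```python
def offline_er_loss(loss_new, loss_memory, weight_new_data):
    # Convex combination of the loss on the new data and the loss on the
    # memory (replay buffer) data, as described for loss_weight_new_data.
    return weight_new_data * loss_new + (1.0 - weight_new_data) * loss_memory

# weight_new_data = 0.25 puts three quarters of the weight on the buffer.
total = offline_er_loss(loss_new=2.0, loss_memory=1.0, weight_new_data=0.25)
```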
- on_save_checkpoint(checkpoint)[source]#
Called by Lightning when saving a checkpoint to give you a chance to store anything else you might want to save.
- Parameters:
checkpoint (Dict[str, Any]) – The full checkpoint dictionary before it gets dumped to a file. Implementations of this hook can insert additional data into this dictionary.
- Return type:
None
Example:
def on_save_checkpoint(self, checkpoint):
    # 99% of use cases you don't need to implement this method
    checkpoint['something_cool_i_want_to_save'] = my_cool_pickable_object
Note
Lightning saves all aspects of training (epoch, global step, etc…) including amp scaling. There is no need for you to store anything about training.
- on_load_checkpoint(checkpoint)[source]#
Called by Lightning to restore your model. If you saved something with on_save_checkpoint() this is your chance to restore this.
- Parameters:
checkpoint (Dict[str, Any]) – Loaded checkpoint
- Return type:
None
Example:
def on_load_checkpoint(self, checkpoint):
    # 99% of the time you don't need to implement this method
    self.something_cool_i_want_to_save = checkpoint['something_cool_i_want_to_save']
Note
Lightning auto-restores global step, epoch, and train state including amp scaling. There is no need for you to restore anything regarding training.
- class renate.updaters.experimental.offline_er.OfflineExperienceReplayModelUpdater(model, loss_fn, optimizer, memory_size, batch_memory_frac=0.5, loss_weight_new_data=None, learning_rate_scheduler=None, learning_rate_scheduler_interval='epoch', batch_size=32, input_state_folder=None, output_state_folder=None, max_epochs=50, train_transform=None, train_target_transform=None, test_transform=None, test_target_transform=None, buffer_transform=None, buffer_target_transform=None, metric=None, mode='min', logged_metrics=None, early_stopping_enabled=False, logger=<pytorch_lightning.loggers.tensorboard.TensorBoardLogger object>, accelerator='auto', devices=None, strategy='ddp', precision='32', seed=0, deterministic_trainer=False, gradient_clip_val=None, gradient_clip_algorithm=None, mask_unused_classes=False)[source]#
Bases: SingleTrainingLoopUpdater
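The defining behavior of the offline variant is that the replay buffer is read during training but only written once training has terminated. The following stand-alone sketch (hypothetical class and method names, not renate's implementation) illustrates that update cycle:

```python
class OfflineERSketch:
    # Minimal sketch of the offline experience-replay update cycle:
    # train on a mixture of new data and buffer, then update the buffer.
    def __init__(self, memory_size):
        self.memory_size = memory_size
        self.buffer = []

    def update(self, new_points, epochs=1):
        for _ in range(epochs):
            # During training the buffer is only read, never written.
            mixed = list(new_points) + list(self.buffer)
            _ = mixed  # placeholder for a real training step on the mixture
        # The buffer is updated only after training has terminated; here we
        # keep the most recent memory_size points as a stand-in for a real
        # buffering policy.
        self.buffer = (self.buffer + list(new_points))[-self.memory_size:]
```

In the online variant, by contrast, the buffer would be updated within the training loop as new batches arrive.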