renate.updaters.experimental.offline_er module

class renate.updaters.experimental.offline_er.OfflineExperienceReplayLearner(loss_weight_new_data=None, **kwargs)

Bases: ReplayLearner

Experience Replay in the offline version.

The model is trained on a weighted mixture of losses computed on the new data and on a replay buffer. In contrast to the online version, the buffer is only updated after training has terminated.
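The offline update schedule can be illustrated with a toy buffer in plain Python. This is a minimal sketch under stated assumptions, not Renate's implementation: the class `ToyOfflineReplay`, its methods, and the uniform random subsampling used to cap the buffer at `memory_size` are hypothetical. What it shows is the ordering: training on a chunk only ever replays data from earlier chunks, because the buffer is written after training ends.

```python
import random
from typing import List


class ToyOfflineReplay:
    """Hypothetical sketch of the offline update schedule, not Renate's implementation."""

    def __init__(self, memory_size: int, seed: int = 0) -> None:
        self.memory_size = memory_size
        self.buffer: List[int] = []
        self._rng = random.Random(seed)

    def train_on_chunk(self, new_data: List[int]) -> List[int]:
        # The buffer is read-only while training runs; return what training could replay.
        replayed = list(self.buffer)
        # ... the weighted new-data/memory training loop would run here ...
        # Offline ER: the buffer is updated only once training has terminated.
        self._update_buffer(new_data)
        return replayed

    def _update_buffer(self, new_data: List[int]) -> None:
        candidates = self.buffer + list(new_data)
        if len(candidates) > self.memory_size:
            # Assumption: keep a uniform random subset so the buffer fits memory_size.
            candidates = self._rng.sample(candidates, self.memory_size)
        self.buffer = candidates


er = ToyOfflineReplay(memory_size=4)
print(er.train_on_chunk([1, 2, 3]))          # [] : nothing to replay on the first chunk
print(sorted(er.train_on_chunk([4, 5, 6])))  # [1, 2, 3] : only data from earlier chunks
print(len(er.buffer))                        # 4
```

An online variant would instead write to the buffer during training, so later batches of the same chunk could already replay samples from that chunk.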


Parameters:

loss_weight_new_data (Optional[float]) – The training loss is a convex combination of the loss on the new data and the loss on the memory data. If a float is given (it must lie in [0, 1]), it is used as the weight for the new data, and the memory data receives the remaining weight. If None, the weight is set dynamically to N_t / sum([N_1, ..., N_t]), where N_i denotes the size of task/chunk i and t is the current task.
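The dynamic default for `loss_weight_new_data` can be worked through in a few lines of Python. This sketches the documented formula only; the function name `new_data_loss_weight` is hypothetical and not part of Renate's API. For example, with chunk sizes N_1 = 1000 and N_2 = 500, the new data gets weight 500 / 1500 = 1/3 and the memory data 2/3; on the first chunk the weight is N_1 / N_1 = 1, so only the new data contributes.

```python
from typing import Optional, Sequence


def new_data_loss_weight(chunk_sizes: Sequence[int], loss_weight_new_data: Optional[float] = None) -> float:
    """Hypothetical helper: weight of the new-data loss for the current chunk.

    chunk_sizes holds [N_1, ..., N_t]; the last entry is the current chunk t.
    """
    if loss_weight_new_data is not None:
        # A fixed weight must lie in [0, 1] so the combination stays convex.
        if not 0.0 <= loss_weight_new_data <= 1.0:
            raise ValueError("loss_weight_new_data must be in [0, 1]")
        return loss_weight_new_data
    # Dynamic rule: N_t / sum([N_1, ..., N_t]).
    return chunk_sizes[-1] / sum(chunk_sizes)


w = new_data_loss_weight([1000, 500])
print(f"new data: {w:.3f}, memory: {1 - w:.3f}")  # new data: 0.333, memory: 0.667
```

The dynamic rule shrinks the new-data weight as more chunks accumulate, so each chunk's data is weighted roughly in proportion to its share of everything seen so far.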

on_model_update_start(train_dataset, val_dataset, train_dataset_collate_fn=None, val_dataset_collate_fn=None, task_id=None)

Called before a model update starts.

Return type:



Returns the dataloader for training the model.

Return type:



Called right before a model update terminates.

Return type:


training_step(batch, batch_idx)

PyTorch Lightning function to return the training loss.

Return type:

Union[Tensor, Dict[str, Any]]
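Given the two per-batch losses, the convex combination that `training_step` returns can be sketched in plain Python. The function `toy_training_step` and its float inputs are hypothetical stand-ins for tensors; returning a dictionary with a `"loss"` key is one of the forms Lightning accepts from `training_step`, matching the `Dict[str, Any]` branch of the return type. Treating a missing memory loss (empty buffer) as "use the new-data loss only" is an assumption of this sketch.

```python
from typing import Any, Dict, Optional


def toy_training_step(loss_new: float, loss_memory: Optional[float], weight_new: float) -> Dict[str, Any]:
    """Hypothetical sketch: combine new-data and memory losses into one training loss."""
    if loss_memory is None:
        # Assumption: with an empty buffer, only the new-data loss is used.
        return {"loss": loss_new}
    # Convex combination: weight_new on new data, (1 - weight_new) on memory data.
    loss = weight_new * loss_new + (1.0 - weight_new) * loss_memory
    return {"loss": loss, "loss_new": loss_new, "loss_memory": loss_memory}


print(toy_training_step(2.0, 4.0, weight_new=0.25)["loss"])  # 3.5
```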


on_save_checkpoint(checkpoint)

Called by Lightning when saving a checkpoint, giving you a chance to store anything else you might want to save.

Parameters:

checkpoint (Dict[str, Any]) – The full checkpoint dictionary before it is dumped to a file. Implementations of this hook can insert additional data into this dictionary.

Return type:



Example:

def on_save_checkpoint(self, checkpoint):
    # In 99% of use cases you don't need to implement this method.
    # Store any picklable object under a custom key.
    checkpoint['something_cool_i_want_to_save'] = self.something_cool_i_want_to_save


Note: Lightning saves all aspects of training (epoch, global step, etc.), including AMP scaling, so there is no need to store anything about training yourself.


on_load_checkpoint(checkpoint)

Called by Lightning to restore your model. If you saved something with on_save_checkpoint(), this is your chance to restore it.

Parameters:

checkpoint (Dict[str, Any]) – The loaded checkpoint.

Return type:



Example:

def on_load_checkpoint(self, checkpoint):
    # 99% of the time you don't need to implement this method
    self.something_cool_i_want_to_save = checkpoint['something_cool_i_want_to_save']


Note: Lightning automatically restores the global step, epoch, and training state, including AMP scaling, so there is no need to restore anything related to training yourself.
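The two checkpoint hooks form a pair: whatever `on_save_checkpoint` adds under a key, `on_load_checkpoint` reads back under the same key. The round trip below sketches that contract in plain Python without Lightning; `ToyLearner`, the key `"toy_buffer_items"`, and the `{"epoch": 5}` placeholder for Lightning's own checkpoint content are hypothetical, and a `pickle` round trip stands in for writing the checkpoint to disk.

```python
import pickle
from typing import Any, Dict, List


class ToyLearner:
    """Hypothetical stand-in for a LightningModule that keeps extra state."""

    def __init__(self) -> None:
        self.buffer_items: List[int] = []

    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        # Insert custom state into the checkpoint dict before it is written out.
        checkpoint["toy_buffer_items"] = list(self.buffer_items)

    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        # Restore the state that on_save_checkpoint stored under the same key.
        self.buffer_items = list(checkpoint["toy_buffer_items"])


learner = ToyLearner()
learner.buffer_items = [3, 1, 4]
checkpoint: Dict[str, Any] = {"epoch": 5}  # stands in for Lightning's own checkpoint content
learner.on_save_checkpoint(checkpoint)

# A pickle round trip mimics dumping the checkpoint to a file and loading it back.
loaded = pickle.loads(pickle.dumps(checkpoint))
restored = ToyLearner()
restored.on_load_checkpoint(loaded)
print(restored.buffer_items)  # [3, 1, 4]
```

Whatever is stored this way must be picklable, since the checkpoint dictionary is serialized as a whole.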

class renate.updaters.experimental.offline_er.OfflineExperienceReplayModelUpdater(model, loss_fn, optimizer, memory_size, batch_memory_frac=0.5, loss_weight_new_data=None, learning_rate_scheduler=None, learning_rate_scheduler_interval='epoch', batch_size=32, input_state_folder=None, output_state_folder=None, max_epochs=50, train_transform=None, train_target_transform=None, test_transform=None, test_target_transform=None, buffer_transform=None, buffer_target_transform=None, metric=None, mode='min', logged_metrics=None, early_stopping_enabled=False, logger=<pytorch_lightning.loggers.tensorboard.TensorBoardLogger object>, accelerator='auto', devices=None, strategy='ddp', precision='32', seed=0, deterministic_trainer=False, gradient_clip_val=None, gradient_clip_algorithm=None, mask_unused_classes=False)

Bases: SingleTrainingLoopUpdater
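With the defaults `batch_size=32` and `batch_memory_frac=0.5`, one plausible reading of the signature above is that each training batch draws half of its samples from the replay buffer. The helper below sketches that arithmetic; the function `split_batch_size`, and the choice to round the memory share down and give the remainder to new data, are assumptions rather than Renate's documented behavior.

```python
from typing import Tuple


def split_batch_size(batch_size: int = 32, batch_memory_frac: float = 0.5) -> Tuple[int, int]:
    """Hypothetical helper returning (new_data_batch_size, memory_batch_size)."""
    if not 0.0 <= batch_memory_frac <= 1.0:
        raise ValueError("batch_memory_frac must be in [0, 1]")
    # Assumption: the memory share is rounded down; the rest of the batch is new data.
    memory_batch_size = int(batch_memory_frac * batch_size)
    return batch_size - memory_batch_size, memory_batch_size


print(split_batch_size())          # (16, 16)
print(split_batch_size(10, 0.25))  # (8, 2)
```

Note that `batch_memory_frac` controls how many samples come from each source, while `loss_weight_new_data` controls how the two resulting losses are weighted; the two knobs are independent.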