renate.updaters.experimental.offline_er module#

class renate.updaters.experimental.offline_er.OfflineExperienceReplayLearner(loss_weight_new_data=None, **kwargs)[source]#

Bases: ReplayLearner

Experience Replay in the offline version.

The model will be trained on a weighted mixture of losses computed on the new data and on a replay buffer. In contrast to the online version, the buffer is only updated after training has terminated.

Parameters:

loss_weight_new_data (Optional[float]) – The training loss will be a convex combination of the loss on the new data and the loss on the memory data. If a float (needs to be in [0, 1]) is given here, it will be used as the weight for the new data. If None, the weight will be set dynamically to N_t / sum([N_1, …, N_t]), where N_i denotes the size of task/chunk i and the current task is t.
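As an illustration of the dynamic default, the weight assigned to the new data when `loss_weight_new_data` is `None` can be sketched as follows (the helper function is hypothetical, not part of Renate's API):

```python
def dynamic_loss_weight(chunk_sizes):
    """Weight for the new data when loss_weight_new_data is None.

    chunk_sizes: sizes N_1, ..., N_t of all tasks/chunks seen so far,
    with the current task t last. Returns N_t / sum([N_1, ..., N_t]).
    """
    if not chunk_sizes:
        raise ValueError("At least one task/chunk size is required")
    return chunk_sizes[-1] / sum(chunk_sizes)

# Three tasks of sizes 100, 200, 100: the current task contributes
# 100 of 400 samples, so the new-data weight is 0.25.
print(dynamic_loss_weight([100, 200, 100]))  # 0.25
```

With a single task the weight is 1.0, so training initially uses the new data only, and later tasks are weighted by their share of all data seen so far.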

on_model_update_start(train_dataset, val_dataset, train_dataset_collate_fn=None, val_dataset_collate_fn=None, task_id=None)[source]#

Called before a model update starts.

Return type:

None

train_dataloader()[source]#

Returns the dataloader for training the model.

Return type:

DataLoader

on_model_update_end()[source]#

Called right before a model update terminates.

Return type:

None

training_step(batch, batch_idx)[source]#

PyTorch Lightning function to return the training loss.

Return type:

Union[Tensor, Dict[str, Any]]
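The loss returned here is, per the class description, a convex combination of the loss on the new data and the loss on the memory data. A minimal sketch of that combination (a hypothetical helper, not Renate's actual `training_step` implementation):

```python
def combined_loss(loss_new, loss_memory, weight_new):
    """Convex combination of the new-data loss and the replay-buffer loss.

    weight_new must lie in [0, 1]; the memory loss receives 1 - weight_new.
    Works with plain floats here, but the same arithmetic applies to tensors.
    """
    if not 0.0 <= weight_new <= 1.0:
        raise ValueError("weight_new must be in [0, 1]")
    return weight_new * loss_new + (1.0 - weight_new) * loss_memory

# With weight_new=0.25, the memory loss dominates: 0.25*2.0 + 0.75*4.0 = 3.5
print(combined_loss(2.0, 4.0, 0.25))  # 3.5
```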

on_save_checkpoint(checkpoint)[source]#

Called by Lightning when saving a checkpoint to give you a chance to store anything else you might want to save.

Parameters:

checkpoint (Dict[str, Any]) – The full checkpoint dictionary before it gets dumped to a file. Implementations of this hook can insert additional data into this dictionary.

Return type:

None

Example:

def on_save_checkpoint(self, checkpoint):
    # 99% of use cases you don't need to implement this method
    checkpoint['something_cool_i_want_to_save'] = my_cool_picklable_object

Note

Lightning saves all aspects of training (epoch, global step, etc…) including amp scaling. There is no need for you to store anything about training.

on_load_checkpoint(checkpoint)[source]#

Called by Lightning to restore your model. If you saved something with on_save_checkpoint() this is your chance to restore this.

Parameters:

checkpoint (Dict[str, Any]) – Loaded checkpoint

Return type:

None

Example:

def on_load_checkpoint(self, checkpoint):
    # 99% of the time you don't need to implement this method
    self.something_cool_i_want_to_save = checkpoint['something_cool_i_want_to_save']

Note

Lightning auto-restores global step, epoch, and train state including amp scaling. There is no need for you to restore anything regarding training.

class renate.updaters.experimental.offline_er.OfflineExperienceReplayModelUpdater(model, loss_fn, optimizer, memory_size, batch_memory_frac=0.5, loss_weight_new_data=None, learning_rate_scheduler=None, learning_rate_scheduler_interval='epoch', batch_size=32, input_state_folder=None, output_state_folder=None, max_epochs=50, train_transform=None, train_target_transform=None, test_transform=None, test_target_transform=None, buffer_transform=None, buffer_target_transform=None, metric=None, mode='min', logged_metrics=None, early_stopping_enabled=False, logger=<pytorch_lightning.loggers.tensorboard.TensorBoardLogger object>, accelerator='auto', devices=None, strategy='ddp', precision='32', seed=0, deterministic_trainer=False, gradient_clip_val=None, gradient_clip_algorithm=None, mask_unused_classes=False)[source]#

Bases: SingleTrainingLoopUpdater
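For illustration, the `batch_memory_frac` parameter above controls how each training batch of size `batch_size` is split between replay-buffer samples and new-data samples. A sketch of that split (the rounding behavior shown is an assumption, not Renate's exact code):

```python
def split_batch(batch_size, batch_memory_frac):
    """Split a training batch between memory samples and new-data samples.

    Assumes the memory share is rounded down; the remainder of the
    batch is filled with samples from the new data.
    """
    if not 0.0 <= batch_memory_frac <= 1.0:
        raise ValueError("batch_memory_frac must be in [0, 1]")
    memory_batch_size = int(batch_size * batch_memory_frac)
    return memory_batch_size, batch_size - memory_batch_size

# Defaults batch_size=32, batch_memory_frac=0.5: 16 memory + 16 new samples
print(split_batch(32, 0.5))  # (16, 16)
```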