renate.updaters.experimental.gdumb module#

class renate.updaters.experimental.gdumb.GDumbLearner(memory_size, buffer_transform=None, buffer_target_transform=None, seed=0, **kwargs)[source]#

Bases: ReplayLearner

A Learner that implements the GDumb strategy.

Prabhu, Ameya, Philip H. S. Torr, and Puneet K. Dokania. “GDumb: A simple approach that questions our progress in continual learning.” ECCV, 2020.

It maintains a memory of previously observed data points and trains on this buffer after it has been updated with the new data. Note that the model is reinitialized before each training run on the buffer; a sketch of the full update cycle follows the parameter list.

Parameters:
  • memory_size (int) – The maximum size of the memory.

  • buffer_transform (Optional[Callable]) – The transform to be applied to the data points in the memory.

  • buffer_target_transform (Optional[Callable]) – The transform to be applied to the targets in the memory.

  • seed (int) – A random seed.
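Conceptually, the strategy has three steps: greedily update a bounded memory buffer with the incoming data, reinitialize the model, and train it from scratch on the buffer alone. The following is a minimal, self-contained sketch of that cycle. It is illustrative only, not the Renate implementation; in particular, the plain reservoir update is a simplified stand-in for the paper's class-balanced greedy sampler.

import random

import torch
from torch import nn
from torch.utils.data import DataLoader


def gdumb_update(model, buffer, new_data, memory_size, max_epochs=50, seed=0):
    # 1. Greedily update the bounded buffer (plain reservoir sampling here;
    #    the paper keeps the buffer balanced across classes).
    rng = random.Random(seed)
    for i, sample in enumerate(new_data, start=len(buffer)):
        if len(buffer) < memory_size:
            buffer.append(sample)
        else:
            j = rng.randrange(i + 1)
            if j < memory_size:
                buffer[j] = sample
    # 2. Reinitialize the model: GDumb never fine-tunes a previous state.
    for module in model.modules():
        if hasattr(module, "reset_parameters"):
            module.reset_parameters()
    # 3. Train from scratch on the buffer only.
    optimizer = torch.optim.SGD(model.parameters(), lr=0.05)
    loss_fn = nn.CrossEntropyLoss()
    loader = DataLoader(buffer, batch_size=32, shuffle=True)
    for _ in range(max_epochs):
        for inputs, targets in loader:
            optimizer.zero_grad()
            loss_fn(model(inputs), targets).backward()
            optimizer.step()
    return model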

on_load_checkpoint(checkpoint)[source]#

Called by Lightning to restore your model. If you saved something with on_save_checkpoint(), this is your chance to restore it.

Parameters:

checkpoint (Dict[str, Any]) – Loaded checkpoint

Return type:

None

Example:

def on_load_checkpoint(self, checkpoint):
    # 99% of the time you don't need to implement this method
    self.something_cool_i_want_to_save = checkpoint['something_cool_i_want_to_save']

Note

Lightning auto-restores global step, epoch, and train state including amp scaling. There is no need for you to restore anything regarding training.

on_model_update_start(train_dataset, val_dataset, train_dataset_collate_fn=None, val_dataset_collate_fn=None, task_id=None)[source]#

Called before a model update starts.

Return type:

None

train_dataloader()[source]#

Returns the dataloader for training the model.

Return type:

DataLoader

training_step(batch, batch_idx)[source]#

PyTorch Lightning hook that computes and returns the training loss for a batch.

Return type:

Union[Tensor, Dict[str, Any]]
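In Lightning, training_step conventionally returns either the loss tensor itself or a dictionary with a "loss" key. Below is a schematic of such a step for a learner that trains on batches drawn from the memory buffer; it is a hedged illustration, not GDumbLearner's actual implementation, and self.loss_fn is assumed to hold the loss function passed at construction.

def training_step(self, batch, batch_idx):
    # Unpack a batch drawn from the memory buffer, compute the loss,
    # and return it so Lightning can run the backward pass.
    inputs, targets = batch
    outputs = self(inputs)
    loss = self.loss_fn(outputs, targets)
    self.log("train_loss", loss)
    return {"loss": loss}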

class renate.updaters.experimental.gdumb.GDumbModelUpdater(model, loss_fn, optimizer, memory_size, batch_memory_frac=0.5, learning_rate_scheduler=None, learning_rate_scheduler_interval='epoch', batch_size=32, input_state_folder=None, output_state_folder=None, max_epochs=50, train_transform=None, train_target_transform=None, test_transform=None, test_target_transform=None, buffer_transform=None, buffer_target_transform=None, metric=None, mode='min', logged_metrics=None, early_stopping_enabled=False, logger=<pytorch_lightning.loggers.tensorboard.TensorBoardLogger object>, accelerator='auto', devices=None, strategy='ddp', precision='32', seed=0, deterministic_trainer=False, gradient_clip_val=None, gradient_clip_algorithm=None, mask_unused_classes=False)[source]#

Bases: SingleTrainingLoopUpdater
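A hypothetical end-to-end usage sketch follows. The toy model and dataset are placeholders (Renate typically expects the model to be a renate.models.RenateModule), the optimizer is assumed to be passed as a partial that builds the torch optimizer from the model parameters, and the update call follows the model-updater interface of SingleTrainingLoopUpdater; exact arguments may differ across versions.

from functools import partial

import torch
from torch.utils.data import TensorDataset

from renate.updaters.experimental.gdumb import GDumbModelUpdater

# Toy data and model for illustration only; in practice the model would be a
# RenateModule and the dataset would come from your application.
train_dataset = TensorDataset(torch.randn(100, 16), torch.randint(0, 2, (100,)))
model = torch.nn.Linear(16, 2)

updater = GDumbModelUpdater(
    model=model,
    loss_fn=torch.nn.CrossEntropyLoss(),
    optimizer=partial(torch.optim.SGD, lr=0.05),  # assumed: callable building the optimizer
    memory_size=50,
    max_epochs=5,
)

# Assumed interface: the updater refreshes the memory buffer, reinitializes
# the model, and retrains it from scratch on the buffer.
updated_model = updater.update(train_dataset)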