renate.updaters.avalanche.model_updater module

class renate.updaters.avalanche.model_updater.AvalancheModelUpdater(*args, **kwargs)[source]

Bases: SingleTrainingLoopUpdater

update(train_dataset, val_dataset=None, train_dataset_collate_fn=None, val_dataset_collate_fn=None, task_id=None)[source]

Updates the model using the data passed as input.

Parameters:
  • train_dataset (Dataset) – The training data.

  • val_dataset (Optional[Dataset]) – The validation data.

  • train_dataset_collate_fn (Optional[Callable]) – collate_fn used to merge a list of samples to form a mini-batch of Tensors for the training data.

  • val_dataset_collate_fn (Optional[Callable]) – collate_fn used to merge a list of samples to form a mini-batch of Tensors for the validation data.

  • task_id (Optional[str]) – The task id.

Return type:

RenateModule

class renate.updaters.avalanche.model_updater.ExperienceReplayAvalancheModelUpdater(model, loss_fn, optimizer, memory_size, batch_memory_frac=0.5, learning_rate_scheduler=None, learning_rate_scheduler_interval='epoch', batch_size=32, input_state_folder=None, output_state_folder=None, max_epochs=50, train_transform=None, train_target_transform=None, test_transform=None, test_target_transform=None, buffer_transform=None, buffer_target_transform=None, metric=None, mode='min', logged_metrics=None, early_stopping_enabled=False, accelerator='auto', devices=None, strategy='ddp', precision='32', seed=0, deterministic_trainer=False, gradient_clip_val=None, gradient_clip_algorithm=None, mask_unused_classes=False)[source]

Bases: AvalancheModelUpdater

class renate.updaters.avalanche.model_updater.ElasticWeightConsolidationModelUpdater(model, loss_fn, optimizer, ewc_lambda=0.4, learning_rate_scheduler=None, learning_rate_scheduler_interval='epoch', batch_size=32, input_state_folder=None, output_state_folder=None, max_epochs=50, train_transform=None, train_target_transform=None, test_transform=None, test_target_transform=None, buffer_transform=None, buffer_target_transform=None, metric=None, mode='min', logged_metrics=None, early_stopping_enabled=False, accelerator='auto', devices=None, strategy='ddp', precision='32', seed=0, deterministic_trainer=False, gradient_clip_val=None, gradient_clip_algorithm=None, mask_unused_classes=False)[source]

Bases: AvalancheModelUpdater

class renate.updaters.avalanche.model_updater.LearningWithoutForgettingModelUpdater(model, loss_fn, optimizer, alpha=1, temperature=2, learning_rate_scheduler=None, learning_rate_scheduler_interval='epoch', batch_size=32, input_state_folder=None, output_state_folder=None, max_epochs=50, train_transform=None, train_target_transform=None, test_transform=None, test_target_transform=None, buffer_transform=None, buffer_target_transform=None, metric=None, mode='min', logged_metrics=None, early_stopping_enabled=False, accelerator='auto', devices=None, seed=0, strategy='ddp', precision='32', deterministic_trainer=False, gradient_clip_val=None, gradient_clip_algorithm=None, mask_unused_classes=False)[source]

Bases: AvalancheModelUpdater

class renate.updaters.avalanche.model_updater.ICaRLModelUpdater(model, loss_fn, optimizer, memory_size, learning_rate_scheduler=None, learning_rate_scheduler_interval='epoch', batch_size=32, input_state_folder=None, output_state_folder=None, max_epochs=50, train_transform=None, train_target_transform=None, test_transform=None, test_target_transform=None, buffer_transform=None, buffer_target_transform=None, metric=None, mode='min', logged_metrics=None, early_stopping_enabled=False, accelerator='auto', devices=None, strategy='ddp', precision='32', seed=0, deterministic_trainer=False, gradient_clip_val=None, gradient_clip_algorithm=None, mask_unused_classes=False)[source]

Bases: AvalancheModelUpdater