renate.updaters.avalanche.model_updater module
- class renate.updaters.avalanche.model_updater.AvalancheModelUpdater(*args, **kwargs)
Bases: SingleTrainingLoopUpdater
- update(train_dataset, val_dataset=None, train_dataset_collate_fn=None, val_dataset_collate_fn=None, task_id=None)
Updates the model using the data passed as input.
- Parameters:
  - train_dataset (Dataset) – The training data.
  - val_dataset (Optional[Dataset]) – The validation data.
  - train_dataset_collate_fn (Optional[Callable]) – collate_fn used to merge a list of samples to form a mini-batch of Tensors for the training data.
  - val_dataset_collate_fn (Optional[Callable]) – collate_fn used to merge a list of samples to form a mini-batch of Tensors for the validation data.
  - task_id (Optional[str]) – The task id.
- Return type:
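train_dataset_collate_fn and val_dataset_collate_fn are ordinary collate functions: they receive a list of samples and return one mini-batch. A common reason to pass one is variable-length inputs. The sketch below pads token sequences to a common length. It is a hypothetical example (pad_collate and pad_id are not part of Renate) and returns plain lists so it runs without PyTorch; a collate_fn actually passed to update would return Tensors, for example by wrapping each list with torch.tensor.

```python
from typing import List, Sequence, Tuple


# Hypothetical collate function: each sample is (token_ids, label).
def pad_collate(
    samples: Sequence[Tuple[List[int], int]], pad_id: int = 0
) -> Tuple[List[List[int]], List[int], List[int]]:
    """Merge samples into one batch, right-padding token lists to equal length."""
    max_len = max(len(tokens) for tokens, _ in samples)
    # Right-pad every sequence with pad_id up to the longest one in the batch.
    padded = [tokens + [pad_id] * (max_len - len(tokens)) for tokens, _ in samples]
    # Keep the true lengths so a model can ignore padding positions.
    lengths = [len(tokens) for tokens, _ in samples]
    labels = [label for _, label in samples]
    return padded, lengths, labels


batch = pad_collate([([5, 6, 7], 0), ([8], 1)])
print(batch)  # ([[5, 6, 7], [8, 0, 0]], [3, 1], [0, 1])
```

Returning the true lengths alongside the padded batch is a design choice that lets downstream code mask out padding; drop it if the model does not need it.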
- class renate.updaters.avalanche.model_updater.ExperienceReplayAvalancheModelUpdater(model, loss_fn, optimizer, memory_size, batch_memory_frac=0.5, learning_rate_scheduler=None, learning_rate_scheduler_interval='epoch', batch_size=32, input_state_folder=None, output_state_folder=None, max_epochs=50, train_transform=None, train_target_transform=None, test_transform=None, test_target_transform=None, buffer_transform=None, buffer_target_transform=None, metric=None, mode='min', logged_metrics=None, early_stopping_enabled=False, accelerator='auto', devices=None, strategy='ddp', precision='32', seed=0, deterministic_trainer=False, gradient_clip_val=None, gradient_clip_algorithm=None, mask_unused_classes=False)
Bases: AvalancheModelUpdater
- class renate.updaters.avalanche.model_updater.ElasticWeightConsolidationModelUpdater(model, loss_fn, optimizer, ewc_lambda=0.4, learning_rate_scheduler=None, learning_rate_scheduler_interval='epoch', batch_size=32, input_state_folder=None, output_state_folder=None, max_epochs=50, train_transform=None, train_target_transform=None, test_transform=None, test_target_transform=None, buffer_transform=None, buffer_target_transform=None, metric=None, mode='min', logged_metrics=None, early_stopping_enabled=False, accelerator='auto', devices=None, strategy='ddp', precision='32', seed=0, deterministic_trainer=False, gradient_clip_val=None, gradient_clip_algorithm=None, mask_unused_classes=False)
Bases: AvalancheModelUpdater
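ewc_lambda weights the Elastic Weight Consolidation regularizer. The pure-Python sketch below shows the standard form of that regularizer on flat parameter lists: a diagonal Fisher estimate (the mean squared per-sample gradient) and a quadratic penalty λ Σ F_i (θ_i − θ*_i)² added to the task loss, with ewc_lambda playing the role of λ. It is illustrative only; diagonal_fisher and ewc_penalty are hypothetical names, and Avalanche's own implementation works on PyTorch tensors and may differ in details such as an extra factor of 1/2 or how importances from several experiences are combined.

```python
from typing import List, Sequence


def diagonal_fisher(per_sample_grads: Sequence[Sequence[float]]) -> List[float]:
    """Diagonal Fisher estimate: mean of squared per-sample gradients."""
    n = len(per_sample_grads)
    dim = len(per_sample_grads[0])
    return [sum(g[i] ** 2 for g in per_sample_grads) / n for i in range(dim)]


def ewc_penalty(
    params: Sequence[float],
    old_params: Sequence[float],
    fisher: Sequence[float],
    ewc_lambda: float = 0.4,
) -> float:
    """ewc_lambda * sum_i F_i * (theta_i - theta*_i)^2."""
    return ewc_lambda * sum(
        f * (p - q) ** 2 for p, q, f in zip(params, old_params, fisher)
    )


# Usage: parameters and gradients recorded after the previous experience.
old_params = [0.0, 2.0]
fisher = diagonal_fisher([[1.0, 2.0], [3.0, 0.0]])  # [5.0, 2.0]
task_loss = 0.3  # placeholder value for the current task's loss
total_loss = task_loss + ewc_penalty([1.0, 2.0], old_params, fisher)
```

Parameters with a large Fisher value (here the first one) are pulled back toward their old values more strongly; larger ewc_lambda trades plasticity on the new task for stability on old ones.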
- class renate.updaters.avalanche.model_updater.LearningWithoutForgettingModelUpdater(model, loss_fn, optimizer, alpha=1, temperature=2, learning_rate_scheduler=None, learning_rate_scheduler_interval='epoch', batch_size=32, input_state_folder=None, output_state_folder=None, max_epochs=50, train_transform=None, train_target_transform=None, test_transform=None, test_target_transform=None, buffer_transform=None, buffer_target_transform=None, metric=None, mode='min', logged_metrics=None, early_stopping_enabled=False, accelerator='auto', devices=None, seed=0, strategy='ddp', precision='32', deterministic_trainer=False, gradient_clip_val=None, gradient_clip_algorithm=None, mask_unused_classes=False)
Bases: AvalancheModelUpdater
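alpha and temperature parametrize the Learning without Forgetting distillation term. A minimal sketch, assuming the common formulation in which the temperature-softened predictions of the model being trained are pulled toward those of the model saved before the update through a KL divergence, weighted by alpha and added to the task loss. softmax, distillation_kl, and lwf_loss are hypothetical helpers; the exact loss used by Avalanche (for example KL versus cross-entropy, or an extra T² scaling) may differ.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float], temperature: float = 1.0) -> List[float]:
    """Temperature-scaled softmax; subtracting the max keeps exp() stable."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def distillation_kl(
    new_logits: Sequence[float], old_logits: Sequence[float], temperature: float = 2.0
) -> float:
    """KL(p_old || p_new) between softened predictions of the old and new model."""
    p_old = softmax(old_logits, temperature)
    p_new = softmax(new_logits, temperature)
    return sum(po * (math.log(po) - math.log(pn)) for po, pn in zip(p_old, p_new))


def lwf_loss(
    task_loss: float,
    new_logits: Sequence[float],
    old_logits: Sequence[float],
    alpha: float = 1.0,
    temperature: float = 2.0,
) -> float:
    """Task loss plus alpha-weighted distillation toward the old model."""
    return task_loss + alpha * distillation_kl(new_logits, old_logits, temperature)


# Usage with toy logits for a three-class problem.
old = [2.0, 0.5, -1.0]  # logits of the model saved before this update
new = [1.0, 1.0, -1.0]  # logits of the model being trained
loss = lwf_loss(task_loss=0.7, new_logits=new, old_logits=old)
```

A temperature above 1 flattens both distributions, so the distillation term also transfers how the old model ranks the non-maximal classes rather than only its top prediction.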
- class renate.updaters.avalanche.model_updater.ICaRLModelUpdater(model, loss_fn, optimizer, memory_size, learning_rate_scheduler=None, learning_rate_scheduler_interval='epoch', batch_size=32, input_state_folder=None, output_state_folder=None, max_epochs=50, train_transform=None, train_target_transform=None, test_transform=None, test_target_transform=None, buffer_transform=None, buffer_target_transform=None, metric=None, mode='min', logged_metrics=None, early_stopping_enabled=False, accelerator='auto', devices=None, strategy='ddp', precision='32', seed=0, deterministic_trainer=False, gradient_clip_val=None, gradient_clip_algorithm=None, mask_unused_classes=False)
Bases: AvalancheModelUpdater
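iCaRL keeps a fixed budget of memory_size exemplars and classifies with a nearest-mean-of-exemplars rule. The sketch below illustrates both ideas in pure Python: an even per-class split of the budget, greedy herding selection of exemplars whose running mean tracks the class mean, and classification by the closest exemplar mean. It follows the description in the iCaRL paper rather than any Renate or Avalanche code; all function names are hypothetical, and features are used as given, whereas the paper additionally L2-normalizes feature vectors and class means.

```python
import math
from typing import Dict, List, Sequence

Vector = Sequence[float]


def exemplars_per_class(memory_size: int, num_classes: int) -> int:
    """Split the fixed exemplar budget evenly across all classes seen so far."""
    return memory_size // num_classes


def mean(vectors: Sequence[Vector]) -> List[float]:
    return [sum(column) / len(vectors) for column in zip(*vectors)]


def distance(a: Vector, b: Vector) -> float:
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def herding_select(features: Sequence[Vector], m: int) -> List[int]:
    """Greedily pick m indices whose running mean best approximates the class mean."""
    mu = mean(features)
    selected: List[int] = []
    running_sum = [0.0] * len(mu)
    for k in range(1, min(m, len(features)) + 1):
        best_idx, best_dist = -1, math.inf
        for i, f in enumerate(features):
            if i in selected:
                continue
            # Mean of the exemplars chosen so far plus candidate i.
            candidate = [(s + x) / k for s, x in zip(running_sum, f)]
            d = distance(mu, candidate)
            if d < best_dist:
                best_idx, best_dist = i, d
        selected.append(best_idx)
        running_sum = [s + x for s, x in zip(running_sum, features[best_idx])]
    return selected


def nme_classify(feature: Vector, class_means: Dict[int, Vector]) -> int:
    """Return the class whose exemplar mean is closest to the feature."""
    return min(class_means, key=lambda c: distance(feature, class_means[c]))


features = [[0.0, 0.0], [2.0, 2.0], [1.0, 1.0], [10.0, 10.0]]
print(herding_select(features, 2))  # [1, 2]
print(exemplars_per_class(memory_size=500, num_classes=10))  # 50
print(nme_classify([0.9, 0.1], {0: [1.0, 0.0], 1: [0.0, 1.0]}))  # 0
```

Because the budget is fixed, the per-class share shrinks as new classes arrive, and herding decides which exemplars to keep so that the stored mean stays close to the full class mean.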