renate.updaters.experimental.er module#
- class renate.updaters.experimental.er.BaseExperienceReplayLearner(components, loss_weight=1.0, ema_memory_update_gamma=1.0, loss_normalization=1, **kwargs)[source]#
Bases: ReplayLearner, ABC
A base implementation of experience replay.
It is designed for the online CL setting, where only one pass over each new chunk of data is allowed. The learner maintains a memory buffer based on reservoir sampling. In the training step, it samples a batch of data from the memory and appends it to the batch of current-task data. At the end of the training step, the memory is updated (a training-step sketch follows this class's method listing).
- Parameters:
  - components (Dict[str, Component]) – An ordered dictionary of components that are part of the experience replay learner.
  - loss_weight (float) – A scalar weight factor for the base loss function to trade it off with other loss functions added by components.
  - ema_memory_update_gamma (float) – The gamma used for the exponential moving average that updates the metadata (logits and intermediate representations, if present).
  - loss_normalization (int) – Whether to normalize the loss by the weights of all the components.
- on_model_update_start(train_dataset, val_dataset, train_dataset_collate_fn=None, val_dataset_collate_fn=None, task_id=None)[source]#
Called before a model update starts.
- Return type:
None
- on_train_start()[source]#
PyTorch Lightning function to be run at the start of training.
- Return type:
None
- training_step(batch, batch_idx)[source]#
PyTorch Lightning function to return the training loss.
- Return type:
Union[Tensor, Dict[str, Any]]
- training_step_end(step_output)[source]#
PyTorch Lightning function run after the training step.
- Return type:
Union[Tensor, Dict[str, Any]]
- on_train_batch_end(outputs, batch, batch_idx)[source]#
PyTorch Lightning function run after the training and optimizer steps.
- Return type:
None
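To make the replay mechanics above concrete, here is a minimal sketch of one experience-replay training step. The buffer interface (sample, update) and the alpha weighting of the memory loss are illustrative assumptions, not Renate's actual internals.

```python
def er_training_step(model, loss_fn, optimizer, batch, buffer, alpha=1.0):
    """Illustrative experience-replay training step (not Renate's implementation)."""
    x, y = batch
    loss = loss_fn(model(x), y)  # loss on the current-task batch
    if len(buffer) > 0:
        # Sample a batch from the reservoir memory and add its (weighted) loss.
        x_mem, y_mem = buffer.sample(batch_size=len(x))
        loss = loss + alpha * loss_fn(model(x_mem), y_mem)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    buffer.update(batch)  # reservoir update with the current batch
    return loss.detach()
```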
- class renate.updaters.experimental.er.ExperienceReplayLearner(alpha=1.0, **kwargs)[source]#
Bases: BaseExperienceReplayLearner
This is the version of experience replay proposed in
Chaudhry, Arslan, et al. “On tiny episodic memories in continual learning.” arXiv preprint arXiv:1902.10486 (2019).
- Parameters:
  - alpha (float) – The weight of the cross-entropy loss component applied to the memory samples.
- class renate.updaters.experimental.er.DarkExperienceReplayLearner(alpha=1.0, beta=1.0, **kwargs)[source]#
Bases: ExperienceReplayLearner
A Learner that implements Dark Experience Replay.
Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara: Dark Experience for General Continual Learning: a Strong, Simple Baseline. NeurIPS 2020
- Parameters:
  - alpha (float) – The weight of the mean squared error loss component between memorised logits and the current logits on the memory data.
  - beta (float) – The weight of the cross-entropy loss component between memorised targets and the current logits on the memory data.
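Following Buzzega et al., the DER++ objective adds a logit-matching term and a memory cross-entropy term to the base loss. A sketch, assuming the memory also stores the logits recorded when each example was first seen (names are illustrative, not Renate's API):

```python
import torch.nn.functional as F

def der_loss(model, batch, mem_batch, alpha=1.0, beta=1.0):
    """Sketch of the Dark Experience Replay (DER++) loss."""
    x, y = batch
    x_mem, y_mem, logits_mem = mem_batch  # memory stores past logits as well
    out_mem = model(x_mem)
    return (
        F.cross_entropy(model(x), y)               # current-task loss
        + alpha * F.mse_loss(out_mem, logits_mem)  # match memorised logits
        + beta * F.cross_entropy(out_mem, y_mem)   # memorised targets
    )
```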
- class renate.updaters.experimental.er.PooledOutputDistillationExperienceReplayLearner(alpha=1.0, distillation_type='spatial', normalize=1, **kwargs)[source]#
Bases: BaseExperienceReplayLearner
A Learner that implements Pooled Output Distillation.
Douillard, Arthur, et al. “Podnet: Pooled outputs distillation for small-tasks incremental learning.” European Conference on Computer Vision. Springer, Cham, 2020.
- Parameters:
  - alpha (float) – Scaling value which scales the loss with respect to all intermediate representations.
  - distillation_type (str) – Which distillation type to apply with respect to the intermediate representation.
  - normalize (bool) – Whether to normalize both the current and cached features before computing the Frobenius norm.
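For the 'spatial' distillation type, feature maps are pooled along height and width before current and cached activations are compared. A rough sketch for a single feature-map pair, following the PODNet paper (the squaring of activations mirrors the reference implementation; this is not Renate's internal code):

```python
import torch
import torch.nn.functional as F

def pod_spatial_loss(feats, cached_feats, normalize=True):
    """Sketch of spatial pooled-output distillation.

    feats, cached_feats: tensors of shape (batch, channels, height, width).
    """
    a, b = feats.pow(2), cached_feats.pow(2)
    # Pool over width and height separately, then concatenate: (B, C, H + W).
    a = torch.cat([a.sum(dim=3), a.sum(dim=2)], dim=-1)
    b = torch.cat([b.sum(dim=3), b.sum(dim=2)], dim=-1)
    if normalize:
        a = F.normalize(a, dim=-1)
        b = F.normalize(b, dim=-1)
    return torch.linalg.vector_norm(a - b, dim=(1, 2)).mean()
```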
- class renate.updaters.experimental.er.CLSExperienceReplayLearner(alpha=0.5, beta=0.1, stable_model_update_weight=0.999, plastic_model_update_weight=0.999, stable_model_update_probability=0.7, plastic_model_update_probability=0.9, **kwargs)[source]#
Bases: BaseExperienceReplayLearner
A learner that implements Complementary Learning Systems-based Experience Replay (CLS-ER).
Arani, Elahe, Fahad Sarfraz, and Bahram Zonooz. “Learning fast, learning slow: A general continual learning method based on complementary learning system.” arXiv preprint arXiv:2201.12604 (2022).
- Parameters:
  - alpha (float) – Scaling value for the cross-entropy loss.
  - beta (float) – Scaling value for the consistency loss.
  - stable_model_update_weight (float) – The starting weight for the exponential moving average used to update the stable model copy.
  - plastic_model_update_weight (float) – The starting weight for the exponential moving average used to update the plastic model copy.
  - stable_model_update_probability (float) – The probability of updating the stable model copy.
  - plastic_model_update_probability (float) – The probability of updating the plastic model copy.
- components(model, loss_fn, alpha=0.5, beta=0.1, plastic_model_update_weight=0.999, stable_model_update_weight=0.999, plastic_model_update_probability=0.9, stable_model_update_probability=0.7)[source]#
Returns the components of the learner.
This is a user-defined function that should return a dictionary of components.
- Return type:
Dict[str, Component]
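CLS-ER maintains two extra copies of the model (a slow "stable" one and a faster "plastic" one) that are updated stochastically via an exponential moving average; the consistency loss (weighted by beta) then keeps the working model's predictions on memory samples close to those of the copies. A minimal sketch of the stochastic EMA update, with illustrative names:

```python
import random
import torch

def maybe_ema_update(ema_model, model, update_weight, update_probability):
    """Sketch of the stochastic EMA update for the stable/plastic copies."""
    if random.random() < update_probability:
        with torch.no_grad():
            for ema_p, p in zip(ema_model.parameters(), model.parameters()):
                # ema = w * ema + (1 - w) * current
                ema_p.mul_(update_weight).add_(p, alpha=1.0 - update_weight)
```

With the defaults above, the stable copy is updated less often (probability 0.7 vs. 0.9), so it drifts more slowly than the plastic copy.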
- class renate.updaters.experimental.er.SuperExperienceReplayLearner(der_alpha=1.0, der_beta=1.0, sp_shrink_factor=0.95, sp_sigma=0.001, cls_alpha=0.1, cls_stable_model_update_weight=0.999, cls_plastic_model_update_weight=0.999, cls_stable_model_update_probability=0.7, cls_plastic_model_update_probability=0.9, pod_alpha=1.0, pod_distillation_type='spatial', pod_normalize=1, ema_memory_update_gamma=1.0, **kwargs)[source]#
Bases: BaseExperienceReplayLearner
A learner that implements a selected combination of methods.
- Parameters:
  - der_alpha (float) – The weight of the mean squared error loss component between memorised logits and the current logits on the memory data.
  - der_beta (float) – The weight of the cross-entropy loss component between memorised targets and the current logits on the memory data.
  - sp_shrink_factor (float) – Shrinking value applied with respect to shrink and perturbation.
  - sp_sigma (float) – Standard deviation applied with respect to shrink and perturbation.
  - cls_alpha (float) – Scaling value for the consistency loss added to the base cross-entropy loss.
  - cls_stable_model_update_weight (float) – The starting weight for the exponential moving average used to update the stable model copy.
  - cls_plastic_model_update_weight (float) – The starting weight for the exponential moving average used to update the plastic model copy.
  - cls_stable_model_update_probability (float) – The probability of updating the stable model copy.
  - cls_plastic_model_update_probability (float) – The probability of updating the plastic model copy.
  - pod_alpha (float) – Scaling value which scales the loss with respect to all intermediate representations.
  - pod_distillation_type (str) – Which distillation type to apply with respect to the intermediate representation.
  - pod_normalize (bool) – Whether to normalize both the current and cached features before computing the Frobenius norm.
  - ema_memory_update_gamma (float) – The gamma used for the exponential moving average that updates the metadata (logits and intermediate representations, if present).
- components(model, loss_fn, der_alpha=1.0, der_beta=1.0, sp_shrink_factor=0.95, sp_sigma=0.001, cls_alpha=0.1, cls_stable_model_update_weight=0.999, cls_plastic_model_update_weight=0.999, cls_stable_model_update_probability=0.7, cls_plastic_model_update_probability=0.9, pod_alpha=1.0, pod_distillation_type='spatial', pod_normalize=1)[source]#
Returns the components of the learner.
This is a user-defined function that should return a dictionary of components.
- Return type:
Dict[str, Component]
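Of the combined methods, shrink-and-perturb (the sp_* parameters) rescales all weights toward zero and adds Gaussian noise before fitting new data. A minimal sketch under those assumptions (illustrative, not Renate's internal code):

```python
import torch

def shrink_and_perturb(model, shrink_factor=0.95, sigma=0.001):
    """Sketch of shrink-and-perturb over all model parameters."""
    with torch.no_grad():
        for p in model.parameters():
            # Shrink toward zero, then perturb with Gaussian noise of std sigma.
            p.mul_(shrink_factor).add_(torch.randn_like(p), alpha=sigma)
```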
- class renate.updaters.experimental.er.ExperienceReplayModelUpdater(model, loss_fn, optimizer, memory_size, batch_memory_frac=0.5, loss_weight=1.0, ema_memory_update_gamma=1.0, loss_normalization=1, alpha=1.0, learning_rate_scheduler=None, learning_rate_scheduler_interval='epoch', batch_size=32, input_state_folder=None, output_state_folder=None, max_epochs=50, train_transform=None, train_target_transform=None, test_transform=None, test_target_transform=None, buffer_transform=None, buffer_target_transform=None, metric=None, mode='min', logged_metrics=None, early_stopping_enabled=False, logger=<pytorch_lightning.loggers.tensorboard.TensorBoardLogger object>, accelerator='auto', devices=None, strategy='ddp', precision='32', seed=0, deterministic_trainer=False, gradient_clip_val=None, gradient_clip_algorithm=None, mask_unused_classes=False)[source]#
Bases: SingleTrainingLoopUpdater
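A hedged usage sketch: the updater classes in this module wrap a single training loop around the corresponding learner. The optimizer factory form and the update call below are assumptions inferred from the constructor signature; consult the Renate user guide for the authoritative interface.

```python
import torch
from renate.updaters.experimental.er import ExperienceReplayModelUpdater

# `my_model`, `train_dataset`, and `val_dataset` are assumed to be defined
# elsewhere (a RenateModule and torch Datasets, respectively).
updater = ExperienceReplayModelUpdater(
    model=my_model,
    loss_fn=torch.nn.CrossEntropyLoss(),
    optimizer=lambda params: torch.optim.SGD(params, lr=0.01),  # assumed factory form
    memory_size=500,
    batch_memory_frac=0.5,
    alpha=1.0,
    max_epochs=10,
)
my_model = updater.update(train_dataset, val_dataset=val_dataset)  # assumed to return the updated model
```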
- class renate.updaters.experimental.er.DarkExperienceReplayModelUpdater(model, loss_fn, optimizer, memory_size, batch_memory_frac=0.5, loss_weight=1.0, ema_memory_update_gamma=1.0, loss_normalization=1, alpha=1.0, beta=1.0, learning_rate_scheduler=None, learning_rate_scheduler_interval='epoch', batch_size=32, input_state_folder=None, output_state_folder=None, max_epochs=50, train_transform=None, train_target_transform=None, test_transform=None, test_target_transform=None, buffer_transform=None, buffer_target_transform=None, metric=None, mode='min', logged_metrics=None, early_stopping_enabled=False, logger=<pytorch_lightning.loggers.tensorboard.TensorBoardLogger object>, accelerator='auto', devices=None, strategy='ddp', precision='32', seed=0, deterministic_trainer=False, gradient_clip_val=None, gradient_clip_algorithm=None, mask_unused_classes=False)[source]#
Bases: SingleTrainingLoopUpdater
- class renate.updaters.experimental.er.PooledOutputDistillationExperienceReplayModelUpdater(model, loss_fn, optimizer, memory_size, batch_memory_frac=0.5, loss_weight=1.0, ema_memory_update_gamma=1.0, loss_normalization=1, alpha=1.0, distillation_type='spatial', normalize=1, learning_rate_scheduler=None, learning_rate_scheduler_interval='epoch', batch_size=32, input_state_folder=None, output_state_folder=None, max_epochs=50, train_transform=None, train_target_transform=None, test_transform=None, test_target_transform=None, buffer_transform=None, buffer_target_transform=None, metric=None, mode='min', logged_metrics=None, early_stopping_enabled=False, logger=<pytorch_lightning.loggers.tensorboard.TensorBoardLogger object>, accelerator='auto', devices=None, strategy='ddp', precision='32', seed=0, deterministic_trainer=False, gradient_clip_val=None, gradient_clip_algorithm=None, mask_unused_classes=False)[source]#
Bases: SingleTrainingLoopUpdater
- class renate.updaters.experimental.er.CLSExperienceReplayModelUpdater(model, loss_fn, optimizer, memory_size, batch_memory_frac=0.5, loss_weight=1.0, ema_memory_update_gamma=1.0, loss_normalization=1, alpha=0.5, beta=0.1, stable_model_update_weight=0.999, plastic_model_update_weight=0.999, stable_model_update_probability=0.7, plastic_model_update_probability=0.9, learning_rate_scheduler=None, learning_rate_scheduler_interval='epoch', batch_size=32, input_state_folder=None, output_state_folder=None, max_epochs=50, train_transform=None, train_target_transform=None, test_transform=None, test_target_transform=None, buffer_transform=None, buffer_target_transform=None, metric=None, mode='min', logged_metrics=None, early_stopping_enabled=False, logger=<pytorch_lightning.loggers.tensorboard.TensorBoardLogger object>, accelerator='auto', devices=None, strategy='ddp', precision='32', seed=0, deterministic_trainer=False, gradient_clip_val=None, gradient_clip_algorithm=None, mask_unused_classes=False)[source]#
Bases: SingleTrainingLoopUpdater
- class renate.updaters.experimental.er.SuperExperienceReplayModelUpdater(model, loss_fn, optimizer, memory_size, batch_memory_frac=0.5, loss_weight=1.0, ema_memory_update_gamma=1.0, loss_normalization=1, der_alpha=1.0, der_beta=1.0, sp_shrink_factor=0.95, sp_sigma=0.001, cls_alpha=0.1, cls_stable_model_update_weight=0.999, cls_plastic_model_update_weight=0.999, cls_stable_model_update_probability=0.7, cls_plastic_model_update_probability=0.9, pod_alpha=1.0, pod_distillation_type='spatial', pod_normalize=1, learning_rate_scheduler=None, learning_rate_scheduler_interval='epoch', batch_size=32, input_state_folder=None, output_state_folder=None, max_epochs=50, train_transform=None, train_target_transform=None, test_transform=None, test_target_transform=None, buffer_transform=None, buffer_target_transform=None, metric=None, mode='min', logged_metrics=None, early_stopping_enabled=False, logger=<pytorch_lightning.loggers.tensorboard.TensorBoardLogger object>, accelerator='auto', devices=None, strategy='ddp', precision='32', seed=0, deterministic_trainer=False, gradient_clip_val=None, gradient_clip_algorithm=None, mask_unused_classes=False)[source]#
Bases: SingleTrainingLoopUpdater