renate.updaters.experimental.repeated_distill module#
- renate.updaters.experimental.repeated_distill.double_distillation_loss(predicted_logits, target_logits)[source]#
Double distillation loss, where target logits are normalized across the class-dimension.
This normalization is useful when distilling from multiple teachers and was proposed in Zhang, Junting, et al. “Class-incremental learning via deep model consolidation.” Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision, 2020.
- Parameters:
- predicted_logits (Tensor) – Logits predicted by the model being trained, of shape (B, C), where B is the batch size and C is the number of classes.
- target_logits (Tensor) – Logits of the teacher model(s), of shape (B, C).
- Return type:
Tensor
- Returns:
A tensor of size (B,) containing the loss values for each datapoint in the batch.
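For intuition, here is a minimal PyTorch sketch of such a loss. It mirrors the documented behavior (class-normalized targets, per-datapoint squared error) but is not necessarily identical to Renate's implementation.

```python
import torch


def double_distillation_loss_sketch(
    predicted_logits: torch.Tensor, target_logits: torch.Tensor
) -> torch.Tensor:
    """Sketch of a double distillation loss for logits of shape (B, C)."""
    # Zero-center the teacher logits across the class dimension.
    normalized_targets = target_logits - target_logits.mean(dim=-1, keepdim=True)
    # Per-datapoint mean squared error between student logits and the
    # normalized teacher logits; the result has shape (B,).
    return ((predicted_logits - normalized_targets) ** 2).mean(dim=-1)
```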
- renate.updaters.experimental.repeated_distill.extract_logits(model, dataset, batch_size, task_id='default_task')[source]#
Extracts logits from a model for each point in a dataset.
- Parameters:
- model – The model used to compute the logits.
- dataset – The dataset for which logits are extracted.
- batch_size – The batch size used to iterate over the dataset.
- task_id – The task id passed to the model when computing the logits.
- Return type:
Tensor
- Returns:
A tensor logits of shape (N, C), where N is the length of the dataset and C is the output dimension of model, i.e., the number of classes.
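A rough sketch of what such an extraction loop looks like is shown below. The actual function additionally takes a task_id and handles device placement through the Renate model; those details are omitted here, and the assumption that batches unpack as (inputs, ...) is illustrative.

```python
import torch
from torch.utils.data import DataLoader, Dataset


@torch.no_grad()
def extract_logits_sketch(
    model: torch.nn.Module, dataset: Dataset, batch_size: int
) -> torch.Tensor:
    """Sketch: forward the whole dataset in batches and stack the logits."""
    model.eval()
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    logits = [model(inputs) for inputs, *_ in loader]  # assumes (inputs, ...) batches
    return torch.cat(logits, dim=0)  # shape (N, C)
```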
- class renate.updaters.experimental.repeated_distill.RepeatedDistillationModelUpdater(model, loss_fn, optimizer, memory_size, learning_rate_scheduler=None, learning_rate_scheduler_interval='epoch', batch_size=32, train_transform=None, train_target_transform=None, test_transform=None, test_target_transform=None, buffer_transform=None, buffer_target_transform=None, input_state_folder=None, output_state_folder=None, max_epochs=50, metric=None, mode='min', logger=<pytorch_lightning.loggers.tensorboard.TensorBoardLogger object>, accelerator='auto', devices=None, strategy='ddp', precision='32', logged_metrics=None, seed=None, early_stopping_enabled=False, deterministic_trainer=False, gradient_clip_val=None, gradient_clip_algorithm=None, mask_unused_classes=False)[source]#
Bases:
ModelUpdater
Repeated Distillation (RD) is inspired by Deep Model Consolidation (DMC), which was proposed in
Zhang, Junting, et al. “Class-incremental learning via deep model consolidation.” Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision, 2020.
The idea underlying RD is the following: Given a new task/batch, a new model copy is trained from scratch on that data. Subsequently, this expert model is consolidated with the previous model state via knowledge distillation. The resulting consolidated model state is maintained, whereas the expert model is discarded.
- Our variant differs from the original algorithm in two ways:
- The original algorithm is designed specifically for the class-incremental setting, where each new task introduces one or more novel classes. This variant is designed for the general continual learning setting with a pre-determined number of classes.
- The original method is supposed to be memory-free and uses auxiliary data for the model consolidation phase. Our variant performs knowledge distillation over a memory (see the sketch below).
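To make the consolidation step concrete, the following is an illustrative sketch only. It assumes the memory yields (inputs, target_logits) pairs, with target logits previously obtained via extract_logits from the expert model; the actual consolidation runs inside the model updater via PyTorch Lightning and may differ in detail.

```python
import torch


def consolidate_sketch(student, memory_loader, epochs=1, lr=0.05):
    """Illustrative consolidation: distill stored logits into a fresh model."""
    optimizer = torch.optim.SGD(student.parameters(), lr=lr)
    for _ in range(epochs):
        for inputs, target_logits in memory_loader:  # assumed memory format
            predicted_logits = student(inputs)
            # Double distillation loss: class-normalized targets, per-point MSE.
            targets = target_logits - target_logits.mean(dim=-1, keepdim=True)
            loss = ((predicted_logits - targets) ** 2).mean()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    return student
```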
- update(train_dataset, val_dataset=None, train_dataset_collate_fn=None, val_dataset_collate_fn=None, task_id=None)[source]#
Updates the model using the data passed as input.
- Parameters:
- train_dataset (Dataset) – The training data.
- val_dataset (Optional[Dataset]) – The validation data.
- train_dataset_collate_fn (Optional[Callable]) – collate_fn used to merge a list of samples to form a mini-batch of Tensors for the training data.
- val_dataset_collate_fn (Optional[Callable]) – collate_fn used to merge a list of samples to form a mini-batch of Tensors for the validation data.
- task_id (Optional[str]) – The task id.
- Return type:
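A hedged usage sketch of the updater follows. my_model, train_dataset, and val_dataset are placeholders for a Renate model and torch datasets, the hyperparameters are arbitrary, and the exact expected types of loss_fn and optimizer may differ between Renate versions.

```python
from functools import partial

import torch

from renate.updaters.experimental.repeated_distill import RepeatedDistillationModelUpdater

updater = RepeatedDistillationModelUpdater(
    model=my_model,  # placeholder for a Renate model
    loss_fn=torch.nn.CrossEntropyLoss(reduction="none"),  # assumed per-sample loss
    optimizer=partial(torch.optim.SGD, lr=0.05, momentum=0.9),  # assumed factory-style optimizer
    memory_size=500,
    batch_size=32,
    max_epochs=20,
)
updated_model = updater.update(train_dataset, val_dataset=val_dataset, task_id="task_1")
```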
- class renate.updaters.experimental.repeated_distill.RepeatedDistillationLearner(**kwargs)[source]#
Bases:
ReplayLearner
A learner performing distillation.