renate.updaters.experimental.repeated_distill module#

renate.updaters.experimental.repeated_distill.double_distillation_loss(predicted_logits, target_logits)[source]#

Double distillation loss, where target logits are normalized across the class-dimension.

This normalization is useful when distilling from multiple teachers and was proposed in:

Zhang, Junting, et al. “Class-incremental learning via deep model consolidation.” Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision, 2020.

Parameters:
  • predicted_logits (Tensor) – Logit predictions of the student model, size (B, C), where B is the batch size and C is the number of classes.

  • target_logits (Tensor) – Logits obtained from the teacher model(s), same size (B, C).

Return type:

Tensor

Returns:

A tensor of size (B,) containing the loss values for each datapoint in the batch.
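As a worked illustration, the following sketch re-implements the loss as described above, assuming the normalization subtracts each example's mean teacher logit over the class dimension (as in the DMC paper); the library's actual implementation may differ in details.

    import torch
    from torch import Tensor


    def double_distillation_loss_sketch(predicted_logits: Tensor, target_logits: Tensor) -> Tensor:
        # Illustrative sketch, not the library function itself.
        # Normalize the teacher logits per example by subtracting their mean
        # over the class dimension.
        normalized_targets = target_logits - target_logits.mean(dim=-1, keepdim=True)
        # Squared error between student and normalized teacher logits, averaged
        # over classes, giving one loss value per datapoint: shape (B,).
        return ((predicted_logits - normalized_targets) ** 2).mean(dim=-1)


    student_logits = torch.randn(4, 10)  # B=4, C=10
    teacher_logits = torch.randn(4, 10)
    print(double_distillation_loss_sketch(student_logits, teacher_logits).shape)  # torch.Size([4])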

renate.updaters.experimental.repeated_distill.extract_logits(model, dataset, batch_size, task_id='default_task')[source]#

Extracts logits from a model for each point in a dataset.

Parameters:
  • model (Module) – The model. model.get_logits(X) is assumed to return logits.

  • dataset (Dataset) – The dataset.

  • batch_size (int) – Batch size used to iterate over the dataset.

  • task_id (Optional[str]) – Task id to be used, e.g., to select the output head.

Return type:

Tensor

Returns:

A tensor of logits of shape (N, C), where N is the length of the dataset and C is the output dimension of the model, i.e., the number of classes.
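A minimal usage sketch follows, assuming a map-style dataset yielding (input, target) pairs; TinyModel is a hypothetical stand-in introduced only for illustration, and any module exposing get_logits should work analogously.

    import torch
    from torch.utils.data import TensorDataset

    from renate.updaters.experimental.repeated_distill import extract_logits


    class TinyModel(torch.nn.Module):
        # Hypothetical stand-in model; only `get_logits` matters here.
        def __init__(self, num_features: int = 16, num_classes: int = 10) -> None:
            super().__init__()
            self.linear = torch.nn.Linear(num_features, num_classes)

        def get_logits(self, x: torch.Tensor, task_id: str = "default_task") -> torch.Tensor:
            return self.linear(x)


    dataset = TensorDataset(torch.randn(100, 16), torch.randint(0, 10, (100,)))
    logits = extract_logits(TinyModel(), dataset, batch_size=32)
    print(logits.shape)  # expected: torch.Size([100, 10])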

class renate.updaters.experimental.repeated_distill.RepeatedDistillationModelUpdater(model, loss_fn, optimizer, memory_size, learning_rate_scheduler=None, learning_rate_scheduler_interval='epoch', batch_size=32, train_transform=None, train_target_transform=None, test_transform=None, test_target_transform=None, buffer_transform=None, buffer_target_transform=None, input_state_folder=None, output_state_folder=None, max_epochs=50, metric=None, mode='min', logger=<pytorch_lightning.loggers.tensorboard.TensorBoardLogger object>, accelerator='auto', devices=None, strategy='ddp', precision='32', logged_metrics=None, seed=None, early_stopping_enabled=False, deterministic_trainer=False, gradient_clip_val=None, gradient_clip_algorithm=None, mask_unused_classes=False)[source]#

Bases: ModelUpdater

Repeated Distillation (RD) is inspired by Deep Model Consolidation (DMC), which was proposed in:

Zhang, Junting, et al. “Class-incremental learning via deep model consolidation.” Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision, 2020.

The idea underlying RD is the following: Given a new task/batch, a new model copy is trained from scratch on that data. Subsequently, this expert model is consolidated with the previous model state via knowledge distillation. The resulting consolidated model state is maintained, whereas the expert model is discarded.

Our variant differs from the original algorithm in two ways:
  • The original algorithm is designed specifically for the class-incremental setting, where each new task introduces one or more novel classes. This variant is designed for the general continual learning setting with a pre-determined number of classes.

  • The original method is supposed to be memory-free and uses auxiliary data for the model consolidation phase. Our variant instead performs the knowledge distillation over a memory buffer of previously observed data.
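A hedged construction sketch is given below. my_model and train_data are hypothetical placeholders for a RenateModule and a torch Dataset defined elsewhere; the optimizer is shown as a factory building the optimizer from the model parameters, which is an assumption about the expected argument type (see the ModelUpdater documentation for the exact interface).

    from functools import partial

    import torch

    from renate.updaters.experimental.repeated_distill import RepeatedDistillationModelUpdater

    # `my_model` and `train_data` are hypothetical placeholders.
    updater = RepeatedDistillationModelUpdater(
        model=my_model,
        loss_fn=torch.nn.CrossEntropyLoss(),
        # Assumption: a factory that creates the optimizer from the model parameters.
        optimizer=partial(torch.optim.SGD, lr=0.05, momentum=0.9),
        memory_size=500,
        batch_size=32,
        max_epochs=10,
    )
    updated_model = updater.update(train_data, task_id="task_1")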

update(train_dataset, val_dataset=None, train_dataset_collate_fn=None, val_dataset_collate_fn=None, task_id=None)[source]#

Updates the model using the data passed as input.

Parameters:
  • train_dataset (Dataset) – The training data.

  • val_dataset (Optional[Dataset]) – The validation data.

  • train_dataset_collate_fn (Optional[Callable]) – collate_fn used to merge a list of samples to form a mini-batch of Tensors for the training data.

  • val_dataset_collate_fn (Optional[Callable]) – collate_fn used to merge a list of samples to form a mini-batch of Tensors for the validation data.

  • task_id (Optional[str]) – The task id.

Return type:

RenateModule
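In a continual learning pipeline, update is typically invoked once per incoming task or data batch. The loop below is a schematic sketch; data_stream is a hypothetical iterable of (train, validation) dataset pairs and updater refers to the instance constructed above.

    # Hypothetical stream of (train, validation) dataset pairs, one per task.
    for step, (train_data, val_data) in enumerate(data_stream):
        model = updater.update(
            train_dataset=train_data,
            val_dataset=val_data,
            task_id=f"task_{step}",
        )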

class renate.updaters.experimental.repeated_distill.RepeatedDistillationLearner(**kwargs)[source]#

Bases: ReplayLearner

A learner performing distillation.

update_expert_logits(new_expert_logits)[source]#

Update expert logits.

Return type:

None
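The expert logits presumably serve as distillation targets during consolidation. The snippet below sketches how such targets could be produced with extract_logits and handed to the learner; expert_model, memory_dataset, and learner are hypothetical placeholders, and the actual wiring inside the updater may differ.

    from renate.updaters.experimental.repeated_distill import extract_logits

    # Hypothetical objects: an expert model trained on the newest task, the
    # memory dataset used for consolidation, and a RepeatedDistillationLearner.
    expert_logits = extract_logits(expert_model, memory_dataset, batch_size=32)
    learner.update_expert_logits(expert_logits)  # store targets for distillation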

on_model_update_start(train_dataset, val_dataset, train_dataset_collate_fn=None, val_dataset_collate_fn=None, task_id=None)[source]#

Called before a model update starts.

Return type:

None

train_dataloader()[source]#

Returns the dataloader for training the model.

Return type:

DataLoader

on_model_update_end()[source]#

Called right before a model update terminates.

Return type:

None

training_step(batch, batch_idx)[source]#

PyTorch Lightning function to return the training loss.

Return type:

Union[Tensor, Dict[str, Any]]