renate.updaters.learner module#

class renate.updaters.learner.RenateLightningModule(model, loss_fn, optimizer, learning_rate_scheduler=None, learning_rate_scheduler_interval='epoch', batch_size=32, logged_metrics=None, seed=0, mask_unused_classes=False)[source]#

Bases: LightningModule, ABC

Base class for LightningModules, which implement metric logging and basic training logic.

The RenateLightningModule is a LightningModule that additionally provides hook functions called by ModelUpdater; a short subclassing sketch follows the list. These hooks are:

  • on_model_update_start, which is called at the beginning of a model update. It registers the new training and (optionally) validation datasets from which the dataloaders are created.

  • on_model_update_end, which is called at the end of a model update.
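
A minimal subclassing sketch of these hooks; the class name and logging are illustrative, and the class omits everything else a concrete subclass would need:

from renate.updaters.learner import RenateLightningModule


class VerboseLightningModule(RenateLightningModule):

    def on_model_update_start(
        self,
        train_dataset,
        val_dataset,
        train_dataset_collate_fn=None,
        val_dataset_collate_fn=None,
        task_id=None,
    ) -> None:
        # Let the base class register the datasets, then add custom logic.
        super().on_model_update_start(
            train_dataset, val_dataset, train_dataset_collate_fn,
            val_dataset_collate_fn, task_id,
        )
        print(f"Model update started for task {task_id}.")

    def on_model_update_end(self) -> None:
        super().on_model_update_end()
        print("Model update finished.")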

Parameters:
  • model (RenateModule) – The model to be trained.

  • loss_fn (Module) – The loss function to be optimized during the model update.

  • optimizer (Callable[[List[Parameter]], Optimizer]) – Partial optimizer; the actual optimizer is created by calling it with the model parameters.

  • learning_rate_scheduler (Optional[Callable[[Optimizer], _LRScheduler]]) – Partial learning rate scheduler; the actual scheduler is created by calling it with the optimizer.

  • learning_rate_scheduler_interval (Literal['epoch', 'step']) – When to update the learning rate scheduler. Options: epoch and step.

  • batch_size (int) – Training batch size.

  • logged_metrics (Optional[Dict[str, Metric]]) – Metrics logged additional to the default ones.

  • seed (int) – See renate.models.utils.get_generator().

  • mask_unused_classes (bool) – If True, logits corresponding to unused classes are ignored in the loss computation. This can be useful for class-incremental learning.
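
Both partial arguments can be built with functools.partial; a hedged sketch with assumed hyperparameter values:

from functools import partial

import torch

# Partial optimizer: the module later calls it with the model parameters.
optimizer = partial(torch.optim.SGD, lr=0.01, momentum=0.9)

# Partial scheduler: the module later calls it with the created optimizer.
learning_rate_scheduler = partial(
    torch.optim.lr_scheduler.StepLR, step_size=10, gamma=0.5
)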

is_logged_metric(metric_name)[source]#

Returns True if there is a metric with name metric_name.

Return type:

bool
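
A hedged usage sketch; module is an assumed RenateLightningModule instance and the metric name "accuracy" is illustrative:

# True only if a metric named "accuracy" was supplied via `logged_metrics`
# or is among the defaults; both names here are assumptions.
if module.is_logged_metric("accuracy"):
    print("Accuracy is tracked during training and validation.")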

on_model_update_start(train_dataset, val_dataset, train_dataset_collate_fn=None, val_dataset_collate_fn=None, task_id=None)[source]#

Called at the beginning of a model update; registers the datasets and collate functions from which the training and validation dataloaders are created.
Return type:

None

train_dataloader()[source]#

Returns the dataloader for training the model.

Return type:

DataLoader

val_dataloader()[source]#

Implement one or multiple PyTorch DataLoaders for validation.

The dataloader you return will not be reloaded unless you set reload_dataloaders_every_n_epochs to a positive integer.

It’s recommended that all data downloads and preparation happen in prepare_data().

This hook is used in conjunction with the following:

  • fit()

  • validate()

  • prepare_data()

  • setup()

Note

Lightning adds the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Return type:

Optional[DataLoader]

Returns:

A torch.utils.data.DataLoader or a sequence of them specifying validation samples.

Examples:

import torch
from torchvision import transforms
from torchvision.datasets import MNIST


def val_dataloader(self):
    # Normalize MNIST images before batching.
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5,), (1.0,))])
    dataset = MNIST(root='/path/to/mnist/', train=False,
                    transform=transform, download=True)
    # No shuffling for validation; Lightning adds a distributed sampler if needed.
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=self.batch_size,
        shuffle=False
    )

    return loader

# can also return multiple dataloaders
def val_dataloader(self):
    return [loader_a, loader_b, ..., loader_n]

Note

If you don’t need a validation dataset and a validation_step(), you don’t need to implement this method.

Note

In the case where you return multiple validation dataloaders, the validation_step() will have an argument dataloader_idx which matches the order here.

on_model_update_end()[source]#

Called right before a model update terminates.

Return type:

None

forward(inputs, task_id=None)[source]#

Forward pass of the model.

Return type:

Tensor
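
A hedged call sketch; module and batch_inputs are assumed to exist:

# Standard nn.Module call syntax dispatches to forward(); task_id is optional.
outputs = module(batch_inputs, task_id=None)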

training_step_unpack_batch(batch)[source]#

Unpacks a batch into inputs and targets for the training step.
Return type:

Tuple[Any, Any]
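
If a dataset yields batches in a custom layout, this hook is the natural place to unpack them. A minimal sketch for dict-shaped batches; the key names are assumptions:

from renate.updaters.learner import RenateLightningModule


class DictBatchModule(RenateLightningModule):

    def training_step_unpack_batch(self, batch):
        # Hypothetical batch layout: {"inputs": ..., "targets": ...}.
        return batch["inputs"], batch["targets"]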

training_step(batch, batch_idx)[source]#

PyTorch Lightning function to return the training loss.

Return type:

Union[Tensor, Dict[str, Any]]

training_step_end(step_output)[source]#

PyTorch Lightning hook executed after each training step.

Return type:

Union[Tensor, Dict[str, Any]]

training_epoch_end(outputs)[source]#

PyTorch Lightning function to run at the end of a training epoch.

Return type:

None

validation_step_unpack_batch(batch)[source]#

Unpacks a batch into inputs and targets for the validation step.
Return type:

Tuple[Any, Any]

validation_step(batch, batch_idx)[source]#

PyTorch Lightning function to estimate validation metrics.

Return type:

None

validation_epoch_end(outputs)[source]#

PyTorch Lightning function to run at the end of a validation epoch.

Return type:

None

configure_optimizers()[source]#

PyTorch Lightning function to create optimizers and learning rate schedulers.

Return type:

Union[Optimizer, Tuple[List[Optimizer], List[Dict[str, Any]]]]
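
A hedged sketch of an override matching the documented return type; the optimizer and scheduler choices are illustrative:

import torch

from renate.updaters.learner import RenateLightningModule


class MyModule(RenateLightningModule):

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.parameters(), lr=0.01)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
        # The (optimizers, scheduler-config dicts) variant of the return type.
        return [optimizer], [{"scheduler": scheduler, "interval": "epoch"}]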

class renate.updaters.learner.Learner(model, loss_fn, optimizer, learning_rate_scheduler=None, learning_rate_scheduler_interval='epoch', batch_size=32, train_transform=None, train_target_transform=None, test_transform=None, test_target_transform=None, logged_metrics=None, seed=0, mask_unused_classes=False)[source]#

Bases: RenateLightningModule, ABC

Base class for Learners, which encapsulate the core CL methodologies.

The Learner is a LightningModule that additionally provides hook functions called by ModelUpdater. These hooks are:

  • Learner.on_model_update_start, which is called at the beginning of a model update. It registers the new training and (optionally) validation datasets from which the dataloaders are created.

  • Learner.on_model_update_end, which is called at the end of a model update.

This base class implements a basic training loop without any mechanism to counteract forgetting.

Parameters:
  • model (RenateModule) – The model to be trained.

  • loss_fn (Module) – The loss function to be optimized during the model update.

  • optimizer (Callable[[List[Parameter]], Optimizer]) – Partial optimizer; the actual optimizer is created by calling it with the model parameters.

  • learning_rate_scheduler (Optional[Callable[[Optimizer], _LRScheduler]]) – Partial learning rate scheduler; the actual scheduler is created by calling it with the optimizer.

  • learning_rate_scheduler_interval (Literal['epoch', 'step']) – When to update the learning rate scheduler. Options: epoch and step.

  • batch_size (int) – Training batch size.

  • train_transform (Optional[Callable]) – The transformation applied during training.

  • train_target_transform (Optional[Callable]) – The target transformation applied during training.

  • test_transform (Optional[Callable]) – The transformation at test time.

  • test_target_transform (Optional[Callable]) – The target transformation at test time.

  • logged_metrics (Optional[Dict[str, Metric]]) – Metrics logged additional to the default ones.

  • seed (int) – See renate.models.utils.get_generator().

  • mask_unused_classes (bool) – If True, logits corresponding to unused classes are ignored in the loss computation. This can be useful for class-incremental learning.
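
Since Learner is abstract, keyword arguments like the following would be passed to a concrete subclass; all values and transforms are illustrative:

from functools import partial

import torch
from torchvision import transforms

learner_kwargs = dict(
    optimizer=partial(torch.optim.Adam, lr=1e-3),
    batch_size=64,
    # Augmentation applied only during training.
    train_transform=transforms.Compose(
        [transforms.RandomHorizontalFlip(), transforms.ToTensor()]
    ),
    # Deterministic preprocessing applied at test time.
    test_transform=transforms.ToTensor(),
)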

on_save_checkpoint(checkpoint)[source]#

Called by Lightning when saving a checkpoint to give you a chance to store anything else you might want to save.

Parameters:

checkpoint (Dict[str, Any]) – The full checkpoint dictionary before it gets dumped to a file. Implementations of this hook can insert additional data into this dictionary.

Return type:

None

Example:

def on_save_checkpoint(self, checkpoint):
    # 99% of use cases you don't need to implement this method
    checkpoint['something_cool_i_want_to_save'] = my_cool_pickable_object

Note

Lightning saves all aspects of training (epoch, global step, etc…) including amp scaling. There is no need for you to store anything about training.

on_load_checkpoint(checkpoint)[source]#

Called by Lightning to restore your model. If you saved something with on_save_checkpoint() this is your chance to restore this.

Parameters:

checkpoint (Dict[str, Any]) – Loaded checkpoint

Example:

def on_load_checkpoint(self, checkpoint):
    # 99% of the time you don't need to implement this method
    self.something_cool_i_want_to_save = checkpoint['something_cool_i_want_to_save']

Note

Lightning auto-restores global step, epoch, and train state including amp scaling. There is no need for you to restore anything regarding training.

save(output_state_dir)[source]#

Saves the learner’s state to the directory output_state_dir.
Return type:

None

load(input_state_dir)[source]#

Loads the learner’s state from the directory input_state_dir.
Return type:

None

on_model_update_start(train_dataset, val_dataset, train_dataset_collate_fn=None, val_dataset_collate_fn=None, task_id=None)[source]#

Called at the beginning of a model update; registers the datasets and collate functions from which the training and validation dataloaders are created.
Return type:

None

train_dataloader()[source]#

Returns the dataloader for training the model.

Return type:

DataLoader

val_dataloader()[source]#

Implement one or multiple PyTorch DataLoaders for validation.

The dataloader you return will not be reloaded unless you set reload_dataloaders_every_n_epochs to a positive integer.

It’s recommended that all data downloads and preparation happen in prepare_data().

This hook is used in conjunction with the following:

  • fit()

  • validate()

  • prepare_data()

  • setup()

Note

Lightning adds the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Return type:

Optional[DataLoader]

Returns:

A torch.utils.data.DataLoader or a sequence of them specifying validation samples.

Examples:

import torch
from torchvision import transforms
from torchvision.datasets import MNIST


def val_dataloader(self):
    # Normalize MNIST images before batching.
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5,), (1.0,))])
    dataset = MNIST(root='/path/to/mnist/', train=False,
                    transform=transform, download=True)
    # No shuffling for validation; Lightning adds a distributed sampler if needed.
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=self.batch_size,
        shuffle=False
    )

    return loader

# can also return multiple dataloaders
def val_dataloader(self):
    return [loader_a, loader_b, ..., loader_n]

Note

If you don’t need a validation dataset and a validation_step(), you don’t need to implement this method.

Note

In the case where you return multiple validation dataloaders, the validation_step() will have an argument dataloader_idx which matches the order here.

validation_step_unpack_batch(batch)[source]#

Unpacks a batch into inputs and targets for the validation step.

Return type:

Tuple[NestedTensors, Any], where NestedTensors = Union[Tensor, Tuple[NestedTensors], Dict[str, NestedTensors]]

class renate.updaters.learner.ReplayLearner(memory_size, batch_size=32, batch_memory_frac=0.5, buffer_transform=None, buffer_target_transform=None, seed=0, **kwargs)[source]#

Bases: Learner, ABC

Base class for Learners which use a buffer to store data and reuse it in future updates.

Parameters:
  • memory_size (int) – The maximum size of the memory.

  • batch_size (int) – Training batch size.

  • batch_memory_frac (float) – Fraction of the batch that is sampled from rehearsal memory.

  • buffer_transform (Optional[Callable]) – The transformation to be applied to the memory buffer data samples.

  • buffer_target_transform (Optional[Callable]) – The target transformation to be applied to the memory buffer target samples.

  • seed (int) – See renate.models.utils.get_generator().
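
A quick worked example of the batch split implied by batch_memory_frac, with assumed values (exact rounding is an implementation detail):

batch_size = 32
batch_memory_frac = 0.5

# Samples drawn from the rehearsal memory buffer in each batch ...
memory_batch_size = int(batch_memory_frac * batch_size)  # 16
# ... and samples drawn from the new training data.
new_data_batch_size = batch_size - memory_batch_size     # 16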

on_save_checkpoint(checkpoint)[source]#

Called by Lightning when saving a checkpoint to give you a chance to store anything else you might want to save.

Parameters:

checkpoint (Dict[str, Any]) – The full checkpoint dictionary before it gets dumped to a file. Implementations of this hook can insert additional data into this dictionary.

Return type:

None

Example:

def on_save_checkpoint(self, checkpoint):
    # 99% of use cases you don't need to implement this method
    checkpoint['something_cool_i_want_to_save'] = my_cool_pickable_object

Note

Lightning saves all aspects of training (epoch, global step, etc…) including amp scaling. There is no need for you to store anything about training.

on_load_checkpoint(checkpoint)[source]#

Called by Lightning to restore your model. If you saved something with on_save_checkpoint() this is your chance to restore this.

Parameters:

checkpoint (Dict[str, Any]) – Loaded checkpoint

Return type:

None

Example:

def on_load_checkpoint(self, checkpoint):
    # 99% of the time you don't need to implement this method
    self.something_cool_i_want_to_save = checkpoint['something_cool_i_want_to_save']

Note

Lightning auto-restores global step, epoch, and train state including amp scaling. There is no need for you to restore anything regarding training.

save(output_state_dir)[source]#

Saves the learner’s state to the directory output_state_dir.
Return type:

None

load(input_state_dir)[source]#

Loads the learner’s state from the directory input_state_dir.
Return type:

None

on_model_update_start(train_dataset, val_dataset, train_dataset_collate_fn=None, val_dataset_collate_fn=None, task_id=None)[source]#

Called at the beginning of a model update; registers the datasets and collate functions from which the training and validation dataloaders are created.
Return type:

None