renate.updaters.learner module
- class renate.updaters.learner.RenateLightningModule(model, loss_fn, optimizer, learning_rate_scheduler=None, learning_rate_scheduler_interval='epoch', batch_size=32, logged_metrics=None, seed=0, mask_unused_classes=False)
Bases: LightningModule, ABC
Base class for LightningModules, which implement metric logging and basic training logic.
The RenateLightningModule is a LightningModule, but provides additional hook functions called by ModelUpdater. These hooks are:
on_model_update_start, which is called at the beginning of a model update. We expect this to return train and (optionally) validation data loader(s).
on_model_update_end, which is called at the end of a model update.
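The order in which an updater drives these hooks can be pictured with a framework-free simulation. This is a minimal sketch: `ToyModule` and `toy_update` are hypothetical stand-ins, not Renate's `ModelUpdater` or Lightning's `Trainer`, the "data loaders" are plain lists, and the start hook here only stores the datasets (matching the documented `None` return type).

```python
from typing import List, Optional


class ToyModule:
    """Hypothetical stand-in for a RenateLightningModule subclass."""

    def __init__(self) -> None:
        self.events: List[str] = []
        self._train_data: List[int] = []
        self._val_data: Optional[List[int]] = None

    def on_model_update_start(
        self, train_dataset: List[int], val_dataset: Optional[List[int]]
    ) -> None:
        # Keep the datasets so that the data loaders can be built from them later.
        self.events.append("on_model_update_start")
        self._train_data = train_dataset
        self._val_data = val_dataset

    def train_dataloader(self) -> List[int]:
        # Toy "loader": each list element plays the role of one batch.
        return self._train_data

    def training_step(self, batch: int) -> None:
        self.events.append(f"train:{batch}")

    def on_model_update_end(self) -> None:
        self.events.append("on_model_update_end")


def toy_update(
    module: ToyModule, train: List[int], val: Optional[List[int]] = None
) -> List[str]:
    """Hypothetical updater: start hook, one pass over the training data, end hook."""
    module.on_model_update_start(train, val)
    for batch in module.train_dataloader():
        module.training_step(batch)
    module.on_model_update_end()
    return module.events


print(toy_update(ToyModule(), [1, 2]))
# ['on_model_update_start', 'train:1', 'train:2', 'on_model_update_end']
```

Training only ever happens between the two hooks, which is why setup work belongs in the start hook and cleanup in the end hook.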
- Parameters:
model (RenateModule) – The model to be trained.
optimizer (Callable[[List[Parameter]], Optimizer]) – Partial optimizer used to create an optimizer by passing the model parameters.
learning_rate_scheduler (Optional[Callable[[Optimizer], _LRScheduler]]) – Partial learning rate scheduler object that is instantiated by passing the optimizer.
learning_rate_scheduler_interval (Literal['epoch', 'step']) – When to update the learning rate scheduler. Options: epoch and step.
batch_size (int) – Training batch size.
logged_metrics (Optional[Dict[str, Metric]]) – Metrics logged in addition to the default ones.
seed (int) – See renate.models.utils.get_generator().
mask_unused_classes (bool) – Whether logits corresponding to unused classes are ignored in the loss computation. Possibly useful for class-incremental learning.
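To illustrate what ignoring the logits of unused classes can mean, the sketch below computes a softmax cross-entropy for one sample in which classes outside a given set are removed from the normalizer. It is a pure-Python sketch under the assumption that "unused" means "not in the set of classes present in the current data"; `masked_cross_entropy` is a hypothetical helper, and Renate's actual implementation (which works on PyTorch tensors) may differ.

```python
import math
from typing import Optional, Sequence, Set


def masked_cross_entropy(
    logits: Sequence[float], target: int, used_classes: Optional[Set[int]] = None
) -> float:
    """Cross-entropy for one sample; logits of classes outside used_classes are ignored."""
    if used_classes is None:
        used_classes = set(range(len(logits)))
    if target not in used_classes:
        raise ValueError("target class must be one of the used classes")
    # Setting a logit to -inf removes that class from the softmax normalizer.
    masked = [z if c in used_classes else -math.inf for c, z in enumerate(logits)]
    m = max(masked)  # subtract the max for numerical stability
    log_norm = m + math.log(sum(math.exp(z - m) for z in masked))
    return log_norm - masked[target]


print(masked_cross_entropy([1.0, 1.0, 9.0], target=1, used_classes={0, 1}))  # ≈ 0.6931
print(masked_cross_entropy([1.0, 1.0, 9.0], target=1))  # ≈ 8.0007
```

With masking, the large logit of class 2 (a class not in the current data) no longer dominates the loss; with `used_classes=None` the function reduces to the ordinary cross-entropy.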
- is_logged_metric(metric_name)
Returns True if there is a metric with name metric_name.
- Return type:
bool
- on_model_update_start(train_dataset, val_dataset, train_dataset_collate_fn=None, val_dataset_collate_fn=None, task_id=None)
- Return type:
None
- val_dataloader()
Implement one or multiple PyTorch DataLoaders for validation.
The dataloader you return will not be reloaded unless you set the Trainer argument reload_dataloaders_every_n_epochs to a positive integer.
It’s recommended that all data downloads and preparation happen in prepare_data().
See also: fit(), validate(), prepare_data(), setup().
Note
Lightning adds the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
- Return type:
Optional[DataLoader]
- Returns:
A torch.utils.data.DataLoader or a sequence of them specifying validation samples.
Examples:
```python
import torch
from torchvision import transforms
from torchvision.datasets import MNIST


def val_dataloader(self):
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))]
    )
    dataset = MNIST(root='/path/to/mnist/', train=False, transform=transform, download=True)
    loader = torch.utils.data.DataLoader(
        dataset=dataset, batch_size=self.batch_size, shuffle=False
    )
    return loader


# can also return multiple dataloaders
def val_dataloader(self):
    return [loader_a, loader_b, ..., loader_n]
```
Note
If you don’t need a validation dataset and a validation_step(), you don’t need to implement this method.
Note
In the case where you return multiple validation dataloaders, the validation_step() will have an argument dataloader_idx which matches the order here.
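The correspondence between the order of the returned loaders and dataloader_idx can be shown with a small simulation of a validation loop. This is a framework-free sketch: `Recorder` and `run_validation` are hypothetical names and only mirror the documented behavior; they are not Lightning's evaluation loop.

```python
from typing import Any, List, Sequence, Tuple


class Recorder:
    """Hypothetical module whose validation_step records the source loader of each batch."""

    def __init__(self) -> None:
        self.seen: List[Tuple[int, Any]] = []

    def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
        self.seen.append((dataloader_idx, batch))


def run_validation(module: Recorder, loaders: Sequence[Sequence[Any]]) -> None:
    # The position of each loader in the returned sequence becomes its dataloader_idx.
    for dataloader_idx, loader in enumerate(loaders):
        for batch_idx, batch in enumerate(loader):
            module.validation_step(batch, batch_idx, dataloader_idx)


rec = Recorder()
run_validation(rec, [["a1", "a2"], ["b1"]])
print(rec.seen)  # [(0, 'a1'), (0, 'a2'), (1, 'b1')]
```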
- training_step(batch, batch_idx)
PyTorch Lightning function to return the training loss.
- Return type:
Union[Tensor,Dict[str,Any]]
- training_step_end(step_output)
PyTorch Lightning function to perform after the training step.
- Return type:
Union[Tensor,Dict[str,Any]]
- training_epoch_end(outputs)
PyTorch Lightning function to run at the end of training epoch.
- Return type:
None
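The order in which these three training hooks run within one epoch can be sketched as a single-device simulation. It assumes the PyTorch Lightning 1.x semantics, in which training_step_end receives the output of training_step and training_epoch_end receives the collected outputs of the epoch (both hooks were removed in Lightning 2.0). `LossModule` and `run_epoch` are hypothetical, and no backward pass or optimizer step is shown.

```python
from typing import Any, Dict, List


class LossModule:
    """Hypothetical module: the 'loss' of a batch is just the batch value squared."""

    def __init__(self) -> None:
        self.epoch_mean_loss: float = float("nan")

    def training_step(self, batch: float, batch_idx: int) -> Dict[str, Any]:
        return {"loss": batch ** 2}

    def training_step_end(self, step_output: Dict[str, Any]) -> Dict[str, Any]:
        # Runs right after training_step, e.g. to combine per-device outputs.
        return step_output

    def training_epoch_end(self, outputs: List[Dict[str, Any]]) -> None:
        # Receives every (post-processed) step output collected during the epoch.
        self.epoch_mean_loss = sum(o["loss"] for o in outputs) / len(outputs)


def run_epoch(module: LossModule, batches: List[float]) -> None:
    outputs = []
    for batch_idx, batch in enumerate(batches):
        step_output = module.training_step(batch, batch_idx)
        outputs.append(module.training_step_end(step_output))
    module.training_epoch_end(outputs)


m = LossModule()
run_epoch(m, [1.0, 2.0, 3.0])
print(m.epoch_mean_loss)  # mean of 1, 4, 9
```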
- validation_step(batch, batch_idx)
PyTorch Lightning function to estimate validation metrics.
- Return type:
None
- class renate.updaters.learner.Learner(model, loss_fn, optimizer, learning_rate_scheduler=None, learning_rate_scheduler_interval='epoch', batch_size=32, train_transform=None, train_target_transform=None, test_transform=None, test_target_transform=None, logged_metrics=None, seed=0, mask_unused_classes=False)
Bases: RenateLightningModule, ABC
Base class for Learners, which encapsulate the core continual learning (CL) methodologies.
The Learner is a LightningModule, but provides additional hook functions called by ModelUpdater. These hooks are:
Learner.on_model_update_start, which is called at the beginning of a model update. We expect this to return train and (optionally) validation data loader(s).
Learner.on_model_update_end, which is called at the end of a model update.
This base class implements a basic training loop without any mechanism to counteract forgetting.
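In a basic training loop, the learning_rate_scheduler_interval parameter decides whether the scheduler advances once per epoch or once per optimizer step. The sketch below makes the difference countable with a hypothetical `HalvingScheduler` and a loop that performs no real optimization; it illustrates the two options and is not Renate's training loop. In Lightning code this choice is usually expressed through the "interval" key of the scheduler configuration returned by configure_optimizers.

```python
from typing import Literal


class HalvingScheduler:
    """Hypothetical scheduler: halves the learning rate on every step() call."""

    def __init__(self, lr: float) -> None:
        self.lr = lr

    def step(self) -> None:
        self.lr /= 2


def train(num_epochs: int, batches_per_epoch: int, interval: Literal["epoch", "step"]) -> float:
    scheduler = HalvingScheduler(lr=1.0)
    for _ in range(num_epochs):
        for _ in range(batches_per_epoch):
            # Forward pass, loss, backward pass and optimizer step would go here.
            if interval == "step":
                scheduler.step()
        if interval == "epoch":
            scheduler.step()
    return scheduler.lr


print(train(2, 3, "epoch"))  # 0.25 (2 scheduler steps)
print(train(2, 3, "step"))   # 0.015625 (6 scheduler steps)
```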
- Parameters:
model (RenateModule) – The model to be trained.
optimizer (Callable[[List[Parameter]], Optimizer]) – Partial optimizer used to create an optimizer by passing the model parameters.
learning_rate_scheduler (Optional[Callable[[Optimizer], _LRScheduler]]) – Partial learning rate scheduler object that is instantiated by passing the optimizer.
learning_rate_scheduler_interval (Literal['epoch', 'step']) – When to update the learning rate scheduler. Options: epoch and step.
batch_size (int) – Training batch size.
train_transform (Optional[Callable]) – The transformation applied during training.
train_target_transform (Optional[Callable]) – The target transformation applied during training.
test_transform (Optional[Callable]) – The transformation applied at test time.
test_target_transform (Optional[Callable]) – The target transformation applied at test time.
logged_metrics (Optional[Dict[str, Metric]]) – Metrics logged in addition to the default ones.
seed (int) – See renate.models.utils.get_generator().
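The four transform parameters can be pictured as a dataset wrapper that applies an input transform and a target transform, with one pair used during training and the other at test time. `TransformedDataset` and the toy lambdas below are hypothetical, and the sketch assumes that a transform left as None means "leave the value unchanged".

```python
from typing import Any, Callable, List, Optional, Tuple


class TransformedDataset:
    """Hypothetical wrapper applying an input transform and a target transform."""

    def __init__(
        self,
        data: List[Tuple[Any, Any]],
        transform: Optional[Callable[[Any], Any]] = None,
        target_transform: Optional[Callable[[Any], Any]] = None,
    ) -> None:
        self.data = data
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
        x, y = self.data[idx]
        if self.transform is not None:
            x = self.transform(x)
        if self.target_transform is not None:
            y = self.target_transform(y)
        return x, y


raw = [(1.0, 0), (2.0, 1)]
# Training pair: an input transform plus a target transform (arbitrary toy functions).
train_set = TransformedDataset(raw, transform=lambda x: x + 0.5, target_transform=lambda y: y + 10)
# Test pair: only an input transform; targets are left unchanged.
test_set = TransformedDataset(raw, transform=lambda x: x * 2)
print(train_set[1], test_set[1])  # (2.5, 11) (4.0, 1)
```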
- on_save_checkpoint(checkpoint)
Called by Lightning when saving a checkpoint to give you a chance to store anything else you might want to save.
- Parameters:
checkpoint (Dict[str, Any]) – The full checkpoint dictionary before it gets dumped to a file. Implementations of this hook can insert additional data into this dictionary.
- Return type:
None
Example:
```python
def on_save_checkpoint(self, checkpoint):
    # 99% of use cases you don't need to implement this method
    checkpoint['something_cool_i_want_to_save'] = my_cool_pickable_object
```
Note
Lightning saves all aspects of training (epoch, global step, etc…) including amp scaling. There is no need for you to store anything about training.
- on_load_checkpoint(checkpoint)
Called by Lightning to restore your model. If you saved something with on_save_checkpoint(), this is your chance to restore it.
- Parameters:
checkpoint (Dict[str, Any]) – Loaded checkpoint.
Example:
```python
def on_load_checkpoint(self, checkpoint):
    # 99% of the time you don't need to implement this method
    self.something_cool_i_want_to_save = checkpoint['something_cool_i_want_to_save']
```
Note
Lightning auto-restores global step, epoch, and train state including amp scaling. There is no need for you to restore anything regarding training.
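The two checkpoint hooks are meant to be used as a pair: whatever on_save_checkpoint writes into the dictionary is what on_load_checkpoint should read back. The sketch simulates that round trip without Lightning, using pickle as a stand-in for the actual checkpoint serialization; the `TaskCounter` class and the "num_tasks_seen" key are hypothetical.

```python
import pickle
from typing import Any, Dict


class TaskCounter:
    """Hypothetical module with one piece of custom state worth checkpointing."""

    def __init__(self) -> None:
        self.num_tasks_seen = 0

    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        # Add custom state to the checkpoint dictionary before it is written.
        checkpoint["num_tasks_seen"] = self.num_tasks_seen

    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        # Restore the custom state from the loaded dictionary.
        self.num_tasks_seen = checkpoint["num_tasks_seen"]


old = TaskCounter()
old.num_tasks_seen = 3
checkpoint: Dict[str, Any] = {}
old.on_save_checkpoint(checkpoint)
blob = pickle.dumps(checkpoint)  # stand-in for writing the checkpoint file

new = TaskCounter()
new.on_load_checkpoint(pickle.loads(blob))
print(new.num_tasks_seen)  # 3
```

Because the dictionary is serialized, anything stored in it must be picklable, which is why the example keeps only a plain integer.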
- on_model_update_start(train_dataset, val_dataset, train_dataset_collate_fn=None, val_dataset_collate_fn=None, task_id=None)
- Return type:
None
- val_dataloader()
Implement one or multiple PyTorch DataLoaders for validation.
The dataloader you return will not be reloaded unless you set the Trainer argument reload_dataloaders_every_n_epochs to a positive integer.
It’s recommended that all data downloads and preparation happen in prepare_data().
See also: fit(), validate(), prepare_data(), setup().
Note
Lightning adds the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
- Return type:
Optional[DataLoader]
- Returns:
A torch.utils.data.DataLoader or a sequence of them specifying validation samples.
Examples:
```python
import torch
from torchvision import transforms
from torchvision.datasets import MNIST


def val_dataloader(self):
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))]
    )
    dataset = MNIST(root='/path/to/mnist/', train=False, transform=transform, download=True)
    loader = torch.utils.data.DataLoader(
        dataset=dataset, batch_size=self.batch_size, shuffle=False
    )
    return loader


# can also return multiple dataloaders
def val_dataloader(self):
    return [loader_a, loader_b, ..., loader_n]
```
Note
If you don’t need a validation dataset and a validation_step(), you don’t need to implement this method.
Note
In the case where you return multiple validation dataloaders, the validation_step() will have an argument dataloader_idx which matches the order here.
- class renate.updaters.learner.ReplayLearner(memory_size, batch_size=32, batch_memory_frac=0.5, buffer_transform=None, buffer_target_transform=None, seed=0, **kwargs)
Bases: Learner, ABC
Base class for Learners which use a buffer to store data and reuse it in future updates.
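One common way to keep such a buffer bounded is reservoir sampling (Algorithm R), which, after n items have been seen, holds each of them with equal probability memory_size / n. The sketch implements it with the standard library; the `ReservoirBuffer` name is hypothetical here, and whether a given ReplayLearner subclass uses this strategy or a different one is an assumption not established by this page.

```python
import random
from typing import Any, List


class ReservoirBuffer:
    """Bounded buffer filled by reservoir sampling (Algorithm R)."""

    def __init__(self, memory_size: int, seed: int = 0) -> None:
        self.memory_size = memory_size
        self.items: List[Any] = []
        self.num_seen = 0
        self._rng = random.Random(seed)  # seeded for reproducibility

    def add(self, item: Any) -> None:
        self.num_seen += 1
        if len(self.items) < self.memory_size:
            # Fill the buffer until it reaches memory_size.
            self.items.append(item)
            return
        # Afterwards, overwrite a random slot with probability memory_size / num_seen.
        j = self._rng.randrange(self.num_seen)
        if j < self.memory_size:
            self.items[j] = item


buffer = ReservoirBuffer(memory_size=5, seed=0)
for x in range(100):
    buffer.add(x)
print(len(buffer.items))  # 5
```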
- Parameters:
memory_size (int) – The maximum size of the memory.
batch_memory_frac (float) – Fraction of the batch that is sampled from rehearsal memory.
buffer_transform (Optional[Callable]) – The transformation to be applied to the memory buffer data samples.
buffer_target_transform (Optional[Callable]) – The target transformation to be applied to the memory buffer target samples.
seed (int) – See renate.models.utils.get_generator().
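With batch_size and batch_memory_frac, each training batch mixes new data with rehearsal data. A minimal sketch of the split, assuming the memory share is rounded down with int() and the remainder is drawn from the new data; both the rounding rule and the `split_batch_size` helper are hypothetical.

```python
from typing import Tuple


def split_batch_size(batch_size: int, batch_memory_frac: float) -> Tuple[int, int]:
    """Return (new-data samples, memory samples) per batch; the rounding rule is assumed."""
    if not 0.0 <= batch_memory_frac <= 1.0:
        raise ValueError("batch_memory_frac must lie in [0, 1]")
    memory_batch_size = int(batch_size * batch_memory_frac)
    return batch_size - memory_batch_size, memory_batch_size


print(split_batch_size(32, 0.5))   # (16, 16), the signature's defaults
print(split_batch_size(32, 0.25))  # (24, 8)
print(split_batch_size(10, 0.33))  # (7, 3)
```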
- on_save_checkpoint(checkpoint)
Called by Lightning when saving a checkpoint to give you a chance to store anything else you might want to save.
- Parameters:
checkpoint (Dict[str, Any]) – The full checkpoint dictionary before it gets dumped to a file. Implementations of this hook can insert additional data into this dictionary.
- Return type:
None
Example:
```python
def on_save_checkpoint(self, checkpoint):
    # 99% of use cases you don't need to implement this method
    checkpoint['something_cool_i_want_to_save'] = my_cool_pickable_object
```
Note
Lightning saves all aspects of training (epoch, global step, etc…) including amp scaling. There is no need for you to store anything about training.
- on_load_checkpoint(checkpoint)
Called by Lightning to restore your model. If you saved something with on_save_checkpoint(), this is your chance to restore it.
- Parameters:
checkpoint (Dict[str, Any]) – Loaded checkpoint.
- Return type:
None
Example:
```python
def on_load_checkpoint(self, checkpoint):
    # 99% of the time you don't need to implement this method
    self.something_cool_i_want_to_save = checkpoint['something_cool_i_want_to_save']
```
Note
Lightning auto-restores global step, epoch, and train state including amp scaling. There is no need for you to restore anything regarding training.