renate.updaters.learner module#
- class renate.updaters.learner.RenateLightningModule(model, loss_fn, optimizer, learning_rate_scheduler=None, learning_rate_scheduler_interval='epoch', batch_size=32, logged_metrics=None, seed=0, mask_unused_classes=False)[source]#
Bases: LightningModule, ABC

Base class for LightningModules, which implement metric logging and basic training logic.

The RenateLightningModule is a LightningModule, but provides additional hook functions called by ModelUpdater. These hooks are:

- on_model_update_start, which is called at the beginning of a model update. We expect this to return train and (optionally) validation data loader(s).
- on_model_update_end, which is called at the end of a model update.
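In practice, a ModelUpdater drives these hooks around a standard Lightning training run. The sketch below is only illustrative and is not Renate's actual updater code; learner stands for a concrete subclass instance, train_dataset and val_dataset for torch datasets, and the trainer settings are arbitrary.

    import pytorch_lightning as pl

    # Hypothetical driver code; `learner`, `train_dataset`, and `val_dataset`
    # are assumed placeholders, not objects defined in this module.
    learner.on_model_update_start(train_dataset, val_dataset, task_id="task_0")

    trainer = pl.Trainer(max_epochs=5)
    trainer.fit(learner)           # regular Lightning training loop

    learner.on_model_update_end()  # called at the end of the model update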
- Parameters:
  - model (RenateModule) – The model to be trained.
  - optimizer (Callable[[List[Parameter]], Optimizer]) – Partial optimizer used to create an optimizer by passing the model parameters (see the example below).
  - learning_rate_scheduler (Optional[Callable[[Optimizer], _LRScheduler]]) – Partial object of learning rate scheduler that will be created by passing the optimizer.
  - learning_rate_scheduler_interval (Literal['epoch', 'step']) – When to update the learning rate scheduler. Options: epoch and step.
  - batch_size (int) – Training batch size.
  - logged_metrics (Optional[Dict[str, Metric]]) – Metrics logged in addition to the default ones.
  - seed (int) – See renate.models.utils.get_generator().
  - mask_unused_classes (bool) – Flag to use if logits corresponding to unused classes are to be ignored in the loss computation. Possibly useful for class-incremental learning.
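A brief illustration of how these arguments can be constructed. The optimizer and learning_rate_scheduler are partial objects: Renate completes them later by passing the model parameters and the optimizer, respectively. The concrete choices below (SGD, StepLR, a torchmetrics accuracy) are arbitrary examples, and it is assumed that Metric refers to torchmetrics.Metric.

    from functools import partial

    import torch
    import torchmetrics

    # Partial optimizer: completed later with the model parameters.
    optimizer = partial(torch.optim.SGD, lr=0.01, momentum=0.9)

    # Partial scheduler: completed later with the instantiated optimizer.
    learning_rate_scheduler = partial(
        torch.optim.lr_scheduler.StepLR, step_size=10, gamma=0.5
    )

    # Extra metrics logged in addition to the defaults.
    logged_metrics = {
        "accuracy": torchmetrics.Accuracy(task="multiclass", num_classes=10)
    }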
- is_logged_metric(metric_name)[source]#
Returns True if there is a metric with name metric_name.
- Return type: bool
- on_model_update_start(train_dataset, val_dataset, train_dataset_collate_fn=None, val_dataset_collate_fn=None, task_id=None)[source]#
- Return type: None
- val_dataloader()[source]#
Implement one or multiple PyTorch DataLoaders for validation.
The dataloader you return will not be reloaded unless you set reload_dataloaders_every_n_epochs to a positive integer.

It's recommended that all data downloads and preparation happen in prepare_data().

- fit()
- validate()
- prepare_data()
- setup()

Note
Lightning adds the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

- Return type: Optional[DataLoader]
- Returns: A torch.utils.data.DataLoader or a sequence of them specifying validation samples.
Examples:
    def val_dataloader(self):
        transform = transforms.Compose([transforms.ToTensor(),
                                        transforms.Normalize((0.5,), (1.0,))])
        dataset = MNIST(root='/path/to/mnist/', train=False,
                        transform=transform, download=True)
        loader = torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=self.batch_size,
            shuffle=False
        )
        return loader

    # can also return multiple dataloaders
    def val_dataloader(self):
        return [loader_a, loader_b, ..., loader_n]
Note
If you don't need a validation dataset and a validation_step(), you don't need to implement this method.

Note
In the case where you return multiple validation dataloaders, the validation_step() will have an argument dataloader_idx which matches the order here.
- training_step(batch, batch_idx)[source]#
PyTorch Lightning function to return the training loss.
- Return type: Union[Tensor, Dict[str, Any]]
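For orientation, the return value follows the usual Lightning convention: either the loss tensor itself or a dictionary containing it under the key "loss". The snippet below is a generic illustration of that convention, not Renate's actual implementation; the loss function and batch layout are assumptions.

    import torch.nn.functional as F

    # Generic Lightning-style training step, shown only to illustrate the
    # accepted return types.
    def training_step(self, batch, batch_idx):
        inputs, targets = batch
        outputs = self(inputs)
        loss = F.cross_entropy(outputs, targets)
        return {"loss": loss}  # returning `loss` directly is also accepted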
- training_step_end(step_output)[source]#
PyTorch Lightning function to perform after the training step.
- Return type:
Union
[Tensor
,Dict
[str
,Any
]]
- training_epoch_end(outputs)[source]#
PyTorch Lightning function to run at the end of training epoch.
- Return type: None
- validation_step(batch, batch_idx)[source]#
PyTorch Lightning function to estimate validation metrics.
- Return type: None
- class renate.updaters.learner.Learner(model, loss_fn, optimizer, learning_rate_scheduler=None, learning_rate_scheduler_interval='epoch', batch_size=32, train_transform=None, train_target_transform=None, test_transform=None, test_target_transform=None, logged_metrics=None, seed=0, mask_unused_classes=False)[source]#
Bases: RenateLightningModule, ABC

Base class for Learners, which encapsulate the core continual learning (CL) methodologies.

The Learner is a LightningModule, but provides additional hook functions called by ModelUpdater. These hooks are:

- Learner.on_model_update_start, which is called at the beginning of a model update. We expect this to return train and (optionally) validation data loader(s).
- Learner.on_model_update_end, which is called at the end of a model update.

This base class implements a basic training loop without any mechanism to counteract forgetting.
- Parameters:
  - model (RenateModule) – The model to be trained.
  - optimizer (Callable[[List[Parameter]], Optimizer]) – Partial optimizer used to create an optimizer by passing the model parameters.
  - learning_rate_scheduler (Optional[Callable[[Optimizer], _LRScheduler]]) – Partial object of learning rate scheduler that will be created by passing the optimizer.
  - learning_rate_scheduler_interval (Literal['epoch', 'step']) – When to update the learning rate scheduler. Options: epoch and step.
  - batch_size (int) – Training batch size.
  - train_transform (Optional[Callable]) – The transformation applied during training (see the example below).
  - train_target_transform (Optional[Callable]) – The target transformation applied during training.
  - test_transform (Optional[Callable]) – The transformation applied at test time.
  - test_target_transform (Optional[Callable]) – The target transformation applied at test time.
  - logged_metrics (Optional[Dict[str, Metric]]) – Metrics logged in addition to the default ones.
  - seed (int) – See renate.models.utils.get_generator().
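As a hedged illustration of the transform arguments, assuming image inputs and torchvision (the specific transforms are arbitrary choices, not Renate defaults):

    from torchvision import transforms

    # Data augmentation applied to training inputs.
    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
    ])

    # Deterministic preprocessing applied at test time.
    test_transform = transforms.CenterCrop(32)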
- on_save_checkpoint(checkpoint)[source]#
Called by Lightning when saving a checkpoint to give you a chance to store anything else you might want to save.
- Parameters:
  - checkpoint (Dict[str, Any]) – The full checkpoint dictionary before it gets dumped to a file. Implementations of this hook can insert additional data into this dictionary.
- Return type: None
Example:
    def on_save_checkpoint(self, checkpoint):
        # 99% of use cases you don't need to implement this method
        checkpoint['something_cool_i_want_to_save'] = my_cool_pickable_object
Note
Lightning saves all aspects of training (epoch, global step, etc…) including amp scaling. There is no need for you to store anything about training.
- on_load_checkpoint(checkpoint)[source]#
Called by Lightning to restore your model. If you saved something with on_save_checkpoint(), this is your chance to restore it.
- Parameters:
  - checkpoint (Dict[str, Any]) – Loaded checkpoint.
Example:
    def on_load_checkpoint(self, checkpoint):
        # 99% of the time you don't need to implement this method
        self.something_cool_i_want_to_save = checkpoint['something_cool_i_want_to_save']
Note
Lightning auto-restores global step, epoch, and train state including amp scaling. There is no need for you to restore anything regarding training.
- on_model_update_start(train_dataset, val_dataset, train_dataset_collate_fn=None, val_dataset_collate_fn=None, task_id=None)[source]#
- Return type: None
- val_dataloader()[source]#
Implement one or multiple PyTorch DataLoaders for validation.
The dataloader you return will not be reloaded unless you set reload_dataloaders_every_n_epochs to a positive integer.

It's recommended that all data downloads and preparation happen in prepare_data().

- fit()
- validate()
- prepare_data()
- setup()

Note
Lightning adds the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

- Return type: Optional[DataLoader]
- Returns: A torch.utils.data.DataLoader or a sequence of them specifying validation samples.
Examples:
    def val_dataloader(self):
        transform = transforms.Compose([transforms.ToTensor(),
                                        transforms.Normalize((0.5,), (1.0,))])
        dataset = MNIST(root='/path/to/mnist/', train=False,
                        transform=transform, download=True)
        loader = torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=self.batch_size,
            shuffle=False
        )
        return loader

    # can also return multiple dataloaders
    def val_dataloader(self):
        return [loader_a, loader_b, ..., loader_n]
Note
If you don't need a validation dataset and a validation_step(), you don't need to implement this method.

Note
In the case where you return multiple validation dataloaders, the validation_step() will have an argument dataloader_idx which matches the order here.
- class renate.updaters.learner.ReplayLearner(memory_size, batch_size=32, batch_memory_frac=0.5, buffer_transform=None, buffer_target_transform=None, seed=0, **kwargs)[source]#
Bases: Learner, ABC
Base class for Learners which use a buffer to store data and reuse it in future updates.
- Parameters:
  - memory_size (int) – The maximum size of the memory.
  - batch_memory_frac (float) – Fraction of the batch that is sampled from rehearsal memory (see the example below).
  - buffer_transform (Optional[Callable]) – The transformation to be applied to the memory buffer data samples.
  - buffer_target_transform (Optional[Callable]) – The target transformation to be applied to the memory buffer target samples.
  - seed (int) – See renate.models.utils.get_generator().
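To make the batch composition concrete: with batch_size=32 and batch_memory_frac=0.5, roughly half of every training batch is drawn from the rehearsal memory and the rest from the new data. The rounding below is illustrative; the exact split used internally is not specified here.

    batch_size = 32
    batch_memory_frac = 0.5

    # Approximate per-batch split between rehearsal memory and new data.
    memory_batch_size = int(batch_memory_frac * batch_size)  # 16 samples from the buffer
    new_batch_size = batch_size - memory_batch_size          # 16 samples from the current dataset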
- on_save_checkpoint(checkpoint)[source]#
Called by Lightning when saving a checkpoint to give you a chance to store anything else you might want to save.
- Parameters:
  - checkpoint (Dict[str, Any]) – The full checkpoint dictionary before it gets dumped to a file. Implementations of this hook can insert additional data into this dictionary.
- Return type: None
Example:
    def on_save_checkpoint(self, checkpoint):
        # 99% of use cases you don't need to implement this method
        checkpoint['something_cool_i_want_to_save'] = my_cool_pickable_object
Note
Lightning saves all aspects of training (epoch, global step, etc…) including amp scaling. There is no need for you to store anything about training.
- on_load_checkpoint(checkpoint)[source]#
Called by Lightning to restore your model. If you saved something with on_save_checkpoint(), this is your chance to restore it.
- Parameters:
  - checkpoint (Dict[str, Any]) – Loaded checkpoint.
- Return type: None
Example:
    def on_load_checkpoint(self, checkpoint):
        # 99% of the time you don't need to implement this method
        self.something_cool_i_want_to_save = checkpoint['something_cool_i_want_to_save']
Note
Lightning auto-restores global step, epoch, and train state including amp scaling. There is no need for you to restore anything regarding training.