renate.updaters.experimental.l2p module#

class renate.updaters.experimental.l2p.LearningToPromptLearner(prompt_sim_loss_weight=0.5, **kwargs)[source]#

Bases: Learner

Learner for Learning to Prompt (L2P).

This is identical to the base Learner, with the addition of a loss term. TODO: Make this loss a component.

Parameters:

prompt_sim_loss_weight (float) – Weight for the loss term measuring the similarity between prompt keys and the image representation.

training_step(batch, batch_idx)[source]#

PyTorch Lightning function to return the training loss.

Return type:

Union[Tensor, Dict[str, Any]]
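A minimal sketch of how such a weighted similarity term can enter the training loss, following the L2P formulation (pull the selected prompt keys toward the query image features). The names base_loss, prompt_keys, and image_features are illustrative placeholders, not Renate's internals:

    import torch.nn.functional as F

    def l2p_loss(base_loss, prompt_keys, image_features, prompt_sim_loss_weight=0.5):
        """Illustrative combination of task loss and prompt key-query similarity term.

        prompt_keys: (batch, n_selected, dim) keys of the prompts selected per image.
        image_features: (batch, dim) representation from the frozen encoder.
        """
        # Cosine similarity between every selected key and its image representation.
        sim = F.cosine_similarity(prompt_keys, image_features.unsqueeze(1), dim=-1)
        # Minimizing (1 - similarity) pulls the selected keys toward the query.
        return base_loss + prompt_sim_loss_weight * (1.0 - sim).mean()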

class renate.updaters.experimental.l2p.LearningToPromptReplayLearner(prompt_sim_loss_weight=0.5, **kwargs)[source]#

Bases: OfflineExperienceReplayLearner

L2P with an offline experience replay (ER) learner.

The model is trained on a weighted mixture of losses computed on the new data and on a replay buffer; a sketch of this mixture follows this entry. In contrast to the online version, the buffer is only updated after training has terminated.

Parameters:

prompt_sim_loss_weight (float) – Weight for the loss term measuring the similarity between prompt keys and the image representation.

training_step(batch, batch_idx)[source]#

PyTorch Lightning function to return the training loss.

Return type:

Union[Tensor, Dict[str, Any]]
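A minimal sketch of the weighted mixture under stated assumptions: loss_weight_new_data (a parameter of LearningToPromptReplayModelUpdater below) is taken as a convex mixture weight, and the prompt-similarity term from above is added on top. The exact weighting inside Renate may differ:

    def replay_loss(loss_new, loss_memory, sim_loss,
                    loss_weight_new_data=0.5, prompt_sim_loss_weight=0.5):
        # Convex mixture of the loss on the current batch and on the replay buffer,
        # plus the weighted L2P prompt-key similarity term.
        er_loss = (loss_weight_new_data * loss_new
                   + (1.0 - loss_weight_new_data) * loss_memory)
        return er_loss + prompt_sim_loss_weight * sim_loss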

class renate.updaters.experimental.l2p.LearningToPromptModelUpdater(model, loss_fn, optimizer, batch_size=32, seed=0, learner_kwargs=None, input_state_folder=None, output_state_folder=None, max_epochs=50, learning_rate_scheduler=None, learning_rate_scheduler_interval='epoch', prompt_sim_loss_weight=0.5, train_transform=None, train_target_transform=None, test_transform=None, test_target_transform=None, buffer_transform=None, buffer_target_transform=None, metric=None, mode='min', logged_metrics=None, early_stopping_enabled=False, logger=<pytorch_lightning.loggers.tensorboard.TensorBoardLogger object>, accelerator='auto', devices=None, strategy='ddp', precision='32', deterministic_trainer=False, gradient_clip_val=None, gradient_clip_algorithm=None, mask_unused_classes=False)[source]#

Bases: SingleTrainingLoopUpdater
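A hypothetical construction of the updater. The model placeholder stands for an L2P-capable RenateModule (prompt pool plus frozen backbone), the optimizer is assumed to be a partial applied to the model's parameters as in other Renate updaters, and the update() call follows the base ModelUpdater interface, so argument names may differ:

    from functools import partial

    import torch
    from renate.updaters.experimental.l2p import LearningToPromptModelUpdater

    model = ...  # placeholder: an L2P-capable RenateModule
    updater = LearningToPromptModelUpdater(
        model=model,
        loss_fn=torch.nn.CrossEntropyLoss(),
        optimizer=partial(torch.optim.Adam, lr=1e-3),  # assumption: partial over params
        prompt_sim_loss_weight=0.5,
        max_epochs=50,
        output_state_folder="l2p_state",  # hypothetical path
    )
    model = updater.update(train_dataset, task_id="task_0")  # train_dataset: placeholder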

class renate.updaters.experimental.l2p.LearningToPromptReplayModelUpdater(model, loss_fn, optimizer, memory_size, batch_memory_frac=0.5, loss_weight_new_data=None, learning_rate_scheduler=None, learning_rate_scheduler_interval='epoch', prompt_sim_loss_weight=0.5, batch_size=32, input_state_folder=None, output_state_folder=None, max_epochs=50, train_transform=None, train_target_transform=None, test_transform=None, test_target_transform=None, buffer_transform=None, buffer_target_transform=None, metric=None, mode='min', logged_metrics=None, early_stopping_enabled=False, logger=<pytorch_lightning.loggers.tensorboard.TensorBoardLogger object>, accelerator='auto', devices=None, strategy='ddp', precision='32', seed=0, deterministic_trainer=False, gradient_clip_val=None, gradient_clip_algorithm=None, mask_unused_classes=False)[source]#

Bases: SingleTrainingLoopUpdater
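The replay variant is constructed the same way, adding the buffer controls from the signature above; the values here are illustrative:

    from functools import partial

    import torch
    from renate.updaters.experimental.l2p import LearningToPromptReplayModelUpdater

    updater = LearningToPromptReplayModelUpdater(
        model=model,                 # same placeholder model as above
        loss_fn=torch.nn.CrossEntropyLoss(),
        optimizer=partial(torch.optim.Adam, lr=1e-3),
        memory_size=500,             # capacity of the replay buffer
        batch_memory_frac=0.5,       # fraction of each batch drawn from the buffer
        loss_weight_new_data=0.5,    # mixture weight, cf. the sketch above
        prompt_sim_loss_weight=0.5,
    )

Since this is the offline variant, the buffer contents are only updated once training has terminated.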