renate.updaters.experimental.speft module

class renate.updaters.experimental.speft.SPeftLearner(model, loss_fn, optimizer, learning_rate_scheduler=None, learning_rate_scheduler_interval='epoch', batch_size=32, train_transform=None, train_target_transform=None, test_transform=None, test_target_transform=None, logged_metrics=None, seed=0, mask_unused_classes=False)

Bases: Learner

Learner that implements S-Prompts, as described in

`Wang, Yabin, et al. "S-prompts learning with pre-trained transformers: An occam’s razor for domain incremental learning." Advances in Neural Information Processing Systems 35 (2022): 5682-5695.`

Parameters:
  • model (RenateModule) – The SPromptTransformer model to be trained.

  • loss_fn (Module) – Loss function to be trained with.

  • optimizer (Callable[[List[Parameter]], Optimizer]) – Partial optimizer used to create an optimizer by passing the model parameters.

  • learning_rate_scheduler (Optional[Callable[[Optimizer], _LRScheduler]]) – Partial object of learning rate scheduler that will be created by passing the optimizer.

  • learning_rate_scheduler_interval (Literal['epoch', 'step']) – When to update the learning rate scheduler. Options: epoch and step.

  • batch_size (int) – Training batch size.

  • train_transform (Optional[Callable]) – The transformation applied during training.

  • train_target_transform (Optional[Callable]) – The target transformation applied during training.

  • test_transform (Optional[Callable]) – The transformation at test time.

  • test_target_transform (Optional[Callable]) – The target transformation at test time.

  • logged_metrics (Optional[Dict[str, Metric]]) – Metrics logged in addition to the default ones.

  • seed (int) – See renate.models.utils.get_generator().

  • mask_unused_classes (bool) – Whether to mask the logits corresponding to unused classes. Useful only for class-incremental problems. Defaults to defaults.MASK_UNUSED_CLASSES.
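
The optimizer and learning_rate_scheduler parameters are "partial" objects: callables with every argument fixed except the one the learner supplies later. A minimal sketch of that pattern, assuming plain functools.partial and standard PyTorch classes; the torch.nn.Linear model is only a hypothetical stand-in for the SPromptTransformer, and the final two calls illustrate what a learner could do internally rather than Renate's actual code:

```python
from functools import partial

import torch

# Factories: all hyperparameters are fixed now, the first positional
# argument (model parameters, or the optimizer) is supplied later.
optimizer_factory = partial(torch.optim.SGD, lr=0.01, momentum=0.9)
scheduler_factory = partial(torch.optim.lr_scheduler.StepLR, step_size=10, gamma=0.5)

# Hypothetical stand-in model, used only to have some parameters to optimize.
model = torch.nn.Linear(4, 2)

# Illustration of the deferred calls: parameters -> optimizer -> scheduler.
optimizer = optimizer_factory(list(model.parameters()))
scheduler = scheduler_factory(optimizer)
```

With this shape, optimizer_factory matches Callable[[List[Parameter]], Optimizer] and scheduler_factory matches Callable[[Optimizer], _LRScheduler], which are the types listed above.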

on_model_update_start(train_dataset, val_dataset, train_dataset_collate_fn=None, val_dataset_collate_fn=None, task_id=None)

A custom on_model_update_start hook for S-Peft methods.

Here, we iterate over the training dataset and extract features. These features are used to compute the task prototypes in the update_task_identifier call. Running this step at the start of the model update, rather than at the end, means the validation metrics reflect test accuracy.

Return type:

None
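
The idea of deriving task prototypes from extracted features can be sketched in a few lines. This is a simplified illustration, not Renate's update_task_identifier: it assumes one mean vector per task and Euclidean nearest-prototype lookup, whereas the S-Prompts paper clusters features with K-Means. The names compute_task_prototypes and identify_task are hypothetical.

```python
import math
from collections import defaultdict


def compute_task_prototypes(features, task_ids):
    """Average the feature vectors of each task into a single prototype."""
    grouped = defaultdict(list)
    for vector, task_id in zip(features, task_ids):
        grouped[task_id].append(vector)
    return {
        task_id: [sum(column) / len(vectors) for column in zip(*vectors)]
        for task_id, vectors in grouped.items()
    }


def identify_task(feature, prototypes):
    """Return the task whose prototype is closest in Euclidean distance."""
    return min(prototypes, key=lambda task_id: math.dist(feature, prototypes[task_id]))


# Toy 2-D features standing in for the output of a frozen feature extractor.
features = [[0.0, 0.0], [0.2, 0.0], [5.0, 5.0], [5.0, 5.4]]
task_ids = [0, 0, 1, 1]

prototypes = compute_task_prototypes(features, task_ids)
predicted_task = identify_task([4.5, 5.0], prototypes)
```

Because the prototypes exist before training begins, a validation pass during the update can already route inputs to a task, which is consistent with the note above about validation metrics.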

setup(stage)

Called at the beginning of fit (train + validate), validate, test, or predict. This is a good hook when you need to build models dynamically or adjust something about them. This hook is called on every process when using DDP.

Parameters:

stage (str) – either 'fit', 'validate', 'test', or 'predict'

Return type:

None

Example:

class LitModel(...):
    def __init__(self):
        self.l1 = None

    def prepare_data(self):
        download_data()
        tokenize()

        # don't do this: prepare_data is not called on every process,
        # so state assigned here is missing on the others
        self.something = ...

    def setup(self, stage):
        data = load_data(...)
        self.l1 = nn.Linear(28, data.num_classes)

optimizer_zero_grad(epoch, batch_idx, optimizer, optimizer_idx)

Explicitly setting grads to None instead of zero.

Return type:

None
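
The difference between zeroing gradients and setting them to None can be seen with plain PyTorch. This is a minimal sketch using torch.optim.Optimizer.zero_grad(set_to_none=True) on a hypothetical toy layer; it illustrates the behavior the hook describes and is not a copy of the hook's body.

```python
import torch

# Hypothetical toy layer and optimizer, used only for illustration.
layer = torch.nn.Linear(3, 1)
optimizer = torch.optim.SGD(layer.parameters(), lr=0.1)

# A backward pass populates .grad on every parameter.
layer(torch.ones(2, 3)).sum().backward()
has_grad_before = [p.grad is not None for p in layer.parameters()]

# set_to_none=True drops the gradient tensors instead of filling them with zeros.
optimizer.zero_grad(set_to_none=True)
is_none_after = [p.grad is None for p in layer.parameters()]
```

Dropping the tensors avoids writing zeros into every gradient buffer. Code that reads p.grad directly then has to handle None.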

class renate.updaters.experimental.speft.SPeftModelUpdater(model, loss_fn, optimizer, batch_size=32, seed=0, learner_kwargs=None, input_state_folder=None, output_state_folder=None, max_epochs=50, learning_rate_scheduler=None, learning_rate_scheduler_interval='epoch', train_transform=None, train_target_transform=None, test_transform=None, test_target_transform=None, buffer_transform=None, buffer_target_transform=None, metric=None, mode='min', logged_metrics=None, early_stopping_enabled=False, logger=<pytorch_lightning.loggers.tensorboard.TensorBoardLogger object>, accelerator='auto', devices=None, strategy='ddp', precision='32', deterministic_trainer=False, gradient_clip_val=None, gradient_clip_algorithm=None, mask_unused_classes=False)

Bases: SingleTrainingLoopUpdater