renate.updaters.experimental.speft module#
- class renate.updaters.experimental.speft.SPeftLearner(model, loss_fn, optimizer, learning_rate_scheduler=None, learning_rate_scheduler_interval='epoch', batch_size=32, train_transform=None, train_target_transform=None, test_transform=None, test_target_transform=None, logged_metrics=None, seed=0, mask_unused_classes=False)[source]#
Bases: Learner
- Learner to implement S-Prompts from
`Wang, Yabin, et al. "S-Prompts Learning with Pre-trained Transformers: An Occam's Razor for Domain Incremental Learning." Advances in Neural Information Processing Systems 35 (2022): 5682-5695.`
- Parameters:
  - model (RenateModule) – The SPromptTransformer model to be trained.
  - loss_fn (Module) – Loss function to be trained with.
  - optimizer (Callable[[List[Parameter]], Optimizer]) – Partial optimizer used to create an optimizer by passing the model parameters.
  - learning_rate_scheduler (Optional[Callable[[Optimizer], _LRScheduler]]) – Partial object of a learning rate scheduler that will be created by passing the optimizer.
  - learning_rate_scheduler_interval (Literal['epoch', 'step']) – When to update the learning rate scheduler. Options: epoch and step.
  - batch_size (int) – Training batch size.
  - train_transform (Optional[Callable]) – The transformation applied during training.
  - train_target_transform (Optional[Callable]) – The target transformation applied during training.
  - test_transform (Optional[Callable]) – The transformation applied at test time.
  - test_target_transform (Optional[Callable]) – The target transformation applied at test time.
  - logged_metrics (Optional[Dict[str, Metric]]) – Metrics logged in addition to the default ones.
  - seed (int) – See renate.models.utils.get_generator().
  - mask_unused_classes (bool) – Whether to mask logits corresponding to unused classes. Useful only for class-incremental problems. Defaults to defaults.MASK_UNUSED_CLASSES.
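The optimizer argument is a partial: a callable that receives the model's parameter list and returns a configured optimizer. A minimal sketch of this pattern, using plain Python with functools.partial (make_optimizer is a hypothetical stand-in for an optimizer class such as torch.optim.SGD):

```python
from functools import partial

# Hypothetical stand-in for an optimizer class: takes the parameter list
# first, with all hyperparameters pre-bound by partial().
def make_optimizer(params, lr, momentum):
    return {"params": list(params), "lr": lr, "momentum": momentum}

# Pre-bind the hyperparameters; the learner later calls this callable with
# the model's parameters to obtain a fully configured optimizer.
optimizer = partial(make_optimizer, lr=0.01, momentum=0.9)
opt = optimizer(["w1", "w2"])  # the learner passes model parameters here
```

With a real model, the same pattern would be `partial(torch.optim.SGD, lr=0.01)`, which the learner then calls with the model's parameter list.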
- on_model_update_start(train_dataset, val_dataset, train_dataset_collate_fn=None, val_dataset_collate_fn=None, task_id=None)[source]#
A custom on_model_update_start hook for S-Peft methods.
Here, we iterate over the training dataset and extract features. These features are used to compute the task prototypes by the subsequent
update_task_identifier
call. Running this at model update start, rather than at the end, makes the validation metrics reflective of test accuracy.- Return type:
None
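S-Prompts uses the extracted features to build per-task prototypes, which are compared against a test example's features to identify its task at inference time (the paper clusters features with K-Means). A simplified sketch of the idea in plain Python, assuming a single mean-vector prototype per task (an assumption made here to keep the example self-contained; the actual logic lives in the model's update_task_identifier):

```python
# Simplified assumption: one prototype per task, the mean feature vector.
# The real method derives several prototypes per task via K-Means clustering.
def compute_prototype(features):
    dim = len(features[0])
    n = len(features)
    return [sum(f[i] for f in features) / n for i in range(dim)]

def nearest_task(feature, prototypes):
    # Pick the task whose prototype is closest in Euclidean distance.
    def dist(a, b):
        return sum((x - y) ** 2 for x, y in zip(a, b)) ** 0.5
    return min(prototypes, key=lambda task: dist(feature, prototypes[task]))

prototypes = {
    0: compute_prototype([[1.0, 2.0], [3.0, 4.0]]),    # -> [2.0, 3.0]
    1: compute_prototype([[8.0, 9.0], [10.0, 11.0]]),  # -> [9.0, 10.0]
}
task = nearest_task([2.5, 3.5], prototypes)  # closest to task 0's prototype
```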
- setup(stage)[source]#
Called at the beginning of fit (train + validate), validate, test, or predict. This is a good hook when you need to build models dynamically or adjust something about them. This hook is called on every process when using DDP.
- Parameters:
  - stage (str) – either 'fit', 'validate', 'test', or 'predict'
- Return type:
None
Example:

    class LitModel(...):
        def __init__(self):
            self.l1 = None

        def prepare_data(self):
            download_data()
            tokenize()

            # don't do this
            self.something = else

        def setup(self, stage):
            data = load_data(...)
            self.l1 = nn.Linear(28, data.num_classes)
- class renate.updaters.experimental.speft.SPeftModelUpdater(model, loss_fn, optimizer, batch_size=32, seed=0, learner_kwargs=None, input_state_folder=None, output_state_folder=None, max_epochs=50, learning_rate_scheduler=None, learning_rate_scheduler_interval='epoch', train_transform=None, train_target_transform=None, test_transform=None, test_target_transform=None, buffer_transform=None, buffer_target_transform=None, metric=None, mode='min', logged_metrics=None, early_stopping_enabled=False, logger=<pytorch_lightning.loggers.tensorboard.TensorBoardLogger object>, accelerator='auto', devices=None, strategy='ddp', precision='32', deterministic_trainer=False, gradient_clip_val=None, gradient_clip_algorithm=None, mask_unused_classes=False)[source]#
Bases: SingleTrainingLoopUpdater
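A construction sketch based only on the signature above (treat as unverified pseudocode: my_model, train_dataset, and val_dataset are hypothetical placeholders, and the update call is assumed from the base ModelUpdater interface rather than confirmed against a Renate installation):

```
from functools import partial
import torch

from renate.updaters.experimental.speft import SPeftModelUpdater

updater = SPeftModelUpdater(
    model=my_model,  # a SPromptTransformer RenateModule (placeholder)
    loss_fn=torch.nn.CrossEntropyLoss(),
    optimizer=partial(torch.optim.SGD, lr=0.01),  # partial optimizer pattern
    max_epochs=50,
    output_state_folder="./state",
)
# Assumed base-class entry point: trains on the new data and saves state.
updater.update(train_dataset, val_dataset=val_dataset, task_id="task_1")
```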