renate.benchmark.models.spromptmodel module
- class renate.benchmark.models.spromptmodel.PromptPool(prompt_size=10, embedding_size=768, current_update_id=0)
Bases: Module
Implements a pool of prompts to be used for S-Prompts.
- Parameters:
- forward(id)
Defines the computation performed at every call.
Should be overridden by all subclasses.
- Return type: Parameter
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
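The note above can be illustrated with a minimal sketch. `TinyPromptPool` is a hypothetical stand-in written for this example, not the Renate `PromptPool` implementation; it assumes a pool simply stores one learnable prompt of shape `(prompt_size, embedding_size)` per update id. Only standard PyTorch calls are used.

```python
import torch
from torch import nn


class TinyPromptPool(nn.Module):
    """Hypothetical pool: one learnable prompt per update id (an assumption)."""

    def __init__(self, prompt_size: int = 10, embedding_size: int = 768) -> None:
        super().__init__()
        self.prompt_size = prompt_size
        self.embedding_size = embedding_size
        self.prompts = nn.ParameterList()

    def add_prompt(self) -> None:
        # Create and register a new prompt for the next update.
        prompt = nn.Parameter(torch.zeros(self.prompt_size, self.embedding_size))
        nn.init.normal_(prompt, std=0.02)
        self.prompts.append(prompt)

    def forward(self, id: int) -> nn.Parameter:
        # Return the prompt that belongs to the given update id.
        return self.prompts[id]


pool = TinyPromptPool()
pool.add_prompt()  # prompt for update 0
pool.add_prompt()  # prompt for update 1

# Record the id argument every time the forward hook fires.
hook_calls = []
handle = pool.register_forward_hook(
    lambda module, args, output: hook_calls.append(args[0])
)

prompt_via_call = pool(1)             # runs forward() and the registered hook
prompt_via_forward = pool.forward(1)  # same result, but the hook is skipped
handle.remove()
```

Both calls return the same prompt, but only `pool(1)` is recorded in `hook_calls`, which is why calling the module instance is preferred over calling `forward` directly.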
- class renate.benchmark.models.spromptmodel.SPromptTransformer(pretrained_model_name_or_path='google/vit-base-patch16-224', image_size=32, patch_size=4, num_layers=12, num_heads=12, hidden_dim=768, mlp_dim=3072, dropout=0.1, attention_dropout=0.1, num_outputs=10, prediction_strategy=None, add_icarl_class_means=True, prompt_size=10, task_id=0, clusters_per_task=5, per_task_classifier=False)
Bases: RenateBenchmarkingModule
Implements the Transformer model for S-Prompts as described in Wang, Yabin, et al. "S-prompts learning with pre-trained transformers: An Occam's razor for domain incremental learning." Advances in Neural Information Processing Systems 35 (2022): 5682-5695.
- Parameters:
pretrained_model_name_or_path – A string that denotes which pretrained model from the HF hub to use.
image_size (int) – Image size. Used if pretrained_model_name_or_path is not set.
patch_size (int) – Patch size to be extracted. Used if pretrained_model_name_or_path is not set.
num_layers (int) – Number of transformer layers. Used only if pretrained_model_name_or_path is not set.
num_heads (int) – Number of heads in MHSA. Used only if pretrained_model_name_or_path is not set.
hidden_dim (int) – Hidden dimension of transformers. Used only if pretrained_model_name_or_path is not set.
mlp_dim (int) – _description_. Used only if pretrained_model_name_or_path is not set.
dropout (float) – _description_. Used only if pretrained_model_name_or_path is not set.
attention_dropout (float) – _description_. Used only if pretrained_model_name_or_path is not set.
num_outputs (int) – Number of output classes. Defaults to 10.
prediction_strategy (Optional[PredictionStrategy]) – Continual learning strategies may alter the prediction at train or test time. Defaults to None.
add_icarl_class_means (bool) – If True, additional parameters used only by the ICaRLModelUpdater are added. Only required when using that updater.
prompt_size (int) – Equivalent to the number of input tokens used per update. Defaults to 10.
task_id (int) – Internal variable used to increment the update id. Shouldn't be set by the user. Defaults to 0.
clusters_per_task (int) – Number of clusters in k-means used for task identification. Defaults to 5.
per_task_classifier (bool) – Whether to use a separate classifier head per task rather than a common head shared by all tasks. Defaults to False.
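To make the role of clusters_per_task more concrete, here is a rough sketch of k-means based task identification. It assumes that features of each task are clustered separately and that a test feature is assigned to the task owning the nearest centroid. The helpers `kmeans` and `identify_task`, the 2-D toy features, and the Euclidean distance are all illustrative assumptions, not Renate's implementation.

```python
import math
import random


def kmeans(points, k, iters=20, seed=0):
    """Plain Lloyd's k-means on a list of equal-length tuples."""
    rng = random.Random(seed)
    centroids = rng.sample(points, k)
    for _ in range(iters):
        # Assign every point to its nearest centroid.
        clusters = [[] for _ in range(k)]
        for p in points:
            j = min(range(k), key=lambda c: math.dist(p, centroids[c]))
            clusters[j].append(p)
        # Recompute centroids; keep the old one if a cluster is empty.
        centroids = [
            tuple(sum(axis) / len(cluster) for axis in zip(*cluster))
            if cluster
            else centroids[j]
            for j, cluster in enumerate(clusters)
        ]
    return centroids


def identify_task(feature, centroids_per_task):
    """Return the task id whose closest centroid is nearest to the feature."""
    best_task, best_dist = None, math.inf
    for task_id, centroids in centroids_per_task.items():
        dist = min(math.dist(feature, c) for c in centroids)
        if dist < best_dist:
            best_task, best_dist = task_id, dist
    return best_task


# Toy features: task 0 is centred at (0, 0), task 1 at (10, 10).
rng = random.Random(0)
task_features = {
    0: [(rng.gauss(0, 1), rng.gauss(0, 1)) for _ in range(50)],
    1: [(rng.gauss(10, 1), rng.gauss(10, 1)) for _ in range(50)],
}

# Five centroids per task, mirroring the default clusters_per_task=5.
centroids_per_task = {t: kmeans(f, k=5) for t, f in task_features.items()}
predicted_task = identify_task((9.5, 10.3), centroids_per_task)
```

In a setup like this, the predicted task id would then select which prompt (and, with per_task_classifier, which head) is used for the input.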