renate.benchmark.models.spromptmodel module
- class renate.benchmark.models.spromptmodel.PromptPool(prompt_size=10, embedding_size=768, current_update_id=0)
Bases: Module
Implements a pool of prompts to be used for S-Prompts.
- Parameters:
- forward(id)
Defines the computation performed at every call.
Should be overridden by all subclasses.
- Return type: Parameter
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
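The note above can be illustrated with a minimal sketch. It does not use PyTorch or Renate: TinyModule is a hypothetical, simplified stand-in for the call/forward split of Module, and TinyPromptPool is a hypothetical pool that returns a stored prompt by id. Real forward hooks can also modify the output, which this sketch ignores.

```python
class TinyModule:
    """Hypothetical, simplified stand-in for Module's call/forward split."""

    def __init__(self):
        self._forward_hooks = []

    def register_forward_hook(self, hook):
        # Hooks are called as hook(module, args, output) after forward().
        self._forward_hooks.append(hook)

    def forward(self, *args):
        raise NotImplementedError

    def __call__(self, *args):
        # Calling the instance runs forward() and then every registered hook.
        output = self.forward(*args)
        for hook in self._forward_hooks:
            hook(self, args, output)
        return output


class TinyPromptPool(TinyModule):
    """Hypothetical pool that stores one prompt per update id."""

    def __init__(self, prompts):
        super().__init__()
        self._prompts = list(prompts)

    def forward(self, id):
        return self._prompts[id]


pool = TinyPromptPool(prompts=["prompt-0", "prompt-1"])
seen = []
pool.register_forward_hook(lambda module, args, output: seen.append(output))

via_call = pool(1)             # runs forward() and then the hook
via_forward = pool.forward(0)  # runs forward() only; the hook is skipped
```

After this runs, only the prompt fetched through `pool(1)` has been recorded by the hook, which is why calling the instance is preferred over calling `forward` directly.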
- class renate.benchmark.models.spromptmodel.SPromptTransformer(pretrained_model_name_or_path='google/vit-base-patch16-224', image_size=32, patch_size=4, num_layers=12, num_heads=12, hidden_dim=768, mlp_dim=3072, dropout=0.1, attention_dropout=0.1, num_outputs=10, prediction_strategy=None, add_icarl_class_means=True, prompt_size=10, task_id=0, clusters_per_task=5, per_task_classifier=False)
Bases: RenateBenchmarkingModule
Implements the Transformer model for S-Prompts as described in Wang, Yabin, et al. “S-prompts learning with pre-trained transformers: An occam’s razor for domain incremental learning.” Advances in Neural Information Processing Systems 35 (2022): 5682-5695.
- Parameters:
pretrained_model_name_or_path – A string that denotes which pretrained model from the HF hub to use.
image_size (int) – Image size. Used if pretrained_model_name_or_path is not set.
patch_size (int) – Patch size to be extracted. Used if pretrained_model_name_or_path is not set.
num_layers (int) – Number of transformer layers. Used only if pretrained_model_name_or_path is not set.
num_heads (int) – Number of heads in multi-head self-attention (MHSA). Used only if pretrained_model_name_or_path is not set.
hidden_dim (int) – Hidden dimension of transformers. Used only if pretrained_model_name_or_path is not set.
mlp_dim (int) – _description_. Used only if pretrained_model_name_or_path is not set.
dropout (float) – _description_. Used only if pretrained_model_name_or_path is not set.
attention_dropout (float) – _description_. Used only if pretrained_model_name_or_path is not set.
num_outputs (int) – Number of output classes. Defaults to 10.
prediction_strategy (Optional[PredictionStrategy]) – Continual learning strategies may alter the prediction at train or test time. Defaults to None.
add_icarl_class_means (bool) – If True, additional parameters used only by the ICaRLModelUpdater are added. Only required when using that updater.
prompt_size (int) – Equivalent to the number of input tokens used per update. Defaults to 10.
task_id (int) – Internal variable used to increment the update id. Shouldn’t be set by the user. Defaults to 0.
clusters_per_task (int) – Number of clusters in k-means used for task identification. Defaults to 5.
per_task_classifier (bool) – Whether to use a separate classifier head per task (True) or a common classifier head shared by all tasks (False). Defaults to False.
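To make the role of clusters_per_task more concrete, here is a minimal sketch of one common way to use k-means centroids for task identification: a nearest-centroid lookup. It assumes the centroids for each task were already fitted on features from that task. The function identify_task and the toy 2-D centroids are hypothetical and are not part of Renate's API.

```python
import math


def identify_task(feature, centroids_per_task):
    """Return the id of the task that owns the centroid closest to `feature`."""
    best_task, best_dist = None, math.inf
    for task_id, centroids in enumerate(centroids_per_task):
        for centroid in centroids:
            dist = math.dist(feature, centroid)  # Euclidean distance
            if dist < best_dist:
                best_task, best_dist = task_id, dist
    return best_task


# Toy data: two tasks, each summarized by two centroids
# (this would correspond to clusters_per_task=2).
centroids_per_task = [
    [(0.0, 0.0), (1.0, 0.0)],  # task 0
    [(5.0, 5.0), (6.0, 5.0)],  # task 1
]

predicted = identify_task((5.2, 4.9), centroids_per_task)
```

In this sketch, a larger clusters_per_task gives each task more centroids, so a task whose features form several distinct groups can still be matched by the lookup.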