renate.benchmark.models.l2p module
- class renate.benchmark.models.l2p.PromptPool(pool_size=10, pool_selection_size=5, prompt_size=5, prompt_key_dim=768, embedding_dim=768, train_prompt_keys=True, similarity_fn='cosine', per_batch_prompt=True)
Bases: Module
Implements the prompt pool for L2P (Learning to Prompt).
- Parameters:
  pool_size (int) – Total size of the prompt pool.
  pool_selection_size (int) – Number of prompts to select from the pool.
  prompt_size (int) – Number of tokens in each prompt.
  prompt_key_dim (int) – Dimension of the prompt keys used to compute similarity. It must equal the dimension of x in forward.
  embedding_dim (int) – Output dimension of the token/patch embedding layer.
  train_prompt_keys (bool) – Whether to train the prompt keys. Currently unused.
  similarity_fn (Union[Callable, str]) – Similarity function between input features and prompt keys.
  per_batch_prompt (bool) – Whether to use the same prompts for all elements in the batch.
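The parameters above describe key-based prompt selection: each input yields a query of size prompt_key_dim, the query is compared to the pool_size prompt keys with similarity_fn, and the pool_selection_size best-matching prompts (each prompt_size tokens long) are used. The following is a minimal pure-Python sketch of that mechanism, not Renate's implementation. The helper names (`select_prompt_ids`, `gather_prompt_tokens`) are hypothetical, and the per-batch rule (keep the prompts selected most often across the batch) is an assumption about what per_batch_prompt does.

```python
import math
from collections import Counter
from typing import Callable, List, Sequence, Union

Vector = Sequence[float]


def cosine_similarity(a: Vector, b: Vector) -> float:
    """Cosine similarity between two equal-length vectors (0.0 if either is all zeros)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm > 0 else 0.0


def select_prompt_ids(
    queries: List[Vector],        # one query feature per batch element
    prompt_keys: List[Vector],    # pool_size keys, each of length prompt_key_dim
    pool_selection_size: int,
    similarity_fn: Union[str, Callable[[Vector, Vector], float]] = "cosine",
    per_batch_prompt: bool = True,
) -> List[List[int]]:
    """Hypothetical helper: pick pool_selection_size prompt indices per query."""
    if isinstance(similarity_fn, str):
        if similarity_fn != "cosine":
            raise ValueError(f"unsupported similarity function: {similarity_fn!r}")
        similarity_fn = cosine_similarity
    key_dim = len(prompt_keys[0])
    per_example = []
    for query in queries:
        # The query dimension must equal prompt_key_dim.
        if len(query) != key_dim:
            raise ValueError("query dimension must equal prompt_key_dim")
        scores = [similarity_fn(query, key) for key in prompt_keys]
        # Highest score first; ties keep the lower index first (stable sort).
        ranked = sorted(range(len(prompt_keys)), key=lambda i: scores[i], reverse=True)
        per_example.append(ranked[:pool_selection_size])
    if not per_batch_prompt:
        return per_example
    # Assumption: a shared selection keeps the ids chosen most often across the
    # batch; equal counts are ordered by first occurrence.
    counts = Counter(i for ids in per_example for i in ids)
    shared = [i for i, _ in counts.most_common(pool_selection_size)]
    return [list(shared) for _ in queries]


def gather_prompt_tokens(prompts: List[List[List[float]]], ids: List[int]) -> List[List[float]]:
    """Hypothetical helper: concatenate the selected prompts into len(ids) * prompt_size tokens."""
    return [token for i in ids for token in prompts[i]]
```

With the defaults pool_selection_size=5 and prompt_size=5, `gather_prompt_tokens` would return 25 prompt tokens per input, each of length embedding_dim.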
- class renate.benchmark.models.l2p.PromptedTransformer(pretrained_model_name_or_path='google/vit-base-patch16-224', image_size=32, patch_size=4, num_layers=12, num_heads=12, hidden_dim=768, mlp_dim=3072, dropout=0.1, attention_dropout=0.1, num_outputs=10, prediction_strategy=None, add_icarl_class_means=True)
Bases: Module
This generic module is the basic prompted transformer. It takes a model name and creates the corresponding transformer (ViT or text transformer). If no prompt is passed to forward, the image/text features are returned. If a prompt is passed, it is concatenated to the output of the embedding layer and the resulting features are returned.
- Parameters:
  pretrained_model_name_or_path – Name of the pretrained model on the Hugging Face Hub to use.
  num_outputs (int) – Size of the output.
  prediction_strategy (Optional[PredictionStrategy]) – Continual learning strategies may alter the prediction at train or test time.
  add_icarl_class_means (bool) – If True, additional parameters used only by the ICaRLModelUpdater are added. Only required when using that updater.
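The forward behaviour described for PromptedTransformer (return plain features without a prompt, otherwise concatenate the prompt to the embedding-layer output) can be illustrated with lists standing in for tensors. This is a hypothetical sketch: `prompted_sequence` and `sequence_length` are not Renate functions, prepending the prompt follows the L2P paper rather than anything stated here, and a single class token is assumed.

```python
from typing import List, Optional

Token = List[float]


def prompted_sequence(embeddings: List[Token], prompt: Optional[List[Token]] = None) -> List[Token]:
    """Hypothetical helper: the token sequence the encoder would see."""
    if prompt is None:
        # No prompt: the plain embedding-layer output is used.
        return list(embeddings)
    if any(len(tok) != len(embeddings[0]) for tok in prompt):
        raise ValueError("prompt tokens must match the embedding dimension")
    # Assumed ordering: prompt tokens are prepended, as in the L2P paper.
    return list(prompt) + list(embeddings)


def sequence_length(
    image_size: int,
    patch_size: int,
    pool_selection_size: int,
    prompt_size: int,
    cls_token: bool = True,
) -> int:
    """Number of tokens after prompting a ViT (assumes one class token)."""
    num_patches = (image_size // patch_size) ** 2
    return num_patches + int(cls_token) + pool_selection_size * prompt_size
```

For example, a 224-pixel input split into 16-pixel patches (as the name google/vit-base-patch16-224 suggests) gives 196 patch tokens plus a class token, so `sequence_length(224, 16, 5, 5)` is 222 once 5 prompts of 5 tokens are added; with image_size=32 and patch_size=4 it is 90.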
- class renate.benchmark.models.l2p.LearningToPromptTransformer(pretrained_model_name_or_path='google/vit-base-patch16-224', image_size=32, patch_size=4, num_layers=12, num_heads=12, hidden_dim=768, mlp_dim=3072, dropout=0.1, attention_dropout=0.1, num_outputs=10, prediction_strategy=None, add_icarl_class_means=True, pool_size=10, pool_selection_size=5, prompt_size=5, prompt_key_dim=768, train_prompt_keys=True, similarity_fn='cosine', per_batch_prompt=True, prompt_embedding_features='cls', patch_pooler='prompt_mean')
Bases: RenateBenchmarkingModule
Implements the vision transformer with prompt pool described in Wang, Zifeng, et al. “Learning to prompt for continual learning.” Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2022.
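In the cited paper, training combines the classification loss with a surrogate term that pulls the keys of the selected prompts toward the query feature, weighted by a hyperparameter lambda and using a distance score between query and key. The sketch below writes that objective for a single example in pure Python, with cosine distance as the score. Whether Renate adds this key-matching term is not stated in this documentation (train_prompt_keys is documented as unused), and the function names and the `lam` value are illustrative assumptions.

```python
import math
from typing import List, Sequence


def cross_entropy(logits: Sequence[float], target: int) -> float:
    """Softmax cross-entropy for one example, computed stably with log-sum-exp."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


def cosine_distance(a: Sequence[float], b: Sequence[float]) -> float:
    """1 - cosine similarity of two non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return 1.0 - dot / norm


def l2p_objective(
    logits: Sequence[float],
    target: int,
    query: Sequence[float],
    selected_keys: List[Sequence[float]],
    lam: float,
) -> float:
    """Classification loss plus lam times the summed query-key distance of the selected prompts."""
    matching = sum(cosine_distance(query, key) for key in selected_keys)
    return cross_entropy(logits, target) + lam * matching
```

A key that already points in the query's direction contributes nothing, while an orthogonal key adds `lam` to the loss.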
- Parameters:
  pretrained_model_name_or_path – Name of the pretrained model on the Hugging Face Hub to use. If provided, it overrides the other architecture arguments.
  image_size (int) – Size of the input image.
  patch_size (int) – Size of the patches.
  num_layers (int) – Number of Encoder layers.
  num_heads (int) – Number of Attention heads.
  hidden_dim (int) – Size of the Encoder’s hidden state.
  mlp_dim (int) – Size of the intermediate Multi-layer Perceptron in the Encoder.
  dropout (float) – Dropout probability.
  attention_dropout (float) – Dropout probability for the attention in the Multi-head Attention layer.
  num_outputs (int) – Size of the output.
  prediction_strategy (Optional[PredictionStrategy]) – Continual learning strategies may alter the prediction at train or test time.
  add_icarl_class_means (bool) – If True, additional parameters used only by the ICaRLModelUpdater are added. Only required when using that updater.
  pool_size (int) – Total size of the prompt pool.
  pool_selection_size (int) – Number of prompts to select from the pool.
  prompt_size (int) – Number of tokens in each prompt.
  prompt_key_dim (int) – Dimension of the prompt keys used to compute similarity. It must equal the dimension of x in forward.
  train_prompt_keys (bool) – Whether to train the prompt keys. Currently unused.
  similarity_fn (Union[Callable, str]) – Similarity function between input features and prompt keys.
  per_batch_prompt (bool) – Whether to use the same prompts for all elements in the batch.
  prompt_embedding_features (str) – Image feature type used to compute the similarity to the prompt keys.
  patch_pooler (str) – Determines which encoder output features are fed to the classifier.
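The default patch_pooler='prompt_mean' suggests the pooling used in the L2P paper: average the encoder outputs at the prompt positions and feed that vector to a linear classifier with num_outputs rows (prompt_embedding_features='cls' presumably means the query for prompt selection comes from the class-token feature). The pure-Python sketch below shows that pooling under two assumptions: the prompt tokens occupy the first positions of the sequence, and the head is a plain linear layer. `prompt_mean_pool` and `linear_head` are hypothetical names, not Renate's API.

```python
from typing import List, Sequence


def prompt_mean_pool(encoder_outputs: List[Sequence[float]], num_prompt_tokens: int) -> List[float]:
    """Average the encoder outputs at the (assumed leading) prompt positions."""
    if not 0 < num_prompt_tokens <= len(encoder_outputs):
        raise ValueError("num_prompt_tokens must be between 1 and the sequence length")
    prompt_outputs = encoder_outputs[:num_prompt_tokens]
    dim = len(prompt_outputs[0])
    return [sum(tok[d] for tok in prompt_outputs) / num_prompt_tokens for d in range(dim)]


def linear_head(
    features: Sequence[float],
    weights: List[Sequence[float]],  # num_outputs rows, each of len(features)
    bias: Sequence[float],           # num_outputs entries
) -> List[float]:
    """Logits = W @ features + b for a single example."""
    return [sum(w * f for w, f in zip(row, features)) + b for row, b in zip(weights, bias)]
```

With the defaults, num_prompt_tokens would be pool_selection_size * prompt_size = 25, and the pooled vector would have the encoder's hidden size.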