Source code for renate.benchmark.models.l2p
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import functools
import logging
from typing import Callable, Optional, Union
import torch
import torch.nn as nn
from renate import defaults
from renate.benchmark.models.base import RenateBenchmarkingModule
from renate.benchmark.models.transformer import HuggingFaceSequenceClassificationTransformer
from renate.benchmark.models.vision_transformer import VisionTransformer
from renate.models.prediction_strategies import PredictionStrategy
logger = logging.getLogger(__name__)
class PromptPool(nn.Module):
"""Implements the prompt pool for L2P
Args:
pool_size: Total size of the prompt pool.
pool_selection_size: Number of prompts to select from the pool.
prompt_size: Number of tokens each prompt is equivalent to.
        prompt_key_dim: Dimension of the prompt keys used to compute similarity. It has to match
            the dimension of `x` in the forward pass.
        embedding_dim: Output dimension of the token/patch embedding layer.
        train_prompt_keys: Whether to train the prompt keys. Currently unused.
        similarity_fn: Similarity function between input features and prompt keys.
        per_batch_prompt: Flag to use the same prompts for all elements in the batch.
"""
def __init__(
self,
pool_size: int = 10,
pool_selection_size: int = 5,
prompt_size: int = 5,
prompt_key_dim: int = 768,
embedding_dim: int = 768,
train_prompt_keys: bool = True,
similarity_fn: Union[Callable, str] = "cosine",
per_batch_prompt: bool = True,
):
super().__init__()
        self._M = pool_size  # total pool size
        self._N = pool_selection_size  # number of prompts selected per input
        self._Lp = prompt_size  # number of tokens each prompt corresponds to
        self._d = embedding_dim  # token/patch embedding dimension
        self._pd = prompt_key_dim  # prompt key dimension
        self._per_batch_prompt = per_batch_prompt
        self._parse_similarity_fn(similarity_fn)
        self.train_prompt_keys = train_prompt_keys  # currently unused
self.prompt_pool = nn.Parameter(torch.empty((self._M, self._Lp, self._d)).uniform_(-1, 1))
self.prompt_keys = nn.Parameter(torch.empty((self._M, self._pd)).uniform_(-1, 1))
self.key_hist = torch.zeros((self._M,), dtype=torch.float32)
def _parse_similarity_fn(self, similarity_fn: Union[Callable, str]) -> None:
if callable(similarity_fn):
self.similarity_fn = similarity_fn
        elif not isinstance(similarity_fn, str):
            raise ValueError(
                "similarity_fn has to be a callable or a string representing a similarity "
                f"metric, but got {similarity_fn}."
            )
elif similarity_fn == "cosine":
normalization_fn = functools.partial(torch.nn.functional.normalize, p=2)
self.similarity_fn = lambda x, y: normalization_fn(x).matmul(normalization_fn(y).t())
else:
raise ValueError(
f"Currently only cosine similarity is supported, but got {similarity_fn}"
)
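    # Note: ``similarity_fn`` may also be any callable taking ``(x, prompt_keys)`` and returning
    # a B x pool_size similarity matrix. A minimal, illustrative example (not part of the library
    # API) using a plain dot product:
    #
    #     pool = PromptPool(similarity_fn=lambda x, keys: x.matmul(keys.t()))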
def forward(self, x: torch.Tensor, manual_prompt_indices: Optional[torch.LongTensor] = None):
"""
Args:
            x: Extracted image features, e.g. the [CLS] token or other features of
                dimension B x ``prompt_key_dim``.
            manual_prompt_indices: Indices to manually select prompts from the pool, instead of
                selecting them by similarity to the prompt keys.
"""
if manual_prompt_indices is None:
similarity_matrix = self.similarity_fn(x, self.prompt_keys)
_, idx = torch.topk(similarity_matrix, k=self._N, dim=1)
if self._per_batch_prompt:
prompt_id, id_counts = torch.unique(idx, return_counts=True, sorted=True)
if prompt_id.shape[0] < self._M:
                    # The logic for this is taken from the public L2P implementation.
temp_pid = torch.full((self._M,), idx.min(), device=prompt_id.device)
temp_pid[: prompt_id.shape[0]] = prompt_id
prompt_id = temp_pid
temp_idc = torch.zeros((self._M,), device=id_counts.device)
temp_idc[: id_counts.shape[0]] = id_counts
id_counts = temp_idc
_, major_idx = torch.topk(id_counts, k=self._N)
idx = prompt_id[major_idx].expand(x.shape[0], -1) # B, top_k
loss_value = similarity_matrix[:, idx].sum() / (x.shape[0] * x.shape[0])
else:
idx = manual_prompt_indices # should be of size B, top_k
loss_value = torch.tensor(0.0, device=x.device)
selected_prompts = self.prompt_pool[idx].flatten(1, 2)
return selected_prompts, loss_value
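# A minimal usage sketch for ``PromptPool`` (illustrative only; the feature dimension of 768
# matches the default ``prompt_key_dim``/``embedding_dim``):
#
#     pool = PromptPool(pool_size=10, pool_selection_size=5, prompt_size=5)
#     cls_features = torch.randn(8, 768)       # B x prompt_key_dim, e.g. [CLS] features
#     prompts, key_loss = pool(cls_features)   # prompts: 8 x 25 x 768, key_loss: scalar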
class PromptedTransformer(nn.Module):
"""This generic module is the basic prompted transformer. It takes in a model string and creates
the appropriate transformer (ViT or Text transformer). If no prompts are provided in the forward
call, image/text features are returned. If a prompt is provided, it is concatenated to the
embedding layer output and the resultant features are returned.
    Args:
        pretrained_model_name_or_path: A string that denotes which pretrained model from the HF hub
            to use.
        image_size: Size of the input image. Only used for the vision transformer.
        patch_size: Size of the patches. Only used for the vision transformer.
        num_layers: Number of Encoder layers. Only used for the vision transformer.
        num_heads: Number of Attention heads. Only used for the vision transformer.
        hidden_dim: Size of the Encoder's hidden state. Only used for the vision transformer.
        mlp_dim: Size of the Encoder's intermediate MLP. Only used for the vision transformer.
        dropout: Dropout probability. Only used for the vision transformer.
        attention_dropout: Attention dropout probability. Only used for the vision transformer.
        num_outputs: Size of the output.
        prediction_strategy: Continual learning strategies may alter the prediction at train or test
            time.
        add_icarl_class_means: If ``True``, additional parameters used only by the
            ``ICaRLModelUpdater`` are added. Only required when using that updater.
"""
def __init__(
self,
        pretrained_model_name_or_path: str = "google/vit-base-patch16-224",
image_size: int = 32,
patch_size: int = 4,
num_layers: int = 12,
num_heads: int = 12,
hidden_dim: int = 768,
mlp_dim: int = 3072,
dropout: float = 0.1,
attention_dropout: float = 0.1,
num_outputs: int = 10,
prediction_strategy: Optional[PredictionStrategy] = None,
add_icarl_class_means: bool = True,
) -> None:
super().__init__()
if "vit" in pretrained_model_name_or_path:
self.transformer = VisionTransformer(
pretrained_model_name_or_path=pretrained_model_name_or_path,
image_size=image_size,
patch_size=patch_size,
num_layers=num_layers,
num_heads=num_heads,
hidden_dim=hidden_dim,
mlp_dim=mlp_dim,
dropout=dropout,
attention_dropout=attention_dropout,
num_outputs=num_outputs,
prediction_strategy=prediction_strategy,
add_icarl_class_means=add_icarl_class_means,
)
self.is_text_transformer = False
else:
self.transformer = HuggingFaceSequenceClassificationTransformer(
pretrained_model_name_or_path=pretrained_model_name_or_path,
num_outputs=num_outputs,
prediction_strategy=prediction_strategy,
add_icarl_class_means=add_icarl_class_means,
)
for named_param, value in self.transformer.named_parameters():
if value.shape[0] == self.transformer._backbone.config.vocab_size:
self.word_embeddings = self.transformer.get_submodule(
named_param.replace(".weight", "")
)
break
self.is_text_transformer = True
self.transformer._tasks_params.clear()
self.transformer.eval()
for p in self.transformer.parameters():
p.requires_grad_(False)
def forward(
self, x: torch.Tensor, prompt: Optional[torch.Tensor] = None, cls_feat: bool = True
) -> torch.Tensor:
"""
Args:
            x: Input tensor. For text transformers, a dict of tokenized inputs containing
                ``input_ids``.
prompt: Prompt tensor. Defaults to None.
cls_feat: Whether to extract [CLS] token or to return full feature tensor.
Ignored for text transformer. Defaults to True.
"""
if prompt is None:
return (
self.transformer.get_features(x)
if self.is_text_transformer
else self.transformer.get_features(x, cls_feat=cls_feat)
)
        # Text transformers don't support cls_feat.
        elif self.is_text_transformer:
            # The implicit assumption here is that, for text transformers, x is a dict
            # containing the input_ids. This simplified forward pass has four steps:
            # 1. Embed the input_ids with the word embedding layer.
            # 2. Expand the prompt to the batch size if necessary.
            # 3. Concatenate the prompt with the input embeddings.
            # 4. Forward the resulting inputs_embeds to get the features.
inputs_embeds = self.word_embeddings(x["input_ids"])
if prompt.size(0) != inputs_embeds.size(0):
prompt = prompt.unsqueeze(0).expand(
inputs_embeds.size(0), -1, -1
) # Expand one prompt to batch size
inputs_embeds = torch.cat((prompt, inputs_embeds), dim=1)
return self.transformer.get_features({"inputs_embeds": inputs_embeds})
else:
patch_embeddings = self.transformer.get_submodule("_backbone.embeddings")(x)
if prompt.size(0) != x.size(0):
prompt = prompt.unsqueeze(0).expand(
x.size(0), -1, -1
                )  # Expand one prompt to batch size
input_concat_prompt = torch.cat([patch_embeddings, prompt], dim=1)
encoded_features = self.transformer.get_submodule("_backbone.encoder")(
input_concat_prompt, return_dict=False
)[0]
encoded_features = self.transformer.get_submodule("_backbone.layernorm")(
encoded_features
)
return encoded_features[:, 0, :] if cls_feat else encoded_features
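# A minimal usage sketch for ``PromptedTransformer`` with a vision backbone (illustrative only;
# the input resolution matches the default checkpoint, and the prompt length of 25 tokens
# (5 prompts x 5 tokens) is an assumption):
#
#     model = PromptedTransformer("google/vit-base-patch16-224")
#     images = torch.randn(2, 3, 224, 224)
#     feats = model(images)                      # plain [CLS] features, no prompt
#     prompt = torch.randn(25, 768)              # N * Lp prompt tokens of dimension 768
#     prompted = model(images, prompt=prompt)    # [CLS] features after appending the prompt
#                                                # tokens to the patch embeddings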
class LearningToPromptTransformer(RenateBenchmarkingModule):
"""
Implements the vision transformer with prompt pool described in
Wang, Zifeng, et al. "Learning to prompt for continual learning." Proceedings of the
IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2022.
Args:
pretrained_model_name_or_path: A string that denotes which pretrained model from the HF hub
to use. If provided, it overrides other arguments about architecture.
image_size: Size of the input image.
patch_size: Size of the patches.
num_layers: Number of Encoder layers.
num_heads: Number of Attention heads.
hidden_dim: Size of the Encoder's hidden state.
mlp_dim: Size of the intermediate Multi-layer Perceptron in the Encoder.
dropout: Dropout probability.
attention_dropout: Dropout probability for the attention in the Multi-head Attention layer.
num_outputs: Size of the output.
prediction_strategy: Continual learning strategies may alter the prediction at train or test
time.
add_icarl_class_means: If ``True``, additional parameters used only by the
``ICaRLModelUpdater`` are added. Only required when using that updater.
pool_size: Total size of the prompt pool.
pool_selection_size: Number of prompts to select from the pool.
prompt_size: Number of tokens each prompt is equivalent to.
        prompt_key_dim: Dimension of the prompt keys used to compute similarity. It has to match
            the dimension of `x` in the forward pass.
train_prompt_keys: Whether to train the prompt keys. Currently unused.
similarity_fn: Similarity function between input features and prompt keys.
per_batch_prompt: Flag to use the same prompts for all elements in the batch.
        prompt_embedding_features: Image feature type used to compute the similarity to prompt
            keys. One of ``cls`` or ``mean``.
        patch_pooler: Features to feed the classifier. One of ``cls``, ``mean``, or
            ``prompt_mean``.
"""
def __init__(
self,
        pretrained_model_name_or_path: str = "google/vit-base-patch16-224",
image_size: int = 32,
patch_size: int = 4,
num_layers: int = 12,
num_heads: int = 12,
hidden_dim: int = 768,
mlp_dim: int = 3072,
dropout: float = 0.1,
attention_dropout: float = 0.1,
num_outputs: int = 10,
prediction_strategy: Optional[PredictionStrategy] = None,
add_icarl_class_means: bool = True,
pool_size: int = 10,
pool_selection_size: int = 5,
prompt_size: int = 5,
prompt_key_dim: int = 768,
train_prompt_keys: bool = True,
similarity_fn: Union[Callable, str] = "cosine",
per_batch_prompt: bool = True,
prompt_embedding_features: str = "cls",
patch_pooler: str = "prompt_mean",
) -> None:
transformer = PromptedTransformer(
pretrained_model_name_or_path=pretrained_model_name_or_path,
image_size=image_size,
patch_size=patch_size,
num_layers=num_layers,
num_heads=num_heads,
hidden_dim=hidden_dim,
mlp_dim=mlp_dim,
dropout=dropout,
attention_dropout=attention_dropout,
num_outputs=num_outputs,
add_icarl_class_means=add_icarl_class_means,
prediction_strategy=prediction_strategy,
)
prompter = PromptPool(
embedding_dim=transformer.transformer._embedding_size,
pool_size=pool_size,
pool_selection_size=pool_selection_size,
prompt_size=prompt_size,
prompt_key_dim=prompt_key_dim,
train_prompt_keys=train_prompt_keys,
similarity_fn=similarity_fn,
per_batch_prompt=per_batch_prompt,
)
super().__init__(
embedding_size=transformer.transformer._embedding_size,
num_outputs=num_outputs,
constructor_arguments=dict(
**transformer.transformer._constructor_arguments,
pool_size=pool_size,
pool_selection_size=pool_selection_size,
prompt_size=prompt_size,
prompt_key_dim=prompt_key_dim,
train_prompt_keys=train_prompt_keys,
similarity_fn=similarity_fn,
per_batch_prompt=per_batch_prompt,
),
prediction_strategy=prediction_strategy,
add_icarl_class_means=add_icarl_class_means,
)
self._backbone = nn.ModuleDict({"transformer": transformer, "prompter": prompter})
self._is_text_transformer = transformer.is_text_transformer
self.prompt_embedding_features = prompt_embedding_features
self.patch_pooler = patch_pooler
self.similarity_score: Optional[torch.Tensor] = None
assert self.prompt_embedding_features in [
"cls",
"mean",
], f"Invalid method to extract prompt embedding features. Got {prompt_embedding_features}"
assert self.patch_pooler in [
"cls",
"mean",
"prompt_mean",
], f"Invalid method to extract prompt embedding features. Got {patch_pooler}"
for p in self._backbone["prompter"].parameters():
p.requires_grad = True
# The backbone's forward is monkey-patched to allow the parent class' forward to work
# without any manual management.
self._backbone.forward = self.forward_for_monkey_patching
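    # Note: after construction only the prompt pool (and the task-specific classifier parameters
    # added by ``RenateBenchmarkingModule``) are trainable; the wrapped transformer stays frozen.
    # An illustrative check (parameter names are indicative, not guaranteed):
    #
    #     model = LearningToPromptTransformer(num_outputs=10)
    #     trainable = [n for n, p in model.named_parameters() if p.requires_grad]
    #     # expect names under "_backbone.prompter." plus the classifier parameters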
def forward_for_monkey_patching(
self, x: torch.Tensor, task_id: str = defaults.TASK_ID
) -> torch.Tensor:
with torch.no_grad():
prompt_pool_input = self._backbone["transformer"](x, cls_feat=False)
if not self._is_text_transformer:
if self.prompt_embedding_features == "cls":
                    # Retrieve the [CLS] token features, as used in the L2P paper.
prompt_pool_input = prompt_pool_input[:, 0, :]
elif self.prompt_embedding_features == "mean":
# compute mean patch features.
prompt_pool_input = prompt_pool_input[:, 1:, :].mean(1)
        # Compute the prompts to be concatenated to the embedded input.
prompts, prompt_similarity = self._backbone["prompter"](prompt_pool_input)
self.similarity_score = prompt_similarity
encoded_features = self._backbone["transformer"](x, prompts, cls_feat=False)
if self._is_text_transformer:
return encoded_features
else:
if self.patch_pooler == "cls":
seq_cls_token = encoded_features[:, 0, :]
elif self.patch_pooler == "mean":
seq_cls_token = encoded_features[:, 1:, :].mean(1)
elif self.patch_pooler == "prompt_mean":
num_prompts = prompts.size(1)
seq_cls_token = encoded_features[:, -num_prompts:, :].mean(1)
return seq_cls_token
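# A minimal end-to-end sketch for ``LearningToPromptTransformer`` (illustrative only; batch size,
# image size, and number of classes are assumptions):
#
#     model = LearningToPromptTransformer(
#         pretrained_model_name_or_path="google/vit-base-patch16-224", num_outputs=10
#     )
#     images = torch.randn(2, 3, 224, 224)
#     logits = model(images)            # shape: 2 x 10
#     # model.similarity_score now holds the prompt-key similarity term that L2P adds
#     # to the training loss.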