Source code for renate.benchmark.models.vision_transformer
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Optional, Tuple, Union
import torch
from transformers import ViTConfig, ViTModel
from transformers.modeling_outputs import BaseModelOutputWithPooling
from renate import defaults
from renate.benchmark.models.base import RenateBenchmarkingModule
from renate.models.prediction_strategies import PredictionStrategy
class FeatureExtractorViTModel(ViTModel):
"""This class directly outputs [CLS] features if cls_feat is True else returns per patch
embeddings
"""
def forward(
self,
pixel_values: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding: Optional[bool] = None,
return_dict: Optional[bool] = None,
cls_feat: bool = True,
task_id: str = defaults.TASK_ID,
) -> Union[Tuple, BaseModelOutputWithPooling]:
"""Output has patch embeddings and the pooled output. We extract pooled CLS out by
taking the second element.
"""
        # Forward the arguments by keyword so they map onto the correct parameters of
        # ``ViTModel.forward`` regardless of its exact positional signature.
        out_to_filter = super().forward(
            pixel_values=pixel_values,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            interpolate_pos_encoding=interpolate_pos_encoding,
            return_dict=return_dict,
        )
if isinstance(out_to_filter, BaseModelOutputWithPooling):
return out_to_filter.pooler_output if cls_feat else out_to_filter.last_hidden_state
return out_to_filter[0][:, 0] if cls_feat else out_to_filter[0]
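# Illustrative usage sketch (not part of the library source): shows how ``cls_feat``
# switches between pooled [CLS] features and per-patch embeddings. The small
# configuration below is an assumption chosen only to keep the example light.
def _feature_extractor_example() -> None:
    config = ViTConfig(
        image_size=32,
        patch_size=4,
        hidden_size=192,
        num_hidden_layers=3,
        num_attention_heads=3,
        intermediate_size=768,
        return_dict=False,
    )
    model = FeatureExtractorViTModel(config=config, add_pooling_layer=False)
    pixels = torch.rand(2, 3, 32, 32)
    cls_features = model(pixels, cls_feat=True)       # pooled [CLS] features
    patch_embeddings = model(pixels, cls_feat=False)  # full token sequence
    print(cls_features.shape)      # expected: torch.Size([2, 192])
    print(patch_embeddings.shape)  # expected: torch.Size([2, 65, 192]) -- 64 patches + [CLS]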
class VisionTransformer(RenateBenchmarkingModule):
"""Vision Transformer base model.
TODO: Fix citation
Dosovitskiy, Alexey, et al. "An image is worth 16x16 words: Transformers for image recognition
at scale."
arXiv preprint arXiv:2010.11929 (2020).
Args:
pretrained_model_name_or_path: A string that denotes which pretrained model from the HF hub
to use. If provided, it overrides other arguments about architecture.
image_size: Size of the input image.
patch_size: Size of the patches.
num_layers: Number of Encoder layers.
num_heads: Number of Attention heads.
hidden_dim: Size of the Encoder's hidden state.
mlp_dim: Size of the intermediate Multi-layer Perceptron in the Encoder.
dropout: Dropout probability.
attention_dropout: Dropout probability for the attention in the Multi-head Attention layer.
num_outputs: Size of the output.
prediction_strategy: Continual learning strategies may alter the prediction at train or test
time.
add_icarl_class_means: If ``True``, additional parameters used only by the
``ICaRLModelUpdater`` are added. Only required when using that updater.
"""
def __init__(
self,
pretrained_model_name_or_path: Optional[str] = None,
image_size: int = 32,
patch_size: int = 4,
num_layers: int = 12,
num_heads: int = 12,
hidden_dim: int = 768,
mlp_dim: int = 3072,
dropout: float = 0.1,
attention_dropout: float = 0.1,
num_outputs: int = 10,
prediction_strategy: Optional[PredictionStrategy] = None,
add_icarl_class_means: bool = True,
) -> None:
if pretrained_model_name_or_path:
model = FeatureExtractorViTModel.from_pretrained(
pretrained_model_name_or_path=pretrained_model_name_or_path,
return_dict=False,
add_pooling_layer=False,
)
constructor_args = dict()
else:
model_config = ViTConfig(
hidden_size=hidden_dim,
num_hidden_layers=num_layers,
num_attention_heads=num_heads,
intermediate_size=mlp_dim,
hidden_act="gelu",
hidden_dropout_prob=dropout,
attention_probs_dropout_prob=attention_dropout,
layer_norm_eps=1e-6,
image_size=image_size,
patch_size=patch_size,
num_channels=3,
qkv_bias=True,
return_dict=False,
)
model = FeatureExtractorViTModel(config=model_config, add_pooling_layer=False)
constructor_args = {
"image_size": image_size,
"patch_size": patch_size,
"num_layers": num_layers,
"num_heads": num_heads,
"hidden_dim": hidden_dim,
"mlp_dim": mlp_dim,
"dropout": dropout,
"attention_dropout": attention_dropout,
}
super().__init__(
embedding_size=model.config.hidden_size,
num_outputs=num_outputs,
constructor_arguments=constructor_args,
prediction_strategy=prediction_strategy,
add_icarl_class_means=add_icarl_class_means,
)
self._backbone = model
def get_features(self, *args, **kwargs):
        # This is needed as a shortcut to bypass the base class's forward and call the
        # backbone's forward directly. Used only in L2P.
return self._backbone(*args, **kwargs)
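# Illustrative usage sketch (not part of the library source): builds a small,
# randomly initialised VisionTransformer and, alternatively, one backed by a
# pretrained checkpoint. The output shapes noted below assume the default task
# head created by ``RenateBenchmarkingModule`` and are stated as expectations,
# not guarantees.
def _vision_transformer_example() -> None:
    model = VisionTransformer(
        image_size=32,
        patch_size=4,
        num_layers=3,
        num_heads=3,
        hidden_dim=192,
        mlp_dim=768,
        num_outputs=10,
    )
    x = torch.rand(2, 3, 32, 32)
    logits = model(x)                 # expected shape: (2, 10)
    features = model.get_features(x)  # backbone [CLS] features, expected shape: (2, 192)
    print(logits.shape, features.shape)

    # Starting from a pretrained checkpoint overrides the architecture arguments;
    # uncommenting this triggers a download from the HF hub.
    # model = VisionTransformer(
    #     pretrained_model_name_or_path="google/vit-base-patch16-224",
    #     num_outputs=10,
    # )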
class VisionTransformerCIFAR(VisionTransformer):
def __init__(self, **kwargs: Any) -> None:
super().__init__(
image_size=32,
patch_size=4,
num_layers=3,
num_heads=3,
hidden_dim=192,
mlp_dim=768,
**kwargs,
)
class VisionTransformerB16(VisionTransformer):
def __init__(self, **kwargs: Any) -> None:
super().__init__(
pretrained_model_name_or_path="google/vit-base-patch16-224",
**kwargs,
)
class VisionTransformerB32(VisionTransformer):
def __init__(self, **kwargs: Any) -> None:
super().__init__(
pretrained_model_name_or_path="google/vit-base-patch32-224-in21k",
**kwargs,
)
class VisionTransformerL16(VisionTransformer):
def __init__(self, **kwargs: Any) -> None:
super().__init__(
pretrained_model_name_or_path="google/vit-large-patch16-224-in21k",
**kwargs,
)
class VisionTransformerL32(VisionTransformer):
def __init__(self, **kwargs: Any) -> None:
super().__init__(
pretrained_model_name_or_path="google/vit-large-patch32-224-in21k",
**kwargs,
)
class VisionTransformerH14(VisionTransformer):
def __init__(self, **kwargs: Any) -> None:
super().__init__(
pretrained_model_name_or_path="google/vit-huge-patch14-224-in21k",
**kwargs,
)