renate.benchmark.models.vision_transformer module#
- class renate.benchmark.models.vision_transformer.FeatureExtractorViTModel(config, add_pooling_layer=True, use_mask_token=False)[source]#
Bases:
ViTModel
Returns the pooled [CLS] feature if cls_feat is True; otherwise returns the per-patch embeddings.
- forward(pixel_values=None, head_mask=None, output_attentions=None, output_hidden_states=None, interpolate_pos_encoding=None, return_dict=None, cls_feat=True, task_id='default_task')[source]#
The output contains the patch embeddings and the pooled output; the pooled [CLS] feature is extracted by taking the second element.
- Return type:
Union[Tuple, BaseModelOutputWithPooling]
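The shapes involved can be sketched as follows. This is an illustrative mock of the return value, not the actual renate or transformers API: for a 32x32 image with 4x4 patches, the encoder yields (32/4)**2 = 64 patch tokens plus one [CLS] token, and the tuple output is (sequence_output, pooled_output).

```python
import numpy as np

# Illustrative shapes only (not actual renate code): a ViT over a 32x32 image
# with 4x4 patches produces 64 patch tokens plus one [CLS] token.
batch, hidden_dim = 2, 768
num_patches = (32 // 4) ** 2                 # 64 patch tokens
seq_len = num_patches + 1                    # +1 for the [CLS] token

sequence_output = np.zeros((batch, seq_len, hidden_dim))  # per-patch embeddings
pooled_output = np.zeros((batch, hidden_dim))             # pooled [CLS] feature

# forward(...) returns (sequence_output, pooled_output); with cls_feat=True
# only the second element, the pooled [CLS] representation, is kept.
outputs = (sequence_output, pooled_output)
cls_feat = True
features = outputs[1] if cls_feat else outputs[0]
print(features.shape)  # (2, 768)
```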
- class renate.benchmark.models.vision_transformer.VisionTransformer(pretrained_model_name_or_path=None, image_size=32, patch_size=4, num_layers=12, num_heads=12, hidden_dim=768, mlp_dim=3072, dropout=0.1, attention_dropout=0.1, num_outputs=10, prediction_strategy=None, add_icarl_class_means=True)[source]#
Bases:
RenateBenchmarkingModule
Vision Transformer base model.
Dosovitskiy, Alexey, et al. “An image is worth 16x16 words: Transformers for image recognition at scale.” arXiv preprint arXiv:2010.11929 (2020).
- Parameters:
  - pretrained_model_name_or_path (Optional[str]) – A string that denotes which pretrained model from the HF hub to use. If provided, it overrides other arguments about the architecture.
  - image_size (int) – Size of the input image.
  - patch_size (int) – Size of the patches.
  - num_layers (int) – Number of Encoder layers.
  - num_heads (int) – Number of Attention heads.
  - hidden_dim (int) – Size of the Encoder’s hidden state.
  - mlp_dim (int) – Size of the intermediate Multi-layer Perceptron in the Encoder.
  - dropout (float) – Dropout probability.
  - attention_dropout (float) – Dropout probability for the attention in the Multi-head Attention layer.
  - num_outputs (int) – Size of the output.
  - prediction_strategy (Optional[PredictionStrategy]) – Continual learning strategies may alter the prediction at train or test time.
  - add_icarl_class_means (bool) – If True, additional parameters used only by the ICaRLModelUpdater are added. Only required when using that updater.
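The constructor defaults are internally consistent in the usual ViT way; the following sketch (plain Python, not renate code) shows how they relate, so mismatched choices such as a hidden_dim that is not divisible by num_heads are easy to spot before instantiating the model.

```python
# Sketch of how the default hyperparameters above relate (not renate code).
image_size, patch_size = 32, 4
hidden_dim, mlp_dim = 768, 3072
num_layers, num_heads = 12, 12

num_patches = (image_size // patch_size) ** 2   # tokens per image: 64
head_dim = hidden_dim // num_heads              # dims per attention head: 64

# Sanity checks: hidden_dim must split evenly across heads, and the defaults
# follow the standard 4x MLP expansion used by ViT.
assert hidden_dim % num_heads == 0
assert mlp_dim == 4 * hidden_dim
print(num_patches, head_dim)  # 64 64
```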
- class renate.benchmark.models.vision_transformer.VisionTransformerCIFAR(**kwargs)[source]#
Bases:
VisionTransformer
- class renate.benchmark.models.vision_transformer.VisionTransformerB16(**kwargs)[source]#
Bases:
VisionTransformer
- class renate.benchmark.models.vision_transformer.VisionTransformerB32(**kwargs)[source]#
Bases:
VisionTransformer
- class renate.benchmark.models.vision_transformer.VisionTransformerL16(**kwargs)[source]#
Bases:
VisionTransformer
- class renate.benchmark.models.vision_transformer.VisionTransformerL32(**kwargs)[source]#
Bases:
VisionTransformer
- class renate.benchmark.models.vision_transformer.VisionTransformerH14(**kwargs)[source]#
Bases:
VisionTransformer
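The class suffixes follow the naming scheme of Dosovitskiy et al. (2020): a model size letter (B, L, H) followed by the patch size. The table below lists the configurations from that paper for reference; renate's actual defaults for these subclasses should be confirmed in the source, so treat this as an illustrative summary only.

```python
# ViT configurations from Dosovitskiy et al. (2020), keyed by class suffix.
# Illustrative reference only; renate's subclass defaults may differ.
VIT_CONFIGS = {
    "B16": dict(patch_size=16, num_layers=12, hidden_dim=768,  num_heads=12, mlp_dim=3072),
    "B32": dict(patch_size=32, num_layers=12, hidden_dim=768,  num_heads=12, mlp_dim=3072),
    "L16": dict(patch_size=16, num_layers=24, hidden_dim=1024, num_heads=16, mlp_dim=4096),
    "L32": dict(patch_size=32, num_layers=24, hidden_dim=1024, num_heads=16, mlp_dim=4096),
    "H14": dict(patch_size=14, num_layers=32, hidden_dim=1280, num_heads=16, mlp_dim=5120),
}
print(VIT_CONFIGS["L16"]["hidden_dim"])  # 1024
```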