renate.benchmark.models.vision_transformer module#

class renate.benchmark.models.vision_transformer.FeatureExtractorViTModel(config, add_pooling_layer=True, use_mask_token=False)[source]#

Bases: ViTModel

This class outputs the pooled [CLS] feature if cls_feat is True; otherwise it returns the per-patch embeddings.

forward(pixel_values=None, head_mask=None, output_attentions=None, output_hidden_states=None, interpolate_pos_encoding=None, return_dict=None, cls_feat=True, task_id='default_task')[source]#

The output contains the patch embeddings and the pooled output. The pooled [CLS] feature is extracted by taking the second element.

Return type:

Union[Tuple, BaseModelOutputWithPooling]
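A minimal sketch of the "second element" convention described above, using a hypothetical stand-in namedtuple in place of the real transformers BaseModelOutputWithPooling (the dummy shapes are illustrative only, not taken from the library):

```python
from collections import namedtuple

# Illustrative stand-in for transformers.BaseModelOutputWithPooling:
# index 0 holds per-patch embeddings, index 1 the pooled output.
BaseModelOutputWithPooling = namedtuple(
    "BaseModelOutputWithPooling", ["last_hidden_state", "pooler_output"]
)

# Dummy data: batch of 2, 65 tokens ([CLS] + 64 patches), hidden size 768.
last_hidden_state = [[[0.0] * 768] * 65] * 2
pooler_output = [[0.0] * 768] * 2

outputs = BaseModelOutputWithPooling(last_hidden_state, pooler_output)

# Taking the second element yields the pooled [CLS] feature.
cls_feat = outputs[1]
print(len(cls_feat), len(cls_feat[0]))  # batch size, hidden dimension
```

Because the output type also behaves like a tuple, indexing with [1] selects the pooled output regardless of whether a tuple or the model-output object is returned.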

training: bool#
class renate.benchmark.models.vision_transformer.VisionTransformer(pretrained_model_name_or_path=None, image_size=32, patch_size=4, num_layers=12, num_heads=12, hidden_dim=768, mlp_dim=3072, dropout=0.1, attention_dropout=0.1, num_outputs=10, prediction_strategy=None, add_icarl_class_means=True)[source]#

Bases: RenateBenchmarkingModule

Vision Transformer base model.

Dosovitskiy, Alexey, et al. "An image is worth 16x16 words: Transformers for image recognition at scale." arXiv preprint arXiv:2010.11929 (2020).

Parameters:
  • pretrained_model_name_or_path (Optional[str]) – Name of a pretrained model on the Hugging Face hub. If provided, it overrides all other architecture arguments.

  • image_size (int) – Size of the input image.

  • patch_size (int) – Size of the patches.

  • num_layers (int) – Number of Encoder layers.

  • num_heads (int) – Number of Attention heads.

  • hidden_dim (int) – Size of the Encoder’s hidden state.

  • mlp_dim (int) – Size of the intermediate Multi-layer Perceptron in the Encoder.

  • dropout (float) – Dropout probability.

  • attention_dropout (float) – Dropout probability for the attention in the Multi-head Attention layer.

  • num_outputs (int) – Size of the output.

  • prediction_strategy (Optional[PredictionStrategy]) – Continual learning strategies may alter the prediction at train or test time.

  • add_icarl_class_means (bool) – If True, additional parameters used only by the ICaRLModelUpdater are added. Only required when using that updater.
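The parameters above determine the encoder's token sequence: the image is split into non-overlapping patches, each of which becomes one token. A small sketch of that arithmetic for the default configuration (image_size=32, patch_size=4, hidden_dim=768, mlp_dim=3072); the variable names are illustrative, not part of the library:

```python
# Default VisionTransformer configuration from the signature above.
image_size, patch_size = 32, 4
hidden_dim, mlp_dim = 768, 3072

# Each non-overlapping patch_size x patch_size patch is one encoder token.
num_patches = (image_size // patch_size) ** 2
seq_len = num_patches + 1  # +1 for the prepended [CLS] token

print(num_patches)            # 64
print(seq_len)                # 65
print(mlp_dim // hidden_dim)  # 4 -- the usual 4x MLP expansion ratio
```

Note that the default mlp_dim is four times hidden_dim, matching the standard ViT expansion ratio.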

get_features(*args, **kwargs)[source]#
training: bool#
class renate.benchmark.models.vision_transformer.VisionTransformerCIFAR(**kwargs)[source]#

Bases: VisionTransformer

training: bool#
class renate.benchmark.models.vision_transformer.VisionTransformerB16(**kwargs)[source]#

Bases: VisionTransformer

training: bool#
class renate.benchmark.models.vision_transformer.VisionTransformerB32(**kwargs)[source]#

Bases: VisionTransformer

training: bool#
class renate.benchmark.models.vision_transformer.VisionTransformerL16(**kwargs)[source]#

Bases: VisionTransformer

training: bool#
class renate.benchmark.models.vision_transformer.VisionTransformerL32(**kwargs)[source]#

Bases: VisionTransformer

training: bool#
class renate.benchmark.models.vision_transformer.VisionTransformerH14(**kwargs)[source]#

Bases: VisionTransformer

training: bool#