renate.benchmark.models package#
- class renate.benchmark.models.MultiLayerPerceptron(num_inputs, num_outputs, num_hidden_layers, hidden_size, activation='ReLU', batch_normalization=False, prediction_strategy=None, add_icarl_class_means=True)[source]#
Bases: RenateBenchmarkingModule
A simple multi-layer perceptron with hidden layers, an activation function, and optional batch normalization.
- Parameters:
  - num_inputs (int) – Number of input nodes.
  - num_outputs (int) – Number of output nodes.
  - num_hidden_layers (int) – Number of hidden layers.
  - hidden_size (Union[int, List[int], Tuple[int]]) – Uniform hidden size or the list or tuple of hidden sizes for individual hidden layers.
  - activation (str) – Activation name, matching an activation name in torch.nn, to be used between the hidden layers.
  - batch_normalization (bool) – Whether to use batch normalization after the activation. By default, the batch normalization tracks the running statistics.
  - prediction_strategy (Optional[PredictionStrategy]) – Continual learning strategies may alter the prediction at train or test time.
  - add_icarl_class_means (bool) – If True, additional parameters used only by the ICaRLModelUpdater are added. Only required when using that updater.
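Example: a minimal usage sketch. It assumes that, like a standard torch.nn.Module, the model can be called directly on a batch of inputs; the shapes in the comments are illustrative.

    import torch

    from renate.benchmark.models import MultiLayerPerceptron

    # Two hidden layers of 64 units each, mapping 784 flattened inputs to 10 outputs.
    model = MultiLayerPerceptron(
        num_inputs=784,
        num_outputs=10,
        num_hidden_layers=2,
        hidden_size=64,
        activation="ReLU",
        batch_normalization=True,
    )

    x = torch.randn(32, 784)  # a batch of 32 flattened inputs
    logits = model(x)         # expected shape: (32, 10)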
- class renate.benchmark.models.ResNet18(block=<class 'torchvision.models.resnet.BasicBlock'>, layers=[2, 2, 2, 2], cifar_stem=False, **kwargs)[source]#
Bases: ResNet
- class renate.benchmark.models.ResNet18CIFAR(block=<class 'torchvision.models.resnet.BasicBlock'>, layers=[2, 2, 2, 2], cifar_stem=True, **kwargs)[source]#
Bases: ResNet
- class renate.benchmark.models.ResNet34(block=<class 'torchvision.models.resnet.BasicBlock'>, layers=[3, 4, 6, 3], cifar_stem=False, **kwargs)[source]#
Bases: ResNet
- class renate.benchmark.models.ResNet34CIFAR(block=<class 'torchvision.models.resnet.BasicBlock'>, layers=[3, 4, 6, 3], cifar_stem=True, **kwargs)[source]#
Bases: ResNet
- class renate.benchmark.models.ResNet50(block=<class 'torchvision.models.resnet.Bottleneck'>, layers=[3, 4, 6, 3], cifar_stem=False, **kwargs)[source]#
Bases: ResNet
- class renate.benchmark.models.ResNet50CIFAR(block=<class 'torchvision.models.resnet.Bottleneck'>, layers=[3, 4, 6, 3], cifar_stem=True, **kwargs)[source]#
Bases: ResNet
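Example: a short sketch contrasting the standard and CIFAR variants. The only difference visible in the signatures is the cifar_stem flag (and the Bottleneck block for ResNet50); the forward call assumes the module behaves like a torchvision-style ResNet.

    import torch

    from renate.benchmark.models import ResNet18, ResNet18CIFAR

    model = ResNet18()             # cifar_stem=False
    model_cifar = ResNet18CIFAR()  # cifar_stem=True, stem configured for CIFAR-sized inputs

    x = torch.randn(8, 3, 32, 32)  # a batch of CIFAR-sized images
    outputs = model_cifar(x)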
- class renate.benchmark.models.LearningToPromptTransformer(pretrained_model_name_or_path='google/vit-base-patch16-224', image_size=32, patch_size=4, num_layers=12, num_heads=12, hidden_dim=768, mlp_dim=3072, dropout=0.1, attention_dropout=0.1, num_outputs=10, prediction_strategy=None, add_icarl_class_means=True, pool_size=10, pool_selection_size=5, prompt_size=5, prompt_key_dim=768, train_prompt_keys=True, similarity_fn='cosine', per_batch_prompt=True, prompt_embedding_features='cls', patch_pooler='prompt_mean')[source]#
Bases: RenateBenchmarkingModule
Implements the vision transformer with prompt pool described in Wang, Zifeng, et al. “Learning to prompt for continual learning.” Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2022.
- Parameters:
  - pretrained_model_name_or_path – A string that denotes which pretrained model from the HF hub to use. If provided, it overrides other arguments about the architecture.
  - image_size (int) – Size of the input image.
  - patch_size (int) – Size of the patches.
  - num_layers (int) – Number of encoder layers.
  - num_heads (int) – Number of attention heads.
  - hidden_dim (int) – Size of the encoder's hidden state.
  - mlp_dim (int) – Size of the intermediate multi-layer perceptron in the encoder.
  - dropout (float) – Dropout probability.
  - attention_dropout (float) – Dropout probability for the attention in the multi-head attention layer.
  - num_outputs (int) – Size of the output.
  - prediction_strategy (Optional[PredictionStrategy]) – Continual learning strategies may alter the prediction at train or test time.
  - add_icarl_class_means (bool) – If True, additional parameters used only by the ICaRLModelUpdater are added. Only required when using that updater.
  - pool_size (int) – Total size of the prompt pool.
  - pool_selection_size (int) – Number of prompts to select from the pool.
  - prompt_size (int) – Number of tokens each prompt is equivalent to.
  - prompt_key_dim (int) – Dimension of the prompt keys used to compute similarity. It has to match the dimension of x in forward.
  - train_prompt_keys (bool) – Whether to train the prompt keys. Currently unused.
  - similarity_fn (Union[Callable, str]) – Similarity function between input features and prompt keys.
  - per_batch_prompt (bool) – Flag to use the same prompts for all elements in the batch.
  - prompt_embedding_features (str) – Image feature type used to compute the similarity to prompt keys.
  - patch_pooler (str) – Features to feed the classifier.
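Example: a minimal instantiation sketch. The keyword values restate defaults from the signature above; the forward call assumes a standard nn.Module interface and a 224x224 image input to match the named pretrained checkpoint.

    import torch

    from renate.benchmark.models import LearningToPromptTransformer

    model = LearningToPromptTransformer(
        pretrained_model_name_or_path="google/vit-base-patch16-224",
        num_outputs=10,
        pool_size=10,           # total number of prompts in the pool
        pool_selection_size=5,  # prompts selected from the pool per input
        prompt_size=5,          # tokens contributed by each selected prompt
        similarity_fn="cosine",
    )

    x = torch.randn(4, 3, 224, 224)
    logits = model(x)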
- class renate.benchmark.models.PromptedTransformer(pretrained_model_name_or_path='google/vit-base-patch16-224', image_size=32, patch_size=4, num_layers=12, num_heads=12, hidden_dim=768, mlp_dim=3072, dropout=0.1, attention_dropout=0.1, num_outputs=10, prediction_strategy=None, add_icarl_class_means=True)[source]#
Bases: Module
This generic module is the basic prompted transformer. It takes in a model string and creates the appropriate transformer (ViT or Text transformer). If no prompts are provided in the forward call, image/text features are returned. If a prompt is provided, it is concatenated to the embedding layer output and the resultant features are returned.
- Parameters:
  - pretrained_model_name_or_path – A string that denotes which pretrained model from the HF hub to use.
  - num_outputs (int) – Size of the output.
  - prediction_strategy (Optional[PredictionStrategy]) – Continual learning strategies may alter the prediction at train or test time.
  - add_icarl_class_means (bool) – If True, additional parameters used only by the ICaRLModelUpdater are added. Only required when using that updater.
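Example: a sketch of plain feature extraction. Per the class description, calling the module without a prompt returns image/text features; passing a prompt would instead return features computed on the prompt-augmented embeddings (the exact prompt argument is not shown here).

    import torch

    from renate.benchmark.models import PromptedTransformer

    transformer = PromptedTransformer(
        pretrained_model_name_or_path="google/vit-base-patch16-224",
    )

    x = torch.randn(4, 3, 224, 224)
    features = transformer(x)  # no prompt passed: plain image features are returned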
- class renate.benchmark.models.SPromptTransformer(pretrained_model_name_or_path='google/vit-base-patch16-224', image_size=32, patch_size=4, num_layers=12, num_heads=12, hidden_dim=768, mlp_dim=3072, dropout=0.1, attention_dropout=0.1, num_outputs=10, prediction_strategy=None, add_icarl_class_means=True, prompt_size=10, task_id=0, clusters_per_task=5, per_task_classifier=False)[source]#
Bases: RenateBenchmarkingModule
Implements the transformer model for S-Prompts as described in Wang, Yabin, et al. "S-Prompts Learning with Pre-trained Transformers: An Occam's Razor for Domain Incremental Learning." Advances in Neural Information Processing Systems 35 (2022): 5682-5695.
- Parameters:
  - pretrained_model_name_or_path – A string that denotes which pretrained model from the HF hub to use.
  - image_size (int) – Image size. Used if pretrained_model_name_or_path is not set.
  - patch_size (int) – Patch size to be extracted. Used if pretrained_model_name_or_path is not set.
  - num_layers (int) – Number of transformer layers. Used only if pretrained_model_name_or_path is not set.
  - num_heads (int) – Number of heads in multi-head self-attention. Used only if pretrained_model_name_or_path is not set.
  - hidden_dim (int) – Hidden dimension of the transformer. Used only if pretrained_model_name_or_path is not set.
  - mlp_dim (int) – Size of the intermediate multi-layer perceptron in the encoder. Used only if pretrained_model_name_or_path is not set.
  - dropout (float) – Dropout probability. Used only if pretrained_model_name_or_path is not set.
  - attention_dropout (float) – Dropout probability for the attention in the multi-head attention layer. Used only if pretrained_model_name_or_path is not set.
  - num_outputs (int) – Number of output classes. Defaults to 10.
  - prediction_strategy (Optional[PredictionStrategy]) – Continual learning strategies may alter the prediction at train or test time. Defaults to None.
  - add_icarl_class_means (bool) – If True, additional parameters used only by the ICaRLModelUpdater are added. Only required when using that updater.
  - prompt_size (int) – Equivalent to the number of input tokens used per update. Defaults to 10.
  - task_id (int) – Internal variable used to increment the update id. Should not be set by the user. Defaults to 0.
  - clusters_per_task (int) – Number of clusters in the k-means used for task identification. Defaults to 5.
  - per_task_classifier (bool) – Flag to use a separate classifier head per task instead of a common head shared across all tasks. Defaults to False.
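Example: a minimal instantiation sketch using the documented defaults. task_id is left untouched since the documentation says it should not be set by the user; the forward call assumes a standard nn.Module interface.

    import torch

    from renate.benchmark.models import SPromptTransformer

    model = SPromptTransformer(
        pretrained_model_name_or_path="google/vit-base-patch16-224",
        num_outputs=10,
        prompt_size=10,       # number of prompt tokens added per update
        clusters_per_task=5,  # k-means clusters used for task identification
        per_task_classifier=False,
    )

    x = torch.randn(4, 3, 224, 224)
    logits = model(x)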
- class renate.benchmark.models.VisionTransformerB16(**kwargs)[source]#
Bases: VisionTransformer
- class renate.benchmark.models.VisionTransformerB32(**kwargs)[source]#
Bases: VisionTransformer
- class renate.benchmark.models.VisionTransformerCIFAR(**kwargs)[source]#
Bases: VisionTransformer
- class renate.benchmark.models.VisionTransformerH14(**kwargs)[source]#
Bases: VisionTransformer
- class renate.benchmark.models.VisionTransformerL16(**kwargs)[source]#
Bases: VisionTransformer
- class renate.benchmark.models.VisionTransformerL32(**kwargs)[source]#
Bases: VisionTransformer
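Example: instantiating two of the variants. The class names follow the usual ViT naming (B/16, B/32, L/16, L/32, H/14, plus a CIFAR-sized variant); the constructors only expose **kwargs, which are presumably forwarded to the VisionTransformer base class.

    from renate.benchmark.models import VisionTransformerB16, VisionTransformerCIFAR

    vit_b16 = VisionTransformerB16()      # ViT-B/16 preset
    vit_cifar = VisionTransformerCIFAR()  # variant sized for CIFAR-style inputs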