renate.benchmark.models.spromptmodel module#

class renate.benchmark.models.spromptmodel.PromptPool(prompt_size=10, embedding_size=768, current_update_id=0)[source]#

Bases: Module

Implements a pool of prompts to be used for S-Prompts.

Parameters:
  • prompt_size (int) – Equivalent to the number of input tokens used per update. Defaults to 10.

  • embedding_size (int) – Hidden size of the transformer used. Defaults to 768.

  • current_update_id (int) – Current update id. Used to initialize the number of prompts. Defaults to 0.

forward(id)[source]#

Defines the computation performed at every call.

Should be overridden by all subclasses.

Return type:

Parameter

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

get_params(id)[source]#
Return type:

List[Parameter]

increment_task()[source]#
Return type:

None

training: bool#
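The mechanics of the pool above (one learnable prompt per update, retrieved by id) can be illustrated with a minimal, framework-free sketch. The real PromptPool is a torch.nn.Module holding learnable Parameter tensors; here plain nested lists of shape (prompt_size, embedding_size) stand in for tensors. Names mirror the documented API, but this is an illustration, not Renate's implementation.

```python
class PromptPoolSketch:
    """Illustrative stand-in for PromptPool (not the Renate implementation)."""

    def __init__(self, prompt_size=10, embedding_size=768, current_update_id=0):
        self.prompt_size = prompt_size
        self.embedding_size = embedding_size
        # One prompt per update seen so far (update ids 0 .. current_update_id - 1).
        self._pool = {i: self._new_prompt() for i in range(current_update_id)}

    def _new_prompt(self):
        # Stand-in for a freshly initialized (prompt_size x embedding_size) tensor.
        return [[0.0] * self.embedding_size for _ in range(self.prompt_size)]

    def forward(self, id):
        # Return the prompt tokens associated with the given update id.
        return self._pool[id]

    def get_params(self, id):
        # The real method returns List[Parameter]; here, a one-element list.
        return [self._pool[id]]

    def increment_task(self):
        # Grow the pool by one freshly initialized prompt for the new update.
        self._pool[len(self._pool)] = self._new_prompt()


pool = PromptPoolSketch(prompt_size=4, embedding_size=8, current_update_id=2)
pool.increment_task()          # pool now holds prompts for update ids 0, 1, 2
prompt = pool.forward(2)       # retrieve the prompt for the newest update
```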
class renate.benchmark.models.spromptmodel.SPromptTransformer(pretrained_model_name_or_path='google/vit-base-patch16-224', image_size=32, patch_size=4, num_layers=12, num_heads=12, hidden_dim=768, mlp_dim=3072, dropout=0.1, attention_dropout=0.1, num_outputs=10, prediction_strategy=None, add_icarl_class_means=True, prompt_size=10, task_id=0, clusters_per_task=5, per_task_classifier=False)[source]#

Bases: RenateBenchmarkingModule

Implements the transformer model for S-Prompts as described in Wang, Yabin, et al.
"S-Prompts learning with pre-trained transformers: An Occam's razor for domain incremental learning." Advances in Neural Information Processing Systems 35 (2022): 5682-5695.

Parameters:
  • pretrained_model_name_or_path – A string that denotes which pretrained model from the HF hub to use.

  • image_size (int) – Image size. Used only if pretrained_model_name_or_path is not set.

  • patch_size (int) – Size of the patches to be extracted. Used only if pretrained_model_name_or_path is not set.

  • num_layers (int) – Number of transformer layers. Used only if pretrained_model_name_or_path is not set.

  • num_heads (int) – Number of attention heads in multi-head self-attention. Used only if pretrained_model_name_or_path is not set.

  • hidden_dim (int) – Hidden dimension of the transformer. Used only if pretrained_model_name_or_path is not set.

  • mlp_dim (int) – Dimension of the MLP in each transformer block. Used only if pretrained_model_name_or_path is not set.

  • dropout (float) – Dropout probability. Used only if pretrained_model_name_or_path is not set.

  • attention_dropout (float) – Dropout probability in the attention layers. Used only if pretrained_model_name_or_path is not set.

  • num_outputs (int) – Number of output classes. Defaults to 10.

  • prediction_strategy (Optional[PredictionStrategy]) – Continual learning strategies may alter the prediction at train or test time. Defaults to None.

  • add_icarl_class_means (bool) – If True, additional parameters used only by the ICaRLModelUpdater are added. Only required when using that updater.

  • prompt_size (int) – Equivalent to the number of input tokens used per update. Defaults to 10.

  • task_id (int) – Internal variable used to increment the update id. Shouldn’t be set by the user. Defaults to 0.

  • clusters_per_task (int) – Number of clusters in the k-means used for task identification. Defaults to 5.

  • per_task_classifier (bool) – Whether to use a separate classifier head per task instead of a common one shared across all tasks. Defaults to False.

increment_task()[source]#
Return type:

None

update_task_identifier(features, labels)[source]#
Return type:

None
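The task-identification mechanism behind update_task_identifier and clusters_per_task can be sketched as follows: at each update, features of the training data are clustered with k-means and the centroids are stored; at test time, the update whose centroid is nearest to the input feature determines which prompt to use. The sketch below is a simplified, framework-free illustration of that idea following the S-Prompts paper, not Renate's implementation; note that the real update_task_identifier takes (features, labels), while here a task id is passed directly, and real code would run proper k-means rather than taking the first k points.

```python
import math


def centroids_of(features, k):
    # Toy "clustering": take the first k feature vectors as centroids to keep
    # the sketch short. Real code would run k-means with clusters_per_task=k.
    return features[:k]


class TaskIdentifierSketch:
    """Illustrative nearest-centroid task identifier (not Renate's code)."""

    def __init__(self, clusters_per_task=5):
        self.clusters_per_task = clusters_per_task
        self.centroids = []  # list of (task_id, centroid) pairs

    def update_task_identifier(self, features, task_id):
        # Store centroids of this update's features, tagged with its task id.
        for c in centroids_of(features, self.clusters_per_task):
            self.centroids.append((task_id, c))

    def identify(self, x):
        # Nearest-centroid rule: return the task id of the closest centroid.
        def dist(a, b):
            return math.sqrt(sum((u - v) ** 2 for u, v in zip(a, b)))

        return min(self.centroids, key=lambda tc: dist(tc[1], x))[0]


tid = TaskIdentifierSketch(clusters_per_task=1)
tid.update_task_identifier([[0.0, 0.0]], task_id=0)
tid.update_task_identifier([[5.0, 5.0]], task_id=1)
```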

set_extra_state(state, decode=True)[source]#

Extracts the content of the _extra_state and sets the related values in the module.

The decode flag indicates whether the tensor of pickled bytes should be decoded.

features(x)[source]#
Return type:

Tensor

training: bool#
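Putting the pieces together, inference in an S-Prompts model follows the flow: extract backbone features via features(), identify the task from those features, fetch the corresponding prompt, and run the prompted forward pass. The sketch below shows only this control flow; the names identify_task, prompt_pool, and classify are illustrative stand-ins (not Renate APIs), and the mock model replaces all learned components with trivial logic.

```python
class MockSPromptModel:
    """Trivial stand-in for an S-Prompts model, to illustrate the flow only."""

    def features(self, x):
        # Stand-in for frozen-backbone feature extraction.
        return x

    def identify_task(self, feats):
        # Stand-in for nearest-centroid task identification.
        return 0 if sum(feats) < 1.0 else 1

    def prompt_pool(self, task_id):
        # Stand-in for retrieving the prompt learned for that update.
        return [float(task_id)] * 4

    def classify(self, x, prompt):
        # Stand-in for the forward pass with the prompt prepended.
        return ("task", int(prompt[0]))


def s_prompt_predict(model, x):
    feats = model.features(x)            # backbone features
    task_id = model.identify_task(feats) # which update does x belong to?
    prompt = model.prompt_pool(task_id)  # prompt learned for that update
    return model.classify(x, prompt)     # prompted forward pass
```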