renate.models.task_identification_strategies module

class renate.models.task_identification_strategies.TaskEstimator

Bases: Module, ABC

An abstract base class from which all task estimator methods inherit.

Subclasses implement two methods: update_task_prototypes and infer_task.

abstract update_task_prototypes()
abstract infer_task()
training: bool
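The two-method contract above can be illustrated with a minimal, dependency-free sketch. It mirrors only the abstract methods using Python's `abc` module; the real class additionally derives from PyTorch's `Module` and works with tensors. `TaskEstimatorSketch` and the toy `LastTaskEstimator` subclass are hypothetical names for illustration, not part of renate.

```python
from abc import ABC, abstractmethod
from typing import List, Sequence


class TaskEstimatorSketch(ABC):
    """Hypothetical stand-in for TaskEstimator (not renate's code)."""

    @abstractmethod
    def update_task_prototypes(
        self, features: Sequence[Sequence[float]], labels: Sequence[int]
    ) -> None:
        """Record whatever per-task statistics the estimator needs."""

    @abstractmethod
    def infer_task(self, features: Sequence[Sequence[float]]) -> List[int]:
        """Return one estimated task id per input feature vector."""


class LastTaskEstimator(TaskEstimatorSketch):
    """Hypothetical toy estimator: always predicts the most recent task id."""

    def __init__(self) -> None:
        self.num_updates = 0

    def update_task_prototypes(self, features, labels) -> None:
        # In this toy, each call stands for one model update (one new task).
        self.num_updates += 1

    def infer_task(self, features) -> List[int]:
        # -1 would mean that no task has been seen yet.
        return [self.num_updates - 1] * len(features)


estimator = LastTaskEstimator()
estimator.update_task_prototypes([[0.0, 1.0]], [0])
estimator.update_task_prototypes([[1.0, 0.0]], [1])
print(estimator.infer_task([[0.5, 0.5], [0.2, 0.8]]))  # [1, 1]
```

Because both methods are marked abstract, `TaskEstimatorSketch` itself cannot be instantiated; only subclasses that implement both methods can.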
class renate.models.task_identification_strategies.TaskPrototypes(task_id, clusters_per_task, embedding_size)

Bases: TaskEstimator

Task identification method proposed in S-Prompts.

Parameters:
  • task_id – The current update id of the method. Required for deserialization.

  • clusters_per_task – Number of K-means clusters to compute per task.

  • embedding_size – Embedding size of the transformer features.

update_task_prototypes(features, labels)
Return type: None
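The S-Prompts idea that this class follows can be sketched, under the assumption that prototypes are K-means centroids of each task's features, as: cluster the current task's features into `clusters_per_task` centroids and store each centroid tagged with the task id. This pure-Python illustration is not renate's implementation (which works on PyTorch tensors); `kmeans`, `PrototypeStore`, `add_task`, the fixed iteration count, and the first-k-points initialization are hypothetical choices made for readability, and this sketch does not use `labels`.

```python
from typing import List, Sequence

Vector = List[float]


def squared_distance(a: Sequence[float], b: Sequence[float]) -> float:
    return sum((x - y) ** 2 for x, y in zip(a, b))


def kmeans(points: Sequence[Sequence[float]], k: int, iters: int = 20) -> List[Vector]:
    """Plain Lloyd's algorithm; the first k points seed the centroids (deterministic)."""
    if len(points) < k:
        raise ValueError("need at least k points")
    centroids = [list(p) for p in points[:k]]
    for _ in range(iters):
        # Assignment step: group points by their nearest centroid.
        groups: List[List[Sequence[float]]] = [[] for _ in range(k)]
        for p in points:
            nearest = min(range(k), key=lambda c: squared_distance(p, centroids[c]))
            groups[nearest].append(p)
        # Update step: move each centroid to its group mean (keep it if the group is empty).
        for c, group in enumerate(groups):
            if group:
                centroids[c] = [sum(coords) / len(group) for coords in zip(*group)]
    return centroids


class PrototypeStore:
    """Hypothetical store that accumulates (centroid, task id) pairs across updates."""

    def __init__(self, clusters_per_task: int) -> None:
        self.clusters_per_task = clusters_per_task
        self.prototypes: List[Vector] = []
        self.task_ids: List[int] = []

    def add_task(self, features: Sequence[Sequence[float]], task_id: int) -> None:
        centroids = kmeans(features, self.clusters_per_task)
        self.prototypes.extend(centroids)
        self.task_ids.extend([task_id] * len(centroids))


store = PrototypeStore(clusters_per_task=2)
store.add_task([[0.0, 0.0], [10.0, 10.0], [0.0, 1.0], [10.0, 11.0]], task_id=0)
print(store.prototypes)  # [[0.0, 0.5], [10.0, 10.5]]
print(store.task_ids)    # [0, 0]
```

Keeping several centroids per task, rather than a single mean, lets a task whose features form multiple separate groups still be represented by a nearby prototype.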

infer_task(features)
Return type: Tensor
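Task inference in a prototype-based scheme can be sketched as a nearest-prototype lookup: each query feature receives the task id of the closest stored centroid. The squared Euclidean distance, the `nearest_prototype_task` helper, and the hand-written prototypes below are assumptions for illustration; the documented method returns a `Tensor`, whereas this dependency-free sketch returns a list of ints.

```python
from typing import List, Sequence


def nearest_prototype_task(
    features: Sequence[Sequence[float]],
    prototypes: Sequence[Sequence[float]],
    prototype_task_ids: Sequence[int],
) -> List[int]:
    """Hypothetical helper: task id of the nearest prototype for each feature vector."""
    if not prototypes:
        raise ValueError("no prototypes stored yet")

    def sq_dist(a: Sequence[float], b: Sequence[float]) -> float:
        return sum((x - y) ** 2 for x, y in zip(a, b))

    predictions = []
    for f in features:
        # Index of the closest prototype under squared Euclidean distance.
        nearest = min(range(len(prototypes)), key=lambda i: sq_dist(f, prototypes[i]))
        predictions.append(prototype_task_ids[nearest])
    return predictions


# Hypothetical prototypes: two centroids for task 0, one for task 1.
prototypes = [[0.0, 0.5], [10.0, 10.5], [5.0, -6.0]]
task_ids = [0, 0, 1]
print(nearest_prototype_task([[1.0, 1.0], [9.0, 9.0], [4.0, -5.0]], prototypes, task_ids))  # [0, 0, 1]
```

In this sketch every query is compared against all stored prototypes, so the lookup cost grows with `clusters_per_task` times the number of tasks seen.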

training: bool