Source code for renate.models.task_identification_strategies
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from abc import ABC, abstractmethod
from typing import Union
import numpy as np
import numpy.typing as npt
import torch
import torch.nn as nn
from sklearn.cluster import KMeans
[docs]
class TaskEstimator(nn.Module, ABC):
    """Abstract base class from which every task estimator derives.

    Concrete subclasses have to provide two methods: `update_task_prototypes` and
    `infer_task`. The class cannot be instantiated directly.
    """

    @abstractmethod
    def update_task_prototypes(self):
        """Update the estimator's internal state. Must be overridden; returns `None` here."""

    @abstractmethod
    def infer_task(self):
        """Infer the task. Must be overridden; returns `None` here."""
[docs]
class TaskPrototypes(TaskEstimator):
    """Task identification method proposed in S-Prompts.

    Training features are l2-normalized and clustered with K-means. The resulting centroids
    are stored in a buffer next to a second buffer holding the task id each centroid belongs
    to. At inference, every sample receives the task id of its nearest centroid.

    Args:
        task_id: The current update id of the method. Required to deserialize.
        clusters_per_task: Number of clusters to use in K-means.
        embedding_size: Embedding size of the transformer features.
    """

    def __init__(self, task_id, clusters_per_task, embedding_size) -> None:
        super().__init__()
        # Both buffers are pre-sized to `task_id * clusters_per_task` rows. The centroid
        # buffer is created with `torch.empty`, i.e. its initial content is uninitialized.
        # NOTE(review): presumably these rows are overwritten by loading a saved state dict
        # from earlier updates -- confirm against the caller.
        self.register_buffer(
            "_training_feat_centroids",
            torch.empty(task_id * clusters_per_task, embedding_size),
        )
        self.register_buffer(
            "_training_feat_task_ids",
            torch.full(
                (self._training_feat_centroids.size(0),), fill_value=task_id, dtype=torch.long
            ),
        )
        self._clusters_per_task = clusters_per_task
        self._task_id = task_id
        self._embedding_size = embedding_size

    @torch.no_grad()
    def update_task_prototypes(
        self,
        features: Union[torch.Tensor, npt.ArrayLike],
        labels: Union[torch.Tensor, npt.ArrayLike],
    ) -> None:
        """Cluster the features of the current task and append the centroids as prototypes.

        Args:
            features: Two-dimensional collection of features, one row per sample.
            labels: Not used by this method; accepted to keep the call signature.
        """
        # At training.
        if isinstance(features, torch.Tensor):
            # `detach` is required because `.numpy()` refuses tensors that require grad.
            features = features.detach().cpu().numpy()
        features = np.asarray(features)

        # l2 normalize features. The norm is clamped from below (same epsilon as the default
        # of `torch.nn.functional.normalize`, which is used in `infer_task`) so that an
        # all-zero row stays zero instead of turning into NaN.
        norms = np.sqrt(np.einsum("ij, ij -> i", features, features))
        features = features / np.maximum(norms, 1e-12)[:, None]

        cluster_centers = (
            KMeans(n_clusters=self._clusters_per_task, random_state=0)
            .fit(features)
            .cluster_centers_
        )
        # Match device *and* dtype of the existing buffer so that concatenation does not
        # silently change the buffer's dtype (K-means returns float64 for float64 input).
        centroids = torch.from_numpy(cluster_centers).to(
            device=self._training_feat_centroids.device,
            dtype=self._training_feat_centroids.dtype,
        )
        self._training_feat_centroids = torch.cat([self._training_feat_centroids, centroids])

        # Use the buffer's own dtype for the new ids; a narrower integer type would overflow
        # for large task ids.
        new_task_ids = torch.full(
            (centroids.size(0),),
            fill_value=self._task_id,
            dtype=self._training_feat_task_ids.dtype,
            device=self._training_feat_task_ids.device,
        )
        self._training_feat_task_ids = torch.cat([self._training_feat_task_ids, new_task_ids])

    def infer_task(self, features: torch.Tensor) -> Union[torch.Tensor, None]:
        """Return the task id of the nearest stored centroid for every row of `features`.

        Args:
            features: Two-dimensional tensor of features, one row per sample.

        Returns:
            A tensor with one task id per row of `features`, or `None` if no centroids are
            stored.
        """
        # At inference.
        if self._training_feat_centroids.numel() == 0:
            return None
        features = torch.nn.functional.normalize(features)
        nearest_p_inds = torch.cdist(features, self._training_feat_centroids, p=2).argmin(1)
        return self._training_feat_task_ids[nearest_p_inds]