Source code for renate.models.prediction_strategies
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from abc import ABC, abstractmethod
from typing import Any
import torch
from torch import Tensor
[docs]
class PredictionStrategy(ABC):
@abstractmethod
def __call__(self, inputs: Tensor, training: bool, **kwargs: Any) -> Tensor:
return inputs
[docs]
class ICaRLClassificationStrategy(PredictionStrategy):
def __call__(self, inputs: Tensor, training: bool, class_means: Tensor) -> Tensor:
if training:
return super().__call__(inputs, training)
normalized_inputs = (inputs.T / torch.norm(inputs.T, dim=0)).T
return (-torch.cdist(class_means.to(normalized_inputs.device)[:, :].T, normalized_inputs)).T