# Source code for renate.shift.detector
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from typing import Optional
import torch
from torch.utils.data import DataLoader, Dataset
from renate.utils.pytorch import move_tensors_to_device
class ShiftDetector:
    """Base class for distribution shift detectors.

    The main interface consists of two methods `fit` and `score`, which expect pytorch Dataset
    objects. One passes a reference dataset to the `fit` method. Then we can check query datasets
    for distribution shifts (relative to the reference dataset) using the `score` method. The
    `score` method returns a scalar shift score with the convention that high values indicate a
    distribution shift. For most methods, this score will be in [0, 1].

    This base class only stores the configuration and provides a data-loader helper; `fit` and
    `score` must be implemented by subclasses.

    Args:
        batch_size: Batch size used to iterate over datasets, e.g., for extracting features. This
            choice does not affect the result of the shift detector, but might affect run time.
        num_preprocessing_workers: Number of workers used in data loaders.
        device: Device to use for computations inside the detector.
    """

    def __init__(
        self,
        batch_size: int = 32,
        num_preprocessing_workers: int = 0,
        device: str = "cpu",
    ) -> None:
        self._batch_size = batch_size
        self._num_preprocessing_workers = num_preprocessing_workers
        self._device = device

    def fit(self, dataset: Dataset) -> None:
        """Fit the detector to a reference dataset.

        Args:
            dataset: The reference dataset.

        Raises:
            NotImplementedError: Always; subclasses have to override this method.
        """
        raise NotImplementedError()

    def score(self, dataset: Dataset) -> float:
        """Compute distribution shift score for a query dataset.

        Args:
            dataset: The query dataset to compare against the reference dataset.

        Returns:
            A scalar shift score; by convention, high values indicate a distribution shift.

        Raises:
            NotImplementedError: Always; subclasses have to override this method.
        """
        raise NotImplementedError()

    def _make_data_loader(self, dataset: Dataset, shuffle: bool = False) -> DataLoader:
        """Return a data loader to iterate over a dataset.

        The loader uses the batch size and number of workers passed to the constructor.

        Args:
            dataset: The dataset.
            shuffle: Whether to shuffle or not.

        Returns:
            A `DataLoader` wrapping `dataset`.
        """
        return DataLoader(
            dataset,
            batch_size=self._batch_size,
            shuffle=shuffle,
            num_workers=self._num_preprocessing_workers,
        )
class ShiftDetectorWithFeatureExtractor(ShiftDetector):
    """Base class for detectors working on extracted features.

    These shift detectors extract some (lower-dimensional) features from the datasets, which are
    used as inputs to the shift detection methods. Subclasses have to override
    `_fit_with_features` and `_score_with_features`.

    Args:
        feature_extractor: A pytorch model used as feature extractor. If `None`, an identity
            mapping is used, i.e., the inputs themselves serve as features. The module is moved
            to `device` via `Module.to`, which modifies the passed module in place.
        batch_size: Batch size used to iterate over datasets.
        num_preprocessing_workers: Number of workers used in data loaders.
        device: Device to use for computations inside the detector.
    """

    def __init__(
        self,
        feature_extractor: Optional[torch.nn.Module] = None,
        batch_size: int = 32,
        num_preprocessing_workers: int = 0,
        device: str = "cpu",
    ) -> None:
        super().__init__(batch_size, num_preprocessing_workers, device)
        # Explicit `None` check: container modules define `__len__`, so relying on truthiness
        # would silently discard an (empty) module that the caller passed on purpose.
        if feature_extractor is None:
            feature_extractor = torch.nn.Identity()
        self._feature_extractor = feature_extractor.to(self._device)

    def fit(self, dataset: Dataset) -> None:
        """Fit the detector to a reference dataset.

        Extracts features from `dataset` and passes them to `_fit_with_features`.

        Args:
            dataset: The reference dataset.
        """
        self._fit_with_features(self.extract_features(dataset))

    def score(self, dataset: Dataset) -> float:
        """Compute distribution shift score for a query dataset.

        Extracts features from `dataset` and passes them to `_score_with_features`.

        Args:
            dataset: The query dataset.

        Returns:
            The shift score computed by `_score_with_features`.
        """
        return self._score_with_features(self.extract_features(dataset))

    @torch.no_grad()
    def extract_features(self, dataset: Dataset) -> torch.Tensor:
        """Extract features from a dataset.

        Iterates over the dataset in batches (unshuffled), applies the feature extractor to the
        first element of each batch and concatenates the outputs along dimension 0.

        Args:
            dataset: The dataset to extract features from. Each batch produced by the data loader
                is indexed with `[0]` to obtain the inputs (presumably datasets yield
                `(input, ...)` tuples -- confirm against callers).

        Returns:
            A tensor on the CPU containing the features of all batches.

        Note:
            A dataset that produces no batches leaves nothing to concatenate, in which case
            `torch.cat` is expected to raise.
        """
        dataloader = self._make_data_loader(dataset)
        # `no_grad` only disables gradient tracking. Layers such as dropout or batch norm still
        # follow the module's training flag, which would make the features stochastic and
        # batch-size dependent (and would update batch-norm statistics). Hence, run the
        # extractor in eval mode and restore every submodule's previous mode afterwards.
        previous_modes = [(module, module.training) for module in self._feature_extractor.modules()]
        self._feature_extractor.eval()
        try:
            features = [
                self._feature_extractor(move_tensors_to_device(batch[0], device=self._device))
                for batch in dataloader
            ]
        finally:
            # Restore per-module flags directly; `train(mode)` would overwrite the mode of all
            # children with a single value.
            for module, was_training in previous_modes:
                module.training = was_training
        return torch.cat(features, dim=0).cpu()

    def _fit_with_features(self, X: torch.Tensor) -> None:
        """Fit the detector to features extracted from the reference dataset.

        Args:
            X: Features as returned by `extract_features`.

        Raises:
            NotImplementedError: Always; subclasses have to override this method.
        """
        raise NotImplementedError()

    def _score_with_features(self, X: torch.Tensor) -> float:
        """Compute the shift score from features extracted from a query dataset.

        Args:
            X: Features as returned by `extract_features`.

        Raises:
            NotImplementedError: Always; subclasses have to override this method.
        """
        raise NotImplementedError()