Source code for renate.shift.mmd_detectors
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from typing import Optional
import torch
from renate.shift.detector import ShiftDetectorWithFeatureExtractor
from renate.shift.kernels import RBFKernel
from renate.shift.mmd_helpers import mmd
[docs]
class MMDCovariateShiftDetector(ShiftDetectorWithFeatureExtractor):
    """Covariate shift detector based on a kernel maximum mean discrepancy (MMD) test.

    The underlying two-sample test was introduced in

    [1] Gretton, A., et al. A kernel two-sample test. JMLR (2012).

    The kernel is currently not configurable: an RBF kernel is always used, with its
    lengthscale chosen by the median heuristic.

    An approximate p-value is obtained through a permutation test. To follow the convention
    that larger scores signal a shift, the `score` method reports `1 - p_value`.

    Args:
        feature_extractor: A pytorch model used as feature extractor.
        num_permutations: Number of permutations for permutation test.
        batch_size: Batch size used to iterate over datasets.
        num_preprocessing_workers: Number of workers used in data loaders.
        device: Device to use for computations inside the detector.
    """

    def __init__(
        self,
        feature_extractor: Optional[torch.nn.Module] = None,
        num_permutations: int = 1000,
        batch_size: int = 32,
        num_preprocessing_workers: int = 0,
        device: str = "cpu",
    ) -> None:
        # Everything except `num_permutations` is handled by the base class.
        super().__init__(
            feature_extractor,
            batch_size,
            num_preprocessing_workers,
            device,
        )
        self._num_permutations = num_permutations

    def _fit_with_features(self, X: torch.Tensor):
        """Remember the features `X` as the reference sample for later scoring."""
        self._X_ref = X

    def _score_with_features(self, X: torch.Tensor) -> float:
        """Compare `X` against the stored reference features and return `1 - p_value`.

        Requires that `_fit_with_features` has stored reference features beforehand.
        """
        reference_features = self._X_ref
        # Only the second value returned by `mmd` (the p-value) is used here.
        _ignored, p_value = mmd(
            reference_features,
            X,
            kernel=RBFKernel(),
            num_permutations=self._num_permutations,
        )
        # A small p-value is evidence of a shift, so invert it to get a "high means shift" score.
        shift_score = 1.0 - p_value.item()
        return shift_score