renate.shift.mmd_detectors module

class renate.shift.mmd_detectors.MMDCovariateShiftDetector(feature_extractor=None, num_permutations=1000, batch_size=32, num_preprocessing_workers=0, device='cpu')

Bases: ShiftDetectorWithFeatureExtractor

A kernel maximum mean discrepancy (MMD) test.

This test was proposed by

[1] Gretton, A., et al. A kernel two-sample test. JMLR (2012).
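For reference, the squared MMD between distributions P and Q under a kernel k, together with the simple biased (V-statistic) estimator computed from samples x_1, ..., x_m ~ P and y_1, ..., y_n ~ Q, is given below. This is the textbook form from [1]; whether the detector uses this biased estimator or an unbiased variant is not stated here, so treat it as background rather than the detector's exact statistic.

```latex
\mathrm{MMD}^2(P, Q) = \mathbb{E}_{x,x' \sim P}[k(x,x')] + \mathbb{E}_{y,y' \sim Q}[k(y,y')] - 2\,\mathbb{E}_{x \sim P,\, y \sim Q}[k(x,y)]

\widehat{\mathrm{MMD}}^2_b = \frac{1}{m^2}\sum_{i=1}^{m}\sum_{j=1}^{m} k(x_i,x_j) + \frac{1}{n^2}\sum_{i=1}^{n}\sum_{j=1}^{n} k(y_i,y_j) - \frac{2}{mn}\sum_{i=1}^{m}\sum_{j=1}^{n} k(x_i,y_j)
```

A large value indicates that the two samples are unlikely to come from the same distribution.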

We currently do not expose the choice of kernel. It defaults to an RBF kernel with a lengthscale set via the median heuristic.
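The kernel setup can be sketched in NumPy as below. This is an illustrative reimplementation, not renate's code: the helper names (`pairwise_sq_dists`, `median_heuristic`, `rbf_kernel`, `mmd2_biased`) are hypothetical, and the sketch assumes the common convention k(x, y) = exp(-||x - y||^2 / (2 l^2)) with the lengthscale l set to the median nonzero pairwise Euclidean distance of the pooled sample. Implementations differ on details such as taking the median over squared distances instead.

```python
import numpy as np


def pairwise_sq_dists(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Squared Euclidean distances between every row of `a` and every row of `b`."""
    d = (a * a).sum(axis=1)[:, None] + (b * b).sum(axis=1)[None, :] - 2.0 * a @ b.T
    return np.maximum(d, 0.0)  # clip tiny negative values caused by round-off


def median_heuristic(z: np.ndarray) -> float:
    """Assumed convention: median of the nonzero pairwise distances in `z`."""
    d = np.sqrt(pairwise_sq_dists(z, z))
    upper = d[np.triu_indices_from(d, k=1)]  # each unordered pair counted once
    upper = upper[upper > 0]
    return float(np.median(upper)) if upper.size else 1.0  # fallback for degenerate data


def rbf_kernel(a: np.ndarray, b: np.ndarray, lengthscale: float) -> np.ndarray:
    """RBF (Gaussian) kernel matrix between the rows of `a` and `b`."""
    return np.exp(-pairwise_sq_dists(a, b) / (2.0 * lengthscale**2))


def mmd2_biased(x: np.ndarray, y: np.ndarray, lengthscale: float) -> float:
    """Biased (V-statistic) estimate of the squared MMD between samples x and y."""
    return float(
        rbf_kernel(x, x, lengthscale).mean()
        + rbf_kernel(y, y, lengthscale).mean()
        - 2.0 * rbf_kernel(x, y, lengthscale).mean()
    )
```

With two samples from the same distribution the estimate stays close to zero, while shifting the mean of one sample increases it. Setting the lengthscale from the data avoids a hand-tuned bandwidth, which is the usual motivation for the median heuristic.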

The detector computes an approximate p-value via a permutation test. The score method returns 1 - p_value to conform to the convention that high scores indicate a shift.
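A minimal permutation-test sketch follows, assuming a biased RBF-kernel MMD statistic (defined inline as `_rbf_mmd2` so the block runs on its own). The function `permutation_score` and its `seed` argument are hypothetical; only `num_permutations` mirrors the detector's parameter. The "+1" correction in the p-value is a common convention and may not match renate's exact formula.

```python
import numpy as np


def _rbf_mmd2(x: np.ndarray, y: np.ndarray, lengthscale: float) -> float:
    """Biased squared-MMD estimate with an RBF kernel (helper for this sketch)."""

    def k(a: np.ndarray, b: np.ndarray) -> np.ndarray:
        d = (a * a).sum(axis=1)[:, None] + (b * b).sum(axis=1)[None, :] - 2.0 * a @ b.T
        return np.exp(-np.maximum(d, 0.0) / (2.0 * lengthscale**2))

    return float(k(x, x).mean() + k(y, y).mean() - 2.0 * k(x, y).mean())


def permutation_score(
    x: np.ndarray,
    y: np.ndarray,
    lengthscale: float,
    num_permutations: int = 1000,
    seed: int = 0,
) -> float:
    """Return 1 - p_value for the null hypothesis that x and y share a distribution."""
    rng = np.random.default_rng(seed)
    observed = _rbf_mmd2(x, y, lengthscale)
    pooled = np.vstack([x, y])
    m = len(x)
    exceed = 0
    for _ in range(num_permutations):
        # Randomly reassign the pooled points to two groups of the original sizes.
        perm = rng.permutation(len(pooled))
        stat = _rbf_mmd2(pooled[perm[:m]], pooled[perm[m:]], lengthscale)
        exceed += int(stat >= observed)
    # Counting the observed split as one of the permutations keeps p_value > 0.
    p_value = (exceed + 1) / (num_permutations + 1)
    return 1.0 - p_value  # high score means evidence of a shift
```

Under no shift, the permuted statistics are often as large as the observed one, so p_value is spread out and the score is moderate. Under a clear shift, almost no permutation reaches the observed statistic and the score approaches 1 - 1/(num_permutations + 1).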

Parameters:
  • feature_extractor (Optional[Module]) – A PyTorch model used as feature extractor.

  • num_permutations (int) – Number of permutations for the permutation test.

  • batch_size (int) – Batch size used to iterate over datasets.

  • num_preprocessing_workers (int) – Number of workers used in data loaders.

  • device (str) – Device to use for computations inside the detector.