renate.shift.mmd_detectors module
class renate.shift.mmd_detectors.MMDCovariateShiftDetector(feature_extractor=None, num_permutations=1000, batch_size=32, num_preprocessing_workers=0, device='cpu')
Bases: ShiftDetectorWithFeatureExtractor
A kernel maximum mean discrepancy (MMD) test.

This test was proposed in [1].

[1] Gretton, A., et al. A kernel two-sample test. JMLR (2012).
The choice of kernel is currently not exposed: the detector always uses an RBF kernel whose lengthscale is set via the median heuristic.
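To make the statistic concrete, here is a minimal, dependency-free sketch of the squared MMD between two samples under an RBF kernel. It is not renate's implementation. The helper names are hypothetical, it uses the biased (V-statistic) estimator, and it assumes one common form of the median heuristic: the lengthscale is the median pairwise Euclidean distance over the pooled sample, with k(a, b) = exp(-||a - b||^2 / (2 * lengthscale^2)). Other conventions exist, and the library may use a different one.

```python
import math
import random
import statistics
from typing import List, Sequence

Point = Sequence[float]


def sq_dist(a: Point, b: Point) -> float:
    """Squared Euclidean distance between two feature vectors."""
    return sum((ai - bi) ** 2 for ai, bi in zip(a, b))


def median_heuristic(points: List[Point]) -> float:
    """Median pairwise Euclidean distance over the pooled sample (assumed convention)."""
    dists = [
        math.sqrt(sq_dist(points[i], points[j]))
        for i in range(len(points))
        for j in range(i + 1, len(points))
    ]
    # Fall back to 1.0 if all points coincide, to avoid a zero lengthscale.
    return statistics.median(dists) or 1.0


def rbf(a: Point, b: Point, lengthscale: float) -> float:
    """RBF kernel k(a, b) = exp(-||a - b||^2 / (2 * lengthscale^2))."""
    return math.exp(-sq_dist(a, b) / (2.0 * lengthscale ** 2))


def mmd2_biased(x: List[Point], y: List[Point], lengthscale: float) -> float:
    """Biased (V-statistic) estimate of the squared MMD between samples x and y."""

    def mean_k(p: List[Point], q: List[Point]) -> float:
        # Average kernel value over all pairs (a, b) with a in p and b in q.
        return sum(rbf(a, b, lengthscale) for a in p for b in q) / (len(p) * len(q))

    return mean_k(x, x) + mean_k(y, y) - 2.0 * mean_k(x, y)


if __name__ == "__main__":
    rng = random.Random(0)
    ref = [[rng.gauss(0, 1), rng.gauss(0, 1)] for _ in range(50)]
    same = [[rng.gauss(0, 1), rng.gauss(0, 1)] for _ in range(50)]
    shifted = [[rng.gauss(2, 1), rng.gauss(2, 1)] for _ in range(50)]
    for name, other in [("same", same), ("shifted", shifted)]:
        # The lengthscale is recomputed on each pooled sample.
        ell = median_heuristic(ref + other)
        print(name, round(mmd2_biased(ref, other, ell), 4))
```

The "shifted" comparison should print a clearly larger value than the "same" comparison. The median heuristic is attractive here because it adapts the kernel width to the scale of the features without any user-facing hyperparameter.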
The detector computes an approximate p-value via a permutation test. The score method returns 1 - p_value to conform to the convention that high scores indicate a shift.

Parameters:
- feature_extractor (Optional[Module]) – A PyTorch model used as a feature extractor.
- num_permutations (int) – Number of permutations for the permutation test.
- batch_size (int) – Batch size used to iterate over datasets.
- num_preprocessing_workers (int) – Number of workers used in data loaders.
- device (str) – Device to use for computations inside the detector.
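The permutation test and the 1 - p_value score convention described for this detector can be sketched as follows, with num_permutations playing the same role as the constructor parameter. This is a minimal sketch, not the renate implementation. It assumes an RBF kernel whose lengthscale is the median pairwise distance over the pooled sample, the biased MMD estimate, and the add-one p-value estimator (exceed + 1) / (num_permutations + 1). The helpers pooled_kernel_matrix, mmd2 and permutation_score are hypothetical, and raw inputs stand in for the output of a feature extractor.

```python
import math
import random
import statistics
from typing import List, Sequence

Point = Sequence[float]


def pooled_kernel_matrix(points: List[Point]) -> List[List[float]]:
    """RBF Gram matrix over the pooled sample; lengthscale via an assumed median heuristic."""
    sq = [[sum((ai - bi) ** 2 for ai, bi in zip(a, b)) for b in points] for a in points]
    n = len(points)
    dists = [math.sqrt(sq[i][j]) for i in range(n) for j in range(i + 1, n)]
    ell = statistics.median(dists) or 1.0  # guard against a zero lengthscale
    return [[math.exp(-d / (2.0 * ell ** 2)) for d in row] for row in sq]


def mmd2(K: List[List[float]], ix: List[int], iy: List[int]) -> float:
    """Biased squared-MMD estimate for the index split (ix, iy) of the pooled sample."""

    def mean_k(p: List[int], q: List[int]) -> float:
        return sum(K[i][j] for i in p for j in q) / (len(p) * len(q))

    return mean_k(ix, ix) + mean_k(iy, iy) - 2.0 * mean_k(ix, iy)


def permutation_score(
    x: List[Point], y: List[Point], num_permutations: int = 200, seed: int = 0
) -> float:
    """Return 1 - p_value, where the p-value comes from a permutation test on MMD."""
    K = pooled_kernel_matrix(list(x) + list(y))  # computed once, reused by every permutation
    n_x = len(x)
    idx = list(range(n_x + len(y)))
    observed = mmd2(K, idx[:n_x], idx[n_x:])  # the true split: x first, then y
    rng = random.Random(seed)
    exceed = 0
    for _ in range(num_permutations):
        rng.shuffle(idx)  # random relabelling that preserves the two group sizes
        if mmd2(K, idx[:n_x], idx[n_x:]) >= observed:
            exceed += 1
    # Add-one correction (assumed estimator) keeps the p-value strictly positive.
    p_value = (exceed + 1) / (num_permutations + 1)
    return 1.0 - p_value


if __name__ == "__main__":
    rng = random.Random(0)
    ref = [[rng.gauss(0, 1), rng.gauss(0, 1)] for _ in range(30)]
    shifted = [[rng.gauss(2, 1), rng.gauss(2, 1)] for _ in range(30)]
    print(round(permutation_score(ref, shifted, num_permutations=200), 3))
```

Two design points follow from this sketch. Because the kernel matrix is computed once over the pooled sample, each permutation only re-indexes it, so the cost per permutation is a sum over existing entries rather than fresh kernel evaluations. Under the add-one estimator, num_permutations also bounds the resolution of the test: the smallest reachable p-value is 1 / (num_permutations + 1), so the score can never exceed num_permutations / (num_permutations + 1).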