renate.shift.detector module
- class renate.shift.detector.ShiftDetector(batch_size=32, num_preprocessing_workers=0, device='cpu')
Bases: object
Base class for distribution shift detectors.
The main interface consists of two methods, fit and score, which expect PyTorch Dataset objects. One passes a reference dataset to the fit method. Query datasets can then be checked for distribution shifts (relative to the reference dataset) using the score method. The score method returns a scalar shift score, with the convention that high values indicate a distribution shift. For most methods, this score will be in [0, 1].
- Parameters:
batch_size (int) – Batch size used to iterate over datasets, e.g., for extracting features. This choice does not affect the result of the shift detector, but might affect run time.
num_preprocessing_workers (int) – Number of workers used in data loaders.
device (str) – Device to use for computations inside the detector.
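The fit/score workflow described above can be illustrated with a small stand-in. This is only a sketch under stated assumptions: ToyMeanShiftDetector is a hypothetical class, not part of Renate, and plain Python lists stand in for map-style PyTorch Dataset objects (anything with __len__ and __getitem__). The scoring rule, a mean difference squashed into [0, 1), is chosen purely for illustration and is not the method any Renate detector uses.

```python
import math
from typing import Optional, Sequence


class ToyMeanShiftDetector:
    """Hypothetical stand-in mimicking the fit/score contract; not a Renate class."""

    def __init__(self, batch_size: int = 32) -> None:
        # As in the documented interface, batch_size only controls how the data
        # is iterated over; it is not meant to change the resulting score.
        self._batch_size = batch_size
        self._reference_mean: Optional[float] = None

    def _mean(self, dataset: Sequence[float]) -> float:
        # Iterate over the dataset in batches and accumulate the sum.
        total = 0.0
        for start in range(0, len(dataset), self._batch_size):
            stop = min(start + self._batch_size, len(dataset))
            total += sum(dataset[i] for i in range(start, stop))
        return total / len(dataset)

    def fit(self, dataset: Sequence[float]) -> None:
        # Remember a summary statistic of the reference dataset.
        self._reference_mean = self._mean(dataset)

    def score(self, dataset: Sequence[float]) -> float:
        # Higher values indicate a larger shift relative to the reference data.
        if self._reference_mean is None:
            raise RuntimeError("Call fit() on a reference dataset before score().")
        gap = abs(self._mean(dataset) - self._reference_mean)
        return 1.0 - math.exp(-gap)  # maps [0, inf) into [0, 1)


reference = [0.0, 1.0, 2.0, 3.0]
detector = ToyMeanShiftDetector(batch_size=2)
detector.fit(reference)

same_score = detector.score(reference)                        # query equals reference
shifted_score = detector.score([x + 5.0 for x in reference])  # query is shifted
print(same_score, shifted_score)
```

The unshifted query scores lower than the shifted one, matching the convention that high values indicate a distribution shift. With an actual detector, the reference and query arguments would be PyTorch Dataset objects rather than lists.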
- class renate.shift.detector.ShiftDetectorWithFeatureExtractor(feature_extractor=None, batch_size=32, num_preprocessing_workers=0, device='cpu')
Bases: ShiftDetector
Base class for detectors working on extracted features.
These shift detectors extract some (lower-dimensional) features from the datasets, which are used as inputs to the shift detection methods. Subclasses have to override fit_with_features and score_with_features.
- Parameters: