renate.shift.detector module
- class renate.shift.detector.ShiftDetector(batch_size=32, num_preprocessing_workers=0, device='cpu')
Bases: object
Base class for distribution shift detectors.
The main interface consists of two methods, fit and score, both of which expect PyTorch Dataset objects. First, pass a reference dataset to fit. Then use score to check query datasets for a distribution shift relative to the reference dataset. score returns a scalar shift score, with the convention that high values indicate a distribution shift. For most methods, this score lies in [0, 1].
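To illustrate the fit/score convention, here is a minimal sketch built on a hypothetical `ToyMeanShiftDetector` that is not part of renate. It assumes a map-style dataset of scalars (anything supporting `len()` and indexing, the same protocol a PyTorch map-style Dataset follows) and maps the standardized difference of means to a score in [0, 1), so that higher values mean "more shifted".

```python
import math
from typing import List, Optional, Sequence


class ToyMeanShiftDetector:
    """Hypothetical detector following the fit/score convention (not renate's code)."""

    def __init__(self) -> None:
        self._ref_mean: Optional[float] = None
        self._ref_std: float = 1.0

    @staticmethod
    def _values(dataset: Sequence[float]) -> List[float]:
        # Map-style access: len() plus integer indexing.
        return [float(dataset[i]) for i in range(len(dataset))]

    def fit(self, dataset: Sequence[float]) -> None:
        # Store summary statistics of the reference dataset.
        xs = self._values(dataset)
        mean = sum(xs) / len(xs)
        var = sum((x - mean) ** 2 for x in xs) / len(xs)
        self._ref_mean = mean
        self._ref_std = math.sqrt(var) or 1.0  # fall back to 1.0 for constant data

    def score(self, dataset: Sequence[float]) -> float:
        # Standardized mean difference squashed into [0, 1); 0.0 means identical means.
        if self._ref_mean is None:
            raise RuntimeError("call fit() with a reference dataset first")
        xs = self._values(dataset)
        z = abs(sum(xs) / len(xs) - self._ref_mean) / self._ref_std
        return 1.0 - math.exp(-z)


detector = ToyMeanShiftDetector()
detector.fit([0.0, 1.0, 2.0, 3.0, 4.0])           # reference data
print(detector.score([0.5, 1.5, 2.5, 3.5, 1.5]))  # similar data: small score
print(detector.score([10.0, 11.0, 12.0]))         # shifted data: score close to 1
```

Real detectors compare full distributions rather than a single mean, but the calling pattern (fit once on reference data, then score any number of query datasets) is the same.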
- Parameters:
  - batch_size (int) – Batch size used to iterate over datasets, e.g., for extracting features. This choice does not affect the result of the shift detector, but it might affect run time.
  - num_preprocessing_workers (int) – Number of workers used in data loaders.
  - device (str) – Device to use for computations inside the detector.
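Why batch_size changes only run time can be seen in a small sketch, assuming features are computed independently per example. The helpers `iterate_batches` and `extract_features` are hypothetical (not renate's code), a plain list stands in for a PyTorch Dataset, and the "feature" is simply x squared.

```python
from typing import Iterator, List, Sequence


def iterate_batches(dataset: Sequence[float], batch_size: int) -> Iterator[List[float]]:
    """Yield consecutive batches from a map-style dataset; the last one may be smaller."""
    for start in range(0, len(dataset), batch_size):
        stop = min(start + batch_size, len(dataset))
        yield [dataset[i] for i in range(start, stop)]


def extract_features(dataset: Sequence[float], batch_size: int) -> List[float]:
    """Apply a per-example feature map batch by batch (hypothetical feature: x squared)."""
    features: List[float] = []
    for batch in iterate_batches(dataset, batch_size):
        features.extend(x * x for x in batch)
    return features


data = [float(i) for i in range(10)]
# The batch size changes how many batches are processed, not the extracted features.
print(extract_features(data, batch_size=3) == extract_features(data, batch_size=32))
```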
- class renate.shift.detector.ShiftDetectorWithFeatureExtractor(feature_extractor=None, batch_size=32, num_preprocessing_workers=0, device='cpu')
Bases: ShiftDetector
Base class for detectors working on extracted features.
These shift detectors extract (typically lower-dimensional) features from the datasets and use them as inputs to the shift detection method. Subclasses have to override fit_with_features and score_with_features.
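The split between feature extraction and the two hooks resembles the template-method pattern. The following is a minimal sketch under stated assumptions: the source only says that subclasses override fit_with_features and score_with_features, so the wiring here (fit and score extract features and then call those hooks, and a missing extractor means raw inputs are used) is assumed. The classes `ToyDetectorWithFeatureExtractor` and `RangeDetector` are hypothetical, and a plain Python callable stands in for whatever feature extractor object renate actually expects.

```python
from typing import Callable, List, Optional, Sequence


class ToyDetectorWithFeatureExtractor:
    """Hypothetical base class: extract features, then delegate to subclass hooks."""

    def __init__(self, feature_extractor: Optional[Callable[[float], float]] = None) -> None:
        # Assumption: with no extractor, raw inputs are used as features.
        self.feature_extractor = feature_extractor or (lambda x: x)

    def _features(self, dataset: Sequence[float]) -> List[float]:
        return [self.feature_extractor(dataset[i]) for i in range(len(dataset))]

    def fit(self, dataset: Sequence[float]) -> None:
        self.fit_with_features(self._features(dataset))

    def score(self, dataset: Sequence[float]) -> float:
        return self.score_with_features(self._features(dataset))

    # Hooks that concrete detectors must override.
    def fit_with_features(self, features: List[float]) -> None:
        raise NotImplementedError

    def score_with_features(self, features: List[float]) -> float:
        raise NotImplementedError


class RangeDetector(ToyDetectorWithFeatureExtractor):
    """Hypothetical subclass: fraction of query features outside the reference range."""

    def fit_with_features(self, features: List[float]) -> None:
        self.low, self.high = min(features), max(features)

    def score_with_features(self, features: List[float]) -> float:
        outside = sum(1 for f in features if f < self.low or f > self.high)
        return outside / len(features)


detector = RangeDetector(feature_extractor=abs)
detector.fit([-2.0, -1.0, 0.0, 1.0, 2.0])         # reference features lie in [0, 2]
print(detector.score([1.5, -0.5, 3.0, -4.0]))     # 2 of 4 features fall outside
```

Keeping feature extraction in the base class means each subclass only has to implement the statistical test on features, and any feature extractor can be paired with any detector.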
- Parameters: