renate.shift.detector module

class renate.shift.detector.ShiftDetector(batch_size=32, num_preprocessing_workers=0, device='cpu')

Bases: object

Base class for distribution shift detectors.

The main interface consists of two methods, fit and score, both of which expect PyTorch Dataset objects. First, pass a reference dataset to fit. Query datasets can then be checked for distribution shift relative to the reference dataset using score. The score method returns a scalar shift score, with the convention that high values indicate a distribution shift. For most methods, this score lies in [0, 1].

Parameters:
  • batch_size (int) – Batch size used to iterate over datasets, e.g., for extracting features. This choice does not affect the result of the shift detector, but might affect run time.

  • num_preprocessing_workers (int) – Number of workers used in data loaders.

  • device (str) – Device to use for computations inside the detector.

fit(dataset)

Fit the detector to a reference dataset.

Return type:

None

score(dataset)

Compute distribution shift score for a query dataset.

Return type:

float
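The fit-then-score workflow described for this class can be illustrated with a minimal, dependency-free sketch. ToyMeanShiftDetector below is a hypothetical stand-in, not part of Renate: it operates on plain sequences of floats rather than PyTorch Dataset objects, and its scoring rule (a squashed difference of means) is an assumption chosen only to show the convention that higher scores indicate more shift.

```python
from statistics import fmean
from typing import Optional, Sequence


class ToyMeanShiftDetector:
    """Hypothetical detector mimicking the fit/score interface; not Renate code."""

    def __init__(self) -> None:
        self._reference_mean: Optional[float] = None

    def fit(self, dataset: Sequence[float]) -> None:
        # Store a summary statistic of the reference data.
        self._reference_mean = fmean(dataset)

    def score(self, dataset: Sequence[float]) -> float:
        if self._reference_mean is None:
            raise RuntimeError("Call fit() before score().")
        # Squash the absolute difference of means into [0, 1);
        # a larger gap gives a higher score.
        gap = abs(fmean(dataset) - self._reference_mean)
        return gap / (1.0 + gap)


detector = ToyMeanShiftDetector()
detector.fit([0.0, 0.1, -0.1, 0.05])  # reference dataset

in_dist_score = detector.score([0.02, -0.03, 0.08, 0.0])  # similar data
shifted_score = detector.score([2.0, 2.1, 1.9, 2.05])     # shifted data
```

In this toy setup the shifted query scores higher than the in-distribution query. An actual detector would replace the mean comparison with a proper statistical test or distance.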

class renate.shift.detector.ShiftDetectorWithFeatureExtractor(feature_extractor=None, batch_size=32, num_preprocessing_workers=0, device='cpu')

Bases: ShiftDetector

Base class for detectors working on extracted features.

These shift detectors extract (lower-dimensional) features from the datasets and use them as inputs to the shift detection method. Subclasses must override fit_with_features and score_with_features.

Parameters:
  • feature_extractor (Optional[Module]) – A PyTorch model used as feature extractor.

  • batch_size (int) – Batch size used to iterate over datasets.

  • num_preprocessing_workers (int) – Number of workers used in data loaders.

  • device (str) – Device to use for computations inside the detector.

fit(dataset)

Fit the detector to a reference dataset.

Return type:

None

score(dataset)

Compute distribution shift score for a query dataset.

Return type:

float

extract_features(dataset)

Extract features from a dataset.

Return type:

Tensor
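The division of labor described for this class (fit and score extract features, then hand them to fit_with_features and score_with_features) can be sketched in plain Python. Everything below is hypothetical and only illustrates the pattern: the class names, the method signatures, the use of lists instead of tensors, and the fallback to raw inputs when no extractor is given are assumptions of this sketch, not Renate's implementation.

```python
from statistics import fmean
from typing import Callable, List, Optional, Sequence


class ToyFeatureDetector:
    """Hypothetical base class showing the delegation pattern; not Renate code."""

    def __init__(
        self,
        feature_extractor: Optional[Callable[[float], float]] = None,
        batch_size: int = 32,
    ) -> None:
        # Assumption of this sketch: without an extractor, use raw inputs.
        self._extract = feature_extractor or (lambda x: x)
        self._batch_size = batch_size

    def extract_features(self, dataset: Sequence[float]) -> List[float]:
        features: List[float] = []
        # Iterate batch by batch and concatenate the per-batch features.
        for start in range(0, len(dataset), self._batch_size):
            batch = dataset[start:start + self._batch_size]
            features.extend(self._extract(x) for x in batch)
        return features

    def fit(self, dataset: Sequence[float]) -> None:
        self.fit_with_features(self.extract_features(dataset))

    def score(self, dataset: Sequence[float]) -> float:
        return self.score_with_features(self.extract_features(dataset))

    def fit_with_features(self, features: List[float]) -> None:
        raise NotImplementedError

    def score_with_features(self, features: List[float]) -> float:
        raise NotImplementedError


class ToyMeanGapDetector(ToyFeatureDetector):
    """Subclass supplying only the two feature-level methods."""

    def fit_with_features(self, features: List[float]) -> None:
        self._reference_mean = fmean(features)

    def score_with_features(self, features: List[float]) -> float:
        gap = abs(fmean(features) - self._reference_mean)
        return gap / (1.0 + gap)


reference = [float(i % 5) for i in range(100)]
query = [float(i % 5) + 3.0 for i in range(100)]

scores = []
for batch_size in (1, 7, 100):
    det = ToyMeanGapDetector(feature_extractor=lambda x: x * x, batch_size=batch_size)
    det.fit(reference)
    scores.append(det.score(query))
```

Because the per-batch features are concatenated in order, the three scores are identical. This matches the documented statement that batch size affects run time but not the detector's result.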