Distribution Shift Detection

A machine learning model usually needs to be retrained or updated when the distribution of the data it is served shifts. Renate provides methods for distribution shift detection that can help you decide when to update your model. This functionality resides in the renate.shift subpackage.

Shift Types

In supervised machine learning tasks, one can distinguish different types of shifts in the joint distribution \(p(x, y)\). A common assumption is that of covariate shift, where we assume that \(p(x)\) changes while \(p(y|x)\) stays constant. In that case, one only needs to inspect \(x\) data to detect a shift. Currently, Renate only supports covariate shift detection.
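The covariate shift assumption can be written out as a factorization of the joint distribution. The subscripts "ref" and "query" are notation introduced here for the reference and query data; they do not appear elsewhere in Renate's documentation.

\[
p_{\text{ref}}(x, y) = p_{\text{ref}}(x)\, p(y|x), \qquad p_{\text{query}}(x, y) = p_{\text{query}}(x)\, p(y|x)
\]

Because the conditional factor \(p(y|x)\) is shared, the two joint distributions differ exactly when the marginals \(p_{\text{ref}}(x)\) and \(p_{\text{query}}(x)\) differ, which is why comparing the \(x\) data alone is sufficient under this assumption.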

Shift Detector Interface

The shift detectors in renate.shift derive from a common class ShiftDetector, which defines the main interface. Once a detector object has been initialized, one calls detector.fit(dataset_ref) on a reference dataset (a PyTorch dataset object). This reference dataset characterizes the expected data distribution. It may, e.g., be the validation set used during the previous fitting of the model. Subsequently, one can score one or multiple query datasets using the detector.score(dataset_query) method. This method returns a scalar distribution shift score. We use the convention that high scores indicate a likely distribution shift. For all currently available detectors, this score lies between 0 and 1.
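To make the fit-then-score protocol concrete, here is a minimal, self-contained sketch in plain Python. It is not Renate's ShiftDetector: the class names ToyShiftDetector and MeanDifferenceDetector, the use of lists of floats instead of PyTorch datasets, and the scoring formula are all assumptions chosen for illustration. Only the calling pattern and the "higher score means likelier shift, score in [0, 1]" convention mirror the description above.

```python
import math
import statistics
from abc import ABC, abstractmethod
from typing import Sequence


class ToyShiftDetector(ABC):
    """Hypothetical stand-in for the fit/score protocol (not Renate's class)."""

    @abstractmethod
    def fit(self, reference: Sequence[float]) -> None:
        """Store whatever is needed to characterize the reference data."""

    @abstractmethod
    def score(self, query: Sequence[float]) -> float:
        """Return a score in [0, 1]; higher means a shift is more likely."""


class MeanDifferenceDetector(ToyShiftDetector):
    """Toy detector that compares the query mean to the reference mean."""

    def fit(self, reference: Sequence[float]) -> None:
        self._mean = statistics.fmean(reference)
        self._std = statistics.stdev(reference)

    def score(self, query: Sequence[float]) -> float:
        # Distance between the means, in units of the reference standard deviation.
        z = abs(statistics.fmean(query) - self._mean) / self._std
        # Squash the non-negative distance into [0, 1).
        return 1.0 - math.exp(-z)


detector = MeanDifferenceDetector()
detector.fit([0.1, -0.2, 0.0, 0.3, -0.1])
score_same = detector.score([0.0, 0.1, -0.1, 0.2, -0.2])
score_shifted = detector.score([2.0, 2.1, 1.9, 2.2, 1.8])
print(f"similar data: {score_same:.3f}, shifted data: {score_shifted:.3f}")
```

A single fitted detector can score any number of query datasets, which is why fitting and scoring are separate calls.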

Available Methods

At the moment, Renate provides two methods for covariate shift detection, one of which is the MMD-based detector used in the example below.

Both tests operate on features extracted from the raw data. The extraction is performed by the module passed via the feature_extractor argument at initialization. The feature extractor is expected to map the raw input data to informative vectorial representations of moderate dimension. It may be based on a pretrained model, e.g., by using its penultimate-layer embeddings (see also the example below).
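For intuition about what an MMD-based test computes on such feature vectors, the sketch below estimates the squared maximum mean discrepancy (MMD) between two small samples and converts it into a score in [0, 1] with a permutation test. This is an illustrative sketch, not Renate's implementation: the RBF kernel, the fixed bandwidth, the number of permutations, the "1 minus p-value" mapping, and all function names are assumptions made here. The two-dimensional Gaussian samples stand in for extracted features.

```python
import math
import random


def rbf_kernel(a, b, bandwidth=1.0):
    """Gaussian (RBF) kernel between two feature vectors."""
    squared_distance = sum((ai - bi) ** 2 for ai, bi in zip(a, b))
    return math.exp(-squared_distance / (2.0 * bandwidth**2))


def mmd_squared(xs, ys, bandwidth=1.0):
    """Biased (V-statistic) estimate of the squared MMD between two samples."""

    def mean_kernel(us, vs):
        total = sum(rbf_kernel(u, v, bandwidth) for u in us for v in vs)
        return total / (len(us) * len(vs))

    return mean_kernel(xs, xs) + mean_kernel(ys, ys) - 2.0 * mean_kernel(xs, ys)


def permutation_score(reference, query, num_permutations=100, seed=0):
    """Return 1 - p, where p is a permutation-test p-value for the MMD."""
    rng = random.Random(seed)
    observed = mmd_squared(reference, query)
    pooled = list(reference) + list(query)
    n_ref = len(reference)
    at_least_as_extreme = 0
    for _ in range(num_permutations):
        # Randomly reassign points to the two groups and recompute the statistic.
        rng.shuffle(pooled)
        if mmd_squared(pooled[:n_ref], pooled[n_ref:]) >= observed:
            at_least_as_extreme += 1
    p_value = (at_least_as_extreme + 1) / (num_permutations + 1)
    return 1.0 - p_value


def sample(rng, n, offset):
    """Draw n two-dimensional Gaussian vectors centered at (offset, offset)."""
    return [[rng.gauss(offset, 1.0), rng.gauss(offset, 1.0)] for _ in range(n)]


rng = random.Random(42)
features_ref = sample(rng, 20, 0.0)
features_query_in = sample(rng, 20, 0.0)   # same distribution as the reference
features_query_out = sample(rng, 20, 5.0)  # mean shifted by 5 in each dimension

score_in = permutation_score(features_ref, features_query_in)
score_out = permutation_score(features_ref, features_query_out)
print(f"in-distribution score = {score_in:.3f}")
print(f"out-of-distribution score = {score_out:.3f}")
```

The pairwise kernel sums grow quadratically with the sample size, which is one reason such tests are typically run on features of moderate dimension and on datasets of moderate size.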

Example

The following example illustrates how to apply the MMD covariate shift detector. We will work with the CIFAR-10 dataset, which we can conveniently load using Renate’s TorchVisionDataModule. In practice, you would ingest your own data here; see the documentation for RenateDataModule.

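# Import path as found in the Renate source tree; verify against your installed version.
from renate.benchmark.datasets.vision_datasets import TorchVisionDataModule

data_module = TorchVisionDataModule(data_path="data", dataset_name="CIFAR10", val_size=0.2)
data_module.prepare_data()
data_module.setup()
dataset = data_module.train_data()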

For the purpose of this demonstration, we now generate a reference dataset as well as two query datasets: one from the same distribution, and one where we simulate a distribution shift by blurring images. In practice, the reference dataset should represent your expected data distribution. It could, e.g., be the validation set you used during the previous training of your model. The query dataset would be the data you want to check for distribution shift, e.g., data collected during the deployment of your model.

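import torch
from torchvision.transforms import GaussianBlur
from renate.data.datasets import _TransformedDataset  # Private helper; path may vary by version.

dataset_ref = torch.utils.data.Subset(dataset, list(range(1000)))
dataset_query_in = torch.utils.data.Subset(dataset, list(range(1000, 2000)))
dataset_query_out = torch.utils.data.Subset(dataset, list(range(2000, 3000)))
# Blur the third subset to simulate a covariate shift.
transform = GaussianBlur(kernel_size=5, sigma=1.0)
dataset_query_out = _TransformedDataset(dataset_query_out, transform)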

Shift detection methods rely on informative (and relatively low-dimensional) features. Here, we use a pretrained ResNet-18 model and chop off its output layer. This yields 512-dimensional feature vectors.

from torchvision.models import ResNet18_Weights, resnet18

feature_extractor = resnet18(weights=ResNet18_Weights.DEFAULT)
feature_extractor.fc = torch.nn.Identity()  # Replace the output layer with a no-op.
feature_extractor.eval()  # Eval mode to use frozen batchnorm stats.

You can use any torch.nn.Module as the feature extractor, be it a pretrained model or a custom model that has been trained on the data at hand. Generally, we have observed very good results when using generic pretrained models such as ResNets for image data or BERT models for text.

Now we can instantiate an MMD-based shift detector. We first fit it to our reference dataset and then score both the in-distribution and the out-of-distribution query datasets.

# Import path as found in the Renate source tree; verify against your installed version.
from renate.shift.mmd_detectors import MMDCovariateShiftDetector

detector = MMDCovariateShiftDetector(feature_extractor=feature_extractor)
print("Fitting detector...")
detector.fit(dataset_ref)
print("Scoring in-distribution data...")
score_in = detector.score(dataset_query_in)
print(f"score = {score_in}")
print("Scoring out-of-distribution data...")
score_out = detector.score(dataset_query_out)
print(f"score = {score_out}")

In this toy example, the shift is quite obvious: the out-of-distribution data receives the maximum score of 1.0, while the in-distribution data scores noticeably lower.

Fitting detector...
Scoring in-distribution data...
score = 0.5410000085830688
Scoring out-of-distribution data...
score = 1.0
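One simple way to act on such scores is to compare them against a threshold and flag the model for an update when the threshold is reached. The sketch below does this for the two scores printed above. The helper needs_update and the threshold value of 0.95 are hypothetical choices for illustration, not part of Renate; a suitable threshold depends on your application and on how costly missed shifts and unnecessary updates are.

```python
def needs_update(score: float, threshold: float = 0.95) -> bool:
    """Flag a query dataset whose shift score reaches the threshold."""
    if not 0.0 <= score <= 1.0:
        raise ValueError(f"Expected a score in [0, 1], got {score}.")
    return score >= threshold


# Scores taken from the example output above.
scores = {"in-distribution": 0.5410000085830688, "out-of-distribution": 1.0}
decisions = {name: needs_update(score) for name, score in scores.items()}
for name, decision in decisions.items():
    print(f"{name}: update recommended = {decision}")
```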