Distribution Shift Detection
Retraining or updating a machine learning model is usually necessitated by shifts in the distribution of the data being served to the model. Renate provides methods for distribution shift detection that can help you decide when to update your model. This functionality resides in the `renate.shift` subpackage.
Shift Types
In supervised machine learning tasks, one can distinguish different types of shifts in the joint distribution \(p(x, y)\). A common assumption is that of covariate shift, where \(p(x)\) changes while \(p(y \mid x)\) stays constant. In that case, one only needs to inspect the inputs \(x\) to detect a shift. Currently, Renate only supports covariate shift detection.
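As a concrete illustration of covariate shift, the following NumPy simulation draws inputs from two different Gaussians but labels both with the same fixed rule. It is purely illustrative and not part of Renate. Because the labeling rule is shared, \(p(x)\) shifts while \(p(y \mid x)\) does not, and comparing the inputs alone is enough to reveal the change.

```python
import numpy as np

rng = np.random.default_rng(0)


def label(x: np.ndarray) -> np.ndarray:
    """Fixed labeling rule p(y|x), shared by both distributions: y = 1 if x > 0 else 0."""
    return (x > 0).astype(int)


# Reference inputs: x ~ N(0, 1).
x_ref = rng.normal(loc=0.0, scale=1.0, size=5000)
# Query inputs: x ~ N(1.5, 1). Only the input distribution p(x) changes.
x_query = rng.normal(loc=1.5, scale=1.0, size=5000)

y_ref, y_query = label(x_ref), label(x_query)

# The shift is visible from x alone; no labels are needed to detect it.
print(f"mean of x: {x_ref.mean():.2f} (reference) vs {x_query.mean():.2f} (query)")
print(f"fraction of y=1: {y_ref.mean():.2f} (reference) vs {y_query.mean():.2f} (query)")
```

The label marginal \(p(y)\) also changes here, but only as a consequence of the shift in \(p(x)\). The conditional rule itself is untouched.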
Shift Detector Interface
The shift detectors in `renate.shift` derive from a common class, `ShiftDetector`, which defines the main interface. Once a `detector` object has been initialized, one calls `detector.fit(dataset_ref)` on a reference dataset (a PyTorch dataset object). This reference dataset characterizes the expected data distribution. It may, e.g., be the validation set used during the previous fitting of the model. Subsequently, we can score one or more query datasets using the `detector.score(dataset_query)` method, which returns a scalar distribution shift score.
We use the convention that high scores indicate a likely distribution shift. For all currently available detectors, this score lies between 0 and 1.
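The fit/score contract can be sketched as an abstract base class. The names `ShiftDetectorSketch` and `MeanShiftDetector` below are hypothetical and do not mirror Renate's `ShiftDetector` implementation. To keep the example self-contained, the toy detector works on NumPy feature arrays rather than PyTorch datasets. It squashes a standardized mean difference into [0, 1) with `tanh` only so that it follows the "higher means more likely shifted" convention.

```python
from abc import ABC, abstractmethod
from typing import Optional

import numpy as np


class ShiftDetectorSketch(ABC):
    """Hypothetical stand-in for the fit/score contract (not Renate's class)."""

    @abstractmethod
    def fit(self, data_ref: np.ndarray) -> None:
        """Characterize the expected (reference) distribution."""

    @abstractmethod
    def score(self, data_query: np.ndarray) -> float:
        """Return a shift score in [0, 1]; higher means a shift is more likely."""


class MeanShiftDetector(ShiftDetectorSketch):
    """Toy detector (hypothetical): compares per-feature means against the reference."""

    def __init__(self) -> None:
        self._mean: Optional[np.ndarray] = None
        self._std: Optional[np.ndarray] = None

    def fit(self, data_ref: np.ndarray) -> None:
        self._mean = data_ref.mean(axis=0)
        self._std = data_ref.std(axis=0) + 1e-12  # Avoid division by zero.

    def score(self, data_query: np.ndarray) -> float:
        if self._mean is None:
            raise RuntimeError("Call fit() before score().")
        # Largest standardized mean difference across features, mapped into [0, 1).
        z = np.abs(data_query.mean(axis=0) - self._mean) / self._std
        return float(np.tanh(z.max()))


rng = np.random.default_rng(0)
detector = MeanShiftDetector()
detector.fit(rng.normal(0.0, 1.0, size=(1000, 8)))
score_in = detector.score(rng.normal(0.0, 1.0, size=(1000, 8)))
score_out = detector.score(rng.normal(2.0, 1.0, size=(1000, 8)))
print(f"in-distribution: {score_in:.3f}, shifted: {score_out:.3f}")
```

Note that one fitted detector can score any number of query datasets. This matches how the interface is meant to be used.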
Available Methods
At the moment, Renate provides two methods for covariate shift detection:

- `MMDCovariateShiftDetector` uses a multivariate kernel MMD (maximum mean discrepancy) test.
- `KolmogorovSmirnovCovariateShiftDetector` uses a univariate Kolmogorov-Smirnov test on each feature, aggregated with a Bonferroni correction.
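For intuition about what the two tests compute, here is a NumPy sketch that operates on already-extracted feature matrices. It is not Renate's implementation, and several parts are assumptions made here:

- The KS p-value uses the large-sample Kolmogorov distribution.
- The MMD test uses a Gaussian kernel with one common variant of the median heuristic and a permutation p-value.
- The score is reported as 1 - p to fit the [0, 1] convention.

```python
import numpy as np


def ks_statistic(a: np.ndarray, b: np.ndarray) -> float:
    """Two-sample Kolmogorov-Smirnov statistic: largest gap between the empirical CDFs."""
    a, b = np.sort(a), np.sort(b)
    grid = np.concatenate([a, b])
    cdf_a = np.searchsorted(a, grid, side="right") / len(a)
    cdf_b = np.searchsorted(b, grid, side="right") / len(b)
    return float(np.max(np.abs(cdf_a - cdf_b)))


def ks_pvalue(d: float, n: int, m: int) -> float:
    """Asymptotic two-sided p-value via the Kolmogorov distribution (approximation)."""
    en = np.sqrt(n * m / (n + m))
    lam = (en + 0.12 + 0.11 / en) * d
    if lam < 0.3:
        return 1.0  # The series converges slowly here, and the true value is ~1.
    k = np.arange(1, 101)
    q = 2.0 * np.sum((-1.0) ** (k - 1) * np.exp(-2.0 * k**2 * lam**2))
    return float(np.clip(q, 0.0, 1.0))


def ks_bonferroni_score(x_ref: np.ndarray, x_query: np.ndarray) -> float:
    """Per-feature KS tests, Bonferroni-corrected minimum p-value p, reported as 1 - p (assumed)."""
    num_features = x_ref.shape[1]
    p_values = [
        ks_pvalue(ks_statistic(x_ref[:, j], x_query[:, j]), len(x_ref), len(x_query))
        for j in range(num_features)
    ]
    return 1.0 - min(1.0, num_features * min(p_values))


def mmd2_from_kernel(k: np.ndarray, n: int) -> float:
    """Biased squared-MMD estimate; the first n rows/columns of k belong to the reference sample."""
    return float(k[:n, :n].mean() + k[n:, n:].mean() - 2.0 * k[:n, n:].mean())


def mmd_permutation_score(x_ref, x_query, num_permutations=200, seed=0) -> float:
    """Gaussian-kernel MMD permutation test, reported as 1 - p (assumed convention)."""
    rng = np.random.default_rng(seed)
    z = np.concatenate([x_ref, x_query])
    n = len(x_ref)
    sq_norms = np.sum(z**2, axis=1)
    sq_dists = np.maximum(sq_norms[:, None] + sq_norms[None, :] - 2.0 * z @ z.T, 0.0)
    # Median heuristic: scale the kernel by the median squared pairwise distance.
    k = np.exp(-sq_dists / np.median(sq_dists[sq_dists > 0]))
    observed = mmd2_from_kernel(k, n)
    exceed = 0
    for _ in range(num_permutations):
        perm = rng.permutation(len(z))  # Randomly reassign samples to the two groups.
        exceed += int(mmd2_from_kernel(k[np.ix_(perm, perm)], n) >= observed)
    p_value = (exceed + 1) / (num_permutations + 1)
    return 1.0 - p_value


rng = np.random.default_rng(1)
x_ref = rng.normal(0.0, 1.0, size=(150, 5))
x_same = rng.normal(0.0, 1.0, size=(150, 5))
x_shifted = rng.normal(1.0, 1.0, size=(150, 5))
print("KS :", ks_bonferroni_score(x_ref, x_same), ks_bonferroni_score(x_ref, x_shifted))
print("MMD:", mmd_permutation_score(x_ref, x_same), mmd_permutation_score(x_ref, x_shifted))
```

The two tests look for different things. The KS variant tests each feature separately and pays for it with the Bonferroni factor. MMD compares the full multivariate distributions in one test.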
Both tests operate on features extracted from the raw data by a feature extractor, which is passed via the `feature_extractor` argument at initialization. The feature extractor is expected to map the raw input data to informative vectorial representations of moderate dimension. It may be based on a pretrained model, e.g., by using its penultimate-layer embeddings (see also the example below).
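The role of the feature extractor is to reduce raw inputs to features of moderate dimension before any test runs. The sketch below shows only this shape contract, with CIFAR-sized images mapped to 64-dimensional features. A fixed random projection followed by ReLU stands in for a pretrained network, and the helper `extract_features` is hypothetical. Batching is there only to bound memory use.

```python
import numpy as np

rng = np.random.default_rng(0)
INPUT_DIM = 32 * 32 * 3  # A flattened CIFAR-sized RGB image.
FEATURE_DIM = 64

# Hypothetical stand-in for a pretrained network: fixed random projection + ReLU.
projection = rng.normal(size=(INPUT_DIM, FEATURE_DIM)) / np.sqrt(INPUT_DIM)


def feature_extractor(batch: np.ndarray) -> np.ndarray:
    """Map raw inputs of shape (B, 32, 32, 3) to features of shape (B, FEATURE_DIM)."""
    flat = batch.reshape(len(batch), -1)
    return np.maximum(flat @ projection, 0.0)


def extract_features(extractor, inputs: np.ndarray, batch_size: int = 128) -> np.ndarray:
    """Apply the extractor batch by batch and stack the results."""
    chunks = [extractor(inputs[i : i + batch_size]) for i in range(0, len(inputs), batch_size)]
    return np.concatenate(chunks, axis=0)


images = rng.random(size=(500, 32, 32, 3))
features = extract_features(feature_extractor, images)
print(images.shape, "->", features.shape)  # (500, 32, 32, 3) -> (500, 64)
```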
Example
The following example illustrates how to apply the MMD covariate shift detector. We will work with the CIFAR-10 dataset, which we can conveniently load using Renate's `TorchVisionDataModule`. In practice, you would ingest your own data here; see the documentation for `RenateDataModule`.
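```python
# Import path assumed from Renate's package layout; check the API reference for your version.
from renate.benchmark.datasets.vision_datasets import TorchVisionDataModule

data_module = TorchVisionDataModule(data_path="data", dataset_name="CIFAR10", val_size=0.2)
data_module.prepare_data()  # Fetch the data into data_path if it is not there yet.
data_module.setup()  # Create the train/validation split (val_size=0.2).
dataset = data_module.train_data()
```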
For the purpose of this demonstration, we now generate a reference dataset as well as two query datasets: one from the same distribution, and one where we simulate a distribution shift by blurring images. In practice, the reference dataset should represent your expected data distribution. It could, e.g., be the validation set you used during the previous training of your model. The query dataset would be the data you want to check for distribution shift, e.g., data collected during the deployment of your model.
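```python
import torch
from torchvision.transforms import GaussianBlur
from renate.data.datasets import _TransformedDataset  # Private Renate helper; import path assumed.

dataset_ref = torch.utils.data.Subset(dataset, list(range(1000)))
dataset_query_in = torch.utils.data.Subset(dataset, list(range(1000, 2000)))
dataset_query_out = torch.utils.data.Subset(dataset, list(range(2000, 3000)))
transform = GaussianBlur(kernel_size=5, sigma=1.0)
# Blur the held-out subset (not dataset_query_in) to simulate a covariate shift.
dataset_query_out = _TransformedDataset(dataset_query_out, transform)
```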
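The blurred query set relies on `_TransformedDataset`, a private Renate helper whose exact behavior is not documented here. A minimal equivalent is sketched below, assuming the helper applies the transform to each input and leaves the label untouched. `TransformedDatasetSketch` is a hypothetical name. It implements the `__len__`/`__getitem__` protocol that PyTorch map-style datasets use, so it can wrap a `Subset` in the same way.

```python
from typing import Any, Callable, Sequence, Tuple


class TransformedDatasetSketch:
    """Hypothetical wrapper: applies `transform` to each input and keeps the label unchanged."""

    def __init__(self, dataset: Sequence[Tuple[Any, Any]], transform: Callable[[Any], Any]) -> None:
        self._dataset = dataset
        self._transform = transform

    def __len__(self) -> int:
        return len(self._dataset)

    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
        x, y = self._dataset[idx]
        return self._transform(x), y  # Only the input is transformed.


toy = TransformedDatasetSketch([(1, "a"), (2, "b")], lambda v: v * 10)
print(len(toy), toy[1])  # 2 (20, 'b')
```

Wrapping the data lazily like this means the blur is applied on access. No second copy of the images is stored.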
Shift detection methods rely on informative (and relatively low-dimensional) features. Here, we use a pretrained ResNet-18 model and chop off its output layer, which yields 512-dimensional feature vectors.
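```python
import torch
from torchvision.models import ResNet18_Weights, resnet18

feature_extractor = resnet18(weights=ResNet18_Weights.DEFAULT)
feature_extractor.fc = torch.nn.Identity()  # Drop the classification layer; outputs 512-d embeddings.
feature_extractor.eval()  # Eval mode to use frozen batchnorm stats.
```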
You can use any `torch.nn.Module` as a feature extractor, either a pretrained model or a custom model that has been trained on the data at hand. Generally, we have observed very good results when using generic pretrained models such as ResNets for image data or BERT models for text.
Now we can instantiate an MMD-based shift detector. We first fit it to our reference dataset and then score both the in-distribution and the out-of-distribution query datasets.
```python
from renate.shift.mmd_detectors import MMDCovariateShiftDetector  # Import path assumed.

detector = MMDCovariateShiftDetector(feature_extractor=feature_extractor)
print("Fitting detector...")
detector.fit(dataset_ref)  # Characterize the expected distribution.
print("Scoring in-distribution data...")
score_in = detector.score(dataset_query_in)
print(f"score = {score_in}")
print("Scoring out-of-distribution data...")
score_out = detector.score(dataset_query_out)
print(f"score = {score_out}")
```
In this toy example, the shift is quite obvious, and we see a very high score for the out-of-distribution data:

```text
Fitting detector...
Scoring in-distribution data...
score = 0.5410000085830688
Scoring out-of-distribution data...
score = 1.0
```
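A score still has to be turned into a decision about whether to update the model, and one simple option is a fixed threshold. The function `should_retrain` and the threshold of 0.95 below are hypothetical choices, not part of Renate. If the score is read as one minus a p-value (an assumption), this threshold roughly corresponds to a 5% significance level. The sketch applies it to the two scores printed above.

```python
def should_retrain(score: float, threshold: float = 0.95) -> bool:
    """Flag a likely shift when the score exceeds a chosen threshold (hypothetical policy)."""
    if not 0.0 <= score <= 1.0:
        raise ValueError(f"expected a score in [0, 1], got {score}")
    return score > threshold


# Scores taken from the example output above.
for name, score in [("in-distribution", 0.5410000085830688), ("out-of-distribution", 1.0)]:
    print(f"{name}: retrain={should_retrain(score)}")
# in-distribution: retrain=False
# out-of-distribution: retrain=True
```

A stricter threshold triggers fewer false alarms. The cost is that milder shifts may go unnoticed.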