# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from typing import Optional, Tuple
import torch
from renate.shift.kernels import Kernel
[docs]
def _mmd_statistic(K: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
    """Evaluate the MMD statistic for the partition `z` on the Gram matrix `K`."""
    inds_0 = torch.where(z == 0)[0]
    inds_1 = torch.where(z == 1)[0]
    n0 = len(inds_0)
    n1 = len(inds_1)
    # NOTE(review): the within-set sums below include the diagonal entries of `K` but are
    # normalised by `n * (n - 1)`, i.e. this is neither the plain biased nor the strictly
    # unbiased estimator of [1]. Kept as-is to preserve the numerical output; confirm intent.
    return (
        K[inds_0][:, inds_0].sum() / (n0 * (n0 - 1))
        + K[inds_1][:, inds_1].sum() / (n1 * (n1 - 1))
        - 2 * K[inds_0][:, inds_1].mean()
    )  # MMD statistic, see Eq. (5) in [1].


def mmd_gram(
    K: torch.Tensor, z: torch.Tensor, num_permutations: int = 0
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Maximum mean discrepancy based on a precomputed kernel-Gram matrix.

    This computes the test statistic and (optionally) a p-value to conduct an MMD two-sample
    test to decide whether two sets are generated by the same distribution. The inputs are
    passed implicitly in the form of a kernel Gram matrix, evaluated across the union of both
    sets, and a binary vector indicating the assignment of data points to the two sets. I.e., a
    value of `z[i] = 0` indicates that the `i`-th data point belongs to set zero. Optionally, a
    permutation test is carried out and a p-value is returned alongside the raw test statistic.

    Each set must contain at least two points; otherwise the normaliser `n * (n - 1)` is zero
    and the resulting statistic is not finite.

    MMD tests have been proposed by
    [1] Gretton, A., et al. A kernel two-sample test. JMLR (2012).

    Args:
        K: A tensor containing a kernel-Gram matrix (size `(n, n)`).
        z: A binary vector of length `n` indicating the partition.
        num_permutations: If this is a positive number, a permutation test will be carried out
            and an approximate p-value will be returned. Zero or negative values skip the test.

    Returns:
        A tuple `(t, p)`, where `t` is the value of the MMD test statistic and `p` is the
        p-value, i.e. the fraction of random partitions whose statistic strictly exceeds `t`
        (or `None` if no permutation test was requested).
    """
    n = K.size(0)
    assert K.size(1) == n, "K must be a square (n, n) matrix."
    assert z.size() == (n,), "z must be a vector of length n."

    statistic = _mmd_statistic(K, z)
    # A non-positive count means "no permutation test". Previously a negative value fell
    # through to the loop below and produced a meaningless p-value of -0.0.
    if num_permutations <= 0:
        return statistic, None

    # Permutation test: Randomize the assignments z (keeping the set sizes fixed), compute MMD,
    # and count how often we exceed the value obtained with the original z.
    n1 = int((z == 1).sum())
    num_exceeding = 0
    for _ in range(num_permutations):
        z_perm = torch.zeros(n, device=K.device)
        perm = torch.randperm(n, device=K.device)
        z_perm[perm[:n1]] = 1
        if _mmd_statistic(K, z_perm) > statistic:
            num_exceeding += 1
    p_val = torch.tensor(num_exceeding / num_permutations)
    return statistic, p_val
[docs]
def mmd(
    X0: torch.Tensor, X1: torch.Tensor, kernel: Kernel, num_permutations: int = 0
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Compute MMD between two samples.

    The two samples are stacked, the kernel is evaluated on the stacked data against itself,
    and the result is handed to `mmd_gram` together with a 0/1 vector marking which rows came
    from `X0` (zeros) and which from `X1` (ones). Optionally, an estimated p-value based on a
    permutation test is returned.

    Args:
        X0: First sample, shape (n0, dx).
        X1: Second sample, shape (n1, dx).
        kernel: Kernel function to use. It is called as `kernel(X, X)` on the concatenated
            samples; presumably this yields the `(n0 + n1, n0 + n1)` Gram matrix (verify
            against the `Kernel` implementation).
        num_permutations: If this is a positive number, a permutation test will
            be carried out and a p-value will be returned.

    Returns:
        A tuple `(t, p)` where `t` is the value of the MMD test statistic and `p` is the p-value
        (or `None` if `num_permutations=0`).
    """
    X = torch.cat([X0, X1], dim=0)
    K = kernel(X, X)
    assert not torch.any(torch.isnan(K)), "Kernel Gram matrix contains NaN values."
    # Build the assignment vector on the same device as the Gram matrix, matching how
    # `mmd_gram` creates its permuted assignment vectors (on `K.device`).
    z = torch.cat(
        [
            torch.zeros(X0.size(0), device=K.device),
            torch.ones(X1.size(0), device=K.device),
        ],
        dim=0,
    )
    return mmd_gram(K, z, num_permutations)