# Source code for renate.shift.mmd_helpers

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from typing import Optional, Tuple

import torch

from renate.shift.kernels import Kernel


def _mmd_from_indices(
    K: torch.Tensor, inds_0: torch.Tensor, inds_1: torch.Tensor
) -> torch.Tensor:
    """Unbiased squared-MMD estimate for the partition given by two index vectors.

    Within-set sums exclude the diagonal (``i == j``) terms, which is what the
    ``1 / (n (n - 1))`` normalization of the unbiased estimator ``MMD_u^2``
    (Gretton et al., 2012) requires.
    """
    n0 = len(inds_0)
    n1 = len(inds_1)
    K_00 = K[inds_0][:, inds_0]
    K_11 = K[inds_1][:, inds_1]
    K_01 = K[inds_0][:, inds_1]
    within_0 = (K_00.sum() - K_00.diagonal().sum()) / (n0 * (n0 - 1))
    within_1 = (K_11.sum() - K_11.diagonal().sum()) / (n1 * (n1 - 1))
    return within_0 + within_1 - 2 * K_01.mean()


def mmd_gram(
    K: torch.Tensor, z: torch.Tensor, num_permutations: int = 0
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Maximum mean discrepancy based on a precomputed kernel-Gram matrix.

    This computes the test statistic and (optionally) p-value to conduct an MMD two-sample test to
    decide whether two sets are generated by the same distribution. The inputs are passed
    implicitly in the form of a kernel Gram matrix, evaluated across the union of both sets, and a
    binary vector indicating the assignments of data points to the two sets. I.e., a value of
    `z[i] = 0` indicates that the `i`-th data point belongs to set zero.

    The statistic is the unbiased estimate `MMD_u^2` from [1]: within-set kernel sums exclude the
    `i == j` terms.

    Optionally, a permutation test is carried out and a p-value is returned alongside the raw test
    statistic. The p-value is the fraction of random re-partitions (with the same set sizes) whose
    statistic strictly exceeds the observed one.

    MMD tests have been proposed by
    [1] Gretton, A., et al. A kernel two-sample test. JMLR (2012).

    Args:
        K: A tensor containing a kernel-Gram matrix (size `(n, n)`).
        z: A binary vector of length `n` indicating the partition.
        num_permutations: If this is a positive number, a permutation test will be carried out and
            an approximate p-value will be returned. Zero or negative values skip the test.

    Returns:
        A tuple `(t, p)` of two scalar tensors, where `t` is the value of the MMD test statistic and
        `p` is the p-value (or `None` if `num_permutations <= 0`).

    Raises:
        ValueError: If `z` contains values other than 0 and 1, or if either set has fewer than two
            elements (the unbiased estimator is undefined in that case).
    """
    n = K.size(0)
    assert K.size(1) == n
    assert z.size() == (n,)
    if not bool(((z == 0) | (z == 1)).all()):
        raise ValueError("`z` must be a binary vector containing only 0s and 1s.")

    # Keep the index vectors on K's device so that indexing K needs no cross-device copies.
    inds_0 = torch.where(z == 0)[0].to(K.device)
    inds_1 = torch.where(z == 1)[0].to(K.device)
    n0 = len(inds_0)
    n1 = len(inds_1)
    if n0 < 2 or n1 < 2:
        raise ValueError(
            "Both sets need at least two elements to compute the MMD statistic; "
            f"got {n0} and {n1}."
        )

    stat = _mmd_from_indices(K, inds_0, inds_1)
    if num_permutations <= 0:
        return stat, None

    # Permutation test: draw random partitions with the same set sizes, compute MMD, and count how
    # often we exceed the value obtained with the original z.
    cnt = 0
    for _ in range(num_permutations):
        pi = torch.randperm(n, device=K.device)
        # The first n1 permuted points form set one, the remaining n - n1 points form set zero.
        stat_perm = _mmd_from_indices(K, pi[n1:], pi[:n1])
        if stat_perm > stat:
            cnt += 1
    p_val = torch.tensor(cnt / num_permutations)
    return stat, p_val
def mmd(
    X0: torch.Tensor, X1: torch.Tensor, kernel: Kernel, num_permutations: int = 0
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Compute MMD between two samples.

    Optionally, return an estimated p-value based on a permutation test.

    Args:
        X0: First sample, shape (n0, dx).
        X1: Second sample, shape (n1, dx).
        kernel: Kernel function to use. It is called as `kernel(X, X)` on the row-wise
            concatenation `X` of both samples to obtain the Gram matrix.
        num_permutations: If this is a positive number, a permutation test will be carried out and
            a p-value will be returned.

    Returns:
        A tuple `(t, p)` where `t` is the value of the MMD test statistic and `p` is the p-value
        (or `None` if `num_permutations=0`).

    Raises:
        AssertionError: If the Gram matrix returned by `kernel` contains NaN entries.
    """
    X = torch.cat([X0, X1], dim=0)
    # Set-membership indicators (0 for rows of X0, 1 for rows of X1). They are created on the data's
    # device rather than always on the CPU, so `mmd_gram` does not have to index across devices.
    z = torch.cat(
        [
            torch.zeros(X0.size(0), device=X.device),
            torch.ones(X1.size(0), device=X.device),
        ],
        dim=0,
    )
    K = kernel(X, X)
    assert not torch.any(torch.isnan(K))
    return mmd_gram(K, z, num_permutations)