Source code for renate.shift.kernels

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from typing import Optional

import torch


class Kernel:
    """Common base class for kernel functions.

    Provides a shared input check, ``_check_inputs``, that subclasses can call before
    evaluating the kernel on a pair of tensors.
    """

    def __init__(self):
        pass

    def _check_inputs(self, X0: torch.Tensor, X1: torch.Tensor):
        """Assert that ``X0`` and ``X1`` form a valid pair of kernel inputs.

        Args:
            X0: First input tensor.
            X1: Second input tensor.

        Raises:
            AssertionError: If either tensor is not 2-dimensional, if the sizes along
                dimension 1 differ, or if the dtypes differ.
        """
        # The rank check comes first so that `size(1)` is only reached for 2-d tensors.
        rank0, rank1 = X0.dim(), X1.dim()
        assert rank0 == rank1 == 2
        width0, width1 = X0.size(1), X1.size(1)
        assert width0 == width1
        assert X0.dtype is X1.dtype
class RBFKernel(Kernel):
    """A radial basis function kernel.

    This kernel has one hyperparameter, a scalar lengthscale. If this is set to `None` (default),
    the lengthscale will be set adaptively, at _each_ call to the kernel, via the median heuristic.

    Args:
        lengthscale: The kernel lengthscale. If `None` (default), this is set automatically via
            the median heuristic. Note: In this case, the lengthscale will be reset at each call
            to the kernel. A lengthscale of `0` is treated the same as `None`.
    """

    def __init__(self, lengthscale: Optional[float] = None):
        super().__init__()
        self._lengthscale = lengthscale

    @torch.no_grad()
    def __call__(self, X0: torch.Tensor, X1: torch.Tensor):
        """Evaluate the kernel between every row of ``X0`` and every row of ``X1``.

        Args:
            X0: First input; must pass ``_check_inputs`` together with ``X1``.
            X1: Second input.

        Returns:
            ``exp(-0.5 * d**2 / lengthscale**2)`` where ``d = torch.cdist(X0, X1)``.
        """
        self._check_inputs(X0, X1)
        dists = torch.cdist(X0, X1)
        if self._lengthscale:
            lengthscale = self._lengthscale
        else:
            # Median heuristic. If the median distance is exactly zero (e.g. all points
            # coincide), dividing by it would yield 0/0 = NaN on zero-distance entries.
            # Substitute machine epsilon so the result stays finite: coinciding points get
            # kernel value 1. A NaN median is left untouched so NaNs still propagate.
            median = torch.median(dists)
            floor = torch.full_like(median, torch.finfo(dists.dtype).eps)
            lengthscale = torch.where(median == 0, floor, median)
        return torch.exp(-0.5 * dists**2 / lengthscale**2)