diff --git a/prototorch/functions/distances.py b/prototorch/functions/distances.py index 961cd69..c5d4699 100644 --- a/prototorch/functions/distances.py +++ b/prototorch/functions/distances.py @@ -261,5 +261,34 @@ def tangent_distance(signals, protos, subspaces, squared=False, epsilon=1e-10): return diss.permute([1, 0, 2]).squeeze(-1) +class KernelDistance: + r"""Kernel Distance + + Distance based on a kernel function. + """ + def __init__(self, kernel_fn): + self.kernel_fn = kernel_fn + + def __call__(self, x, y): + kappa_xx = self.kernel_fn(x, x) + kappa_xy = self.kernel_fn(x, y) + kappa_yy = self.kernel_fn(y, y) + + return torch.sqrt(kappa_xx - 2 * kappa_xy + kappa_yy) + + +class SquaredKernelDistance(KernelDistance): + r"""Squared Kernel Distance + + Kernel distance without final squareroot. + """ + def __call__(self, x, y): + kappa_xx = self.kernel_fn(x, x) + kappa_xy = self.kernel_fn(x, y) + kappa_yy = self.kernel_fn(y, y) + + return kappa_xx - 2 * kappa_xy + kappa_yy + + # Aliases sed = squared_euclidean_distance diff --git a/prototorch/functions/kernels.py b/prototorch/functions/kernels.py new file mode 100644 index 0000000..6773c8e --- /dev/null +++ b/prototorch/functions/kernels.py @@ -0,0 +1,21 @@ +""" +Experimental Kernels +""" + +import torch + + +class ExplicitKernel: + def __init__(self, projection=torch.nn.Identity()): + self.projection = projection + + def __call__(self, x, y): + return self.projection(x) @ self.projection(y) + + +class RadialBasisFunctionKernel: + def __init__(self, sigma) -> None: + self.s2 = sigma * sigma + + def __call__(self, x, y): + return torch.exp(-torch.sum((x - y)**2) / (2 * self.s2))