Kernel Distances.

This commit is contained in:
Alexander Engelsberger 2021-04-27 12:06:15 +02:00
parent 40751aa50a
commit 7d353f5b5a
2 changed files with 50 additions and 0 deletions

View File

@ -261,5 +261,34 @@ def tangent_distance(signals, protos, subspaces, squared=False, epsilon=1e-10):
return diss.permute([1, 0, 2]).squeeze(-1)
class KernelDistance:
r"""Kernel Distance
Distance based on a kernel function.
"""
def __init__(self, kernel_fn):
self.kernel_fn = kernel_fn
def __call__(self, x, y):
kappa_xx = self.kernel_fn(x, x)
kappa_xy = self.kernel_fn(x, y)
kappa_yy = self.kernel_fn(y, y)
return torch.sqrt(kappa_xx - 2 * kappa_xy + kappa_yy)
class SquaredKernelDistance(KernelDistance):
r"""Squared Kernel Distance
Kernel distance without final squareroot.
"""
def __call__(self, x, y):
kappa_xx = self.kernel_fn(x, x)
kappa_xy = self.kernel_fn(x, y)
kappa_yy = self.kernel_fn(y, y)
return kappa_xx - 2 * kappa_xy + kappa_yy
# Aliases
sed = squared_euclidean_distance

View File

@ -0,0 +1,21 @@
"""
Experimental Kernels
"""
import torch
class ExplicitKernel:
def __init__(self, projection=torch.nn.Identity()):
self.projection = projection
def __call__(self, x, y):
return self.projection(x) @ self.projection(y)
class RadialBasisFunctionKernel:
def __init__(self, sigma) -> None:
self.s2 = sigma * sigma
def __call__(self, x, y):
return torch.exp(-torch.sum((x - y)**2) / (2 * self.s2))