Kernel Distances.
This commit is contained in:
parent
40751aa50a
commit
7d353f5b5a
@ -261,5 +261,34 @@ def tangent_distance(signals, protos, subspaces, squared=False, epsilon=1e-10):
|
||||
return diss.permute([1, 0, 2]).squeeze(-1)
|
||||
|
||||
|
||||
class KernelDistance:
|
||||
r"""Kernel Distance
|
||||
|
||||
Distance based on a kernel function.
|
||||
"""
|
||||
def __init__(self, kernel_fn):
|
||||
self.kernel_fn = kernel_fn
|
||||
|
||||
def __call__(self, x, y):
|
||||
kappa_xx = self.kernel_fn(x, x)
|
||||
kappa_xy = self.kernel_fn(x, y)
|
||||
kappa_yy = self.kernel_fn(y, y)
|
||||
|
||||
return torch.sqrt(kappa_xx - 2 * kappa_xy + kappa_yy)
|
||||
|
||||
|
||||
class SquaredKernelDistance(KernelDistance):
|
||||
r"""Squared Kernel Distance
|
||||
|
||||
Kernel distance without final squareroot.
|
||||
"""
|
||||
def __call__(self, x, y):
|
||||
kappa_xx = self.kernel_fn(x, x)
|
||||
kappa_xy = self.kernel_fn(x, y)
|
||||
kappa_yy = self.kernel_fn(y, y)
|
||||
|
||||
return kappa_xx - 2 * kappa_xy + kappa_yy
|
||||
|
||||
|
||||
# Aliases
|
||||
sed = squared_euclidean_distance
|
||||
|
21
prototorch/functions/kernels.py
Normal file
21
prototorch/functions/kernels.py
Normal file
@ -0,0 +1,21 @@
|
||||
"""
|
||||
Experimental Kernels
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class ExplicitKernel:
|
||||
def __init__(self, projection=torch.nn.Identity()):
|
||||
self.projection = projection
|
||||
|
||||
def __call__(self, x, y):
|
||||
return self.projection(x) @ self.projection(y)
|
||||
|
||||
|
||||
class RadialBasisFunctionKernel:
|
||||
def __init__(self, sigma) -> None:
|
||||
self.s2 = sigma * sigma
|
||||
|
||||
def __call__(self, x, y):
|
||||
return torch.exp(-torch.sum((x - y)**2) / (2 * self.s2))
|
Loading…
Reference in New Issue
Block a user