Kernel Distances.
This commit is contained in:
parent
40751aa50a
commit
7d353f5b5a
@ -261,5 +261,34 @@ def tangent_distance(signals, protos, subspaces, squared=False, epsilon=1e-10):
|
|||||||
return diss.permute([1, 0, 2]).squeeze(-1)
|
return diss.permute([1, 0, 2]).squeeze(-1)
|
||||||
|
|
||||||
|
|
||||||
|
class KernelDistance:
|
||||||
|
r"""Kernel Distance
|
||||||
|
|
||||||
|
Distance based on a kernel function.
|
||||||
|
"""
|
||||||
|
def __init__(self, kernel_fn):
|
||||||
|
self.kernel_fn = kernel_fn
|
||||||
|
|
||||||
|
def __call__(self, x, y):
|
||||||
|
kappa_xx = self.kernel_fn(x, x)
|
||||||
|
kappa_xy = self.kernel_fn(x, y)
|
||||||
|
kappa_yy = self.kernel_fn(y, y)
|
||||||
|
|
||||||
|
return torch.sqrt(kappa_xx - 2 * kappa_xy + kappa_yy)
|
||||||
|
|
||||||
|
|
||||||
|
class SquaredKernelDistance(KernelDistance):
|
||||||
|
r"""Squared Kernel Distance
|
||||||
|
|
||||||
|
Kernel distance without final squareroot.
|
||||||
|
"""
|
||||||
|
def __call__(self, x, y):
|
||||||
|
kappa_xx = self.kernel_fn(x, x)
|
||||||
|
kappa_xy = self.kernel_fn(x, y)
|
||||||
|
kappa_yy = self.kernel_fn(y, y)
|
||||||
|
|
||||||
|
return kappa_xx - 2 * kappa_xy + kappa_yy
|
||||||
|
|
||||||
|
|
||||||
# Aliases
|
# Aliases
|
||||||
sed = squared_euclidean_distance
|
sed = squared_euclidean_distance
|
||||||
|
21
prototorch/functions/kernels.py
Normal file
21
prototorch/functions/kernels.py
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
"""
|
||||||
|
Experimental Kernels
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class ExplicitKernel:
|
||||||
|
def __init__(self, projection=torch.nn.Identity()):
|
||||||
|
self.projection = projection
|
||||||
|
|
||||||
|
def __call__(self, x, y):
|
||||||
|
return self.projection(x) @ self.projection(y)
|
||||||
|
|
||||||
|
|
||||||
|
class RadialBasisFunctionKernel:
|
||||||
|
def __init__(self, sigma) -> None:
|
||||||
|
self.s2 = sigma * sigma
|
||||||
|
|
||||||
|
def __call__(self, x, y):
|
||||||
|
return torch.exp(-torch.sum((x - y)**2) / (2 * self.s2))
|
Loading…
Reference in New Issue
Block a user