Kernel Distances.
This commit is contained in:
		@@ -261,5 +261,34 @@ def tangent_distance(signals, protos, subspaces, squared=False, epsilon=1e-10):
 | 
			
		||||
            return diss.permute([1, 0, 2]).squeeze(-1)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class KernelDistance:
 | 
			
		||||
    r"""Kernel Distance
 | 
			
		||||
 | 
			
		||||
    Distance based on a kernel function.
 | 
			
		||||
    """
 | 
			
		||||
    def __init__(self, kernel_fn):
 | 
			
		||||
        self.kernel_fn = kernel_fn
 | 
			
		||||
 | 
			
		||||
    def __call__(self, x, y):
 | 
			
		||||
        kappa_xx = self.kernel_fn(x, x)
 | 
			
		||||
        kappa_xy = self.kernel_fn(x, y)
 | 
			
		||||
        kappa_yy = self.kernel_fn(y, y)
 | 
			
		||||
 | 
			
		||||
        return torch.sqrt(kappa_xx - 2 * kappa_xy + kappa_yy)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class SquaredKernelDistance(KernelDistance):
 | 
			
		||||
    r"""Squared Kernel Distance
 | 
			
		||||
 | 
			
		||||
    Kernel distance without final squareroot.
 | 
			
		||||
    """
 | 
			
		||||
    def __call__(self, x, y):
 | 
			
		||||
        kappa_xx = self.kernel_fn(x, x)
 | 
			
		||||
        kappa_xy = self.kernel_fn(x, y)
 | 
			
		||||
        kappa_yy = self.kernel_fn(y, y)
 | 
			
		||||
 | 
			
		||||
        return kappa_xx - 2 * kappa_xy + kappa_yy
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Aliases
 | 
			
		||||
sed = squared_euclidean_distance
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										21
									
								
								prototorch/functions/kernels.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										21
									
								
								prototorch/functions/kernels.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,21 @@
 | 
			
		||||
"""
 | 
			
		||||
Experimental Kernels
 | 
			
		||||
"""
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ExplicitKernel:
 | 
			
		||||
    def __init__(self, projection=torch.nn.Identity()):
 | 
			
		||||
        self.projection = projection
 | 
			
		||||
 | 
			
		||||
    def __call__(self, x, y):
 | 
			
		||||
        return self.projection(x) @ self.projection(y)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class RadialBasisFunctionKernel:
 | 
			
		||||
    def __init__(self, sigma) -> None:
 | 
			
		||||
        self.s2 = sigma * sigma
 | 
			
		||||
 | 
			
		||||
    def __call__(self, x, y):
 | 
			
		||||
        return torch.exp(-torch.sum((x - y)**2) / (2 * self.s2))
 | 
			
		||||
		Reference in New Issue
	
	Block a user