Fix RBF Kernel Dimensions.
This commit is contained in:
parent
209f9e641b
commit
65e0637b17
@ -18,4 +18,11 @@ class RadialBasisFunctionKernel:
|
|||||||
self.s2 = sigma * sigma
|
self.s2 = sigma * sigma
|
||||||
|
|
||||||
def __call__(self, x, y):
|
def __call__(self, x, y):
|
||||||
return torch.exp(-torch.sum((x - y)**2) / (2 * self.s2))
|
remove_dim = False
|
||||||
|
if len(x.shape) > 1:
|
||||||
|
x = x.unsqueeze(1)
|
||||||
|
remove_dim = True
|
||||||
|
output = torch.exp(-torch.sum((x - y)**2, dim=-1) / (2 * self.s2))
|
||||||
|
if remove_dim:
|
||||||
|
output = output.squeeze(1)
|
||||||
|
return output
|
||||||
|
@ -6,7 +6,7 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from prototorch.functions.distances import KernelDistance
|
from prototorch.functions.distances import KernelDistance
|
||||||
from prototorch.functions.kernels import ExplicitKernel
|
from prototorch.functions.kernels import ExplicitKernel, RadialBasisFunctionKernel
|
||||||
|
|
||||||
|
|
||||||
class TestExplicitKernel(unittest.TestCase):
|
class TestExplicitKernel(unittest.TestCase):
|
||||||
@ -17,8 +17,6 @@ class TestExplicitKernel(unittest.TestCase):
|
|||||||
self.batch_x = torch.randn(32, 1024)
|
self.batch_x = torch.randn(32, 1024)
|
||||||
self.batch_y = torch.randn(32, 1024)
|
self.batch_y = torch.randn(32, 1024)
|
||||||
|
|
||||||
self.kernel = ExplicitKernel()
|
|
||||||
|
|
||||||
def test_single_values(self):
|
def test_single_values(self):
|
||||||
kernel = ExplicitKernel()
|
kernel = ExplicitKernel()
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
@ -40,6 +38,35 @@ class TestExplicitKernel(unittest.TestCase):
|
|||||||
kernel(self.batch_x, self.batch_y).shape, torch.Size([32, 32]))
|
kernel(self.batch_x, self.batch_y).shape, torch.Size([32, 32]))
|
||||||
|
|
||||||
|
|
||||||
|
class TestRadialBasisFunctionKernel(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
self.single_x = torch.randn(1024)
|
||||||
|
self.single_y = torch.randn(1024)
|
||||||
|
|
||||||
|
self.batch_x = torch.randn(32, 1024)
|
||||||
|
self.batch_y = torch.randn(32, 1024)
|
||||||
|
|
||||||
|
def test_single_values(self):
|
||||||
|
kernel = RadialBasisFunctionKernel(1)
|
||||||
|
self.assertEqual(
|
||||||
|
kernel(self.single_x, self.single_y).shape, torch.Size([]))
|
||||||
|
|
||||||
|
def test_single_batch(self):
|
||||||
|
kernel = RadialBasisFunctionKernel(1)
|
||||||
|
self.assertEqual(
|
||||||
|
kernel(self.single_x, self.batch_y).shape, torch.Size([32]))
|
||||||
|
|
||||||
|
def test_batch_single(self):
|
||||||
|
kernel = RadialBasisFunctionKernel(1)
|
||||||
|
self.assertEqual(
|
||||||
|
kernel(self.batch_x, self.single_y).shape, torch.Size([32]))
|
||||||
|
|
||||||
|
def test_batch_values(self):
|
||||||
|
kernel = RadialBasisFunctionKernel(1)
|
||||||
|
self.assertEqual(
|
||||||
|
kernel(self.batch_x, self.batch_y).shape, torch.Size([32, 32]))
|
||||||
|
|
||||||
|
|
||||||
class TestKernelDistance(unittest.TestCase):
|
class TestKernelDistance(unittest.TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.single_x = torch.randn(1024)
|
self.single_x = torch.randn(1024)
|
||||||
|
Loading…
Reference in New Issue
Block a user