Fix RBF Kernel Dimensions.
This commit is contained in:
		@@ -18,4 +18,11 @@ class RadialBasisFunctionKernel:
 | 
			
		||||
        self.s2 = sigma * sigma
 | 
			
		||||
 | 
			
		||||
    def __call__(self, x, y):
 | 
			
		||||
        return torch.exp(-torch.sum((x - y)**2) / (2 * self.s2))
 | 
			
		||||
        remove_dim = False
 | 
			
		||||
        if len(x.shape) > 1:
 | 
			
		||||
            x = x.unsqueeze(1)
 | 
			
		||||
            remove_dim = True
 | 
			
		||||
        output = torch.exp(-torch.sum((x - y)**2, dim=-1) / (2 * self.s2))
 | 
			
		||||
        if remove_dim:
 | 
			
		||||
            output = output.squeeze(1)
 | 
			
		||||
        return output
 | 
			
		||||
 
 | 
			
		||||
@@ -6,7 +6,7 @@ import numpy as np
 | 
			
		||||
import torch
 | 
			
		||||
 | 
			
		||||
from prototorch.functions.distances import KernelDistance
 | 
			
		||||
from prototorch.functions.kernels import ExplicitKernel
 | 
			
		||||
from prototorch.functions.kernels import ExplicitKernel, RadialBasisFunctionKernel
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestExplicitKernel(unittest.TestCase):
 | 
			
		||||
@@ -17,8 +17,6 @@ class TestExplicitKernel(unittest.TestCase):
 | 
			
		||||
        self.batch_x = torch.randn(32, 1024)
 | 
			
		||||
        self.batch_y = torch.randn(32, 1024)
 | 
			
		||||
 | 
			
		||||
        self.kernel = ExplicitKernel()
 | 
			
		||||
 | 
			
		||||
    def test_single_values(self):
 | 
			
		||||
        kernel = ExplicitKernel()
 | 
			
		||||
        self.assertEqual(
 | 
			
		||||
@@ -40,6 +38,35 @@ class TestExplicitKernel(unittest.TestCase):
 | 
			
		||||
            kernel(self.batch_x, self.batch_y).shape, torch.Size([32, 32]))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestRadialBasisFunctionKernel(unittest.TestCase):
 | 
			
		||||
    def setUp(self):
 | 
			
		||||
        self.single_x = torch.randn(1024)
 | 
			
		||||
        self.single_y = torch.randn(1024)
 | 
			
		||||
 | 
			
		||||
        self.batch_x = torch.randn(32, 1024)
 | 
			
		||||
        self.batch_y = torch.randn(32, 1024)
 | 
			
		||||
 | 
			
		||||
    def test_single_values(self):
 | 
			
		||||
        kernel = RadialBasisFunctionKernel(1)
 | 
			
		||||
        self.assertEqual(
 | 
			
		||||
            kernel(self.single_x, self.single_y).shape, torch.Size([]))
 | 
			
		||||
 | 
			
		||||
    def test_single_batch(self):
 | 
			
		||||
        kernel = RadialBasisFunctionKernel(1)
 | 
			
		||||
        self.assertEqual(
 | 
			
		||||
            kernel(self.single_x, self.batch_y).shape, torch.Size([32]))
 | 
			
		||||
 | 
			
		||||
    def test_batch_single(self):
 | 
			
		||||
        kernel = RadialBasisFunctionKernel(1)
 | 
			
		||||
        self.assertEqual(
 | 
			
		||||
            kernel(self.batch_x, self.single_y).shape, torch.Size([32]))
 | 
			
		||||
 | 
			
		||||
    def test_batch_values(self):
 | 
			
		||||
        kernel = RadialBasisFunctionKernel(1)
 | 
			
		||||
        self.assertEqual(
 | 
			
		||||
            kernel(self.batch_x, self.batch_y).shape, torch.Size([32, 32]))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestKernelDistance(unittest.TestCase):
 | 
			
		||||
    def setUp(self):
 | 
			
		||||
        self.single_x = torch.randn(1024)
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user