Fix kernel dimensions.

This commit is contained in:
Alexander Engelsberger 2021-04-27 16:56:56 +02:00
parent ba537fe1d5
commit 209f9e641b
3 changed files with 110 additions and 12 deletions

View File

@ -272,30 +272,57 @@ class KernelDistance:
def __init__(self, kernel_fn): def __init__(self, kernel_fn):
self.kernel_fn = kernel_fn self.kernel_fn = kernel_fn
def __call__(self, x_batch, y_batch): def __call__(self, x_batch: torch.Tensor, y_batch: torch.Tensor):
return self._single_call(x_batch, y_batch)
def _single_call(self, x, y):
remove_dims = []
if len(x.shape) == 1:
x = x.unsqueeze(0)
remove_dims.append(0)
if len(y.shape) == 1:
y = y.unsqueeze(0)
remove_dims.append(-1)
output = self.kernel_fn(x, x).diag().unsqueeze(1) - 2 * self.kernel_fn(
x, y) + self.kernel_fn(y, y).diag()
for dim in remove_dims:
output.squeeze_(dim)
return torch.sqrt(output)
class BatchKernelDistance:
r"""Kernel Distance
Distance based on a kernel function.
"""
def __init__(self, kernel_fn):
self.kernel_fn = kernel_fn
def __call__(self, x_batch: torch.Tensor, y_batch: torch.Tensor):
remove_dims = 0 remove_dims = 0
# Extend Single inputs # Extend Single inputs
if len(x_batch.shape) == 1: if len(x_batch.shape) == 1:
x_batch = [x_batch] x_batch = x_batch.unsqueeze(0)
remove_dims += 1 remove_dims += 1
if len(y_batch.shape) == 1: if len(y_batch.shape) == 1:
y_batch = [y_batch] y_batch = y_batch.unsqueeze(0)
remove_dims += 1 remove_dims += 1
# Loop over batches # Loop over batches
output = [] output = torch.FloatTensor(len(x_batch), len(y_batch))
for x in x_batch: for i, x in enumerate(x_batch):
output.append([]) for j, y in enumerate(y_batch):
for y in y_batch: output[i][j] = self._single_call(x, y)
output[-1].append(self.single_call(x, y))
output = torch.Tensor(output)
for _ in range(remove_dims): for _ in range(remove_dims):
output.squeeze_(0) output.squeeze_(0)
return output return output
def single_call(self, x, y): def _single_call(self, x, y):
kappa_xx = self.kernel_fn(x, x) kappa_xx = self.kernel_fn(x, x)
kappa_xy = self.kernel_fn(x, y) kappa_xy = self.kernel_fn(x, y)
kappa_yy = self.kernel_fn(y, y) kappa_yy = self.kernel_fn(y, y)

View File

@ -10,7 +10,7 @@ class ExplicitKernel:
self.projection = projection self.projection = projection
def __call__(self, x, y): def __call__(self, x, y):
return self.projection(x) @ self.projection(y) return self.projection(x) @ self.projection(y).T
class RadialBasisFunctionKernel: class RadialBasisFunctionKernel:

71
tests/test_kernels.py Normal file
View File

@ -0,0 +1,71 @@
"""ProtoTorch kernels test suite."""
import unittest
import numpy as np
import torch
from prototorch.functions.distances import KernelDistance
from prototorch.functions.kernels import ExplicitKernel
class TestExplicitKernel(unittest.TestCase):
def setUp(self):
self.single_x = torch.randn(1024)
self.single_y = torch.randn(1024)
self.batch_x = torch.randn(32, 1024)
self.batch_y = torch.randn(32, 1024)
self.kernel = ExplicitKernel()
def test_single_values(self):
kernel = ExplicitKernel()
self.assertEqual(
kernel(self.single_x, self.single_y).shape, torch.Size([]))
def test_single_batch(self):
kernel = ExplicitKernel()
self.assertEqual(
kernel(self.single_x, self.batch_y).shape, torch.Size([32]))
def test_batch_single(self):
kernel = ExplicitKernel()
self.assertEqual(
kernel(self.batch_x, self.single_y).shape, torch.Size([32]))
def test_batch_values(self):
kernel = ExplicitKernel()
self.assertEqual(
kernel(self.batch_x, self.batch_y).shape, torch.Size([32, 32]))
class TestKernelDistance(unittest.TestCase):
def setUp(self):
self.single_x = torch.randn(1024)
self.single_y = torch.randn(1024)
self.batch_x = torch.randn(32, 1024)
self.batch_y = torch.randn(32, 1024)
self.kernel = ExplicitKernel()
def test_single_values(self):
distance = KernelDistance(self.kernel)
self.assertEqual(
distance(self.single_x, self.single_y).shape, torch.Size([]))
def test_single_batch(self):
distance = KernelDistance(self.kernel)
self.assertEqual(
distance(self.single_x, self.batch_y).shape, torch.Size([32]))
def test_batch_single(self):
distance = KernelDistance(self.kernel)
self.assertEqual(
distance(self.batch_x, self.single_y).shape, torch.Size([32]))
def test_batch_values(self):
distance = KernelDistance(self.kernel)
self.assertEqual(
distance(self.batch_x, self.batch_y).shape, torch.Size([32, 32]))