Batch Kernel. [Ineficient]

This commit is contained in:
Alexander Engelsberger 2021-04-27 15:38:34 +02:00
parent 7d353f5b5a
commit b0cd2de18e

View File

@ -269,12 +269,37 @@ class KernelDistance:
def __init__(self, kernel_fn):
self.kernel_fn = kernel_fn
def __call__(self, x, y):
def __call__(self, x_batch, y_batch):
remove_dims = 0
# Extend Single inputs
if len(x_batch.shape) == 1:
x_batch = [x_batch]
remove_dims += 1
if len(y_batch.shape) == 1:
y_batch = [y_batch]
remove_dims += 1
# Loop over batches
output = []
for x in x_batch:
output.append([])
for y in y_batch:
output[-1].append(self.single_call(x, y))
output = torch.Tensor(output)
for _ in range(remove_dims):
output.squeeze_(0)
return output
def single_call(self, x, y):
kappa_xx = self.kernel_fn(x, x)
kappa_xy = self.kernel_fn(x, y)
kappa_yy = self.kernel_fn(y, y)
return torch.sqrt(kappa_xx - 2 * kappa_xy + kappa_yy)
squared_distance = kappa_xx - 2 * kappa_xy + kappa_yy
return torch.sqrt(squared_distance)
class SquaredKernelDistance(KernelDistance):
@ -282,7 +307,7 @@ class SquaredKernelDistance(KernelDistance):
Kernel distance without final squareroot.
"""
def __call__(self, x, y):
def single_call(self, x, y):
kappa_xx = self.kernel_fn(x, x)
kappa_xy = self.kernel_fn(x, y)
kappa_yy = self.kernel_fn(y, y)