From b0cd2de18e8350771b6ad08734fe9ed79b49e08c Mon Sep 17 00:00:00 2001 From: Alexander Engelsberger Date: Tue, 27 Apr 2021 15:38:34 +0200 Subject: [PATCH] Batch Kernel. [Ineficient] --- prototorch/functions/distances.py | 31 ++++++++++++++++++++++++++++--- 1 file changed, 28 insertions(+), 3 deletions(-) diff --git a/prototorch/functions/distances.py b/prototorch/functions/distances.py index c5d4699..5949d0b 100644 --- a/prototorch/functions/distances.py +++ b/prototorch/functions/distances.py @@ -269,12 +269,37 @@ class KernelDistance: def __init__(self, kernel_fn): self.kernel_fn = kernel_fn - def __call__(self, x, y): + def __call__(self, x_batch, y_batch): + remove_dims = 0 + # Extend Single inputs + if len(x_batch.shape) == 1: + x_batch = [x_batch] + remove_dims += 1 + if len(y_batch.shape) == 1: + y_batch = [y_batch] + remove_dims += 1 + + # Loop over batches + output = [] + for x in x_batch: + output.append([]) + for y in y_batch: + output[-1].append(self.single_call(x, y)) + + output = torch.Tensor(output) + for _ in range(remove_dims): + output.squeeze_(0) + + return output + + def single_call(self, x, y): kappa_xx = self.kernel_fn(x, x) kappa_xy = self.kernel_fn(x, y) kappa_yy = self.kernel_fn(y, y) - return torch.sqrt(kappa_xx - 2 * kappa_xy + kappa_yy) + squared_distance = kappa_xx - 2 * kappa_xy + kappa_yy + + return torch.sqrt(squared_distance) class SquaredKernelDistance(KernelDistance): @@ -282,7 +307,7 @@ class SquaredKernelDistance(KernelDistance): Kernel distance without final squareroot. """ - def __call__(self, x, y): + def single_call(self, x, y): kappa_xx = self.kernel_fn(x, x) kappa_xy = self.kernel_fn(x, y) kappa_yy = self.kernel_fn(y, y)