Batch Kernel. [Ineficient]
This commit is contained in:
parent
7d353f5b5a
commit
b0cd2de18e
@ -269,12 +269,37 @@ class KernelDistance:
|
||||
def __init__(self, kernel_fn):
|
||||
self.kernel_fn = kernel_fn
|
||||
|
||||
def __call__(self, x, y):
|
||||
def __call__(self, x_batch, y_batch):
|
||||
remove_dims = 0
|
||||
# Extend Single inputs
|
||||
if len(x_batch.shape) == 1:
|
||||
x_batch = [x_batch]
|
||||
remove_dims += 1
|
||||
if len(y_batch.shape) == 1:
|
||||
y_batch = [y_batch]
|
||||
remove_dims += 1
|
||||
|
||||
# Loop over batches
|
||||
output = []
|
||||
for x in x_batch:
|
||||
output.append([])
|
||||
for y in y_batch:
|
||||
output[-1].append(self.single_call(x, y))
|
||||
|
||||
output = torch.Tensor(output)
|
||||
for _ in range(remove_dims):
|
||||
output.squeeze_(0)
|
||||
|
||||
return output
|
||||
|
||||
def single_call(self, x, y):
|
||||
kappa_xx = self.kernel_fn(x, x)
|
||||
kappa_xy = self.kernel_fn(x, y)
|
||||
kappa_yy = self.kernel_fn(y, y)
|
||||
|
||||
return torch.sqrt(kappa_xx - 2 * kappa_xy + kappa_yy)
|
||||
squared_distance = kappa_xx - 2 * kappa_xy + kappa_yy
|
||||
|
||||
return torch.sqrt(squared_distance)
|
||||
|
||||
|
||||
class SquaredKernelDistance(KernelDistance):
|
||||
@ -282,7 +307,7 @@ class SquaredKernelDistance(KernelDistance):
|
||||
|
||||
Kernel distance without final squareroot.
|
||||
"""
|
||||
def __call__(self, x, y):
|
||||
def single_call(self, x, y):
|
||||
kappa_xx = self.kernel_fn(x, x)
|
||||
kappa_xy = self.kernel_fn(x, y)
|
||||
kappa_yy = self.kernel_fn(y, y)
|
||||
|
Loading…
Reference in New Issue
Block a user