Batch Kernel. [Ineficient]
This commit is contained in:
parent
7d353f5b5a
commit
b0cd2de18e
@ -269,12 +269,37 @@ class KernelDistance:
|
|||||||
def __init__(self, kernel_fn):
|
def __init__(self, kernel_fn):
|
||||||
self.kernel_fn = kernel_fn
|
self.kernel_fn = kernel_fn
|
||||||
|
|
||||||
def __call__(self, x, y):
|
def __call__(self, x_batch, y_batch):
|
||||||
|
remove_dims = 0
|
||||||
|
# Extend Single inputs
|
||||||
|
if len(x_batch.shape) == 1:
|
||||||
|
x_batch = [x_batch]
|
||||||
|
remove_dims += 1
|
||||||
|
if len(y_batch.shape) == 1:
|
||||||
|
y_batch = [y_batch]
|
||||||
|
remove_dims += 1
|
||||||
|
|
||||||
|
# Loop over batches
|
||||||
|
output = []
|
||||||
|
for x in x_batch:
|
||||||
|
output.append([])
|
||||||
|
for y in y_batch:
|
||||||
|
output[-1].append(self.single_call(x, y))
|
||||||
|
|
||||||
|
output = torch.Tensor(output)
|
||||||
|
for _ in range(remove_dims):
|
||||||
|
output.squeeze_(0)
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
def single_call(self, x, y):
|
||||||
kappa_xx = self.kernel_fn(x, x)
|
kappa_xx = self.kernel_fn(x, x)
|
||||||
kappa_xy = self.kernel_fn(x, y)
|
kappa_xy = self.kernel_fn(x, y)
|
||||||
kappa_yy = self.kernel_fn(y, y)
|
kappa_yy = self.kernel_fn(y, y)
|
||||||
|
|
||||||
return torch.sqrt(kappa_xx - 2 * kappa_xy + kappa_yy)
|
squared_distance = kappa_xx - 2 * kappa_xy + kappa_yy
|
||||||
|
|
||||||
|
return torch.sqrt(squared_distance)
|
||||||
|
|
||||||
|
|
||||||
class SquaredKernelDistance(KernelDistance):
|
class SquaredKernelDistance(KernelDistance):
|
||||||
@ -282,7 +307,7 @@ class SquaredKernelDistance(KernelDistance):
|
|||||||
|
|
||||||
Kernel distance without final squareroot.
|
Kernel distance without final squareroot.
|
||||||
"""
|
"""
|
||||||
def __call__(self, x, y):
|
def single_call(self, x, y):
|
||||||
kappa_xx = self.kernel_fn(x, x)
|
kappa_xx = self.kernel_fn(x, x)
|
||||||
kappa_xy = self.kernel_fn(x, y)
|
kappa_xy = self.kernel_fn(x, y)
|
||||||
kappa_yy = self.kernel_fn(y, y)
|
kappa_yy = self.kernel_fn(y, y)
|
||||||
|
Loading…
Reference in New Issue
Block a user