chore: fix minor errors and upgrade codebase

This commit is contained in:
Alexander Engelsberger
2023-06-20 16:06:53 +02:00
parent 6ed1b9a832
commit ee4cf583e3
6 changed files with 15 additions and 14 deletions

View File

@@ -11,7 +11,7 @@ def squared_euclidean_distance(x, y):
**Alias:**
``prototorch.functions.distances.sed``
"""
x, y = [arr.view(arr.size(0), -1) for arr in (x, y)]
x, y = (arr.view(arr.size(0), -1) for arr in (x, y))
expanded_x = x.unsqueeze(dim=1)
batchwise_difference = y - expanded_x
differences_raised = torch.pow(batchwise_difference, 2)
@@ -27,14 +27,14 @@ def euclidean_distance(x, y):
:returns: Distance Tensor of shape :math:`X \times Y`
:rtype: `torch.tensor`
"""
x, y = [arr.view(arr.size(0), -1) for arr in (x, y)]
x, y = (arr.view(arr.size(0), -1) for arr in (x, y))
distances_raised = squared_euclidean_distance(x, y)
distances = torch.sqrt(distances_raised)
return distances
def euclidean_distance_v2(x, y):
x, y = [arr.view(arr.size(0), -1) for arr in (x, y)]
x, y = (arr.view(arr.size(0), -1) for arr in (x, y))
diff = y - x.unsqueeze(1)
pairwise_distances = (diff @ diff.permute((0, 2, 1))).sqrt()
# Passing `dim1=-2` and `dim2=-1` to `diagonal()` takes the
@@ -54,7 +54,7 @@ def lpnorm_distance(x, y, p):
:param p: p parameter of the lp norm
"""
x, y = [arr.view(arr.size(0), -1) for arr in (x, y)]
x, y = (arr.view(arr.size(0), -1) for arr in (x, y))
distances = torch.cdist(x, y, p=p)
return distances
@@ -66,7 +66,7 @@ def omega_distance(x, y, omega):
:param `torch.tensor` omega: Two dimensional matrix
"""
x, y = [arr.view(arr.size(0), -1) for arr in (x, y)]
x, y = (arr.view(arr.size(0), -1) for arr in (x, y))
projected_x = x @ omega
projected_y = y @ omega
distances = squared_euclidean_distance(projected_x, projected_y)
@@ -80,7 +80,7 @@ def lomega_distance(x, y, omegas):
:param `torch.tensor` omegas: Three dimensional matrix
"""
x, y = [arr.view(arr.size(0), -1) for arr in (x, y)]
x, y = (arr.view(arr.size(0), -1) for arr in (x, y))
projected_x = x @ omegas
projected_y = torch.diagonal(y @ omegas).T
expanded_y = torch.unsqueeze(projected_y, dim=1)

View File

@@ -21,7 +21,7 @@ def cosine_similarity(x, y):
Expected dimension of x is 2.
Expected dimension of y is 2.
"""
x, y = [arr.view(arr.size(0), -1) for arr in (x, y)]
x, y = (arr.view(arr.size(0), -1) for arr in (x, y))
norm_x = x.pow(2).sum(1).sqrt()
norm_y = y.pow(2).sum(1).sqrt()
norm_mat = norm_x.unsqueeze(-1) @ norm_y.unsqueeze(-1).T