chore: fix minor errors and upgrade codebase
This commit is contained in:
@@ -11,7 +11,7 @@ def squared_euclidean_distance(x, y):
|
||||
**Alias:**
|
||||
``prototorch.functions.distances.sed``
|
||||
"""
|
||||
x, y = [arr.view(arr.size(0), -1) for arr in (x, y)]
|
||||
x, y = (arr.view(arr.size(0), -1) for arr in (x, y))
|
||||
expanded_x = x.unsqueeze(dim=1)
|
||||
batchwise_difference = y - expanded_x
|
||||
differences_raised = torch.pow(batchwise_difference, 2)
|
||||
@@ -27,14 +27,14 @@ def euclidean_distance(x, y):
|
||||
:returns: Distance Tensor of shape :math:`X \times Y`
|
||||
:rtype: `torch.tensor`
|
||||
"""
|
||||
x, y = [arr.view(arr.size(0), -1) for arr in (x, y)]
|
||||
x, y = (arr.view(arr.size(0), -1) for arr in (x, y))
|
||||
distances_raised = squared_euclidean_distance(x, y)
|
||||
distances = torch.sqrt(distances_raised)
|
||||
return distances
|
||||
|
||||
|
||||
def euclidean_distance_v2(x, y):
|
||||
x, y = [arr.view(arr.size(0), -1) for arr in (x, y)]
|
||||
x, y = (arr.view(arr.size(0), -1) for arr in (x, y))
|
||||
diff = y - x.unsqueeze(1)
|
||||
pairwise_distances = (diff @ diff.permute((0, 2, 1))).sqrt()
|
||||
# Passing `dim1=-2` and `dim2=-1` to `diagonal()` takes the
|
||||
@@ -54,7 +54,7 @@ def lpnorm_distance(x, y, p):
|
||||
|
||||
:param p: p parameter of the lp norm
|
||||
"""
|
||||
x, y = [arr.view(arr.size(0), -1) for arr in (x, y)]
|
||||
x, y = (arr.view(arr.size(0), -1) for arr in (x, y))
|
||||
distances = torch.cdist(x, y, p=p)
|
||||
return distances
|
||||
|
||||
@@ -66,7 +66,7 @@ def omega_distance(x, y, omega):
|
||||
|
||||
:param `torch.tensor` omega: Two dimensional matrix
|
||||
"""
|
||||
x, y = [arr.view(arr.size(0), -1) for arr in (x, y)]
|
||||
x, y = (arr.view(arr.size(0), -1) for arr in (x, y))
|
||||
projected_x = x @ omega
|
||||
projected_y = y @ omega
|
||||
distances = squared_euclidean_distance(projected_x, projected_y)
|
||||
@@ -80,7 +80,7 @@ def lomega_distance(x, y, omegas):
|
||||
|
||||
:param `torch.tensor` omegas: Three dimensional matrix
|
||||
"""
|
||||
x, y = [arr.view(arr.size(0), -1) for arr in (x, y)]
|
||||
x, y = (arr.view(arr.size(0), -1) for arr in (x, y))
|
||||
projected_x = x @ omegas
|
||||
projected_y = torch.diagonal(y @ omegas).T
|
||||
expanded_y = torch.unsqueeze(projected_y, dim=1)
|
||||
|
@@ -21,7 +21,7 @@ def cosine_similarity(x, y):
|
||||
Expected dimension of x is 2.
|
||||
Expected dimension of y is 2.
|
||||
"""
|
||||
x, y = [arr.view(arr.size(0), -1) for arr in (x, y)]
|
||||
x, y = (arr.view(arr.size(0), -1) for arr in (x, y))
|
||||
norm_x = x.pow(2).sum(1).sqrt()
|
||||
norm_y = y.pow(2).sum(1).sqrt()
|
||||
norm_mat = norm_x.unsqueeze(-1) @ norm_y.unsqueeze(-1).T
|
||||
|
Reference in New Issue
Block a user