Compare commits
2 Commits
Author | SHA1 | Date | |
---|---|---|---|
|
f35a08a070 | ||
|
3a5a2bb473 |
@ -73,6 +73,20 @@ def omega_distance(x, y, omega):
|
||||
return distances
|
||||
|
||||
|
||||
def ML_omega_distance(x, y, omegas, masks):
|
||||
"""Multi-Layer Omega distance."""
|
||||
x, y = (arr.view(arr.size(0), -1) for arr in (x, y))
|
||||
# omega = (omega_0 * mask_0) @ (omega_1 * mask_1)
|
||||
omegas = [torch.mul(_omega, _mask) for _omega, _mask in zip(omegas, masks)]
|
||||
omega = omegas[0] @ omegas[1]
|
||||
for _omega in omegas[2:]:
|
||||
omega = omega @ _omega
|
||||
projected_x = x @ omega
|
||||
projected_y = y @ omega
|
||||
distances = squared_euclidean_distance(projected_x, projected_y)
|
||||
return distances
|
||||
|
||||
|
||||
def lomega_distance(x, y, omegas):
|
||||
r"""Localized Omega distance.
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user