Compare commits

...

2 Commits

Author SHA1 Message Date
julius
f35a08a070
ML_omega_distance: allow for 2 or more omegas and masks 2023-11-07 16:45:05 +01:00
julius
3a5a2bb473
Implement a prototypical 2-layer Ω distance 2023-11-03 14:51:29 +01:00

View File

@ -73,6 +73,20 @@ def omega_distance(x, y, omega):
return distances return distances
def ML_omega_distance(x, y, omegas, masks):
"""Multi-Layer Omega distance."""
x, y = (arr.view(arr.size(0), -1) for arr in (x, y))
# omega = (omega_0 * mask_0) @ (omega_1 * mask_1)
omegas = [torch.mul(_omega, _mask) for _omega, _mask in zip(omegas, masks)]
omega = omegas[0] @ omegas[1]
for _omega in omegas[2:]:
omega = omega @ _omega
projected_x = x @ omega
projected_y = y @ omega
distances = squared_euclidean_distance(projected_x, projected_y)
return distances
def lomega_distance(x, y, omegas): def lomega_distance(x, y, omegas):
r"""Localized Omega distance. r"""Localized Omega distance.