diff --git a/prototorch/core/distances.py b/prototorch/core/distances.py index b3278fd..e86f498 100644 --- a/prototorch/core/distances.py +++ b/prototorch/core/distances.py @@ -73,6 +73,16 @@ def omega_distance(x, y, omega): return distances +def ML_omega_distance(x, y, omega_0, omega_1, mask_0, mask_1): + """Multi-Layer Omega distance.""" + x, y = (arr.view(arr.size(0), -1) for arr in (x, y)) + omega = (omega_0 * mask_0) @ (omega_1 * mask_1) + projected_x = x @ omega + projected_y = y @ omega + distances = squared_euclidean_distance(projected_x, projected_y) + return distances + + def lomega_distance(x, y, omegas): r"""Localized Omega distance.