From f35a08a070e12232969869ee93211aad5a5a0fbe Mon Sep 17 00:00:00 2001 From: julius Date: Tue, 7 Nov 2023 16:45:05 +0100 Subject: [PATCH] ML_omega_distance: allow for 2 or more omegas and masks --- prototorch/core/distances.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/prototorch/core/distances.py b/prototorch/core/distances.py index e86f498..33b4880 100644 --- a/prototorch/core/distances.py +++ b/prototorch/core/distances.py @@ -73,10 +73,14 @@ def omega_distance(x, y, omega): return distances -def ML_omega_distance(x, y, omega_0, omega_1, mask_0, mask_1): +def ML_omega_distance(x, y, omegas, masks): """Multi-Layer Omega distance.""" x, y = (arr.view(arr.size(0), -1) for arr in (x, y)) - omega = (omega_0 * mask_0) @ (omega_1 * mask_1) + # omega = (omega_0 * mask_0) @ (omega_1 * mask_1) + omegas = [torch.mul(_omega, _mask) for _omega, _mask in zip(omegas, masks)] + omega = omegas[0] @ omegas[1] + for _omega in omegas[2:]: + omega = omega @ _omega projected_x = x @ omega projected_y = y @ omega distances = squared_euclidean_distance(projected_x, projected_y)