From 3a5a2bb473fbbb97aca27b558d9df213c2f8f5cf Mon Sep 17 00:00:00 2001 From: julius Date: Fri, 3 Nov 2023 14:51:29 +0100 Subject: [PATCH] =?UTF-8?q?Implement=20a=20prototypical=202-layer=20=CE=A9?= =?UTF-8?q?=20distance?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- prototorch/core/distances.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/prototorch/core/distances.py b/prototorch/core/distances.py index b3278fd..e86f498 100644 --- a/prototorch/core/distances.py +++ b/prototorch/core/distances.py @@ -73,6 +73,16 @@ def omega_distance(x, y, omega): return distances +def ML_omega_distance(x, y, omega_0, omega_1, mask_0, mask_1): + """Multi-Layer Omega distance.""" + x, y = (arr.view(arr.size(0), -1) for arr in (x, y)) + omega = (omega_0 * mask_0) @ (omega_1 * mask_1) + projected_x = x @ omega + projected_y = y @ omega + distances = squared_euclidean_distance(projected_x, projected_y) + return distances + + def lomega_distance(x, y, omegas): r"""Localized Omega distance.