ML_omega_distance: allow for 2 or more omegas and masks
This commit is contained in:
		@@ -73,10 +73,14 @@ def omega_distance(x, y, omega):
 | 
				
			|||||||
    return distances
 | 
					    return distances
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def ML_omega_distance(x, y, omega_0, omega_1, mask_0, mask_1):
 | 
					def ML_omega_distance(x, y, omegas, masks):
 | 
				
			||||||
    """Multi-Layer Omega distance."""
 | 
					    """Multi-Layer Omega distance."""
 | 
				
			||||||
    x, y = (arr.view(arr.size(0), -1) for arr in (x, y))
 | 
					    x, y = (arr.view(arr.size(0), -1) for arr in (x, y))
 | 
				
			||||||
    omega = (omega_0 * mask_0) @ (omega_1 * mask_1)
 | 
					    # omega = (omega_0 * mask_0) @ (omega_1 * mask_1)
 | 
				
			||||||
 | 
					    omegas = [torch.mul(_omega, _mask) for _omega, _mask in zip(omegas, masks)]
 | 
				
			||||||
 | 
					    omega = omegas[0] @ omegas[1]
 | 
				
			||||||
 | 
					    for _omega in omegas[2:]:
 | 
				
			||||||
 | 
					        omega = omega @ _omega
 | 
				
			||||||
    projected_x = x @ omega
 | 
					    projected_x = x @ omega
 | 
				
			||||||
    projected_y = y @ omega
 | 
					    projected_y = y @ omega
 | 
				
			||||||
    distances = squared_euclidean_distance(projected_x, projected_y)
 | 
					    distances = squared_euclidean_distance(projected_x, projected_y)
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user