Fix zero-distance bug in glvq_loss
This commit is contained in:
		@@ -12,12 +12,9 @@ def glvq_loss(distances, target_labels, prototype_labels):
 | 
				
			|||||||
        matcher = torch.eq(torch.sum(matcher, dim=-1), nclasses)
 | 
					        matcher = torch.eq(torch.sum(matcher, dim=-1), nclasses)
 | 
				
			||||||
    not_matcher = torch.bitwise_not(matcher)
 | 
					    not_matcher = torch.bitwise_not(matcher)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    dplus_criterion = distances * matcher > 0.0
 | 
					 | 
				
			||||||
    dminus_criterion = distances * not_matcher > 0.0
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    inf = torch.full_like(distances, fill_value=float('inf'))
 | 
					    inf = torch.full_like(distances, fill_value=float('inf'))
 | 
				
			||||||
    distances_to_wpluses = torch.where(dplus_criterion, distances, inf)
 | 
					    distances_to_wpluses = torch.where(matcher, distances, inf)
 | 
				
			||||||
    distances_to_wminuses = torch.where(dminus_criterion, distances, inf)
 | 
					    distances_to_wminuses = torch.where(not_matcher, distances, inf)
 | 
				
			||||||
    dpluses = torch.min(distances_to_wpluses, dim=1, keepdim=True).values
 | 
					    dpluses = torch.min(distances_to_wpluses, dim=1, keepdim=True).values
 | 
				
			||||||
    dminuses = torch.min(distances_to_wminuses, dim=1, keepdim=True).values
 | 
					    dminuses = torch.min(distances_to_wminuses, dim=1, keepdim=True).values
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -12,7 +12,7 @@ class GLVQLoss(torch.nn.Module):
 | 
				
			|||||||
        super().__init__(**kwargs)
 | 
					        super().__init__(**kwargs)
 | 
				
			||||||
        self.margin = margin
 | 
					        self.margin = margin
 | 
				
			||||||
        self.squashing = get_activation(squashing)
 | 
					        self.squashing = get_activation(squashing)
 | 
				
			||||||
        self.beta = beta
 | 
					        self.beta = torch.tensor(beta)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def forward(self, outputs, targets):
 | 
					    def forward(self, outputs, targets):
 | 
				
			||||||
        distances, plabels = outputs
 | 
					        distances, plabels = outputs
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user