refactor: refactor GLVQLoss
				
					
				
			This commit is contained in:
		| @@ -106,17 +106,16 @@ def margin_loss(y_pred, y_true, margin=0.3): | |||||||
|  |  | ||||||
|  |  | ||||||
| class GLVQLoss(torch.nn.Module): | class GLVQLoss(torch.nn.Module): | ||||||
|     def __init__(self, margin=0.0, squashing="identity", beta=10, **kwargs): |     def __init__(self, margin=0.0, transfer_fn="identity", beta=10, **kwargs): | ||||||
|         super().__init__(**kwargs) |         super().__init__(**kwargs) | ||||||
|         self.margin = margin |         self.margin = margin | ||||||
|         self.squashing = get_activation(squashing) |         self.transfer_fn = get_activation(transfer_fn) | ||||||
|         self.beta = torch.tensor(beta) |         self.beta = torch.tensor(beta) | ||||||
|  |  | ||||||
|     def forward(self, outputs, targets): |     def forward(self, outputs, targets, plabels): | ||||||
|         distances, plabels = outputs |         mu = glvq_loss(outputs, targets, prototype_labels=plabels) | ||||||
|         mu = glvq_loss(distances, targets, prototype_labels=plabels) |         batch_loss = self.transfer_fn(mu + self.margin, beta=self.beta) | ||||||
|         batch_loss = self.squashing(mu + self.margin, beta=self.beta) |         return batch_loss.sum() | ||||||
|         return torch.sum(batch_loss, dim=0) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class MarginLoss(torch.nn.modules.loss._Loss): | class MarginLoss(torch.nn.modules.loss._Loss): | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user