Make prototype_labels non-trainable Parameters
This commit is contained in:
		| @@ -162,4 +162,5 @@ class Prototypes1D(_Prototypes): | |||||||
|  |  | ||||||
|         # Register module parameters |         # Register module parameters | ||||||
|         self.prototypes = torch.nn.Parameter(prototypes) |         self.prototypes = torch.nn.Parameter(prototypes) | ||||||
|         self.prototype_labels = prototype_labels |         self.prototype_labels = torch.nn.Parameter( | ||||||
|  |             prototype_labels.type(dtype)).requires_grad_(False) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user