Make prototype_labels non-trainable Parameters
This commit is contained in:
parent
0cfbc0473b
commit
8a4a596035
@ -162,4 +162,5 @@ class Prototypes1D(_Prototypes):
|
||||
|
||||
# Register module parameters
|
||||
self.prototypes = torch.nn.Parameter(prototypes)
|
||||
self.prototype_labels = prototype_labels
|
||||
self.prototype_labels = torch.nn.Parameter(
|
||||
prototype_labels.type(dtype)).requires_grad_(False)
|
||||
|
Loading…
Reference in New Issue
Block a user