Make prototype_labels non-trainable Parameters
This commit is contained in:
parent
0cfbc0473b
commit
8a4a596035
@ -162,4 +162,5 @@ class Prototypes1D(_Prototypes):
|
|||||||
|
|
||||||
# Register module parameters
|
# Register module parameters
|
||||||
self.prototypes = torch.nn.Parameter(prototypes)
|
self.prototypes = torch.nn.Parameter(prototypes)
|
||||||
self.prototype_labels = prototype_labels
|
self.prototype_labels = torch.nn.Parameter(
|
||||||
|
prototype_labels.type(dtype)).requires_grad_(False)
|
||||||
|
Loading…
Reference in New Issue
Block a user