Make prototype_labels non-trainable Parameters

This commit is contained in:
blackfly 2020-04-27 13:39:27 +02:00
parent 0cfbc0473b
commit 8a4a596035

View File

@ -162,4 +162,5 @@ class Prototypes1D(_Prototypes):
# Register module parameters
self.prototypes = torch.nn.Parameter(prototypes)
self.prototype_labels = prototype_labels
self.prototype_labels = torch.nn.Parameter(
prototype_labels.type(dtype)).requires_grad_(False)