From c6e06ceaa413dbd72c872c65870d70670490bc6b Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Sun, 9 May 2021 20:55:28 +0200 Subject: [PATCH] Properly initialize prototypes in LVQMLN --- prototorch/models/glvq.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/prototorch/models/glvq.py b/prototorch/models/glvq.py index 60292b6..308bfae 100644 --- a/prototorch/models/glvq.py +++ b/prototorch/models/glvq.py @@ -253,6 +253,9 @@ class LVQMLN(GLVQ): **kwargs): super().__init__(hparams, **kwargs) self.backbone = backbone_module(**backbone_params) + with torch.no_grad(): + protos = self.backbone(self.proto_layer()[0]) + self.proto_layer.load_state_dict({"_components": protos}, strict=False) def forward(self, x): latent_protos, _ = self.proto_layer()