diff --git a/prototorch/models/glvq.py b/prototorch/models/glvq.py index 60292b6..308bfae 100644 --- a/prototorch/models/glvq.py +++ b/prototorch/models/glvq.py @@ -253,6 +253,9 @@ class LVQMLN(GLVQ): **kwargs): super().__init__(hparams, **kwargs) self.backbone = backbone_module(**backbone_params) + with torch.no_grad(): + protos = self.backbone(self.proto_layer()[0]) + self.proto_layer.load_state_dict({"_components": protos}, strict=False) def forward(self, x): latent_protos, _ = self.proto_layer()