Properly initialize prototypes in LVQMLN

This commit is contained in:
Jensun Ravichandran 2021-05-09 20:55:28 +02:00
parent 7a86bb19a8
commit c6e06ceaa4

View File

@ -253,6 +253,9 @@ class LVQMLN(GLVQ):
**kwargs):
super().__init__(hparams, **kwargs)
self.backbone = backbone_module(**backbone_params)
with torch.no_grad():
protos = self.backbone(self.proto_layer()[0])
self.proto_layer.load_state_dict({"_components": protos}, strict=False)
def forward(self, x):
latent_protos, _ = self.proto_layer()