Properly initialize prototypes in LVQMLN
This commit is contained in:
parent
7a86bb19a8
commit
c6e06ceaa4
@ -253,6 +253,9 @@ class LVQMLN(GLVQ):
|
|||||||
**kwargs):
|
**kwargs):
|
||||||
super().__init__(hparams, **kwargs)
|
super().__init__(hparams, **kwargs)
|
||||||
self.backbone = backbone_module(**backbone_params)
|
self.backbone = backbone_module(**backbone_params)
|
||||||
|
with torch.no_grad():
|
||||||
|
protos = self.backbone(self.proto_layer()[0])
|
||||||
|
self.proto_layer.load_state_dict({"_components": protos}, strict=False)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
latent_protos, _ = self.proto_layer()
|
latent_protos, _ = self.proto_layer()
|
||||||
|
Loading…
Reference in New Issue
Block a user