Update SiameseGLVQ
This commit is contained in:
parent
96aeaa3448
commit
d8e017ae74
@ -107,8 +107,13 @@ class SiameseGLVQ(GLVQ):
|
|||||||
optim = self.hparams.optimizer
|
optim = self.hparams.optimizer
|
||||||
proto_opt = optim(self.proto_layer.parameters(),
|
proto_opt = optim(self.proto_layer.parameters(),
|
||||||
lr=self.hparams.proto_lr)
|
lr=self.hparams.proto_lr)
|
||||||
|
if list(self.backbone.parameters()):
|
||||||
|
# only add an optimizer is the backbone has trainable parameters
|
||||||
|
# otherwise, the next line fails
|
||||||
bb_opt = optim(self.backbone.parameters(), lr=self.hparams.bb_lr)
|
bb_opt = optim(self.backbone.parameters(), lr=self.hparams.bb_lr)
|
||||||
return proto_opt, bb_opt
|
return proto_opt, bb_opt
|
||||||
|
else:
|
||||||
|
return proto_opt
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
self.sync_backbones()
|
self.sync_backbones()
|
||||||
|
Loading…
Reference in New Issue
Block a user