Update SiameseGLVQ
This commit is contained in:
parent
96aeaa3448
commit
d8e017ae74
@ -107,8 +107,13 @@ class SiameseGLVQ(GLVQ):
|
||||
optim = self.hparams.optimizer
|
||||
proto_opt = optim(self.proto_layer.parameters(),
|
||||
lr=self.hparams.proto_lr)
|
||||
bb_opt = optim(self.backbone.parameters(), lr=self.hparams.bb_lr)
|
||||
return proto_opt, bb_opt
|
||||
if list(self.backbone.parameters()):
|
||||
# only add an optimizer is the backbone has trainable parameters
|
||||
# otherwise, the next line fails
|
||||
bb_opt = optim(self.backbone.parameters(), lr=self.hparams.bb_lr)
|
||||
return proto_opt, bb_opt
|
||||
else:
|
||||
return proto_opt
|
||||
|
||||
def forward(self, x):
|
||||
self.sync_backbones()
|
||||
|
Loading…
Reference in New Issue
Block a user