Update SiameseGLVQ

This commit is contained in:
Jensun Ravichandran 2021-05-03 16:09:22 +02:00
parent 96aeaa3448
commit d8e017ae74

View File

@ -107,8 +107,13 @@ class SiameseGLVQ(GLVQ):
optim = self.hparams.optimizer optim = self.hparams.optimizer
proto_opt = optim(self.proto_layer.parameters(), proto_opt = optim(self.proto_layer.parameters(),
lr=self.hparams.proto_lr) lr=self.hparams.proto_lr)
bb_opt = optim(self.backbone.parameters(), lr=self.hparams.bb_lr) if list(self.backbone.parameters()):
return proto_opt, bb_opt # only add an optimizer is the backbone has trainable parameters
# otherwise, the next line fails
bb_opt = optim(self.backbone.parameters(), lr=self.hparams.bb_lr)
return proto_opt, bb_opt
else:
return proto_opt
def forward(self, x): def forward(self, x):
self.sync_backbones() self.sync_backbones()