fix: siameseGLVQ checks requires_grad of backbone

Necessary for different optimizer runs
This commit is contained in:
Alexander Engelsberger 2022-03-29 17:08:40 +02:00
parent fab786a07e
commit 41f0e77fc9

View File

@ -147,9 +147,13 @@ class SiameseGLVQ(GLVQ):
protos, _ = self.proto_layer()
x, protos = [arr.view(arr.size(0), -1) for arr in (x, protos)]
latent_x = self.backbone(x)
self.backbone.requires_grad_(self.both_path_gradients)
bb_grad = self.backbone._weights.requires_grad
self.backbone.requires_grad_(bb_grad and self.both_path_gradients)
latent_protos = self.backbone(protos)
self.backbone.requires_grad_(True)
self.backbone.requires_grad_(bb_grad)
distances = self.distance_layer(latent_x, latent_protos)
return distances