fix: siameseGLVQ checks requires_grad of backbone
Necessary for different optimizer runs
This commit is contained in:
parent
fab786a07e
commit
41f0e77fc9
@ -147,9 +147,13 @@ class SiameseGLVQ(GLVQ):
|
||||
protos, _ = self.proto_layer()
|
||||
x, protos = [arr.view(arr.size(0), -1) for arr in (x, protos)]
|
||||
latent_x = self.backbone(x)
|
||||
self.backbone.requires_grad_(self.both_path_gradients)
|
||||
|
||||
bb_grad = self.backbone._weights.requires_grad
|
||||
|
||||
self.backbone.requires_grad_(bb_grad and self.both_path_gradients)
|
||||
latent_protos = self.backbone(protos)
|
||||
self.backbone.requires_grad_(True)
|
||||
self.backbone.requires_grad_(bb_grad)
|
||||
|
||||
distances = self.distance_layer(latent_x, latent_protos)
|
||||
return distances
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user