fix: siameseGLVQ checks requires_grad of backbone
Necessary for different optimizer runs
This commit is contained in:
parent
fab786a07e
commit
41f0e77fc9
@ -147,9 +147,13 @@ class SiameseGLVQ(GLVQ):
|
|||||||
protos, _ = self.proto_layer()
|
protos, _ = self.proto_layer()
|
||||||
x, protos = [arr.view(arr.size(0), -1) for arr in (x, protos)]
|
x, protos = [arr.view(arr.size(0), -1) for arr in (x, protos)]
|
||||||
latent_x = self.backbone(x)
|
latent_x = self.backbone(x)
|
||||||
self.backbone.requires_grad_(self.both_path_gradients)
|
|
||||||
|
bb_grad = self.backbone._weights.requires_grad
|
||||||
|
|
||||||
|
self.backbone.requires_grad_(bb_grad and self.both_path_gradients)
|
||||||
latent_protos = self.backbone(protos)
|
latent_protos = self.backbone(protos)
|
||||||
self.backbone.requires_grad_(True)
|
self.backbone.requires_grad_(bb_grad)
|
||||||
|
|
||||||
distances = self.distance_layer(latent_x, latent_protos)
|
distances = self.distance_layer(latent_x, latent_protos)
|
||||||
return distances
|
return distances
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user