diff --git a/prototorch/models/glvq.py b/prototorch/models/glvq.py index dc2858a..0074a78 100644 --- a/prototorch/models/glvq.py +++ b/prototorch/models/glvq.py @@ -147,9 +147,13 @@ class SiameseGLVQ(GLVQ): protos, _ = self.proto_layer() x, protos = [arr.view(arr.size(0), -1) for arr in (x, protos)] latent_x = self.backbone(x) - self.backbone.requires_grad_(self.both_path_gradients) + + bb_grad = self.backbone._weights.requires_grad + + self.backbone.requires_grad_(bb_grad and self.both_path_gradients) latent_protos = self.backbone(protos) - self.backbone.requires_grad_(True) + self.backbone.requires_grad_(bb_grad) + distances = self.distance_layer(latent_x, latent_protos) return distances