From 41f0e77fc9dd1d56b55f839c4f3ac8b4120c3282 Mon Sep 17 00:00:00 2001 From: Alexander Engelsberger Date: Tue, 29 Mar 2022 17:08:40 +0200 Subject: [PATCH] fix: siameseGLVQ checks requires_grad of backbone Necessary for different optimizer runs --- prototorch/models/glvq.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/prototorch/models/glvq.py b/prototorch/models/glvq.py index dc2858a..0074a78 100644 --- a/prototorch/models/glvq.py +++ b/prototorch/models/glvq.py @@ -147,9 +147,13 @@ class SiameseGLVQ(GLVQ): protos, _ = self.proto_layer() x, protos = [arr.view(arr.size(0), -1) for arr in (x, protos)] latent_x = self.backbone(x) - self.backbone.requires_grad_(self.both_path_gradients) + + bb_grad = self.backbone._weights.requires_grad + + self.backbone.requires_grad_(bb_grad and self.both_path_gradients) latent_protos = self.backbone(protos) - self.backbone.requires_grad_(True) + self.backbone.requires_grad_(bb_grad) + distances = self.distance_layer(latent_x, latent_protos) return distances