fix: siameseGLVQ checks requires_grad of backbone
Necessary for different optimizer runs
This commit is contained in:
		@@ -147,9 +147,13 @@ class SiameseGLVQ(GLVQ):
 | 
				
			|||||||
        protos, _ = self.proto_layer()
 | 
					        protos, _ = self.proto_layer()
 | 
				
			||||||
        x, protos = [arr.view(arr.size(0), -1) for arr in (x, protos)]
 | 
					        x, protos = [arr.view(arr.size(0), -1) for arr in (x, protos)]
 | 
				
			||||||
        latent_x = self.backbone(x)
 | 
					        latent_x = self.backbone(x)
 | 
				
			||||||
        self.backbone.requires_grad_(self.both_path_gradients)
 | 
					
 | 
				
			||||||
 | 
					        bb_grad = self.backbone._weights.requires_grad
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.backbone.requires_grad_(bb_grad and self.both_path_gradients)
 | 
				
			||||||
        latent_protos = self.backbone(protos)
 | 
					        latent_protos = self.backbone(protos)
 | 
				
			||||||
        self.backbone.requires_grad_(True)
 | 
					        self.backbone.requires_grad_(bb_grad)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        distances = self.distance_layer(latent_x, latent_protos)
 | 
					        distances = self.distance_layer(latent_x, latent_protos)
 | 
				
			||||||
        return distances
 | 
					        return distances
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user