diff --git a/prototorch/models/glvq.py b/prototorch/models/glvq.py index 6fe76d3..60292b6 100644 --- a/prototorch/models/glvq.py +++ b/prototorch/models/glvq.py @@ -191,6 +191,9 @@ class GMLVQ(GLVQ): self.hparams.latent_dim, bias=False) + # Namespace hook for the visualization callbacks to work + self.backbone = self.omega_layer + @property def omega_matrix(self): return self.omega_layer.weight.detach().cpu()