diff --git a/examples/gmlvq_iris.py b/examples/gmlvq_iris.py index dee0a84..b713983 100644 --- a/examples/gmlvq_iris.py +++ b/examples/gmlvq_iris.py @@ -72,4 +72,4 @@ if __name__ == "__main__": # Training loop trainer.fit(model, train_loader) - torch.save(model, "iris.pth") \ No newline at end of file + torch.save(model, "iris.pth") diff --git a/examples/grlvq_iris.py b/examples/grlvq_iris.py index 2ede559..97e0a0c 100644 --- a/examples/grlvq_iris.py +++ b/examples/grlvq_iris.py @@ -71,4 +71,4 @@ if __name__ == "__main__": # Training loop trainer.fit(model, train_loader) - torch.save(model, "iris.pth") \ No newline at end of file + torch.save(model, "iris.pth") diff --git a/prototorch/models/glvq.py b/prototorch/models/glvq.py index d376c9a..2ca9834 100644 --- a/prototorch/models/glvq.py +++ b/prototorch/models/glvq.py @@ -274,7 +274,7 @@ class GMLVQ(GLVQ): omega = omega_initializer.generate(self.hparams["input_dim"], self.hparams["latent_dim"]) self.register_parameter("_omega", Parameter(omega)) - + @property def omega_matrix(self): return self._omega.detach().cpu()