fix: style fixes
This commit is contained in:
parent
46dfb82371
commit
16f410e809
@ -72,4 +72,4 @@ if __name__ == "__main__":
|
|||||||
# Training loop
|
# Training loop
|
||||||
trainer.fit(model, train_loader)
|
trainer.fit(model, train_loader)
|
||||||
|
|
||||||
torch.save(model, "iris.pth")
|
torch.save(model, "iris.pth")
|
||||||
|
@ -71,4 +71,4 @@ if __name__ == "__main__":
|
|||||||
# Training loop
|
# Training loop
|
||||||
trainer.fit(model, train_loader)
|
trainer.fit(model, train_loader)
|
||||||
|
|
||||||
torch.save(model, "iris.pth")
|
torch.save(model, "iris.pth")
|
||||||
|
@ -274,7 +274,7 @@ class GMLVQ(GLVQ):
|
|||||||
omega = omega_initializer.generate(self.hparams["input_dim"],
|
omega = omega_initializer.generate(self.hparams["input_dim"],
|
||||||
self.hparams["latent_dim"])
|
self.hparams["latent_dim"])
|
||||||
self.register_parameter("_omega", Parameter(omega))
|
self.register_parameter("_omega", Parameter(omega))
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def omega_matrix(self):
|
def omega_matrix(self):
|
||||||
return self._omega.detach().cpu()
|
return self._omega.detach().cpu()
|
||||||
|
Loading…
Reference in New Issue
Block a user