Update NG

This commit is contained in:
Jensun Ravichandran 2021-06-07 18:35:08 +02:00
parent d558fa6a4a
commit b031382072
3 changed files with 7 additions and 6 deletions

View File

@ -17,11 +17,12 @@ if __name__ == "__main__":
# Prepare the data
train_ds = pt.datasets.Iris(dims=[0, 2])
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=8)
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=64)
# Hyperparameters
hparams = dict(
num_prototypes=5,
input_dim=2,
lr=0.1,
)
@ -50,6 +51,3 @@ if __name__ == "__main__":
# Training loop
trainer.fit(model, train_loader)
# Model summary
print(model)

View File

@ -28,7 +28,11 @@ if __name__ == "__main__":
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=150)
# Hyperparameters
hparams = dict(num_prototypes=30, lr=0.03)
hparams = dict(
num_prototypes=30,
input_dim=2,
lr=0.03,
)
# Initialize the model
model = pt.models.NeuralGas(

View File

@ -28,7 +28,6 @@ class NeuralGas(UnsupervisedPrototypeModel):
self.save_hyperparameters(hparams)
# Default hparams
self.hparams.setdefault("input_dim", 2)
self.hparams.setdefault("agelimit", 10)
self.hparams.setdefault("lm", 1)