Update NG

This commit is contained in:
Jensun Ravichandran 2021-06-07 18:35:08 +02:00
parent d558fa6a4a
commit b031382072
3 changed files with 7 additions and 6 deletions

View File

@ -17,11 +17,12 @@ if __name__ == "__main__":
# Prepare the data # Prepare the data
train_ds = pt.datasets.Iris(dims=[0, 2]) train_ds = pt.datasets.Iris(dims=[0, 2])
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=8) train_loader = torch.utils.data.DataLoader(train_ds, batch_size=64)
# Hyperparameters # Hyperparameters
hparams = dict( hparams = dict(
num_prototypes=5, num_prototypes=5,
input_dim=2,
lr=0.1, lr=0.1,
) )
@ -50,6 +51,3 @@ if __name__ == "__main__":
# Training loop # Training loop
trainer.fit(model, train_loader) trainer.fit(model, train_loader)
# Model summary
print(model)

View File

@ -28,7 +28,11 @@ if __name__ == "__main__":
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=150) train_loader = torch.utils.data.DataLoader(train_ds, batch_size=150)
# Hyperparameters # Hyperparameters
hparams = dict(num_prototypes=30, lr=0.03) hparams = dict(
num_prototypes=30,
input_dim=2,
lr=0.03,
)
# Initialize the model # Initialize the model
model = pt.models.NeuralGas( model = pt.models.NeuralGas(

View File

@ -28,7 +28,6 @@ class NeuralGas(UnsupervisedPrototypeModel):
self.save_hyperparameters(hparams) self.save_hyperparameters(hparams)
# Default hparams # Default hparams
self.hparams.setdefault("input_dim", 2)
self.hparams.setdefault("agelimit", 10) self.hparams.setdefault("agelimit", 10)
self.hparams.setdefault("lm", 1) self.hparams.setdefault("lm", 1)