Update NG
This commit is contained in:
		@@ -17,11 +17,12 @@ if __name__ == "__main__":
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    # Prepare the data
 | 
					    # Prepare the data
 | 
				
			||||||
    train_ds = pt.datasets.Iris(dims=[0, 2])
 | 
					    train_ds = pt.datasets.Iris(dims=[0, 2])
 | 
				
			||||||
    train_loader = torch.utils.data.DataLoader(train_ds, batch_size=8)
 | 
					    train_loader = torch.utils.data.DataLoader(train_ds, batch_size=64)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Hyperparameters
 | 
					    # Hyperparameters
 | 
				
			||||||
    hparams = dict(
 | 
					    hparams = dict(
 | 
				
			||||||
        num_prototypes=5,
 | 
					        num_prototypes=5,
 | 
				
			||||||
 | 
					        input_dim=2,
 | 
				
			||||||
        lr=0.1,
 | 
					        lr=0.1,
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -50,6 +51,3 @@ if __name__ == "__main__":
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    # Training loop
 | 
					    # Training loop
 | 
				
			||||||
    trainer.fit(model, train_loader)
 | 
					    trainer.fit(model, train_loader)
 | 
				
			||||||
 | 
					 | 
				
			||||||
    # Model summary
 | 
					 | 
				
			||||||
    print(model)
 | 
					 | 
				
			||||||
 
 | 
				
			|||||||
@@ -28,7 +28,11 @@ if __name__ == "__main__":
 | 
				
			|||||||
    train_loader = torch.utils.data.DataLoader(train_ds, batch_size=150)
 | 
					    train_loader = torch.utils.data.DataLoader(train_ds, batch_size=150)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Hyperparameters
 | 
					    # Hyperparameters
 | 
				
			||||||
    hparams = dict(num_prototypes=30, lr=0.03)
 | 
					    hparams = dict(
 | 
				
			||||||
 | 
					        num_prototypes=30,
 | 
				
			||||||
 | 
					        input_dim=2,
 | 
				
			||||||
 | 
					        lr=0.03,
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Initialize the model
 | 
					    # Initialize the model
 | 
				
			||||||
    model = pt.models.NeuralGas(
 | 
					    model = pt.models.NeuralGas(
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -28,7 +28,6 @@ class NeuralGas(UnsupervisedPrototypeModel):
 | 
				
			|||||||
        self.save_hyperparameters(hparams)
 | 
					        self.save_hyperparameters(hparams)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # Default hparams
 | 
					        # Default hparams
 | 
				
			||||||
        self.hparams.setdefault("input_dim", 2)
 | 
					 | 
				
			||||||
        self.hparams.setdefault("agelimit", 10)
 | 
					        self.hparams.setdefault("agelimit", 10)
 | 
				
			||||||
        self.hparams.setdefault("lm", 1)
 | 
					        self.hparams.setdefault("lm", 1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user