Update example scripts
This commit is contained in:
		@@ -17,15 +17,16 @@ if __name__ == "__main__":
 | 
				
			|||||||
                                               batch_size=150)
 | 
					                                               batch_size=150)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Hyperparameters
 | 
					    # Hyperparameters
 | 
				
			||||||
 | 
					    nclasses = 3
 | 
				
			||||||
 | 
					    prototypes_per_class = 2
 | 
				
			||||||
    hparams = dict(
 | 
					    hparams = dict(
 | 
				
			||||||
        nclasses=3,
 | 
					        distribution=(nclasses, prototypes_per_class),
 | 
				
			||||||
        prototypes_per_class=2,
 | 
					 | 
				
			||||||
        prototype_initializer=pt.components.SMI(train_ds),
 | 
					        prototype_initializer=pt.components.SMI(train_ds),
 | 
				
			||||||
        lr=0.01,
 | 
					        lr=0.01,
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Initialize the model
 | 
					    # Initialize the model
 | 
				
			||||||
    model = pt.models.GLVQ(hparams)
 | 
					    model = pt.models.GLVQ(hparams, optimizer=torch.optim.Adam)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Callbacks
 | 
					    # Callbacks
 | 
				
			||||||
    vis = pt.models.VisGLVQ2D(data=(x_train, y_train))
 | 
					    vis = pt.models.VisGLVQ2D(data=(x_train, y_train))
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -25,10 +25,11 @@ if __name__ == "__main__":
 | 
				
			|||||||
                                               batch_size=256)
 | 
					                                               batch_size=256)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Hyperparameters
 | 
					    # Hyperparameters
 | 
				
			||||||
 | 
					    nclasses = 2
 | 
				
			||||||
 | 
					    prototypes_per_class = 20
 | 
				
			||||||
    hparams = dict(
 | 
					    hparams = dict(
 | 
				
			||||||
        nclasses=2,
 | 
					        distribution=(nclasses, prototypes_per_class),
 | 
				
			||||||
        prototypes_per_class=20,
 | 
					        prototype_initializer=pt.components.SSI(train_ds, noise=1e-1),
 | 
				
			||||||
        prototype_initializer=pt.components.SSI(train_ds, noise=1e-7),
 | 
					 | 
				
			||||||
        transfer_function="sigmoid_beta",
 | 
					        transfer_function="sigmoid_beta",
 | 
				
			||||||
        transfer_beta=10.0,
 | 
					        transfer_beta=10.0,
 | 
				
			||||||
        lr=0.01,
 | 
					        lr=0.01,
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -15,9 +15,10 @@ if __name__ == "__main__":
 | 
				
			|||||||
                                               num_workers=0,
 | 
					                                               num_workers=0,
 | 
				
			||||||
                                               batch_size=150)
 | 
					                                               batch_size=150)
 | 
				
			||||||
    # Hyperparameters
 | 
					    # Hyperparameters
 | 
				
			||||||
 | 
					    nclasses = 3
 | 
				
			||||||
 | 
					    prototypes_per_class = 1
 | 
				
			||||||
    hparams = dict(
 | 
					    hparams = dict(
 | 
				
			||||||
        nclasses=3,
 | 
					        distribution=(nclasses, prototypes_per_class),
 | 
				
			||||||
        prototypes_per_class=1,
 | 
					 | 
				
			||||||
        input_dim=x_train.shape[1],
 | 
					        input_dim=x_train.shape[1],
 | 
				
			||||||
        latent_dim=x_train.shape[1],
 | 
					        latent_dim=x_train.shape[1],
 | 
				
			||||||
        prototype_initializer=pt.components.SMI(train_ds),
 | 
					        prototype_initializer=pt.components.SMI(train_ds),
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -17,9 +17,10 @@ if __name__ == "__main__":
 | 
				
			|||||||
                                               batch_size=32)
 | 
					                                               batch_size=32)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Hyperparameters
 | 
					    # Hyperparameters
 | 
				
			||||||
 | 
					    nclasses = 2
 | 
				
			||||||
 | 
					    prototypes_per_class = 2
 | 
				
			||||||
    hparams = dict(
 | 
					    hparams = dict(
 | 
				
			||||||
        nclasses=2,
 | 
					        distribution=(nclasses, prototypes_per_class),
 | 
				
			||||||
        prototypes_per_class=2,
 | 
					 | 
				
			||||||
        input_dim=100,
 | 
					        input_dim=100,
 | 
				
			||||||
        latent_dim=2,
 | 
					        latent_dim=2,
 | 
				
			||||||
        prototype_initializer=pt.components.SMI(train_ds),
 | 
					        prototype_initializer=pt.components.SMI(train_ds),
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,42 +0,0 @@
 | 
				
			|||||||
"""Classical LVQ using GLVQ example on the Iris dataset."""
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import prototorch as pt
 | 
					 | 
				
			||||||
import pytorch_lightning as pl
 | 
					 | 
				
			||||||
import torch
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
if __name__ == "__main__":
 | 
					 | 
				
			||||||
    # Dataset
 | 
					 | 
				
			||||||
    from sklearn.datasets import load_iris
 | 
					 | 
				
			||||||
    x_train, y_train = load_iris(return_X_y=True)
 | 
					 | 
				
			||||||
    x_train = x_train[:, [0, 2]]
 | 
					 | 
				
			||||||
    train_ds = pt.datasets.NumpyDataset(x_train, y_train)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # Dataloaders
 | 
					 | 
				
			||||||
    train_loader = torch.utils.data.DataLoader(train_ds,
 | 
					 | 
				
			||||||
                                               num_workers=0,
 | 
					 | 
				
			||||||
                                               batch_size=150)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # Hyperparameters
 | 
					 | 
				
			||||||
    hparams = dict(
 | 
					 | 
				
			||||||
        nclasses=3,
 | 
					 | 
				
			||||||
        prototypes_per_class=2,
 | 
					 | 
				
			||||||
        prototype_initializer=pt.components.SMI(train_ds),
 | 
					 | 
				
			||||||
        #prototype_initializer=pt.components.Random(2),
 | 
					 | 
				
			||||||
        lr=0.005,
 | 
					 | 
				
			||||||
    )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # Initialize the model
 | 
					 | 
				
			||||||
    model = pt.models.LVQ1(hparams)
 | 
					 | 
				
			||||||
    #model = pt.models.LVQ21(hparams)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # Callbacks
 | 
					 | 
				
			||||||
    vis = pt.models.VisGLVQ2D(data=(x_train, y_train))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # Setup trainer
 | 
					 | 
				
			||||||
    trainer = pl.Trainer(
 | 
					 | 
				
			||||||
        max_epochs=200,
 | 
					 | 
				
			||||||
        callbacks=[vis],
 | 
					 | 
				
			||||||
    )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # Training loop
 | 
					 | 
				
			||||||
    trainer.fit(model, train_loader)
 | 
					 | 
				
			||||||
@@ -38,11 +38,10 @@ if __name__ == "__main__":
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    # Hyperparameters
 | 
					    # Hyperparameters
 | 
				
			||||||
    hparams = dict(
 | 
					    hparams = dict(
 | 
				
			||||||
        nclasses=3,
 | 
					        distribution=[1, 2, 3],
 | 
				
			||||||
        prototypes_per_class=2,
 | 
					 | 
				
			||||||
        prototype_initializer=pt.components.SMI((x_train, y_train)),
 | 
					        prototype_initializer=pt.components.SMI((x_train, y_train)),
 | 
				
			||||||
        proto_lr=0.001,
 | 
					        proto_lr=0.01,
 | 
				
			||||||
        bb_lr=0.001,
 | 
					        bb_lr=0.01,
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Initialize the model
 | 
					    # Initialize the model
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user