All examples use argparse
This commit is contained in:
		@@ -1,10 +1,17 @@
 | 
				
			|||||||
"""CBC example using the Iris dataset."""
 | 
					"""CBC example using the Iris dataset."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import argparse
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import prototorch as pt
 | 
					import prototorch as pt
 | 
				
			||||||
import pytorch_lightning as pl
 | 
					import pytorch_lightning as pl
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
 | 
					
 | 
				
			||||||
if __name__ == "__main__":
 | 
					if __name__ == "__main__":
 | 
				
			||||||
 | 
					    # Command-line arguments
 | 
				
			||||||
 | 
					    parser = argparse.ArgumentParser()
 | 
				
			||||||
 | 
					    parser = pl.Trainer.add_argparse_args(parser)
 | 
				
			||||||
 | 
					    args = parser.parse_args()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Dataset
 | 
					    # Dataset
 | 
				
			||||||
    train_ds = pt.datasets.Iris(dims=[0, 2])
 | 
					    train_ds = pt.datasets.Iris(dims=[0, 2])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -30,18 +37,15 @@ if __name__ == "__main__":
 | 
				
			|||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Callbacks
 | 
					    # Callbacks
 | 
				
			||||||
    dvis = pt.models.VisCBC2D(data=train_ds,
 | 
					    vis = pt.models.VisCBC2D(data=train_ds,
 | 
				
			||||||
                             title="CBC Iris Example",
 | 
					                             title="CBC Iris Example",
 | 
				
			||||||
                             resolution=300,
 | 
					                             resolution=300,
 | 
				
			||||||
                             axis_off=True)
 | 
					                             axis_off=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Setup trainer
 | 
					    # Setup trainer
 | 
				
			||||||
    trainer = pl.Trainer(
 | 
					    trainer = pl.Trainer.from_argparse_args(
 | 
				
			||||||
        gpus=0,
 | 
					        args,
 | 
				
			||||||
        max_epochs=200,
 | 
					        callbacks=[vis],
 | 
				
			||||||
        callbacks=[
 | 
					 | 
				
			||||||
            dvis,
 | 
					 | 
				
			||||||
        ],
 | 
					 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Training loop
 | 
					    # Training loop
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,12 +1,19 @@
 | 
				
			|||||||
"""GLVQ example using the Iris dataset."""
 | 
					"""GLVQ example using the Iris dataset."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import argparse
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import prototorch as pt
 | 
					import prototorch as pt
 | 
				
			||||||
import pytorch_lightning as pl
 | 
					import pytorch_lightning as pl
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
 | 
					from sklearn.datasets import load_iris
 | 
				
			||||||
 | 
					
 | 
				
			||||||
if __name__ == "__main__":
 | 
					if __name__ == "__main__":
 | 
				
			||||||
 | 
					    # Command-line arguments
 | 
				
			||||||
 | 
					    parser = argparse.ArgumentParser()
 | 
				
			||||||
 | 
					    parser = pl.Trainer.add_argparse_args(parser)
 | 
				
			||||||
 | 
					    args = parser.parse_args()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Dataset
 | 
					    # Dataset
 | 
				
			||||||
    from sklearn.datasets import load_iris
 | 
					 | 
				
			||||||
    x_train, y_train = load_iris(return_X_y=True)
 | 
					    x_train, y_train = load_iris(return_X_y=True)
 | 
				
			||||||
    x_train = x_train[:, [0, 2]]
 | 
					    x_train = x_train[:, [0, 2]]
 | 
				
			||||||
    train_ds = pt.datasets.NumpyDataset(x_train, y_train)
 | 
					    train_ds = pt.datasets.NumpyDataset(x_train, y_train)
 | 
				
			||||||
@@ -33,9 +40,8 @@ if __name__ == "__main__":
 | 
				
			|||||||
    vis = pt.models.VisGLVQ2D(data=(x_train, y_train), block=False)
 | 
					    vis = pt.models.VisGLVQ2D(data=(x_train, y_train), block=False)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Setup trainer
 | 
					    # Setup trainer
 | 
				
			||||||
    trainer = pl.Trainer(
 | 
					    trainer = pl.Trainer.from_argparse_args(
 | 
				
			||||||
        gpus=0,
 | 
					        args,
 | 
				
			||||||
        max_epochs=50,
 | 
					 | 
				
			||||||
        callbacks=[vis],
 | 
					        callbacks=[vis],
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,11 +1,17 @@
 | 
				
			|||||||
"""GLVQ example using the spiral dataset."""
 | 
					"""GLVQ example using the spiral dataset."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import argparse
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import prototorch as pt
 | 
					import prototorch as pt
 | 
				
			||||||
import pytorch_lightning as pl
 | 
					import pytorch_lightning as pl
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
from prototorch.models.callbacks import StopOnNaN
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
if __name__ == "__main__":
 | 
					if __name__ == "__main__":
 | 
				
			||||||
 | 
					    # Command-line arguments
 | 
				
			||||||
 | 
					    parser = argparse.ArgumentParser()
 | 
				
			||||||
 | 
					    parser = pl.Trainer.add_argparse_args(parser)
 | 
				
			||||||
 | 
					    args = parser.parse_args()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Dataset
 | 
					    # Dataset
 | 
				
			||||||
    train_ds = pt.datasets.Spiral(n_samples=600, noise=0.6)
 | 
					    train_ds = pt.datasets.Spiral(n_samples=600, noise=0.6)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -31,13 +37,12 @@ if __name__ == "__main__":
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    # Callbacks
 | 
					    # Callbacks
 | 
				
			||||||
    vis = pt.models.VisGLVQ2D(train_ds, show_last_only=True, block=True)
 | 
					    vis = pt.models.VisGLVQ2D(train_ds, show_last_only=True, block=True)
 | 
				
			||||||
    snan = StopOnNaN(model.proto_layer.components)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Setup trainer
 | 
					    # Setup trainer
 | 
				
			||||||
    trainer = pl.Trainer(
 | 
					    trainer = pl.Trainer.from_argparse_args(
 | 
				
			||||||
        gpus=0,
 | 
					        args,
 | 
				
			||||||
        max_epochs=200,
 | 
					        callbacks=[vis],
 | 
				
			||||||
        callbacks=[vis, snan],
 | 
					        terminate_on_nan=True,
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Training loop
 | 
					    # Training loop
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,12 +1,19 @@
 | 
				
			|||||||
"""GMLVQ example using all four dimensions of the Iris dataset."""
 | 
					"""GMLVQ example using all four dimensions of the Iris dataset."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import argparse
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import prototorch as pt
 | 
					import prototorch as pt
 | 
				
			||||||
import pytorch_lightning as pl
 | 
					import pytorch_lightning as pl
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
 | 
					from sklearn.datasets import load_iris
 | 
				
			||||||
 | 
					
 | 
				
			||||||
if __name__ == "__main__":
 | 
					if __name__ == "__main__":
 | 
				
			||||||
 | 
					    # Command-line arguments
 | 
				
			||||||
 | 
					    parser = argparse.ArgumentParser()
 | 
				
			||||||
 | 
					    parser = pl.Trainer.add_argparse_args(parser)
 | 
				
			||||||
 | 
					    args = parser.parse_args()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Dataset
 | 
					    # Dataset
 | 
				
			||||||
    from sklearn.datasets import load_iris
 | 
					 | 
				
			||||||
    x_train, y_train = load_iris(return_X_y=True)
 | 
					    x_train, y_train = load_iris(return_X_y=True)
 | 
				
			||||||
    train_ds = pt.datasets.NumpyDataset(x_train, y_train)
 | 
					    train_ds = pt.datasets.NumpyDataset(x_train, y_train)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -30,7 +37,7 @@ if __name__ == "__main__":
 | 
				
			|||||||
                            prototype_initializer=pt.components.SMI(train_ds))
 | 
					                            prototype_initializer=pt.components.SMI(train_ds))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Setup trainer
 | 
					    # Setup trainer
 | 
				
			||||||
    trainer = pl.Trainer(max_epochs=100, gpus=0)
 | 
					    trainer = pl.Trainer.from_argparse_args(args, )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Training loop
 | 
					    # Training loop
 | 
				
			||||||
    trainer.fit(model, train_loader)
 | 
					    trainer.fit(model, train_loader)
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -74,10 +74,6 @@ if __name__ == "__main__":
 | 
				
			|||||||
    trainer = pl.Trainer.from_argparse_args(
 | 
					    trainer = pl.Trainer.from_argparse_args(
 | 
				
			||||||
        args,
 | 
					        args,
 | 
				
			||||||
        callbacks=[vis],
 | 
					        callbacks=[vis],
 | 
				
			||||||
        # kwargs override the cli-arguments
 | 
					 | 
				
			||||||
        # max_epochs=50,
 | 
					 | 
				
			||||||
        # overfit_batches=1,
 | 
					 | 
				
			||||||
        # fast_dev_run=1,
 | 
					 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Training loop
 | 
					    # Training loop
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,12 +1,19 @@
 | 
				
			|||||||
"""k-NN example using the Iris dataset."""
 | 
					"""k-NN example using the Iris dataset."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import argparse
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import prototorch as pt
 | 
					import prototorch as pt
 | 
				
			||||||
import pytorch_lightning as pl
 | 
					import pytorch_lightning as pl
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
 | 
					from sklearn.datasets import load_iris
 | 
				
			||||||
 | 
					
 | 
				
			||||||
if __name__ == "__main__":
 | 
					if __name__ == "__main__":
 | 
				
			||||||
 | 
					    # Command-line arguments
 | 
				
			||||||
 | 
					    parser = argparse.ArgumentParser()
 | 
				
			||||||
 | 
					    parser = pl.Trainer.add_argparse_args(parser)
 | 
				
			||||||
 | 
					    args = parser.parse_args()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Dataset
 | 
					    # Dataset
 | 
				
			||||||
    from sklearn.datasets import load_iris
 | 
					 | 
				
			||||||
    x_train, y_train = load_iris(return_X_y=True)
 | 
					    x_train, y_train = load_iris(return_X_y=True)
 | 
				
			||||||
    x_train = x_train[:, [0, 2]]
 | 
					    x_train = x_train[:, [0, 2]]
 | 
				
			||||||
    train_ds = pt.datasets.NumpyDataset(x_train, y_train)
 | 
					    train_ds = pt.datasets.NumpyDataset(x_train, y_train)
 | 
				
			||||||
@@ -26,7 +33,10 @@ if __name__ == "__main__":
 | 
				
			|||||||
    vis = pt.models.VisGLVQ2D(data=(x_train, y_train), resolution=200)
 | 
					    vis = pt.models.VisGLVQ2D(data=(x_train, y_train), resolution=200)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Setup trainer
 | 
					    # Setup trainer
 | 
				
			||||||
    trainer = pl.Trainer(max_epochs=1, callbacks=[vis], gpus=0)
 | 
					    trainer = pl.Trainer.from_argparse_args(
 | 
				
			||||||
 | 
					        args,
 | 
				
			||||||
 | 
					        callbacks=[vis],
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Training loop
 | 
					    # Training loop
 | 
				
			||||||
    # This is only for visualization. k-NN has no training phase.
 | 
					    # This is only for visualization. k-NN has no training phase.
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,10 +1,17 @@
 | 
				
			|||||||
"""Limited Rank Matrix LVQ example using the Tecator dataset."""
 | 
					"""Limited Rank Matrix LVQ example using the Tecator dataset."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import argparse
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import prototorch as pt
 | 
					import prototorch as pt
 | 
				
			||||||
import pytorch_lightning as pl
 | 
					import pytorch_lightning as pl
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
 | 
					
 | 
				
			||||||
if __name__ == "__main__":
 | 
					if __name__ == "__main__":
 | 
				
			||||||
 | 
					    # Command-line arguments
 | 
				
			||||||
 | 
					    parser = argparse.ArgumentParser()
 | 
				
			||||||
 | 
					    parser = pl.Trainer.add_argparse_args(parser)
 | 
				
			||||||
 | 
					    args = parser.parse_args()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Dataset
 | 
					    # Dataset
 | 
				
			||||||
    train_ds = pt.datasets.Tecator(root="~/datasets/", train=True)
 | 
					    train_ds = pt.datasets.Tecator(root="~/datasets/", train=True)
 | 
				
			||||||
    test_ds = pt.datasets.Tecator(root="~/datasets/", train=False)
 | 
					    test_ds = pt.datasets.Tecator(root="~/datasets/", train=False)
 | 
				
			||||||
@@ -40,11 +47,9 @@ if __name__ == "__main__":
 | 
				
			|||||||
                                    mode="min")
 | 
					                                    mode="min")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Setup trainer
 | 
					    # Setup trainer
 | 
				
			||||||
    trainer = pl.Trainer(
 | 
					    trainer = pl.Trainer.from_argparse_args(
 | 
				
			||||||
        gpus=0,
 | 
					        args,
 | 
				
			||||||
        max_epochs=100,
 | 
					 | 
				
			||||||
        callbacks=[vis, es],
 | 
					        callbacks=[vis, es],
 | 
				
			||||||
        weights_summary=None,
 | 
					 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Training loop
 | 
					    # Training loop
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,12 +1,33 @@
 | 
				
			|||||||
"""LVQMLN example using all four dimensions of the Iris dataset."""
 | 
					"""LVQMLN example using all four dimensions of the Iris dataset."""
 | 
				
			||||||
 | 
					import argparse
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import prototorch as pt
 | 
					import prototorch as pt
 | 
				
			||||||
import pytorch_lightning as pl
 | 
					import pytorch_lightning as pl
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from siamese_glvq_iris import Backbone
 | 
					
 | 
				
			||||||
 | 
					class Backbone(torch.nn.Module):
 | 
				
			||||||
 | 
					    def __init__(self, input_size=4, hidden_size=10, latent_size=2):
 | 
				
			||||||
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					        self.input_size = input_size
 | 
				
			||||||
 | 
					        self.hidden_size = hidden_size
 | 
				
			||||||
 | 
					        self.latent_size = latent_size
 | 
				
			||||||
 | 
					        self.dense1 = torch.nn.Linear(self.input_size, self.hidden_size)
 | 
				
			||||||
 | 
					        self.dense2 = torch.nn.Linear(self.hidden_size, self.latent_size)
 | 
				
			||||||
 | 
					        self.activation = torch.nn.Sigmoid()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def forward(self, x):
 | 
				
			||||||
 | 
					        x = self.activation(self.dense1(x))
 | 
				
			||||||
 | 
					        out = self.activation(self.dense2(x))
 | 
				
			||||||
 | 
					        return out
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
if __name__ == "__main__":
 | 
					if __name__ == "__main__":
 | 
				
			||||||
 | 
					    # Command-line arguments
 | 
				
			||||||
 | 
					    parser = argparse.ArgumentParser()
 | 
				
			||||||
 | 
					    parser = pl.Trainer.add_argparse_args(parser)
 | 
				
			||||||
 | 
					    args = parser.parse_args()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Dataset
 | 
					    # Dataset
 | 
				
			||||||
    train_ds = pt.datasets.Iris()
 | 
					    train_ds = pt.datasets.Iris()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -48,7 +69,10 @@ if __name__ == "__main__":
 | 
				
			|||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Setup trainer
 | 
					    # Setup trainer
 | 
				
			||||||
    trainer = pl.Trainer(max_epochs=100, callbacks=[vis], gpus=0)
 | 
					    trainer = pl.Trainer.from_argparse_args(
 | 
				
			||||||
 | 
					        args,
 | 
				
			||||||
 | 
					        callbacks=[vis],
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Training loop
 | 
					    # Training loop
 | 
				
			||||||
    trainer.fit(model, train_loader)
 | 
					    trainer.fit(model, train_loader)
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,13 +1,20 @@
 | 
				
			|||||||
"""Neural Gas example using the Iris dataset."""
 | 
					"""Neural Gas example using the Iris dataset."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import argparse
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import prototorch as pt
 | 
					import prototorch as pt
 | 
				
			||||||
import pytorch_lightning as pl
 | 
					import pytorch_lightning as pl
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
 | 
					from sklearn.datasets import load_iris
 | 
				
			||||||
 | 
					from sklearn.preprocessing import StandardScaler
 | 
				
			||||||
 | 
					
 | 
				
			||||||
if __name__ == "__main__":
 | 
					if __name__ == "__main__":
 | 
				
			||||||
 | 
					    # Command-line arguments
 | 
				
			||||||
 | 
					    parser = argparse.ArgumentParser()
 | 
				
			||||||
 | 
					    parser = pl.Trainer.add_argparse_args(parser)
 | 
				
			||||||
 | 
					    args = parser.parse_args()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Prepare and pre-process the dataset
 | 
					    # Prepare and pre-process the dataset
 | 
				
			||||||
    from sklearn.datasets import load_iris
 | 
					 | 
				
			||||||
    from sklearn.preprocessing import StandardScaler
 | 
					 | 
				
			||||||
    x_train, y_train = load_iris(return_X_y=True)
 | 
					    x_train, y_train = load_iris(return_X_y=True)
 | 
				
			||||||
    x_train = x_train[:, [0, 2]]
 | 
					    x_train = x_train[:, [0, 2]]
 | 
				
			||||||
    scaler = StandardScaler()
 | 
					    scaler = StandardScaler()
 | 
				
			||||||
@@ -34,7 +41,10 @@ if __name__ == "__main__":
 | 
				
			|||||||
    vis = pt.models.VisNG2D(data=train_ds)
 | 
					    vis = pt.models.VisNG2D(data=train_ds)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Setup trainer
 | 
					    # Setup trainer
 | 
				
			||||||
    trainer = pl.Trainer(gpus=0, max_epochs=200, callbacks=[vis])
 | 
					    trainer = pl.Trainer.from_argparse_args(
 | 
				
			||||||
 | 
					        args,
 | 
				
			||||||
 | 
					        callbacks=[vis],
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Training loop
 | 
					    # Training loop
 | 
				
			||||||
    trainer.fit(model, train_loader)
 | 
					    trainer.fit(model, train_loader)
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,5 +1,7 @@
 | 
				
			|||||||
"""Siamese GLVQ example using all four dimensions of the Iris dataset."""
 | 
					"""Siamese GLVQ example using all four dimensions of the Iris dataset."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import argparse
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import prototorch as pt
 | 
					import prototorch as pt
 | 
				
			||||||
import pytorch_lightning as pl
 | 
					import pytorch_lightning as pl
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
@@ -22,6 +24,11 @@ class Backbone(torch.nn.Module):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
if __name__ == "__main__":
 | 
					if __name__ == "__main__":
 | 
				
			||||||
 | 
					    # Command-line arguments
 | 
				
			||||||
 | 
					    parser = argparse.ArgumentParser()
 | 
				
			||||||
 | 
					    parser = pl.Trainer.add_argparse_args(parser)
 | 
				
			||||||
 | 
					    args = parser.parse_args()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Dataset
 | 
					    # Dataset
 | 
				
			||||||
    train_ds = pt.datasets.Iris()
 | 
					    train_ds = pt.datasets.Iris()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -58,7 +65,10 @@ if __name__ == "__main__":
 | 
				
			|||||||
    vis = pt.models.VisSiameseGLVQ2D(data=train_ds, border=0.1)
 | 
					    vis = pt.models.VisSiameseGLVQ2D(data=train_ds, border=0.1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Setup trainer
 | 
					    # Setup trainer
 | 
				
			||||||
    trainer = pl.Trainer(max_epochs=100, callbacks=[vis], gpus=0)
 | 
					    trainer = pl.Trainer.from_argparse_args(
 | 
				
			||||||
 | 
					        args,
 | 
				
			||||||
 | 
					        callbacks=[vis],
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Training loop
 | 
					    # Training loop
 | 
				
			||||||
    trainer.fit(model, train_loader)
 | 
					    trainer.fit(model, train_loader)
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user