Update example scripts
This commit is contained in:
		@@ -30,7 +30,7 @@ if __name__ == "__main__":
 | 
				
			|||||||
    prototypes_per_class = num_clusters * 5
 | 
					    prototypes_per_class = num_clusters * 5
 | 
				
			||||||
    hparams = dict(
 | 
					    hparams = dict(
 | 
				
			||||||
        distribution=(num_classes, prototypes_per_class),
 | 
					        distribution=(num_classes, prototypes_per_class),
 | 
				
			||||||
        lr=0.1,
 | 
					        lr=0.2,
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Initialize the model
 | 
					    # Initialize the model
 | 
				
			||||||
@@ -39,6 +39,12 @@ if __name__ == "__main__":
 | 
				
			|||||||
        prototype_initializer=pt.components.Ones(2, scale=3),
 | 
					        prototype_initializer=pt.components.Ones(2, scale=3),
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Compute intermediate input and output sizes
 | 
				
			||||||
 | 
					    model.example_input_array = torch.zeros(4, 2)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Summary
 | 
				
			||||||
 | 
					    print(model)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Callbacks
 | 
					    # Callbacks
 | 
				
			||||||
    vis = pt.models.VisGLVQ2D(train_ds)
 | 
					    vis = pt.models.VisGLVQ2D(train_ds)
 | 
				
			||||||
    pruning = pt.models.PruneLoserPrototypes(
 | 
					    pruning = pt.models.PruneLoserPrototypes(
 | 
				
			||||||
@@ -67,7 +73,7 @@ if __name__ == "__main__":
 | 
				
			|||||||
        ],
 | 
					        ],
 | 
				
			||||||
        progress_bar_refresh_rate=0,
 | 
					        progress_bar_refresh_rate=0,
 | 
				
			||||||
        terminate_on_nan=True,
 | 
					        terminate_on_nan=True,
 | 
				
			||||||
        weights_summary=None,
 | 
					        weights_summary="full",
 | 
				
			||||||
        accelerator="ddp",
 | 
					        accelerator="ddp",
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -2,11 +2,10 @@
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import argparse
 | 
					import argparse
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import prototorch as pt
 | 
				
			||||||
import pytorch_lightning as pl
 | 
					import pytorch_lightning as pl
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import prototorch as pt
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
if __name__ == "__main__":
 | 
					if __name__ == "__main__":
 | 
				
			||||||
    # Command-line arguments
 | 
					    # Command-line arguments
 | 
				
			||||||
    parser = argparse.ArgumentParser()
 | 
					    parser = argparse.ArgumentParser()
 | 
				
			||||||
@@ -14,35 +13,63 @@ if __name__ == "__main__":
 | 
				
			|||||||
    args = parser.parse_args()
 | 
					    args = parser.parse_args()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Dataset
 | 
					    # Dataset
 | 
				
			||||||
    train_ds = pt.datasets.Spiral(num_samples=600, noise=0.6)
 | 
					    train_ds = pt.datasets.Spiral(num_samples=500, noise=0.5)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Dataloaders
 | 
					    # Dataloaders
 | 
				
			||||||
    train_loader = torch.utils.data.DataLoader(train_ds,
 | 
					    train_loader = torch.utils.data.DataLoader(train_ds, batch_size=256)
 | 
				
			||||||
                                               num_workers=0,
 | 
					 | 
				
			||||||
                                               batch_size=256)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Hyperparameters
 | 
					    # Hyperparameters
 | 
				
			||||||
    num_classes = 2
 | 
					    num_classes = 2
 | 
				
			||||||
    prototypes_per_class = 20
 | 
					    prototypes_per_class = 10
 | 
				
			||||||
    hparams = dict(
 | 
					    hparams = dict(
 | 
				
			||||||
        distribution=(num_classes, prototypes_per_class),
 | 
					        distribution=(num_classes, prototypes_per_class),
 | 
				
			||||||
        transfer_function="sigmoid_beta",
 | 
					        transfer_function="swish_beta",
 | 
				
			||||||
        transfer_beta=10.0,
 | 
					        transfer_beta=10.0,
 | 
				
			||||||
        lr=0.01,
 | 
					        # lr=0.1,
 | 
				
			||||||
 | 
					        proto_lr=0.1,
 | 
				
			||||||
 | 
					        bb_lr=0.1,
 | 
				
			||||||
 | 
					        input_dim=2,
 | 
				
			||||||
 | 
					        latent_dim=2,
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Initialize the model
 | 
					    # Initialize the model
 | 
				
			||||||
    model = pt.models.GLVQ(hparams,
 | 
					    model = pt.models.GMLVQ(
 | 
				
			||||||
                           prototype_initializer=pt.components.SSI(train_ds,
 | 
					        hparams,
 | 
				
			||||||
                                                                   noise=1e-1))
 | 
					        optimizer=torch.optim.Adam,
 | 
				
			||||||
 | 
					        prototype_initializer=pt.components.SSI(train_ds, noise=1e-2),
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Callbacks
 | 
					    # Callbacks
 | 
				
			||||||
    vis = pt.models.VisGLVQ2D(train_ds, show_last_only=True, block=True)
 | 
					    vis = pt.models.VisGLVQ2D(
 | 
				
			||||||
 | 
					        train_ds,
 | 
				
			||||||
 | 
					        show_last_only=False,
 | 
				
			||||||
 | 
					        block=False,
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    pruning = pt.models.PruneLoserPrototypes(
 | 
				
			||||||
 | 
					        threshold=0.02,
 | 
				
			||||||
 | 
					        idle_epochs=10,
 | 
				
			||||||
 | 
					        prune_quota_per_epoch=5,
 | 
				
			||||||
 | 
					        frequency=2,
 | 
				
			||||||
 | 
					        replace=True,
 | 
				
			||||||
 | 
					        initializer=pt.components.SSI(train_ds, noise=1e-2),
 | 
				
			||||||
 | 
					        verbose=True,
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    es = pl.callbacks.EarlyStopping(
 | 
				
			||||||
 | 
					        monitor="train_loss",
 | 
				
			||||||
 | 
					        min_delta=1.0,
 | 
				
			||||||
 | 
					        patience=5,
 | 
				
			||||||
 | 
					        mode="min",
 | 
				
			||||||
 | 
					        check_on_train_epoch_end=True,
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Setup trainer
 | 
					    # Setup trainer
 | 
				
			||||||
    trainer = pl.Trainer.from_argparse_args(
 | 
					    trainer = pl.Trainer.from_argparse_args(
 | 
				
			||||||
        args,
 | 
					        args,
 | 
				
			||||||
        callbacks=[vis],
 | 
					        callbacks=[
 | 
				
			||||||
 | 
					            vis,
 | 
				
			||||||
 | 
					            # es,
 | 
				
			||||||
 | 
					            pruning,
 | 
				
			||||||
 | 
					        ],
 | 
				
			||||||
        terminate_on_nan=True,
 | 
					        terminate_on_nan=True,
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -43,7 +43,7 @@ if __name__ == "__main__":
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    # Hyperparameters
 | 
					    # Hyperparameters
 | 
				
			||||||
    num_classes = 10
 | 
					    num_classes = 10
 | 
				
			||||||
    prototypes_per_class = 2
 | 
					    prototypes_per_class = 10
 | 
				
			||||||
    hparams = dict(
 | 
					    hparams = dict(
 | 
				
			||||||
        input_dim=28 * 28,
 | 
					        input_dim=28 * 28,
 | 
				
			||||||
        latent_dim=28 * 28,
 | 
					        latent_dim=28 * 28,
 | 
				
			||||||
@@ -62,19 +62,40 @@ if __name__ == "__main__":
 | 
				
			|||||||
    # Callbacks
 | 
					    # Callbacks
 | 
				
			||||||
    vis = pt.models.VisImgComp(
 | 
					    vis = pt.models.VisImgComp(
 | 
				
			||||||
        data=train_ds,
 | 
					        data=train_ds,
 | 
				
			||||||
        num_columns=5,
 | 
					        num_columns=10,
 | 
				
			||||||
        show=False,
 | 
					        show=False,
 | 
				
			||||||
        tensorboard=True,
 | 
					        tensorboard=True,
 | 
				
			||||||
        random_data=20,
 | 
					        random_data=100,
 | 
				
			||||||
        add_embedding=True,
 | 
					        add_embedding=True,
 | 
				
			||||||
        embedding_data=100,
 | 
					        embedding_data=200,
 | 
				
			||||||
        flatten_data=False,
 | 
					        flatten_data=False,
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					    pruning = pt.models.PruneLoserPrototypes(
 | 
				
			||||||
 | 
					        threshold=0.01,
 | 
				
			||||||
 | 
					        idle_epochs=1,
 | 
				
			||||||
 | 
					        prune_quota_per_epoch=10,
 | 
				
			||||||
 | 
					        frequency=1,
 | 
				
			||||||
 | 
					        verbose=True,
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    es = pl.callbacks.EarlyStopping(
 | 
				
			||||||
 | 
					        monitor="train_loss",
 | 
				
			||||||
 | 
					        min_delta=0.001,
 | 
				
			||||||
 | 
					        patience=15,
 | 
				
			||||||
 | 
					        mode="min",
 | 
				
			||||||
 | 
					        check_on_train_epoch_end=True,
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Setup trainer
 | 
					    # Setup trainer
 | 
				
			||||||
    trainer = pl.Trainer.from_argparse_args(
 | 
					    trainer = pl.Trainer.from_argparse_args(
 | 
				
			||||||
        args,
 | 
					        args,
 | 
				
			||||||
        callbacks=[vis],
 | 
					        callbacks=[
 | 
				
			||||||
 | 
					            vis,
 | 
				
			||||||
 | 
					            pruning,
 | 
				
			||||||
 | 
					            # es,
 | 
				
			||||||
 | 
					        ],
 | 
				
			||||||
 | 
					        terminate_on_nan=True,
 | 
				
			||||||
 | 
					        weights_summary=None,
 | 
				
			||||||
 | 
					        accelerator="ddp",
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Training loop
 | 
					    # Training loop
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -4,10 +4,7 @@ import argparse
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import prototorch as pt
 | 
					import prototorch as pt
 | 
				
			||||||
import pytorch_lightning as pl
 | 
					import pytorch_lightning as pl
 | 
				
			||||||
from prototorch.components.initializers import Zeros
 | 
					import torch
 | 
				
			||||||
from prototorch.datasets import Iris
 | 
					 | 
				
			||||||
from prototorch.models.unsupervised import GrowingNeuralGas
 | 
					 | 
				
			||||||
from torch.utils.data import DataLoader
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
if __name__ == "__main__":
 | 
					if __name__ == "__main__":
 | 
				
			||||||
    # Command-line arguments
 | 
					    # Command-line arguments
 | 
				
			||||||
@@ -19,8 +16,8 @@ if __name__ == "__main__":
 | 
				
			|||||||
    pl.utilities.seed.seed_everything(seed=42)
 | 
					    pl.utilities.seed.seed_everything(seed=42)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Prepare the data
 | 
					    # Prepare the data
 | 
				
			||||||
    train_ds = Iris(dims=[0, 2])
 | 
					    train_ds = pt.datasets.Iris(dims=[0, 2])
 | 
				
			||||||
    train_loader = DataLoader(train_ds, batch_size=8)
 | 
					    train_loader = torch.utils.data.DataLoader(train_ds, batch_size=8)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Hyperparameters
 | 
					    # Hyperparameters
 | 
				
			||||||
    hparams = dict(
 | 
					    hparams = dict(
 | 
				
			||||||
@@ -29,11 +26,14 @@ if __name__ == "__main__":
 | 
				
			|||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Initialize the model
 | 
					    # Initialize the model
 | 
				
			||||||
    model = GrowingNeuralGas(
 | 
					    model = pt.models.GrowingNeuralGas(
 | 
				
			||||||
        hparams,
 | 
					        hparams,
 | 
				
			||||||
        prototype_initializer=Zeros(2),
 | 
					        prototype_initializer=pt.components.Zeros(2),
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Compute intermediate input and output sizes
 | 
				
			||||||
 | 
					    model.example_input_array = torch.zeros(4, 2)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Model summary
 | 
					    # Model summary
 | 
				
			||||||
    print(model)
 | 
					    print(model)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -45,6 +45,7 @@ if __name__ == "__main__":
 | 
				
			|||||||
        args,
 | 
					        args,
 | 
				
			||||||
        max_epochs=100,
 | 
					        max_epochs=100,
 | 
				
			||||||
        callbacks=[vis],
 | 
					        callbacks=[vis],
 | 
				
			||||||
 | 
					        weights_summary="full",
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Training loop
 | 
					    # Training loop
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,13 +1,12 @@
 | 
				
			|||||||
"""k-NN example using the Iris dataset."""
 | 
					"""k-NN example using the Iris dataset from scikit-learn."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import argparse
 | 
					import argparse
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import prototorch as pt
 | 
				
			||||||
import pytorch_lightning as pl
 | 
					import pytorch_lightning as pl
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
from sklearn.datasets import load_iris
 | 
					from sklearn.datasets import load_iris
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import prototorch as pt
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
if __name__ == "__main__":
 | 
					if __name__ == "__main__":
 | 
				
			||||||
    # Command-line arguments
 | 
					    # Command-line arguments
 | 
				
			||||||
    parser = argparse.ArgumentParser()
 | 
					    parser = argparse.ArgumentParser()
 | 
				
			||||||
@@ -23,18 +22,30 @@ if __name__ == "__main__":
 | 
				
			|||||||
    train_loader = torch.utils.data.DataLoader(train_ds, batch_size=150)
 | 
					    train_loader = torch.utils.data.DataLoader(train_ds, batch_size=150)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Hyperparameters
 | 
					    # Hyperparameters
 | 
				
			||||||
    hparams = dict(k=20)
 | 
					    hparams = dict(k=5)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Initialize the model
 | 
					    # Initialize the model
 | 
				
			||||||
    model = pt.models.KNN(hparams, data=train_ds)
 | 
					    model = pt.models.KNN(hparams, data=train_ds)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Compute intermediate input and output sizes
 | 
				
			||||||
 | 
					    model.example_input_array = torch.zeros(4, 2)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Summary
 | 
				
			||||||
 | 
					    print(model)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Callbacks
 | 
					    # Callbacks
 | 
				
			||||||
    vis = pt.models.VisGLVQ2D(data=(x_train, y_train), resolution=200)
 | 
					    vis = pt.models.VisGLVQ2D(
 | 
				
			||||||
 | 
					        data=(x_train, y_train),
 | 
				
			||||||
 | 
					        resolution=200,
 | 
				
			||||||
 | 
					        block=True,
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Setup trainer
 | 
					    # Setup trainer
 | 
				
			||||||
    trainer = pl.Trainer.from_argparse_args(
 | 
					    trainer = pl.Trainer.from_argparse_args(
 | 
				
			||||||
        args,
 | 
					        args,
 | 
				
			||||||
 | 
					        max_epochs=1,
 | 
				
			||||||
        callbacks=[vis],
 | 
					        callbacks=[vis],
 | 
				
			||||||
 | 
					        weights_summary="full",
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Training loop
 | 
					    # Training loop
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -7,6 +7,7 @@ import pytorch_lightning as pl
 | 
				
			|||||||
import torch
 | 
					import torch
 | 
				
			||||||
from sklearn.datasets import load_iris
 | 
					from sklearn.datasets import load_iris
 | 
				
			||||||
from sklearn.preprocessing import StandardScaler
 | 
					from sklearn.preprocessing import StandardScaler
 | 
				
			||||||
 | 
					from torch.optim.lr_scheduler import ExponentialLR
 | 
				
			||||||
 | 
					
 | 
				
			||||||
if __name__ == "__main__":
 | 
					if __name__ == "__main__":
 | 
				
			||||||
    # Command-line arguments
 | 
					    # Command-line arguments
 | 
				
			||||||
@@ -30,8 +31,15 @@ if __name__ == "__main__":
 | 
				
			|||||||
    hparams = dict(num_prototypes=30, lr=0.03)
 | 
					    hparams = dict(num_prototypes=30, lr=0.03)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Initialize the model
 | 
					    # Initialize the model
 | 
				
			||||||
    model = pt.models.NeuralGas(hparams,
 | 
					    model = pt.models.NeuralGas(
 | 
				
			||||||
                                prototype_initializer=pt.components.Zeros(2))
 | 
					        hparams,
 | 
				
			||||||
 | 
					        prototype_initializer=pt.components.Zeros(2),
 | 
				
			||||||
 | 
					        lr_scheduler=ExponentialLR,
 | 
				
			||||||
 | 
					        lr_scheduler_kwargs=dict(gamma=0.99, verbose=False),
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Compute intermediate input and output sizes
 | 
				
			||||||
 | 
					    model.example_input_array = torch.zeros(4, 2)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Model summary
 | 
					    # Model summary
 | 
				
			||||||
    print(model)
 | 
					    print(model)
 | 
				
			||||||
@@ -43,6 +51,7 @@ if __name__ == "__main__":
 | 
				
			|||||||
    trainer = pl.Trainer.from_argparse_args(
 | 
					    trainer = pl.Trainer.from_argparse_args(
 | 
				
			||||||
        args,
 | 
					        args,
 | 
				
			||||||
        callbacks=[vis],
 | 
					        callbacks=[vis],
 | 
				
			||||||
 | 
					        weights_summary="full",
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Training loop
 | 
					    # Training loop
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,12 +1,11 @@
 | 
				
			|||||||
"""Probabilistic-LVQ example using the Iris dataset."""
 | 
					"""RSLVQ example using the Iris dataset."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import argparse
 | 
					import argparse
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import prototorch as pt
 | 
				
			||||||
import pytorch_lightning as pl
 | 
					import pytorch_lightning as pl
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import prototorch as pt
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
if __name__ == "__main__":
 | 
					if __name__ == "__main__":
 | 
				
			||||||
    # Command-line arguments
 | 
					    # Command-line arguments
 | 
				
			||||||
    parser = argparse.ArgumentParser()
 | 
					    parser = argparse.ArgumentParser()
 | 
				
			||||||
@@ -26,16 +25,23 @@ if __name__ == "__main__":
 | 
				
			|||||||
    hparams = dict(
 | 
					    hparams = dict(
 | 
				
			||||||
        distribution=[2, 2, 3],
 | 
					        distribution=[2, 2, 3],
 | 
				
			||||||
        lr=0.05,
 | 
					        lr=0.05,
 | 
				
			||||||
        variance=0.3,
 | 
					        variance=0.1,
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Initialize the model
 | 
					    # Initialize the model
 | 
				
			||||||
    model = pt.models.probabilistic.RSLVQ(
 | 
					    model = pt.models.probabilistic.RSLVQ(
 | 
				
			||||||
        hparams,
 | 
					        hparams,
 | 
				
			||||||
        optimizer=torch.optim.Adam,
 | 
					        optimizer=torch.optim.Adam,
 | 
				
			||||||
 | 
					        # prototype_initializer=pt.components.SMI(train_ds),
 | 
				
			||||||
        prototype_initializer=pt.components.SSI(train_ds, noise=0.2),
 | 
					        prototype_initializer=pt.components.SSI(train_ds, noise=0.2),
 | 
				
			||||||
 | 
					        # prototype_initializer=pt.components.Zeros(2),
 | 
				
			||||||
 | 
					        # prototype_initializer=pt.components.Ones(2, scale=2.0),
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Compute intermediate input and output sizes
 | 
				
			||||||
 | 
					    model.example_input_array = torch.zeros(4, 2)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Summary
 | 
				
			||||||
    print(model)
 | 
					    print(model)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Callbacks
 | 
					    # Callbacks
 | 
				
			||||||
@@ -46,8 +52,8 @@ if __name__ == "__main__":
 | 
				
			|||||||
        args,
 | 
					        args,
 | 
				
			||||||
        callbacks=[vis],
 | 
					        callbacks=[vis],
 | 
				
			||||||
        terminate_on_nan=True,
 | 
					        terminate_on_nan=True,
 | 
				
			||||||
        weights_summary=None,
 | 
					        weights_summary="full",
 | 
				
			||||||
        # accelerator="ddp",
 | 
					        accelerator="ddp",
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Training loop
 | 
					    # Training loop
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user