fix: remove removed CLI syntax from examples
This commit is contained in:
		@@ -5,8 +5,8 @@ import warnings
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import prototorch as pt
 | 
					import prototorch as pt
 | 
				
			||||||
import pytorch_lightning as pl
 | 
					import pytorch_lightning as pl
 | 
				
			||||||
 | 
					from lightning_fabric.utilities.seed import seed_everything
 | 
				
			||||||
from prototorch.models import CBC, VisCBC2D
 | 
					from prototorch.models import CBC, VisCBC2D
 | 
				
			||||||
from pytorch_lightning.utilities.seed import seed_everything
 | 
					 | 
				
			||||||
from pytorch_lightning.utilities.warnings import PossibleUserWarning
 | 
					from pytorch_lightning.utilities.warnings import PossibleUserWarning
 | 
				
			||||||
from torch.utils.data import DataLoader
 | 
					from torch.utils.data import DataLoader
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -19,7 +19,8 @@ if __name__ == "__main__":
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    # Command-line arguments
 | 
					    # Command-line arguments
 | 
				
			||||||
    parser = argparse.ArgumentParser()
 | 
					    parser = argparse.ArgumentParser()
 | 
				
			||||||
    parser = pl.Trainer.add_argparse_args(parser)
 | 
					    parser.add_argument("--gpus", type=int, default=0)
 | 
				
			||||||
 | 
					    parser.add_argument("--fast_dev_run", type=bool, default=False)
 | 
				
			||||||
    args = parser.parse_args()
 | 
					    args = parser.parse_args()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Dataset
 | 
					    # Dataset
 | 
				
			||||||
@@ -53,8 +54,10 @@ if __name__ == "__main__":
 | 
				
			|||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Setup trainer
 | 
					    # Setup trainer
 | 
				
			||||||
    trainer = pl.Trainer.from_argparse_args(
 | 
					    trainer = pl.Trainer(
 | 
				
			||||||
        args,
 | 
					        accelerator="cuda" if args.gpus else "cpu",
 | 
				
			||||||
 | 
					        devices=args.gpus if args.gpus else "auto",
 | 
				
			||||||
 | 
					        fast_dev_run=args.fast_dev_run,
 | 
				
			||||||
        callbacks=[
 | 
					        callbacks=[
 | 
				
			||||||
            vis,
 | 
					            vis,
 | 
				
			||||||
        ],
 | 
					        ],
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -7,13 +7,13 @@ import warnings
 | 
				
			|||||||
import prototorch as pt
 | 
					import prototorch as pt
 | 
				
			||||||
import pytorch_lightning as pl
 | 
					import pytorch_lightning as pl
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
 | 
					from lightning_fabric.utilities.seed import seed_everything
 | 
				
			||||||
from prototorch.models import (
 | 
					from prototorch.models import (
 | 
				
			||||||
    CELVQ,
 | 
					    CELVQ,
 | 
				
			||||||
    PruneLoserPrototypes,
 | 
					    PruneLoserPrototypes,
 | 
				
			||||||
    VisGLVQ2D,
 | 
					    VisGLVQ2D,
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
from pytorch_lightning.callbacks import EarlyStopping
 | 
					from pytorch_lightning.callbacks import EarlyStopping
 | 
				
			||||||
from pytorch_lightning.utilities.seed import seed_everything
 | 
					 | 
				
			||||||
from pytorch_lightning.utilities.warnings import PossibleUserWarning
 | 
					from pytorch_lightning.utilities.warnings import PossibleUserWarning
 | 
				
			||||||
from torch.utils.data import DataLoader
 | 
					from torch.utils.data import DataLoader
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -26,7 +26,8 @@ if __name__ == "__main__":
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    # Command-line arguments
 | 
					    # Command-line arguments
 | 
				
			||||||
    parser = argparse.ArgumentParser()
 | 
					    parser = argparse.ArgumentParser()
 | 
				
			||||||
    parser = pl.Trainer.add_argparse_args(parser)
 | 
					    parser.add_argument("--gpus", type=int, default=0)
 | 
				
			||||||
 | 
					    parser.add_argument("--fast_dev_run", type=bool, default=False)
 | 
				
			||||||
    args = parser.parse_args()
 | 
					    args = parser.parse_args()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Dataset
 | 
					    # Dataset
 | 
				
			||||||
@@ -83,8 +84,10 @@ if __name__ == "__main__":
 | 
				
			|||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Setup trainer
 | 
					    # Setup trainer
 | 
				
			||||||
    trainer = pl.Trainer.from_argparse_args(
 | 
					    trainer = pl.Trainer(
 | 
				
			||||||
        args,
 | 
					        accelerator="cuda" if args.gpus else "cpu",
 | 
				
			||||||
 | 
					        devices=args.gpus if args.gpus else "auto",
 | 
				
			||||||
 | 
					        fast_dev_run=args.fast_dev_run,
 | 
				
			||||||
        callbacks=[
 | 
					        callbacks=[
 | 
				
			||||||
            vis,
 | 
					            vis,
 | 
				
			||||||
            pruning,
 | 
					            pruning,
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -7,8 +7,8 @@ import warnings
 | 
				
			|||||||
import prototorch as pt
 | 
					import prototorch as pt
 | 
				
			||||||
import pytorch_lightning as pl
 | 
					import pytorch_lightning as pl
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
 | 
					from lightning_fabric.utilities.seed import seed_everything
 | 
				
			||||||
from prototorch.models import GLVQ, VisGLVQ2D
 | 
					from prototorch.models import GLVQ, VisGLVQ2D
 | 
				
			||||||
from pytorch_lightning.utilities.seed import seed_everything
 | 
					 | 
				
			||||||
from pytorch_lightning.utilities.warnings import PossibleUserWarning
 | 
					from pytorch_lightning.utilities.warnings import PossibleUserWarning
 | 
				
			||||||
from torch.optim.lr_scheduler import ExponentialLR
 | 
					from torch.optim.lr_scheduler import ExponentialLR
 | 
				
			||||||
from torch.utils.data import DataLoader
 | 
					from torch.utils.data import DataLoader
 | 
				
			||||||
@@ -21,7 +21,8 @@ if __name__ == "__main__":
 | 
				
			|||||||
    seed_everything(seed=4)
 | 
					    seed_everything(seed=4)
 | 
				
			||||||
    # Command-line arguments
 | 
					    # Command-line arguments
 | 
				
			||||||
    parser = argparse.ArgumentParser()
 | 
					    parser = argparse.ArgumentParser()
 | 
				
			||||||
    parser = pl.Trainer.add_argparse_args(parser)
 | 
					    parser.add_argument("--gpus", type=int, default=0)
 | 
				
			||||||
 | 
					    parser.add_argument("--fast_dev_run", type=bool, default=False)
 | 
				
			||||||
    args = parser.parse_args()
 | 
					    args = parser.parse_args()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Dataset
 | 
					    # Dataset
 | 
				
			||||||
@@ -55,8 +56,10 @@ if __name__ == "__main__":
 | 
				
			|||||||
    vis = VisGLVQ2D(data=train_ds)
 | 
					    vis = VisGLVQ2D(data=train_ds)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Setup trainer
 | 
					    # Setup trainer
 | 
				
			||||||
    trainer = pl.Trainer.from_argparse_args(
 | 
					    trainer = pl.Trainer(
 | 
				
			||||||
        args,
 | 
					        accelerator="cuda" if args.gpus else "cpu",
 | 
				
			||||||
 | 
					        devices=args.gpus if args.gpus else "auto",
 | 
				
			||||||
 | 
					        fast_dev_run=args.fast_dev_run,
 | 
				
			||||||
        callbacks=[
 | 
					        callbacks=[
 | 
				
			||||||
            vis,
 | 
					            vis,
 | 
				
			||||||
        ],
 | 
					        ],
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -6,8 +6,8 @@ import warnings
 | 
				
			|||||||
import prototorch as pt
 | 
					import prototorch as pt
 | 
				
			||||||
import pytorch_lightning as pl
 | 
					import pytorch_lightning as pl
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
 | 
					from lightning_fabric.utilities.seed import seed_everything
 | 
				
			||||||
from prototorch.models import GMLVQ, VisGMLVQ2D
 | 
					from prototorch.models import GMLVQ, VisGMLVQ2D
 | 
				
			||||||
from pytorch_lightning.utilities.seed import seed_everything
 | 
					 | 
				
			||||||
from pytorch_lightning.utilities.warnings import PossibleUserWarning
 | 
					from pytorch_lightning.utilities.warnings import PossibleUserWarning
 | 
				
			||||||
from torch.optim.lr_scheduler import ExponentialLR
 | 
					from torch.optim.lr_scheduler import ExponentialLR
 | 
				
			||||||
from torch.utils.data import DataLoader
 | 
					from torch.utils.data import DataLoader
 | 
				
			||||||
@@ -22,7 +22,8 @@ if __name__ == "__main__":
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    # Command-line arguments
 | 
					    # Command-line arguments
 | 
				
			||||||
    parser = argparse.ArgumentParser()
 | 
					    parser = argparse.ArgumentParser()
 | 
				
			||||||
    parser = pl.Trainer.add_argparse_args(parser)
 | 
					    parser.add_argument("--gpus", type=int, default=0)
 | 
				
			||||||
 | 
					    parser.add_argument("--fast_dev_run", type=bool, default=False)
 | 
				
			||||||
    args = parser.parse_args()
 | 
					    args = parser.parse_args()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Dataset
 | 
					    # Dataset
 | 
				
			||||||
@@ -59,8 +60,10 @@ if __name__ == "__main__":
 | 
				
			|||||||
    vis = VisGMLVQ2D(data=train_ds)
 | 
					    vis = VisGMLVQ2D(data=train_ds)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Setup trainer
 | 
					    # Setup trainer
 | 
				
			||||||
    trainer = pl.Trainer.from_argparse_args(
 | 
					    trainer = pl.Trainer(
 | 
				
			||||||
        args,
 | 
					        accelerator="cuda" if args.gpus else "cpu",
 | 
				
			||||||
 | 
					        devices=args.gpus if args.gpus else "auto",
 | 
				
			||||||
 | 
					        fast_dev_run=args.fast_dev_run,
 | 
				
			||||||
        callbacks=[
 | 
					        callbacks=[
 | 
				
			||||||
            vis,
 | 
					            vis,
 | 
				
			||||||
        ],
 | 
					        ],
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -6,13 +6,13 @@ import warnings
 | 
				
			|||||||
import prototorch as pt
 | 
					import prototorch as pt
 | 
				
			||||||
import pytorch_lightning as pl
 | 
					import pytorch_lightning as pl
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
 | 
					from lightning_fabric.utilities.seed import seed_everything
 | 
				
			||||||
from prototorch.models import (
 | 
					from prototorch.models import (
 | 
				
			||||||
    ImageGMLVQ,
 | 
					    ImageGMLVQ,
 | 
				
			||||||
    PruneLoserPrototypes,
 | 
					    PruneLoserPrototypes,
 | 
				
			||||||
    VisImgComp,
 | 
					    VisImgComp,
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
from pytorch_lightning.callbacks import EarlyStopping
 | 
					from pytorch_lightning.callbacks import EarlyStopping
 | 
				
			||||||
from pytorch_lightning.utilities.seed import seed_everything
 | 
					 | 
				
			||||||
from pytorch_lightning.utilities.warnings import PossibleUserWarning
 | 
					from pytorch_lightning.utilities.warnings import PossibleUserWarning
 | 
				
			||||||
from torch.utils.data import DataLoader
 | 
					from torch.utils.data import DataLoader
 | 
				
			||||||
from torchvision import transforms
 | 
					from torchvision import transforms
 | 
				
			||||||
@@ -26,7 +26,8 @@ if __name__ == "__main__":
 | 
				
			|||||||
    seed_everything(seed=4)
 | 
					    seed_everything(seed=4)
 | 
				
			||||||
    # Command-line arguments
 | 
					    # Command-line arguments
 | 
				
			||||||
    parser = argparse.ArgumentParser()
 | 
					    parser = argparse.ArgumentParser()
 | 
				
			||||||
    parser = pl.Trainer.add_argparse_args(parser)
 | 
					    parser.add_argument("--gpus", type=int, default=0)
 | 
				
			||||||
 | 
					    parser.add_argument("--fast_dev_run", type=bool, default=False)
 | 
				
			||||||
    args = parser.parse_args()
 | 
					    args = parser.parse_args()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Dataset
 | 
					    # Dataset
 | 
				
			||||||
@@ -96,8 +97,10 @@ if __name__ == "__main__":
 | 
				
			|||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Setup trainer
 | 
					    # Setup trainer
 | 
				
			||||||
    trainer = pl.Trainer.from_argparse_args(
 | 
					    trainer = pl.Trainer(
 | 
				
			||||||
        args,
 | 
					        accelerator="cuda" if args.gpus else "cpu",
 | 
				
			||||||
 | 
					        devices=args.gpus if args.gpus else "auto",
 | 
				
			||||||
 | 
					        fast_dev_run=args.fast_dev_run,
 | 
				
			||||||
        callbacks=[
 | 
					        callbacks=[
 | 
				
			||||||
            vis,
 | 
					            vis,
 | 
				
			||||||
            pruning,
 | 
					            pruning,
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -6,13 +6,13 @@ import warnings
 | 
				
			|||||||
import prototorch as pt
 | 
					import prototorch as pt
 | 
				
			||||||
import pytorch_lightning as pl
 | 
					import pytorch_lightning as pl
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
 | 
					from lightning_fabric.utilities.seed import seed_everything
 | 
				
			||||||
from prototorch.models import (
 | 
					from prototorch.models import (
 | 
				
			||||||
    GMLVQ,
 | 
					    GMLVQ,
 | 
				
			||||||
    PruneLoserPrototypes,
 | 
					    PruneLoserPrototypes,
 | 
				
			||||||
    VisGLVQ2D,
 | 
					    VisGLVQ2D,
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
from pytorch_lightning.callbacks import EarlyStopping
 | 
					from pytorch_lightning.callbacks import EarlyStopping
 | 
				
			||||||
from pytorch_lightning.utilities.seed import seed_everything
 | 
					 | 
				
			||||||
from pytorch_lightning.utilities.warnings import PossibleUserWarning
 | 
					from pytorch_lightning.utilities.warnings import PossibleUserWarning
 | 
				
			||||||
from torch.utils.data import DataLoader
 | 
					from torch.utils.data import DataLoader
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -25,7 +25,8 @@ if __name__ == "__main__":
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    # Command-line arguments
 | 
					    # Command-line arguments
 | 
				
			||||||
    parser = argparse.ArgumentParser()
 | 
					    parser = argparse.ArgumentParser()
 | 
				
			||||||
    parser = pl.Trainer.add_argparse_args(parser)
 | 
					    parser.add_argument("--gpus", type=int, default=0)
 | 
				
			||||||
 | 
					    parser.add_argument("--fast_dev_run", type=bool, default=False)
 | 
				
			||||||
    args = parser.parse_args()
 | 
					    args = parser.parse_args()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Dataset
 | 
					    # Dataset
 | 
				
			||||||
@@ -78,8 +79,10 @@ if __name__ == "__main__":
 | 
				
			|||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Setup trainer
 | 
					    # Setup trainer
 | 
				
			||||||
    trainer = pl.Trainer.from_argparse_args(
 | 
					    trainer = pl.Trainer(
 | 
				
			||||||
        args,
 | 
					        accelerator="cuda" if args.gpus else "cpu",
 | 
				
			||||||
 | 
					        devices=args.gpus if args.gpus else "auto",
 | 
				
			||||||
 | 
					        fast_dev_run=args.fast_dev_run,
 | 
				
			||||||
        callbacks=[
 | 
					        callbacks=[
 | 
				
			||||||
            vis,
 | 
					            vis,
 | 
				
			||||||
            es,
 | 
					            es,
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -7,8 +7,8 @@ import warnings
 | 
				
			|||||||
import prototorch as pt
 | 
					import prototorch as pt
 | 
				
			||||||
import pytorch_lightning as pl
 | 
					import pytorch_lightning as pl
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
 | 
					from lightning_fabric.utilities.seed import seed_everything
 | 
				
			||||||
from prototorch.models import GrowingNeuralGas, VisNG2D
 | 
					from prototorch.models import GrowingNeuralGas, VisNG2D
 | 
				
			||||||
from pytorch_lightning.utilities.seed import seed_everything
 | 
					 | 
				
			||||||
from pytorch_lightning.utilities.warnings import PossibleUserWarning
 | 
					from pytorch_lightning.utilities.warnings import PossibleUserWarning
 | 
				
			||||||
from torch.utils.data import DataLoader
 | 
					from torch.utils.data import DataLoader
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -18,7 +18,8 @@ warnings.filterwarnings("ignore", category=UserWarning)
 | 
				
			|||||||
if __name__ == "__main__":
 | 
					if __name__ == "__main__":
 | 
				
			||||||
    # Command-line arguments
 | 
					    # Command-line arguments
 | 
				
			||||||
    parser = argparse.ArgumentParser()
 | 
					    parser = argparse.ArgumentParser()
 | 
				
			||||||
    parser = pl.Trainer.add_argparse_args(parser)
 | 
					    parser.add_argument("--gpus", type=int, default=0)
 | 
				
			||||||
 | 
					    parser.add_argument("--fast_dev_run", type=bool, default=False)
 | 
				
			||||||
    args = parser.parse_args()
 | 
					    args = parser.parse_args()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Reproducibility
 | 
					    # Reproducibility
 | 
				
			||||||
@@ -51,8 +52,10 @@ if __name__ == "__main__":
 | 
				
			|||||||
    vis = VisNG2D(data=train_loader)
 | 
					    vis = VisNG2D(data=train_loader)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Setup trainer
 | 
					    # Setup trainer
 | 
				
			||||||
    trainer = pl.Trainer.from_argparse_args(
 | 
					    trainer = pl.Trainer(
 | 
				
			||||||
        args,
 | 
					        accelerator="cuda" if args.gpus else "cpu",
 | 
				
			||||||
 | 
					        devices=args.gpus if args.gpus else "auto",
 | 
				
			||||||
 | 
					        fast_dev_run=args.fast_dev_run,
 | 
				
			||||||
        callbacks=[
 | 
					        callbacks=[
 | 
				
			||||||
            vis,
 | 
					            vis,
 | 
				
			||||||
        ],
 | 
					        ],
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -6,8 +6,8 @@ import warnings
 | 
				
			|||||||
import prototorch as pt
 | 
					import prototorch as pt
 | 
				
			||||||
import pytorch_lightning as pl
 | 
					import pytorch_lightning as pl
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
 | 
					from lightning_fabric.utilities.seed import seed_everything
 | 
				
			||||||
from prototorch.models import GRLVQ, VisSiameseGLVQ2D
 | 
					from prototorch.models import GRLVQ, VisSiameseGLVQ2D
 | 
				
			||||||
from pytorch_lightning.utilities.seed import seed_everything
 | 
					 | 
				
			||||||
from pytorch_lightning.utilities.warnings import PossibleUserWarning
 | 
					from pytorch_lightning.utilities.warnings import PossibleUserWarning
 | 
				
			||||||
from torch.optim.lr_scheduler import ExponentialLR
 | 
					from torch.optim.lr_scheduler import ExponentialLR
 | 
				
			||||||
from torch.utils.data import DataLoader
 | 
					from torch.utils.data import DataLoader
 | 
				
			||||||
@@ -22,7 +22,8 @@ if __name__ == "__main__":
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    # Command-line arguments
 | 
					    # Command-line arguments
 | 
				
			||||||
    parser = argparse.ArgumentParser()
 | 
					    parser = argparse.ArgumentParser()
 | 
				
			||||||
    parser = pl.Trainer.add_argparse_args(parser)
 | 
					    parser.add_argument("--gpus", type=int, default=0)
 | 
				
			||||||
 | 
					    parser.add_argument("--fast_dev_run", type=bool, default=False)
 | 
				
			||||||
    args = parser.parse_args()
 | 
					    args = parser.parse_args()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Dataset
 | 
					    # Dataset
 | 
				
			||||||
@@ -58,8 +59,10 @@ if __name__ == "__main__":
 | 
				
			|||||||
    vis = VisSiameseGLVQ2D(data=train_ds)
 | 
					    vis = VisSiameseGLVQ2D(data=train_ds)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Setup trainer
 | 
					    # Setup trainer
 | 
				
			||||||
    trainer = pl.Trainer.from_argparse_args(
 | 
					    trainer = pl.Trainer(
 | 
				
			||||||
        args,
 | 
					        accelerator="cuda" if args.gpus else "cpu",
 | 
				
			||||||
 | 
					        devices=args.gpus if args.gpus else "auto",
 | 
				
			||||||
 | 
					        fast_dev_run=args.fast_dev_run,
 | 
				
			||||||
        callbacks=[
 | 
					        callbacks=[
 | 
				
			||||||
            vis,
 | 
					            vis,
 | 
				
			||||||
        ],
 | 
					        ],
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -6,13 +6,13 @@ import warnings
 | 
				
			|||||||
import prototorch as pt
 | 
					import prototorch as pt
 | 
				
			||||||
import pytorch_lightning as pl
 | 
					import pytorch_lightning as pl
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
 | 
					from lightning_fabric.utilities.seed import seed_everything
 | 
				
			||||||
from prototorch.models import (
 | 
					from prototorch.models import (
 | 
				
			||||||
    ImageGTLVQ,
 | 
					    ImageGTLVQ,
 | 
				
			||||||
    PruneLoserPrototypes,
 | 
					    PruneLoserPrototypes,
 | 
				
			||||||
    VisImgComp,
 | 
					    VisImgComp,
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
from pytorch_lightning.callbacks import EarlyStopping
 | 
					from pytorch_lightning.callbacks import EarlyStopping
 | 
				
			||||||
from pytorch_lightning.utilities.seed import seed_everything
 | 
					 | 
				
			||||||
from pytorch_lightning.utilities.warnings import PossibleUserWarning
 | 
					from pytorch_lightning.utilities.warnings import PossibleUserWarning
 | 
				
			||||||
from torch.utils.data import DataLoader
 | 
					from torch.utils.data import DataLoader
 | 
				
			||||||
from torchvision import transforms
 | 
					from torchvision import transforms
 | 
				
			||||||
@@ -27,7 +27,8 @@ if __name__ == "__main__":
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    # Command-line arguments
 | 
					    # Command-line arguments
 | 
				
			||||||
    parser = argparse.ArgumentParser()
 | 
					    parser = argparse.ArgumentParser()
 | 
				
			||||||
    parser = pl.Trainer.add_argparse_args(parser)
 | 
					    parser.add_argument("--gpus", type=int, default=0)
 | 
				
			||||||
 | 
					    parser.add_argument("--fast_dev_run", type=bool, default=False)
 | 
				
			||||||
    args = parser.parse_args()
 | 
					    args = parser.parse_args()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Dataset
 | 
					    # Dataset
 | 
				
			||||||
@@ -100,8 +101,10 @@ if __name__ == "__main__":
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    # Setup trainer
 | 
					    # Setup trainer
 | 
				
			||||||
    # using GPUs here is strongly recommended!
 | 
					    # using GPUs here is strongly recommended!
 | 
				
			||||||
    trainer = pl.Trainer.from_argparse_args(
 | 
					    trainer = pl.Trainer(
 | 
				
			||||||
        args,
 | 
					        accelerator="cuda" if args.gpus else "cpu",
 | 
				
			||||||
 | 
					        devices=args.gpus if args.gpus else "auto",
 | 
				
			||||||
 | 
					        fast_dev_run=args.fast_dev_run,
 | 
				
			||||||
        callbacks=[
 | 
					        callbacks=[
 | 
				
			||||||
            vis,
 | 
					            vis,
 | 
				
			||||||
            pruning,
 | 
					            pruning,
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -7,9 +7,9 @@ import warnings
 | 
				
			|||||||
import prototorch as pt
 | 
					import prototorch as pt
 | 
				
			||||||
import pytorch_lightning as pl
 | 
					import pytorch_lightning as pl
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
 | 
					from lightning_fabric.utilities.seed import seed_everything
 | 
				
			||||||
from prototorch.models import GTLVQ, VisGLVQ2D
 | 
					from prototorch.models import GTLVQ, VisGLVQ2D
 | 
				
			||||||
from pytorch_lightning.callbacks import EarlyStopping
 | 
					from pytorch_lightning.callbacks import EarlyStopping
 | 
				
			||||||
from pytorch_lightning.utilities.seed import seed_everything
 | 
					 | 
				
			||||||
from pytorch_lightning.utilities.warnings import PossibleUserWarning
 | 
					from pytorch_lightning.utilities.warnings import PossibleUserWarning
 | 
				
			||||||
from torch.utils.data import DataLoader
 | 
					from torch.utils.data import DataLoader
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -19,7 +19,8 @@ warnings.filterwarnings("ignore", category=UserWarning)
 | 
				
			|||||||
if __name__ == "__main__":
 | 
					if __name__ == "__main__":
 | 
				
			||||||
    # Command-line arguments
 | 
					    # Command-line arguments
 | 
				
			||||||
    parser = argparse.ArgumentParser()
 | 
					    parser = argparse.ArgumentParser()
 | 
				
			||||||
    parser = pl.Trainer.add_argparse_args(parser)
 | 
					    parser.add_argument("--gpus", type=int, default=0)
 | 
				
			||||||
 | 
					    parser.add_argument("--fast_dev_run", type=bool, default=False)
 | 
				
			||||||
    args = parser.parse_args()
 | 
					    args = parser.parse_args()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Reproducibility
 | 
					    # Reproducibility
 | 
				
			||||||
@@ -61,8 +62,10 @@ if __name__ == "__main__":
 | 
				
			|||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Setup trainer
 | 
					    # Setup trainer
 | 
				
			||||||
    trainer = pl.Trainer.from_argparse_args(
 | 
					    trainer = pl.Trainer(
 | 
				
			||||||
        args,
 | 
					        accelerator="cuda" if args.gpus else "cpu",
 | 
				
			||||||
 | 
					        devices=args.gpus if args.gpus else "auto",
 | 
				
			||||||
 | 
					        fast_dev_run=args.fast_dev_run,
 | 
				
			||||||
        callbacks=[
 | 
					        callbacks=[
 | 
				
			||||||
            vis,
 | 
					            vis,
 | 
				
			||||||
            es,
 | 
					            es,
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -18,7 +18,8 @@ warnings.filterwarnings("ignore", category=PossibleUserWarning)
 | 
				
			|||||||
if __name__ == "__main__":
 | 
					if __name__ == "__main__":
 | 
				
			||||||
    # Command-line arguments
 | 
					    # Command-line arguments
 | 
				
			||||||
    parser = argparse.ArgumentParser()
 | 
					    parser = argparse.ArgumentParser()
 | 
				
			||||||
    parser = pl.Trainer.add_argparse_args(parser)
 | 
					    parser.add_argument("--gpus", type=int, default=0)
 | 
				
			||||||
 | 
					    parser.add_argument("--fast_dev_run", type=bool, default=False)
 | 
				
			||||||
    args = parser.parse_args()
 | 
					    args = parser.parse_args()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Dataset
 | 
					    # Dataset
 | 
				
			||||||
@@ -59,8 +60,10 @@ if __name__ == "__main__":
 | 
				
			|||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Setup trainer
 | 
					    # Setup trainer
 | 
				
			||||||
    trainer = pl.Trainer.from_argparse_args(
 | 
					    trainer = pl.Trainer(
 | 
				
			||||||
        args,
 | 
					        accelerator="cuda" if args.gpus else "cpu",
 | 
				
			||||||
 | 
					        devices=args.gpus if args.gpus else "auto",
 | 
				
			||||||
 | 
					        fast_dev_run=args.fast_dev_run,
 | 
				
			||||||
        max_epochs=1,
 | 
					        max_epochs=1,
 | 
				
			||||||
        callbacks=[
 | 
					        callbacks=[
 | 
				
			||||||
            vis,
 | 
					            vis,
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -7,10 +7,10 @@ import warnings
 | 
				
			|||||||
import prototorch as pt
 | 
					import prototorch as pt
 | 
				
			||||||
import pytorch_lightning as pl
 | 
					import pytorch_lightning as pl
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
 | 
					from lightning_fabric.utilities.seed import seed_everything
 | 
				
			||||||
from matplotlib import pyplot as plt
 | 
					from matplotlib import pyplot as plt
 | 
				
			||||||
from prototorch.models import KohonenSOM
 | 
					from prototorch.models import KohonenSOM
 | 
				
			||||||
from prototorch.utils.colors import hex_to_rgb
 | 
					from prototorch.utils.colors import hex_to_rgb
 | 
				
			||||||
from pytorch_lightning.utilities.seed import seed_everything
 | 
					 | 
				
			||||||
from pytorch_lightning.utilities.warnings import PossibleUserWarning
 | 
					from pytorch_lightning.utilities.warnings import PossibleUserWarning
 | 
				
			||||||
from torch.utils.data import DataLoader, TensorDataset
 | 
					from torch.utils.data import DataLoader, TensorDataset
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -58,7 +58,8 @@ class Vis2DColorSOM(pl.Callback):
 | 
				
			|||||||
if __name__ == "__main__":
 | 
					if __name__ == "__main__":
 | 
				
			||||||
    # Command-line arguments
 | 
					    # Command-line arguments
 | 
				
			||||||
    parser = argparse.ArgumentParser()
 | 
					    parser = argparse.ArgumentParser()
 | 
				
			||||||
    parser = pl.Trainer.add_argparse_args(parser)
 | 
					    parser.add_argument("--gpus", type=int, default=0)
 | 
				
			||||||
 | 
					    parser.add_argument("--fast_dev_run", type=bool, default=False)
 | 
				
			||||||
    args = parser.parse_args()
 | 
					    args = parser.parse_args()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Reproducibility
 | 
					    # Reproducibility
 | 
				
			||||||
@@ -104,8 +105,10 @@ if __name__ == "__main__":
 | 
				
			|||||||
    vis = Vis2DColorSOM(data=data)
 | 
					    vis = Vis2DColorSOM(data=data)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Setup trainer
 | 
					    # Setup trainer
 | 
				
			||||||
    trainer = pl.Trainer.from_argparse_args(
 | 
					    trainer = pl.Trainer(
 | 
				
			||||||
        args,
 | 
					        accelerator="cuda" if args.gpus else "cpu",
 | 
				
			||||||
 | 
					        devices=args.gpus if args.gpus else "auto",
 | 
				
			||||||
 | 
					        fast_dev_run=args.fast_dev_run,
 | 
				
			||||||
        max_epochs=500,
 | 
					        max_epochs=500,
 | 
				
			||||||
        callbacks=[
 | 
					        callbacks=[
 | 
				
			||||||
            vis,
 | 
					            vis,
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -7,9 +7,9 @@ import warnings
 | 
				
			|||||||
import prototorch as pt
 | 
					import prototorch as pt
 | 
				
			||||||
import pytorch_lightning as pl
 | 
					import pytorch_lightning as pl
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
 | 
					from lightning_fabric.utilities.seed import seed_everything
 | 
				
			||||||
from prototorch.models import LGMLVQ, VisGLVQ2D
 | 
					from prototorch.models import LGMLVQ, VisGLVQ2D
 | 
				
			||||||
from pytorch_lightning.callbacks import EarlyStopping
 | 
					from pytorch_lightning.callbacks import EarlyStopping
 | 
				
			||||||
from pytorch_lightning.utilities.seed import seed_everything
 | 
					 | 
				
			||||||
from pytorch_lightning.utilities.warnings import PossibleUserWarning
 | 
					from pytorch_lightning.utilities.warnings import PossibleUserWarning
 | 
				
			||||||
from torch.utils.data import DataLoader
 | 
					from torch.utils.data import DataLoader
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -19,7 +19,8 @@ warnings.filterwarnings("ignore", category=UserWarning)
 | 
				
			|||||||
if __name__ == "__main__":
 | 
					if __name__ == "__main__":
 | 
				
			||||||
    # Command-line arguments
 | 
					    # Command-line arguments
 | 
				
			||||||
    parser = argparse.ArgumentParser()
 | 
					    parser = argparse.ArgumentParser()
 | 
				
			||||||
    parser = pl.Trainer.add_argparse_args(parser)
 | 
					    parser.add_argument("--gpus", type=int, default=0)
 | 
				
			||||||
 | 
					    parser.add_argument("--fast_dev_run", type=bool, default=False)
 | 
				
			||||||
    args = parser.parse_args()
 | 
					    args = parser.parse_args()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Reproducibility
 | 
					    # Reproducibility
 | 
				
			||||||
@@ -62,8 +63,10 @@ if __name__ == "__main__":
 | 
				
			|||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Setup trainer
 | 
					    # Setup trainer
 | 
				
			||||||
    trainer = pl.Trainer.from_argparse_args(
 | 
					    trainer = pl.Trainer(
 | 
				
			||||||
        args,
 | 
					        accelerator="cuda" if args.gpus else "cpu",
 | 
				
			||||||
 | 
					        devices=args.gpus if args.gpus else "auto",
 | 
				
			||||||
 | 
					        fast_dev_run=args.fast_dev_run,
 | 
				
			||||||
        callbacks=[
 | 
					        callbacks=[
 | 
				
			||||||
            vis,
 | 
					            vis,
 | 
				
			||||||
            es,
 | 
					            es,
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -6,12 +6,12 @@ import warnings
 | 
				
			|||||||
import prototorch as pt
 | 
					import prototorch as pt
 | 
				
			||||||
import pytorch_lightning as pl
 | 
					import pytorch_lightning as pl
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
 | 
					from lightning_fabric.utilities.seed import seed_everything
 | 
				
			||||||
from prototorch.models import (
 | 
					from prototorch.models import (
 | 
				
			||||||
    LVQMLN,
 | 
					    LVQMLN,
 | 
				
			||||||
    PruneLoserPrototypes,
 | 
					    PruneLoserPrototypes,
 | 
				
			||||||
    VisSiameseGLVQ2D,
 | 
					    VisSiameseGLVQ2D,
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
from pytorch_lightning.utilities.seed import seed_everything
 | 
					 | 
				
			||||||
from pytorch_lightning.utilities.warnings import PossibleUserWarning
 | 
					from pytorch_lightning.utilities.warnings import PossibleUserWarning
 | 
				
			||||||
from torch.utils.data import DataLoader
 | 
					from torch.utils.data import DataLoader
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -39,7 +39,8 @@ class Backbone(torch.nn.Module):
 | 
				
			|||||||
if __name__ == "__main__":
 | 
					if __name__ == "__main__":
 | 
				
			||||||
    # Command-line arguments
 | 
					    # Command-line arguments
 | 
				
			||||||
    parser = argparse.ArgumentParser()
 | 
					    parser = argparse.ArgumentParser()
 | 
				
			||||||
    parser = pl.Trainer.add_argparse_args(parser)
 | 
					    parser.add_argument("--gpus", type=int, default=0)
 | 
				
			||||||
 | 
					    parser.add_argument("--fast_dev_run", type=bool, default=False)
 | 
				
			||||||
    args = parser.parse_args()
 | 
					    args = parser.parse_args()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Dataset
 | 
					    # Dataset
 | 
				
			||||||
@@ -88,8 +89,10 @@ if __name__ == "__main__":
 | 
				
			|||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Setup trainer
 | 
					    # Setup trainer
 | 
				
			||||||
    trainer = pl.Trainer.from_argparse_args(
 | 
					    trainer = pl.Trainer(
 | 
				
			||||||
        args,
 | 
					        accelerator="cuda" if args.gpus else "cpu",
 | 
				
			||||||
 | 
					        devices=args.gpus if args.gpus else "auto",
 | 
				
			||||||
 | 
					        fast_dev_run=args.fast_dev_run,
 | 
				
			||||||
        callbacks=[
 | 
					        callbacks=[
 | 
				
			||||||
            vis,
 | 
					            vis,
 | 
				
			||||||
            pruning,
 | 
					            pruning,
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -6,9 +6,9 @@ import warnings
 | 
				
			|||||||
import prototorch as pt
 | 
					import prototorch as pt
 | 
				
			||||||
import pytorch_lightning as pl
 | 
					import pytorch_lightning as pl
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
 | 
					from lightning_fabric.utilities.seed import seed_everything
 | 
				
			||||||
from prototorch.models import MedianLVQ, VisGLVQ2D
 | 
					from prototorch.models import MedianLVQ, VisGLVQ2D
 | 
				
			||||||
from pytorch_lightning.callbacks import EarlyStopping
 | 
					from pytorch_lightning.callbacks import EarlyStopping
 | 
				
			||||||
from pytorch_lightning.utilities.seed import seed_everything
 | 
					 | 
				
			||||||
from pytorch_lightning.utilities.warnings import PossibleUserWarning
 | 
					from pytorch_lightning.utilities.warnings import PossibleUserWarning
 | 
				
			||||||
from torch.utils.data import DataLoader
 | 
					from torch.utils.data import DataLoader
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -20,7 +20,8 @@ if __name__ == "__main__":
 | 
				
			|||||||
    seed_everything(seed=4)
 | 
					    seed_everything(seed=4)
 | 
				
			||||||
    # Command-line arguments
 | 
					    # Command-line arguments
 | 
				
			||||||
    parser = argparse.ArgumentParser()
 | 
					    parser = argparse.ArgumentParser()
 | 
				
			||||||
    parser = pl.Trainer.add_argparse_args(parser)
 | 
					    parser.add_argument("--gpus", type=int, default=0)
 | 
				
			||||||
 | 
					    parser.add_argument("--fast_dev_run", type=bool, default=False)
 | 
				
			||||||
    args = parser.parse_args()
 | 
					    args = parser.parse_args()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Dataset
 | 
					    # Dataset
 | 
				
			||||||
@@ -53,8 +54,10 @@ if __name__ == "__main__":
 | 
				
			|||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Setup trainer
 | 
					    # Setup trainer
 | 
				
			||||||
    trainer = pl.Trainer.from_argparse_args(
 | 
					    trainer = pl.Trainer(
 | 
				
			||||||
        args,
 | 
					        accelerator="cuda" if args.gpus else "cpu",
 | 
				
			||||||
 | 
					        devices=args.gpus if args.gpus else "auto",
 | 
				
			||||||
 | 
					        fast_dev_run=args.fast_dev_run,
 | 
				
			||||||
        callbacks=[
 | 
					        callbacks=[
 | 
				
			||||||
            vis,
 | 
					            vis,
 | 
				
			||||||
            es,
 | 
					            es,
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -6,8 +6,8 @@ import warnings
 | 
				
			|||||||
import prototorch as pt
 | 
					import prototorch as pt
 | 
				
			||||||
import pytorch_lightning as pl
 | 
					import pytorch_lightning as pl
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
 | 
					from lightning_fabric.utilities.seed import seed_everything
 | 
				
			||||||
from prototorch.models import NeuralGas, VisNG2D
 | 
					from prototorch.models import NeuralGas, VisNG2D
 | 
				
			||||||
from pytorch_lightning.utilities.seed import seed_everything
 | 
					 | 
				
			||||||
from pytorch_lightning.utilities.warnings import PossibleUserWarning
 | 
					from pytorch_lightning.utilities.warnings import PossibleUserWarning
 | 
				
			||||||
from sklearn.datasets import load_iris
 | 
					from sklearn.datasets import load_iris
 | 
				
			||||||
from sklearn.preprocessing import StandardScaler
 | 
					from sklearn.preprocessing import StandardScaler
 | 
				
			||||||
@@ -23,7 +23,8 @@ if __name__ == "__main__":
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    # Command-line arguments
 | 
					    # Command-line arguments
 | 
				
			||||||
    parser = argparse.ArgumentParser()
 | 
					    parser = argparse.ArgumentParser()
 | 
				
			||||||
    parser = pl.Trainer.add_argparse_args(parser)
 | 
					    parser.add_argument("--gpus", type=int, default=0)
 | 
				
			||||||
 | 
					    parser.add_argument("--fast_dev_run", type=bool, default=False)
 | 
				
			||||||
    args = parser.parse_args()
 | 
					    args = parser.parse_args()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Prepare and pre-process the dataset
 | 
					    # Prepare and pre-process the dataset
 | 
				
			||||||
@@ -60,8 +61,10 @@ if __name__ == "__main__":
 | 
				
			|||||||
    vis = VisNG2D(data=train_ds)
 | 
					    vis = VisNG2D(data=train_ds)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Setup trainer
 | 
					    # Setup trainer
 | 
				
			||||||
    trainer = pl.Trainer.from_argparse_args(
 | 
					    trainer = pl.Trainer(
 | 
				
			||||||
        args,
 | 
					        accelerator="cuda" if args.gpus else "cpu",
 | 
				
			||||||
 | 
					        devices=args.gpus if args.gpus else "auto",
 | 
				
			||||||
 | 
					        fast_dev_run=args.fast_dev_run,
 | 
				
			||||||
        callbacks=[
 | 
					        callbacks=[
 | 
				
			||||||
            vis,
 | 
					            vis,
 | 
				
			||||||
        ],
 | 
					        ],
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -6,8 +6,8 @@ import warnings
 | 
				
			|||||||
import prototorch as pt
 | 
					import prototorch as pt
 | 
				
			||||||
import pytorch_lightning as pl
 | 
					import pytorch_lightning as pl
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
 | 
					from lightning_fabric.utilities.seed import seed_everything
 | 
				
			||||||
from prototorch.models import RSLVQ, VisGLVQ2D
 | 
					from prototorch.models import RSLVQ, VisGLVQ2D
 | 
				
			||||||
from pytorch_lightning.utilities.seed import seed_everything
 | 
					 | 
				
			||||||
from pytorch_lightning.utilities.warnings import PossibleUserWarning
 | 
					from pytorch_lightning.utilities.warnings import PossibleUserWarning
 | 
				
			||||||
from torch.utils.data import DataLoader
 | 
					from torch.utils.data import DataLoader
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -17,7 +17,8 @@ warnings.filterwarnings("ignore", category=UserWarning)
 | 
				
			|||||||
if __name__ == "__main__":
 | 
					if __name__ == "__main__":
 | 
				
			||||||
    # Command-line arguments
 | 
					    # Command-line arguments
 | 
				
			||||||
    parser = argparse.ArgumentParser()
 | 
					    parser = argparse.ArgumentParser()
 | 
				
			||||||
    parser = pl.Trainer.add_argparse_args(parser)
 | 
					    parser.add_argument("--gpus", type=int, default=0)
 | 
				
			||||||
 | 
					    parser.add_argument("--fast_dev_run", type=bool, default=False)
 | 
				
			||||||
    args = parser.parse_args()
 | 
					    args = parser.parse_args()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Reproducibility
 | 
					    # Reproducibility
 | 
				
			||||||
@@ -54,8 +55,10 @@ if __name__ == "__main__":
 | 
				
			|||||||
    vis = VisGLVQ2D(data=train_ds)
 | 
					    vis = VisGLVQ2D(data=train_ds)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Setup trainer
 | 
					    # Setup trainer
 | 
				
			||||||
    trainer = pl.Trainer.from_argparse_args(
 | 
					    trainer = pl.Trainer(
 | 
				
			||||||
        args,
 | 
					        accelerator="cuda" if args.gpus else "cpu",
 | 
				
			||||||
 | 
					        devices=args.gpus if args.gpus else "auto",
 | 
				
			||||||
 | 
					        fast_dev_run=args.fast_dev_run,
 | 
				
			||||||
        callbacks=[
 | 
					        callbacks=[
 | 
				
			||||||
            vis,
 | 
					            vis,
 | 
				
			||||||
        ],
 | 
					        ],
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -6,8 +6,8 @@ import warnings
 | 
				
			|||||||
import prototorch as pt
 | 
					import prototorch as pt
 | 
				
			||||||
import pytorch_lightning as pl
 | 
					import pytorch_lightning as pl
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
 | 
					from lightning_fabric.utilities.seed import seed_everything
 | 
				
			||||||
from prototorch.models import SiameseGLVQ, VisSiameseGLVQ2D
 | 
					from prototorch.models import SiameseGLVQ, VisSiameseGLVQ2D
 | 
				
			||||||
from pytorch_lightning.utilities.seed import seed_everything
 | 
					 | 
				
			||||||
from pytorch_lightning.utilities.warnings import PossibleUserWarning
 | 
					from pytorch_lightning.utilities.warnings import PossibleUserWarning
 | 
				
			||||||
from torch.utils.data import DataLoader
 | 
					from torch.utils.data import DataLoader
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -35,7 +35,8 @@ class Backbone(torch.nn.Module):
 | 
				
			|||||||
if __name__ == "__main__":
 | 
					if __name__ == "__main__":
 | 
				
			||||||
    # Command-line arguments
 | 
					    # Command-line arguments
 | 
				
			||||||
    parser = argparse.ArgumentParser()
 | 
					    parser = argparse.ArgumentParser()
 | 
				
			||||||
    parser = pl.Trainer.add_argparse_args(parser)
 | 
					    parser.add_argument("--gpus", type=int, default=0)
 | 
				
			||||||
 | 
					    parser.add_argument("--fast_dev_run", type=bool, default=False)
 | 
				
			||||||
    args = parser.parse_args()
 | 
					    args = parser.parse_args()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Dataset
 | 
					    # Dataset
 | 
				
			||||||
@@ -69,8 +70,10 @@ if __name__ == "__main__":
 | 
				
			|||||||
    vis = VisSiameseGLVQ2D(data=train_ds, border=0.1)
 | 
					    vis = VisSiameseGLVQ2D(data=train_ds, border=0.1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Setup trainer
 | 
					    # Setup trainer
 | 
				
			||||||
    trainer = pl.Trainer.from_argparse_args(
 | 
					    trainer = pl.Trainer(
 | 
				
			||||||
        args,
 | 
					        accelerator="cuda" if args.gpus else "cpu",
 | 
				
			||||||
 | 
					        devices=args.gpus if args.gpus else "auto",
 | 
				
			||||||
 | 
					        fast_dev_run=args.fast_dev_run,
 | 
				
			||||||
        callbacks=[
 | 
					        callbacks=[
 | 
				
			||||||
            vis,
 | 
					            vis,
 | 
				
			||||||
        ],
 | 
					        ],
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -6,8 +6,8 @@ import warnings
 | 
				
			|||||||
import prototorch as pt
 | 
					import prototorch as pt
 | 
				
			||||||
import pytorch_lightning as pl
 | 
					import pytorch_lightning as pl
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
 | 
					from lightning_fabric.utilities.seed import seed_everything
 | 
				
			||||||
from prototorch.models import SiameseGTLVQ, VisSiameseGLVQ2D
 | 
					from prototorch.models import SiameseGTLVQ, VisSiameseGLVQ2D
 | 
				
			||||||
from pytorch_lightning.utilities.seed import seed_everything
 | 
					 | 
				
			||||||
from pytorch_lightning.utilities.warnings import PossibleUserWarning
 | 
					from pytorch_lightning.utilities.warnings import PossibleUserWarning
 | 
				
			||||||
from torch.utils.data import DataLoader
 | 
					from torch.utils.data import DataLoader
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -35,7 +35,8 @@ class Backbone(torch.nn.Module):
 | 
				
			|||||||
if __name__ == "__main__":
 | 
					if __name__ == "__main__":
 | 
				
			||||||
    # Command-line arguments
 | 
					    # Command-line arguments
 | 
				
			||||||
    parser = argparse.ArgumentParser()
 | 
					    parser = argparse.ArgumentParser()
 | 
				
			||||||
    parser = pl.Trainer.add_argparse_args(parser)
 | 
					    parser.add_argument("--gpus", type=int, default=0)
 | 
				
			||||||
 | 
					    parser.add_argument("--fast_dev_run", type=bool, default=False)
 | 
				
			||||||
    args = parser.parse_args()
 | 
					    args = parser.parse_args()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Dataset
 | 
					    # Dataset
 | 
				
			||||||
@@ -71,8 +72,10 @@ if __name__ == "__main__":
 | 
				
			|||||||
    vis = VisSiameseGLVQ2D(data=train_ds, border=0.1)
 | 
					    vis = VisSiameseGLVQ2D(data=train_ds, border=0.1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Setup trainer
 | 
					    # Setup trainer
 | 
				
			||||||
    trainer = pl.Trainer.from_argparse_args(
 | 
					    trainer = pl.Trainer(
 | 
				
			||||||
        args,
 | 
					        accelerator="cuda" if args.gpus else "cpu",
 | 
				
			||||||
 | 
					        devices=args.gpus if args.gpus else "auto",
 | 
				
			||||||
 | 
					        fast_dev_run=args.fast_dev_run,
 | 
				
			||||||
        callbacks=[
 | 
					        callbacks=[
 | 
				
			||||||
            vis,
 | 
					            vis,
 | 
				
			||||||
        ],
 | 
					        ],
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -6,6 +6,7 @@ import warnings
 | 
				
			|||||||
import prototorch as pt
 | 
					import prototorch as pt
 | 
				
			||||||
import pytorch_lightning as pl
 | 
					import pytorch_lightning as pl
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
 | 
					from lightning_fabric.utilities.seed import seed_everything
 | 
				
			||||||
from prototorch.models import (
 | 
					from prototorch.models import (
 | 
				
			||||||
    GLVQ,
 | 
					    GLVQ,
 | 
				
			||||||
    KNN,
 | 
					    KNN,
 | 
				
			||||||
@@ -14,7 +15,6 @@ from prototorch.models import (
 | 
				
			|||||||
    VisGLVQ2D,
 | 
					    VisGLVQ2D,
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
from pytorch_lightning.callbacks import EarlyStopping
 | 
					from pytorch_lightning.callbacks import EarlyStopping
 | 
				
			||||||
from pytorch_lightning.utilities.seed import seed_everything
 | 
					 | 
				
			||||||
from pytorch_lightning.utilities.warnings import PossibleUserWarning
 | 
					from pytorch_lightning.utilities.warnings import PossibleUserWarning
 | 
				
			||||||
from torch.optim.lr_scheduler import ExponentialLR
 | 
					from torch.optim.lr_scheduler import ExponentialLR
 | 
				
			||||||
from torch.utils.data import DataLoader
 | 
					from torch.utils.data import DataLoader
 | 
				
			||||||
@@ -27,7 +27,8 @@ if __name__ == "__main__":
 | 
				
			|||||||
    seed_everything(seed=4)
 | 
					    seed_everything(seed=4)
 | 
				
			||||||
    # Command-line arguments
 | 
					    # Command-line arguments
 | 
				
			||||||
    parser = argparse.ArgumentParser()
 | 
					    parser = argparse.ArgumentParser()
 | 
				
			||||||
    parser = pl.Trainer.add_argparse_args(parser)
 | 
					    parser.add_argument("--gpus", type=int, default=0)
 | 
				
			||||||
 | 
					    parser.add_argument("--fast_dev_run", type=bool, default=False)
 | 
				
			||||||
    args = parser.parse_args()
 | 
					    args = parser.parse_args()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Prepare the data
 | 
					    # Prepare the data
 | 
				
			||||||
@@ -108,8 +109,10 @@ if __name__ == "__main__":
 | 
				
			|||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Setup trainer
 | 
					    # Setup trainer
 | 
				
			||||||
    trainer = pl.Trainer.from_argparse_args(
 | 
					    trainer = pl.Trainer(
 | 
				
			||||||
        args,
 | 
					        accelerator="cuda" if args.gpus else "cpu",
 | 
				
			||||||
 | 
					        devices=args.gpus if args.gpus else "auto",
 | 
				
			||||||
 | 
					        fast_dev_run=args.fast_dev_run,
 | 
				
			||||||
        callbacks=[
 | 
					        callbacks=[
 | 
				
			||||||
            vis,
 | 
					            vis,
 | 
				
			||||||
            pruning,
 | 
					            pruning,
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										12
									
								
								setup.py
									
									
									
									
									
								
							
							
						
						
									
										12
									
								
								setup.py
									
									
									
									
									
								
							@@ -22,10 +22,10 @@ with open("README.md") as fh:
 | 
				
			|||||||
    long_description = fh.read()
 | 
					    long_description = fh.read()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
INSTALL_REQUIRES = [
 | 
					INSTALL_REQUIRES = [
 | 
				
			||||||
    "prototorch>=0.7.3",
 | 
					    "prototorch>=0.7.5",
 | 
				
			||||||
    "pytorch_lightning>=1.6.0",
 | 
					    "lightning>=2.0.0",
 | 
				
			||||||
    "torchmetrics<0.10.0",
 | 
					    "torchmetrics",
 | 
				
			||||||
    "protobuf<3.20.0",
 | 
					    "protobuf",
 | 
				
			||||||
]
 | 
					]
 | 
				
			||||||
CLI = [
 | 
					CLI = [
 | 
				
			||||||
    "jsonargparse",
 | 
					    "jsonargparse",
 | 
				
			||||||
@@ -65,7 +65,7 @@ setup(
 | 
				
			|||||||
    url=PROJECT_URL,
 | 
					    url=PROJECT_URL,
 | 
				
			||||||
    download_url=DOWNLOAD_URL,
 | 
					    download_url=DOWNLOAD_URL,
 | 
				
			||||||
    license="MIT",
 | 
					    license="MIT",
 | 
				
			||||||
    python_requires=">=3.7",
 | 
					    python_requires=">=3.8",
 | 
				
			||||||
    install_requires=INSTALL_REQUIRES,
 | 
					    install_requires=INSTALL_REQUIRES,
 | 
				
			||||||
    extras_require={
 | 
					    extras_require={
 | 
				
			||||||
        "dev": DEV,
 | 
					        "dev": DEV,
 | 
				
			||||||
@@ -82,10 +82,10 @@ setup(
 | 
				
			|||||||
        "License :: OSI Approved :: MIT License",
 | 
					        "License :: OSI Approved :: MIT License",
 | 
				
			||||||
        "Natural Language :: English",
 | 
					        "Natural Language :: English",
 | 
				
			||||||
        "Programming Language :: Python :: 3",
 | 
					        "Programming Language :: Python :: 3",
 | 
				
			||||||
 | 
					        "Programming Language :: Python :: 3.11",
 | 
				
			||||||
        "Programming Language :: Python :: 3.10",
 | 
					        "Programming Language :: Python :: 3.10",
 | 
				
			||||||
        "Programming Language :: Python :: 3.9",
 | 
					        "Programming Language :: Python :: 3.9",
 | 
				
			||||||
        "Programming Language :: Python :: 3.8",
 | 
					        "Programming Language :: Python :: 3.8",
 | 
				
			||||||
        "Programming Language :: Python :: 3.7",
 | 
					 | 
				
			||||||
        "Operating System :: OS Independent",
 | 
					        "Operating System :: OS Independent",
 | 
				
			||||||
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
 | 
					        "Topic :: Scientific/Engineering :: Artificial Intelligence",
 | 
				
			||||||
        "Topic :: Software Development :: Libraries",
 | 
					        "Topic :: Software Development :: Libraries",
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user