fix: remove removed CLI syntax from examples

This commit is contained in:
Alexander Engelsberger 2023-06-20 17:30:21 +02:00
parent f5e1edf31f
commit 72e9587a10
No known key found for this signature in database
21 changed files with 145 additions and 85 deletions

View File

@ -5,8 +5,8 @@ import warnings
import prototorch as pt import prototorch as pt
import pytorch_lightning as pl import pytorch_lightning as pl
from lightning_fabric.utilities.seed import seed_everything
from prototorch.models import CBC, VisCBC2D from prototorch.models import CBC, VisCBC2D
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
@ -19,7 +19,8 @@ if __name__ == "__main__":
# Command-line arguments # Command-line arguments
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser) parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
args = parser.parse_args() args = parser.parse_args()
# Dataset # Dataset
@ -53,8 +54,10 @@ if __name__ == "__main__":
) )
# Setup trainer # Setup trainer
trainer = pl.Trainer.from_argparse_args( trainer = pl.Trainer(
args, accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
callbacks=[ callbacks=[
vis, vis,
], ],

View File

@ -7,13 +7,13 @@ import warnings
import prototorch as pt import prototorch as pt
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from lightning_fabric.utilities.seed import seed_everything
from prototorch.models import ( from prototorch.models import (
CELVQ, CELVQ,
PruneLoserPrototypes, PruneLoserPrototypes,
VisGLVQ2D, VisGLVQ2D,
) )
from pytorch_lightning.callbacks import EarlyStopping from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
@ -26,7 +26,8 @@ if __name__ == "__main__":
# Command-line arguments # Command-line arguments
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser) parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
args = parser.parse_args() args = parser.parse_args()
# Dataset # Dataset
@ -83,8 +84,10 @@ if __name__ == "__main__":
) )
# Setup trainer # Setup trainer
trainer = pl.Trainer.from_argparse_args( trainer = pl.Trainer(
args, accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
callbacks=[ callbacks=[
vis, vis,
pruning, pruning,

View File

@ -7,8 +7,8 @@ import warnings
import prototorch as pt import prototorch as pt
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from lightning_fabric.utilities.seed import seed_everything
from prototorch.models import GLVQ, VisGLVQ2D from prototorch.models import GLVQ, VisGLVQ2D
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.optim.lr_scheduler import ExponentialLR from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
@ -21,7 +21,8 @@ if __name__ == "__main__":
seed_everything(seed=4) seed_everything(seed=4)
# Command-line arguments # Command-line arguments
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser) parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
args = parser.parse_args() args = parser.parse_args()
# Dataset # Dataset
@ -55,8 +56,10 @@ if __name__ == "__main__":
vis = VisGLVQ2D(data=train_ds) vis = VisGLVQ2D(data=train_ds)
# Setup trainer # Setup trainer
trainer = pl.Trainer.from_argparse_args( trainer = pl.Trainer(
args, accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
callbacks=[ callbacks=[
vis, vis,
], ],

View File

@ -6,8 +6,8 @@ import warnings
import prototorch as pt import prototorch as pt
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from lightning_fabric.utilities.seed import seed_everything
from prototorch.models import GMLVQ, VisGMLVQ2D from prototorch.models import GMLVQ, VisGMLVQ2D
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.optim.lr_scheduler import ExponentialLR from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
@ -22,7 +22,8 @@ if __name__ == "__main__":
# Command-line arguments # Command-line arguments
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser) parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
args = parser.parse_args() args = parser.parse_args()
# Dataset # Dataset
@ -59,8 +60,10 @@ if __name__ == "__main__":
vis = VisGMLVQ2D(data=train_ds) vis = VisGMLVQ2D(data=train_ds)
# Setup trainer # Setup trainer
trainer = pl.Trainer.from_argparse_args( trainer = pl.Trainer(
args, accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
callbacks=[ callbacks=[
vis, vis,
], ],

View File

@ -6,13 +6,13 @@ import warnings
import prototorch as pt import prototorch as pt
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from lightning_fabric.utilities.seed import seed_everything
from prototorch.models import ( from prototorch.models import (
ImageGMLVQ, ImageGMLVQ,
PruneLoserPrototypes, PruneLoserPrototypes,
VisImgComp, VisImgComp,
) )
from pytorch_lightning.callbacks import EarlyStopping from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from torchvision import transforms from torchvision import transforms
@ -26,7 +26,8 @@ if __name__ == "__main__":
seed_everything(seed=4) seed_everything(seed=4)
# Command-line arguments # Command-line arguments
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser) parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
args = parser.parse_args() args = parser.parse_args()
# Dataset # Dataset
@ -96,8 +97,10 @@ if __name__ == "__main__":
) )
# Setup trainer # Setup trainer
trainer = pl.Trainer.from_argparse_args( trainer = pl.Trainer(
args, accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
callbacks=[ callbacks=[
vis, vis,
pruning, pruning,

View File

@ -6,13 +6,13 @@ import warnings
import prototorch as pt import prototorch as pt
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from lightning_fabric.utilities.seed import seed_everything
from prototorch.models import ( from prototorch.models import (
GMLVQ, GMLVQ,
PruneLoserPrototypes, PruneLoserPrototypes,
VisGLVQ2D, VisGLVQ2D,
) )
from pytorch_lightning.callbacks import EarlyStopping from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
@ -25,7 +25,8 @@ if __name__ == "__main__":
# Command-line arguments # Command-line arguments
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser) parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
args = parser.parse_args() args = parser.parse_args()
# Dataset # Dataset
@ -78,8 +79,10 @@ if __name__ == "__main__":
) )
# Setup trainer # Setup trainer
trainer = pl.Trainer.from_argparse_args( trainer = pl.Trainer(
args, accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
callbacks=[ callbacks=[
vis, vis,
es, es,

View File

@ -7,8 +7,8 @@ import warnings
import prototorch as pt import prototorch as pt
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from lightning_fabric.utilities.seed import seed_everything
from prototorch.models import GrowingNeuralGas, VisNG2D from prototorch.models import GrowingNeuralGas, VisNG2D
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
@ -18,7 +18,8 @@ warnings.filterwarnings("ignore", category=UserWarning)
if __name__ == "__main__": if __name__ == "__main__":
# Command-line arguments # Command-line arguments
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser) parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
args = parser.parse_args() args = parser.parse_args()
# Reproducibility # Reproducibility
@ -51,8 +52,10 @@ if __name__ == "__main__":
vis = VisNG2D(data=train_loader) vis = VisNG2D(data=train_loader)
# Setup trainer # Setup trainer
trainer = pl.Trainer.from_argparse_args( trainer = pl.Trainer(
args, accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
callbacks=[ callbacks=[
vis, vis,
], ],

View File

@ -6,8 +6,8 @@ import warnings
import prototorch as pt import prototorch as pt
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from lightning_fabric.utilities.seed import seed_everything
from prototorch.models import GRLVQ, VisSiameseGLVQ2D from prototorch.models import GRLVQ, VisSiameseGLVQ2D
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.optim.lr_scheduler import ExponentialLR from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
@ -22,7 +22,8 @@ if __name__ == "__main__":
# Command-line arguments # Command-line arguments
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser) parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
args = parser.parse_args() args = parser.parse_args()
# Dataset # Dataset
@ -58,8 +59,10 @@ if __name__ == "__main__":
vis = VisSiameseGLVQ2D(data=train_ds) vis = VisSiameseGLVQ2D(data=train_ds)
# Setup trainer # Setup trainer
trainer = pl.Trainer.from_argparse_args( trainer = pl.Trainer(
args, accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
callbacks=[ callbacks=[
vis, vis,
], ],

View File

@ -6,13 +6,13 @@ import warnings
import prototorch as pt import prototorch as pt
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from lightning_fabric.utilities.seed import seed_everything
from prototorch.models import ( from prototorch.models import (
ImageGTLVQ, ImageGTLVQ,
PruneLoserPrototypes, PruneLoserPrototypes,
VisImgComp, VisImgComp,
) )
from pytorch_lightning.callbacks import EarlyStopping from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from torchvision import transforms from torchvision import transforms
@ -27,7 +27,8 @@ if __name__ == "__main__":
# Command-line arguments # Command-line arguments
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser) parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
args = parser.parse_args() args = parser.parse_args()
# Dataset # Dataset
@ -100,8 +101,10 @@ if __name__ == "__main__":
# Setup trainer # Setup trainer
# using GPUs here is strongly recommended! # using GPUs here is strongly recommended!
trainer = pl.Trainer.from_argparse_args( trainer = pl.Trainer(
args, accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
callbacks=[ callbacks=[
vis, vis,
pruning, pruning,

View File

@ -7,9 +7,9 @@ import warnings
import prototorch as pt import prototorch as pt
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from lightning_fabric.utilities.seed import seed_everything
from prototorch.models import GTLVQ, VisGLVQ2D from prototorch.models import GTLVQ, VisGLVQ2D
from pytorch_lightning.callbacks import EarlyStopping from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
@ -19,7 +19,8 @@ warnings.filterwarnings("ignore", category=UserWarning)
if __name__ == "__main__": if __name__ == "__main__":
# Command-line arguments # Command-line arguments
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser) parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
args = parser.parse_args() args = parser.parse_args()
# Reproducibility # Reproducibility
@ -61,8 +62,10 @@ if __name__ == "__main__":
) )
# Setup trainer # Setup trainer
trainer = pl.Trainer.from_argparse_args( trainer = pl.Trainer(
args, accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
callbacks=[ callbacks=[
vis, vis,
es, es,

View File

@ -18,7 +18,8 @@ warnings.filterwarnings("ignore", category=PossibleUserWarning)
if __name__ == "__main__": if __name__ == "__main__":
# Command-line arguments # Command-line arguments
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser) parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
args = parser.parse_args() args = parser.parse_args()
# Dataset # Dataset
@ -59,8 +60,10 @@ if __name__ == "__main__":
) )
# Setup trainer # Setup trainer
trainer = pl.Trainer.from_argparse_args( trainer = pl.Trainer(
args, accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
max_epochs=1, max_epochs=1,
callbacks=[ callbacks=[
vis, vis,

View File

@ -7,10 +7,10 @@ import warnings
import prototorch as pt import prototorch as pt
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from lightning_fabric.utilities.seed import seed_everything
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
from prototorch.models import KohonenSOM from prototorch.models import KohonenSOM
from prototorch.utils.colors import hex_to_rgb from prototorch.utils.colors import hex_to_rgb
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader, TensorDataset from torch.utils.data import DataLoader, TensorDataset
@ -58,7 +58,8 @@ class Vis2DColorSOM(pl.Callback):
if __name__ == "__main__": if __name__ == "__main__":
# Command-line arguments # Command-line arguments
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser) parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
args = parser.parse_args() args = parser.parse_args()
# Reproducibility # Reproducibility
@ -104,8 +105,10 @@ if __name__ == "__main__":
vis = Vis2DColorSOM(data=data) vis = Vis2DColorSOM(data=data)
# Setup trainer # Setup trainer
trainer = pl.Trainer.from_argparse_args( trainer = pl.Trainer(
args, accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
max_epochs=500, max_epochs=500,
callbacks=[ callbacks=[
vis, vis,

View File

@ -7,9 +7,9 @@ import warnings
import prototorch as pt import prototorch as pt
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from lightning_fabric.utilities.seed import seed_everything
from prototorch.models import LGMLVQ, VisGLVQ2D from prototorch.models import LGMLVQ, VisGLVQ2D
from pytorch_lightning.callbacks import EarlyStopping from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
@ -19,7 +19,8 @@ warnings.filterwarnings("ignore", category=UserWarning)
if __name__ == "__main__": if __name__ == "__main__":
# Command-line arguments # Command-line arguments
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser) parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
args = parser.parse_args() args = parser.parse_args()
# Reproducibility # Reproducibility
@ -62,8 +63,10 @@ if __name__ == "__main__":
) )
# Setup trainer # Setup trainer
trainer = pl.Trainer.from_argparse_args( trainer = pl.Trainer(
args, accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
callbacks=[ callbacks=[
vis, vis,
es, es,

View File

@ -6,12 +6,12 @@ import warnings
import prototorch as pt import prototorch as pt
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from lightning_fabric.utilities.seed import seed_everything
from prototorch.models import ( from prototorch.models import (
LVQMLN, LVQMLN,
PruneLoserPrototypes, PruneLoserPrototypes,
VisSiameseGLVQ2D, VisSiameseGLVQ2D,
) )
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
@ -39,7 +39,8 @@ class Backbone(torch.nn.Module):
if __name__ == "__main__": if __name__ == "__main__":
# Command-line arguments # Command-line arguments
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser) parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
args = parser.parse_args() args = parser.parse_args()
# Dataset # Dataset
@ -88,8 +89,10 @@ if __name__ == "__main__":
) )
# Setup trainer # Setup trainer
trainer = pl.Trainer.from_argparse_args( trainer = pl.Trainer(
args, accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
callbacks=[ callbacks=[
vis, vis,
pruning, pruning,

View File

@ -6,9 +6,9 @@ import warnings
import prototorch as pt import prototorch as pt
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from lightning_fabric.utilities.seed import seed_everything
from prototorch.models import MedianLVQ, VisGLVQ2D from prototorch.models import MedianLVQ, VisGLVQ2D
from pytorch_lightning.callbacks import EarlyStopping from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
@ -20,7 +20,8 @@ if __name__ == "__main__":
seed_everything(seed=4) seed_everything(seed=4)
# Command-line arguments # Command-line arguments
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser) parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
args = parser.parse_args() args = parser.parse_args()
# Dataset # Dataset
@ -53,8 +54,10 @@ if __name__ == "__main__":
) )
# Setup trainer # Setup trainer
trainer = pl.Trainer.from_argparse_args( trainer = pl.Trainer(
args, accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
callbacks=[ callbacks=[
vis, vis,
es, es,

View File

@ -6,8 +6,8 @@ import warnings
import prototorch as pt import prototorch as pt
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from lightning_fabric.utilities.seed import seed_everything
from prototorch.models import NeuralGas, VisNG2D from prototorch.models import NeuralGas, VisNG2D
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning from pytorch_lightning.utilities.warnings import PossibleUserWarning
from sklearn.datasets import load_iris from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler from sklearn.preprocessing import StandardScaler
@ -23,7 +23,8 @@ if __name__ == "__main__":
# Command-line arguments # Command-line arguments
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser) parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
args = parser.parse_args() args = parser.parse_args()
# Prepare and pre-process the dataset # Prepare and pre-process the dataset
@ -60,8 +61,10 @@ if __name__ == "__main__":
vis = VisNG2D(data=train_ds) vis = VisNG2D(data=train_ds)
# Setup trainer # Setup trainer
trainer = pl.Trainer.from_argparse_args( trainer = pl.Trainer(
args, accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
callbacks=[ callbacks=[
vis, vis,
], ],

View File

@ -6,8 +6,8 @@ import warnings
import prototorch as pt import prototorch as pt
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from lightning_fabric.utilities.seed import seed_everything
from prototorch.models import RSLVQ, VisGLVQ2D from prototorch.models import RSLVQ, VisGLVQ2D
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
@ -17,7 +17,8 @@ warnings.filterwarnings("ignore", category=UserWarning)
if __name__ == "__main__": if __name__ == "__main__":
# Command-line arguments # Command-line arguments
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser) parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
args = parser.parse_args() args = parser.parse_args()
# Reproducibility # Reproducibility
@ -54,8 +55,10 @@ if __name__ == "__main__":
vis = VisGLVQ2D(data=train_ds) vis = VisGLVQ2D(data=train_ds)
# Setup trainer # Setup trainer
trainer = pl.Trainer.from_argparse_args( trainer = pl.Trainer(
args, accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
callbacks=[ callbacks=[
vis, vis,
], ],

View File

@ -6,8 +6,8 @@ import warnings
import prototorch as pt import prototorch as pt
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from lightning_fabric.utilities.seed import seed_everything
from prototorch.models import SiameseGLVQ, VisSiameseGLVQ2D from prototorch.models import SiameseGLVQ, VisSiameseGLVQ2D
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
@ -35,7 +35,8 @@ class Backbone(torch.nn.Module):
if __name__ == "__main__": if __name__ == "__main__":
# Command-line arguments # Command-line arguments
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser) parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
args = parser.parse_args() args = parser.parse_args()
# Dataset # Dataset
@ -69,8 +70,10 @@ if __name__ == "__main__":
vis = VisSiameseGLVQ2D(data=train_ds, border=0.1) vis = VisSiameseGLVQ2D(data=train_ds, border=0.1)
# Setup trainer # Setup trainer
trainer = pl.Trainer.from_argparse_args( trainer = pl.Trainer(
args, accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
callbacks=[ callbacks=[
vis, vis,
], ],

View File

@ -6,8 +6,8 @@ import warnings
import prototorch as pt import prototorch as pt
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from lightning_fabric.utilities.seed import seed_everything
from prototorch.models import SiameseGTLVQ, VisSiameseGLVQ2D from prototorch.models import SiameseGTLVQ, VisSiameseGLVQ2D
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
@ -35,7 +35,8 @@ class Backbone(torch.nn.Module):
if __name__ == "__main__": if __name__ == "__main__":
# Command-line arguments # Command-line arguments
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser) parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
args = parser.parse_args() args = parser.parse_args()
# Dataset # Dataset
@ -71,8 +72,10 @@ if __name__ == "__main__":
vis = VisSiameseGLVQ2D(data=train_ds, border=0.1) vis = VisSiameseGLVQ2D(data=train_ds, border=0.1)
# Setup trainer # Setup trainer
trainer = pl.Trainer.from_argparse_args( trainer = pl.Trainer(
args, accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
callbacks=[ callbacks=[
vis, vis,
], ],

View File

@ -6,6 +6,7 @@ import warnings
import prototorch as pt import prototorch as pt
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from lightning_fabric.utilities.seed import seed_everything
from prototorch.models import ( from prototorch.models import (
GLVQ, GLVQ,
KNN, KNN,
@ -14,7 +15,6 @@ from prototorch.models import (
VisGLVQ2D, VisGLVQ2D,
) )
from pytorch_lightning.callbacks import EarlyStopping from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.optim.lr_scheduler import ExponentialLR from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
@ -27,7 +27,8 @@ if __name__ == "__main__":
seed_everything(seed=4) seed_everything(seed=4)
# Command-line arguments # Command-line arguments
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser) parser.add_argument("--gpus", type=int, default=0)
parser.add_argument("--fast_dev_run", type=bool, default=False)
args = parser.parse_args() args = parser.parse_args()
# Prepare the data # Prepare the data
@ -108,8 +109,10 @@ if __name__ == "__main__":
) )
# Setup trainer # Setup trainer
trainer = pl.Trainer.from_argparse_args( trainer = pl.Trainer(
args, accelerator="cuda" if args.gpus else "cpu",
devices=args.gpus if args.gpus else "auto",
fast_dev_run=args.fast_dev_run,
callbacks=[ callbacks=[
vis, vis,
pruning, pruning,

View File

@ -22,10 +22,10 @@ with open("README.md") as fh:
long_description = fh.read() long_description = fh.read()
INSTALL_REQUIRES = [ INSTALL_REQUIRES = [
"prototorch>=0.7.3", "prototorch>=0.7.5",
"pytorch_lightning>=1.6.0", "lightning>=2.0.0",
"torchmetrics<0.10.0", "torchmetrics",
"protobuf<3.20.0", "protobuf",
] ]
CLI = [ CLI = [
"jsonargparse", "jsonargparse",
@ -65,7 +65,7 @@ setup(
url=PROJECT_URL, url=PROJECT_URL,
download_url=DOWNLOAD_URL, download_url=DOWNLOAD_URL,
license="MIT", license="MIT",
python_requires=">=3.7", python_requires=">=3.8",
install_requires=INSTALL_REQUIRES, install_requires=INSTALL_REQUIRES,
extras_require={ extras_require={
"dev": DEV, "dev": DEV,
@ -82,10 +82,10 @@ setup(
"License :: OSI Approved :: MIT License", "License :: OSI Approved :: MIT License",
"Natural Language :: English", "Natural Language :: English",
"Programming Language :: Python :: 3", "Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.7",
"Operating System :: OS Independent", "Operating System :: OS Independent",
"Topic :: Scientific/Engineering :: Artificial Intelligence", "Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Software Development :: Libraries", "Topic :: Software Development :: Libraries",