From 72e9587a10a8683e22ebfcc067d0ad0438d54b5f Mon Sep 17 00:00:00 2001 From: Alexander Engelsberger Date: Tue, 20 Jun 2023 17:30:21 +0200 Subject: [PATCH] fix: remove removed CLI syntax from examples --- examples/cbc_iris.py | 11 +++++++---- examples/dynamic_pruning.py | 11 +++++++---- examples/glvq_iris.py | 11 +++++++---- examples/gmlvq_iris.py | 11 +++++++---- examples/gmlvq_mnist.py | 11 +++++++---- examples/gmlvq_spiral.py | 11 +++++++---- examples/gng_iris.py | 11 +++++++---- examples/grlvq_iris.py | 11 +++++++---- examples/gtlvq_mnist.py | 11 +++++++---- examples/gtlvq_moons.py | 11 +++++++---- examples/knn_iris.py | 9 ++++++--- examples/ksom_colors.py | 11 +++++++---- examples/lgmlvq_moons.py | 11 +++++++---- examples/lvqmln_iris.py | 11 +++++++---- examples/median_lvq_iris.py | 11 +++++++---- examples/ng_iris.py | 11 +++++++---- examples/rslvq_iris.py | 11 +++++++---- examples/siamese_glvq_iris.py | 11 +++++++---- examples/siamese_gtlvq_iris.py | 11 +++++++---- examples/warm_starting.py | 11 +++++++---- setup.py | 12 ++++++------ 21 files changed, 145 insertions(+), 85 deletions(-) diff --git a/examples/cbc_iris.py b/examples/cbc_iris.py index 30610af..47b3e3f 100644 --- a/examples/cbc_iris.py +++ b/examples/cbc_iris.py @@ -5,8 +5,8 @@ import warnings import prototorch as pt import pytorch_lightning as pl +from lightning_fabric.utilities.seed import seed_everything from prototorch.models import CBC, VisCBC2D -from pytorch_lightning.utilities.seed import seed_everything from pytorch_lightning.utilities.warnings import PossibleUserWarning from torch.utils.data import DataLoader @@ -19,7 +19,8 @@ if __name__ == "__main__": # Command-line arguments parser = argparse.ArgumentParser() - parser = pl.Trainer.add_argparse_args(parser) + parser.add_argument("--gpus", type=int, default=0) + parser.add_argument("--fast_dev_run", type=bool, default=False) args = parser.parse_args() # Dataset @@ -53,8 +54,10 @@ if __name__ == "__main__": ) # Setup trainer - trainer = pl.Trainer.from_argparse_args( - args, + trainer = pl.Trainer( + accelerator="cuda" if args.gpus else "cpu", + devices=args.gpus if args.gpus else "auto", + fast_dev_run=args.fast_dev_run, callbacks=[ vis, ], diff --git a/examples/dynamic_pruning.py b/examples/dynamic_pruning.py index 275cdff..c820a32 100644 --- a/examples/dynamic_pruning.py +++ b/examples/dynamic_pruning.py @@ -7,13 +7,13 @@ import warnings import prototorch as pt import pytorch_lightning as pl import torch +from lightning_fabric.utilities.seed import seed_everything from prototorch.models import ( CELVQ, PruneLoserPrototypes, VisGLVQ2D, ) from pytorch_lightning.callbacks import EarlyStopping -from pytorch_lightning.utilities.seed import seed_everything from pytorch_lightning.utilities.warnings import PossibleUserWarning from torch.utils.data import DataLoader @@ -26,7 +26,8 @@ if __name__ == "__main__": # Command-line arguments parser = argparse.ArgumentParser() - parser = pl.Trainer.add_argparse_args(parser) + parser.add_argument("--gpus", type=int, default=0) + parser.add_argument("--fast_dev_run", type=bool, default=False) args = parser.parse_args() # Dataset @@ -83,8 +84,10 @@ if __name__ == "__main__": ) # Setup trainer - trainer = pl.Trainer.from_argparse_args( - args, + trainer = pl.Trainer( + accelerator="cuda" if args.gpus else "cpu", + devices=args.gpus if args.gpus else "auto", + fast_dev_run=args.fast_dev_run, callbacks=[ vis, pruning, diff --git a/examples/glvq_iris.py b/examples/glvq_iris.py index d23e82a..d1a71a7 100644 --- a/examples/glvq_iris.py +++ b/examples/glvq_iris.py @@ -7,8 +7,8 @@ import warnings import prototorch as pt import pytorch_lightning as pl import torch +from lightning_fabric.utilities.seed import seed_everything from prototorch.models import GLVQ, VisGLVQ2D -from pytorch_lightning.utilities.seed import seed_everything from pytorch_lightning.utilities.warnings import PossibleUserWarning from torch.optim.lr_scheduler import ExponentialLR from torch.utils.data import DataLoader @@ -21,7 +21,8 @@ if __name__ == "__main__": seed_everything(seed=4) # Command-line arguments parser = argparse.ArgumentParser() - parser = pl.Trainer.add_argparse_args(parser) + parser.add_argument("--gpus", type=int, default=0) + parser.add_argument("--fast_dev_run", type=bool, default=False) args = parser.parse_args() # Dataset @@ -55,8 +56,10 @@ if __name__ == "__main__": vis = VisGLVQ2D(data=train_ds) # Setup trainer - trainer = pl.Trainer.from_argparse_args( - args, + trainer = pl.Trainer( + accelerator="cuda" if args.gpus else "cpu", + devices=args.gpus if args.gpus else "auto", + fast_dev_run=args.fast_dev_run, callbacks=[ vis, ], diff --git a/examples/gmlvq_iris.py b/examples/gmlvq_iris.py index b713983..57b7e04 100644 --- a/examples/gmlvq_iris.py +++ b/examples/gmlvq_iris.py @@ -6,8 +6,8 @@ import warnings import prototorch as pt import pytorch_lightning as pl import torch +from lightning_fabric.utilities.seed import seed_everything from prototorch.models import GMLVQ, VisGMLVQ2D -from pytorch_lightning.utilities.seed import seed_everything from pytorch_lightning.utilities.warnings import PossibleUserWarning from torch.optim.lr_scheduler import ExponentialLR from torch.utils.data import DataLoader @@ -22,7 +22,8 @@ if __name__ == "__main__": # Command-line arguments parser = argparse.ArgumentParser() - parser = pl.Trainer.add_argparse_args(parser) + parser.add_argument("--gpus", type=int, default=0) + parser.add_argument("--fast_dev_run", type=bool, default=False) args = parser.parse_args() # Dataset @@ -59,8 +60,10 @@ if __name__ == "__main__": vis = VisGMLVQ2D(data=train_ds) # Setup trainer - trainer = pl.Trainer.from_argparse_args( - args, + trainer = pl.Trainer( + accelerator="cuda" if args.gpus else "cpu", + devices=args.gpus if args.gpus else "auto", + fast_dev_run=args.fast_dev_run, callbacks=[ vis, ], diff --git a/examples/gmlvq_mnist.py b/examples/gmlvq_mnist.py index 64eef26..699771e 100644 --- a/examples/gmlvq_mnist.py +++ b/examples/gmlvq_mnist.py @@ -6,13 +6,13 @@ import warnings import prototorch as pt import pytorch_lightning as pl import torch +from lightning_fabric.utilities.seed import seed_everything from prototorch.models import ( ImageGMLVQ, PruneLoserPrototypes, VisImgComp, ) from pytorch_lightning.callbacks import EarlyStopping -from pytorch_lightning.utilities.seed import seed_everything from pytorch_lightning.utilities.warnings import PossibleUserWarning from torch.utils.data import DataLoader from torchvision import transforms @@ -26,7 +26,8 @@ if __name__ == "__main__": seed_everything(seed=4) # Command-line arguments parser = argparse.ArgumentParser() - parser = pl.Trainer.add_argparse_args(parser) + parser.add_argument("--gpus", type=int, default=0) + parser.add_argument("--fast_dev_run", type=bool, default=False) args = parser.parse_args() # Dataset @@ -96,8 +97,10 @@ if __name__ == "__main__": ) # Setup trainer - trainer = pl.Trainer.from_argparse_args( - args, + trainer = pl.Trainer( + accelerator="cuda" if args.gpus else "cpu", + devices=args.gpus if args.gpus else "auto", + fast_dev_run=args.fast_dev_run, callbacks=[ vis, pruning, diff --git a/examples/gmlvq_spiral.py b/examples/gmlvq_spiral.py index 4cd2e7d..29eecec 100644 --- a/examples/gmlvq_spiral.py +++ b/examples/gmlvq_spiral.py @@ -6,13 +6,13 @@ import warnings import prototorch as pt import pytorch_lightning as pl import torch +from lightning_fabric.utilities.seed import seed_everything from prototorch.models import ( GMLVQ, PruneLoserPrototypes, VisGLVQ2D, ) from pytorch_lightning.callbacks import EarlyStopping -from pytorch_lightning.utilities.seed import seed_everything from pytorch_lightning.utilities.warnings import PossibleUserWarning from torch.utils.data import DataLoader @@ -25,7 +25,8 @@ if __name__ == "__main__": # Command-line arguments parser = argparse.ArgumentParser() - parser = pl.Trainer.add_argparse_args(parser) + parser.add_argument("--gpus", type=int, default=0) + parser.add_argument("--fast_dev_run", type=bool, default=False) args = parser.parse_args() # Dataset @@ -78,8 +79,10 @@ if __name__ == "__main__": ) # Setup trainer - trainer = pl.Trainer.from_argparse_args( - args, + trainer = pl.Trainer( + accelerator="cuda" if args.gpus else "cpu", + devices=args.gpus if args.gpus else "auto", + fast_dev_run=args.fast_dev_run, callbacks=[ vis, es, diff --git a/examples/gng_iris.py b/examples/gng_iris.py index 0783b04..aec888c 100644 --- a/examples/gng_iris.py +++ b/examples/gng_iris.py @@ -7,8 +7,8 @@ import warnings import prototorch as pt import pytorch_lightning as pl import torch +from lightning_fabric.utilities.seed import seed_everything from prototorch.models import GrowingNeuralGas, VisNG2D -from pytorch_lightning.utilities.seed import seed_everything from pytorch_lightning.utilities.warnings import PossibleUserWarning from torch.utils.data import DataLoader @@ -18,7 +18,8 @@ warnings.filterwarnings("ignore", category=UserWarning) if __name__ == "__main__": # Command-line arguments parser = argparse.ArgumentParser() - parser = pl.Trainer.add_argparse_args(parser) + parser.add_argument("--gpus", type=int, default=0) + parser.add_argument("--fast_dev_run", type=bool, default=False) args = parser.parse_args() # Reproducibility @@ -51,8 +52,10 @@ if __name__ == "__main__": vis = VisNG2D(data=train_loader) # Setup trainer - trainer = pl.Trainer.from_argparse_args( - args, + trainer = pl.Trainer( + accelerator="cuda" if args.gpus else "cpu", + devices=args.gpus if args.gpus else "auto", + fast_dev_run=args.fast_dev_run, callbacks=[ vis, ], diff --git a/examples/grlvq_iris.py b/examples/grlvq_iris.py index 97e0a0c..045ad46 100644 --- a/examples/grlvq_iris.py +++ b/examples/grlvq_iris.py @@ -6,8 +6,8 @@ import warnings import prototorch as pt import pytorch_lightning as pl import torch +from lightning_fabric.utilities.seed import seed_everything from prototorch.models import GRLVQ, VisSiameseGLVQ2D -from pytorch_lightning.utilities.seed import seed_everything from pytorch_lightning.utilities.warnings import PossibleUserWarning from torch.optim.lr_scheduler import ExponentialLR from torch.utils.data import DataLoader @@ -22,7 +22,8 @@ if __name__ == "__main__": # Command-line arguments parser = argparse.ArgumentParser() - parser = pl.Trainer.add_argparse_args(parser) + parser.add_argument("--gpus", type=int, default=0) + parser.add_argument("--fast_dev_run", type=bool, default=False) args = parser.parse_args() # Dataset @@ -58,8 +59,10 @@ if __name__ == "__main__": vis = VisSiameseGLVQ2D(data=train_ds) # Setup trainer - trainer = pl.Trainer.from_argparse_args( - args, + trainer = pl.Trainer( + accelerator="cuda" if args.gpus else "cpu", + devices=args.gpus if args.gpus else "auto", + fast_dev_run=args.fast_dev_run, callbacks=[ vis, ], diff --git a/examples/gtlvq_mnist.py b/examples/gtlvq_mnist.py index e5af4eb..ca4bfb1 100644 --- a/examples/gtlvq_mnist.py +++ b/examples/gtlvq_mnist.py @@ -6,13 +6,13 @@ import warnings import prototorch as pt import pytorch_lightning as pl import torch +from lightning_fabric.utilities.seed import seed_everything from prototorch.models import ( ImageGTLVQ, PruneLoserPrototypes, VisImgComp, ) from pytorch_lightning.callbacks import EarlyStopping -from pytorch_lightning.utilities.seed import seed_everything from pytorch_lightning.utilities.warnings import PossibleUserWarning from torch.utils.data import DataLoader from torchvision import transforms @@ -27,7 +27,8 @@ if __name__ == "__main__": # Command-line arguments parser = argparse.ArgumentParser() - parser = pl.Trainer.add_argparse_args(parser) + parser.add_argument("--gpus", type=int, default=0) + parser.add_argument("--fast_dev_run", type=bool, default=False) args = parser.parse_args() # Dataset @@ -100,8 +101,10 @@ if __name__ == "__main__": # Setup trainer # using GPUs here is strongly recommended! - trainer = pl.Trainer.from_argparse_args( - args, + trainer = pl.Trainer( + accelerator="cuda" if args.gpus else "cpu", + devices=args.gpus if args.gpus else "auto", + fast_dev_run=args.fast_dev_run, callbacks=[ vis, pruning, diff --git a/examples/gtlvq_moons.py b/examples/gtlvq_moons.py index 7e8368b..70e932f 100644 --- a/examples/gtlvq_moons.py +++ b/examples/gtlvq_moons.py @@ -7,9 +7,9 @@ import warnings import prototorch as pt import pytorch_lightning as pl import torch +from lightning_fabric.utilities.seed import seed_everything from prototorch.models import GTLVQ, VisGLVQ2D from pytorch_lightning.callbacks import EarlyStopping -from pytorch_lightning.utilities.seed import seed_everything from pytorch_lightning.utilities.warnings import PossibleUserWarning from torch.utils.data import DataLoader @@ -19,7 +19,8 @@ warnings.filterwarnings("ignore", category=UserWarning) if __name__ == "__main__": # Command-line arguments parser = argparse.ArgumentParser() - parser = pl.Trainer.add_argparse_args(parser) + parser.add_argument("--gpus", type=int, default=0) + parser.add_argument("--fast_dev_run", type=bool, default=False) args = parser.parse_args() # Reproducibility @@ -61,8 +62,10 @@ if __name__ == "__main__": ) # Setup trainer - trainer = pl.Trainer.from_argparse_args( - args, + trainer = pl.Trainer( + accelerator="cuda" if args.gpus else "cpu", + devices=args.gpus if args.gpus else "auto", + fast_dev_run=args.fast_dev_run, callbacks=[ vis, es, diff --git a/examples/knn_iris.py b/examples/knn_iris.py index c1b3279..d531a2e 100644 --- a/examples/knn_iris.py +++ b/examples/knn_iris.py @@ -18,7 +18,8 @@ warnings.filterwarnings("ignore", category=PossibleUserWarning) if __name__ == "__main__": # Command-line arguments parser = argparse.ArgumentParser() - parser = pl.Trainer.add_argparse_args(parser) + parser.add_argument("--gpus", type=int, default=0) + parser.add_argument("--fast_dev_run", type=bool, default=False) args = parser.parse_args() # Dataset @@ -59,8 +60,10 @@ if __name__ == "__main__": ) # Setup trainer - trainer = pl.Trainer.from_argparse_args( - args, + trainer = pl.Trainer( + accelerator="cuda" if args.gpus else "cpu", + devices=args.gpus if args.gpus else "auto", + fast_dev_run=args.fast_dev_run, max_epochs=1, callbacks=[ vis, diff --git a/examples/ksom_colors.py b/examples/ksom_colors.py index bd2ded0..47abcfa 100644 --- a/examples/ksom_colors.py +++ b/examples/ksom_colors.py @@ -7,10 +7,10 @@ import warnings import prototorch as pt import pytorch_lightning as pl import torch +from lightning_fabric.utilities.seed import seed_everything from matplotlib import pyplot as plt from prototorch.models import KohonenSOM from prototorch.utils.colors import hex_to_rgb -from pytorch_lightning.utilities.seed import seed_everything from pytorch_lightning.utilities.warnings import PossibleUserWarning from torch.utils.data import DataLoader, TensorDataset @@ -58,7 +58,8 @@ class Vis2DColorSOM(pl.Callback): if __name__ == "__main__": # Command-line arguments parser = argparse.ArgumentParser() - parser = pl.Trainer.add_argparse_args(parser) + parser.add_argument("--gpus", type=int, default=0) + parser.add_argument("--fast_dev_run", type=bool, default=False) args = parser.parse_args() # Reproducibility @@ -104,8 +105,10 @@ if __name__ == "__main__": vis = Vis2DColorSOM(data=data) # Setup trainer - trainer = pl.Trainer.from_argparse_args( - args, + trainer = pl.Trainer( + accelerator="cuda" if args.gpus else "cpu", + devices=args.gpus if args.gpus else "auto", + fast_dev_run=args.fast_dev_run, max_epochs=500, callbacks=[ vis, diff --git a/examples/lgmlvq_moons.py b/examples/lgmlvq_moons.py index 155bf10..384d759 100644 --- a/examples/lgmlvq_moons.py +++ b/examples/lgmlvq_moons.py @@ -7,9 +7,9 @@ import warnings import prototorch as pt import pytorch_lightning as pl import torch +from lightning_fabric.utilities.seed import seed_everything from prototorch.models import LGMLVQ, VisGLVQ2D from pytorch_lightning.callbacks import EarlyStopping -from pytorch_lightning.utilities.seed import seed_everything from pytorch_lightning.utilities.warnings import PossibleUserWarning from torch.utils.data import DataLoader @@ -19,7 +19,8 @@ warnings.filterwarnings("ignore", category=UserWarning) if __name__ == "__main__": # Command-line arguments parser = argparse.ArgumentParser() - parser = pl.Trainer.add_argparse_args(parser) + parser.add_argument("--gpus", type=int, default=0) + parser.add_argument("--fast_dev_run", type=bool, default=False) args = parser.parse_args() # Reproducibility @@ -62,8 +63,10 @@ if __name__ == "__main__": ) # Setup trainer - trainer = pl.Trainer.from_argparse_args( - args, + trainer = pl.Trainer( + accelerator="cuda" if args.gpus else "cpu", + devices=args.gpus if args.gpus else "auto", + fast_dev_run=args.fast_dev_run, callbacks=[ vis, es, diff --git a/examples/lvqmln_iris.py b/examples/lvqmln_iris.py index 42c8b7b..44c3285 100644 --- a/examples/lvqmln_iris.py +++ b/examples/lvqmln_iris.py @@ -6,12 +6,12 @@ import warnings import prototorch as pt import pytorch_lightning as pl import torch +from lightning_fabric.utilities.seed import seed_everything from prototorch.models import ( LVQMLN, PruneLoserPrototypes, VisSiameseGLVQ2D, ) -from pytorch_lightning.utilities.seed import seed_everything from pytorch_lightning.utilities.warnings import PossibleUserWarning from torch.utils.data import DataLoader @@ -39,7 +39,8 @@ class Backbone(torch.nn.Module): if __name__ == "__main__": # Command-line arguments parser = argparse.ArgumentParser() - parser = pl.Trainer.add_argparse_args(parser) + parser.add_argument("--gpus", type=int, default=0) + parser.add_argument("--fast_dev_run", type=bool, default=False) args = parser.parse_args() # Dataset @@ -88,8 +89,10 @@ if __name__ == "__main__": ) # Setup trainer - trainer = pl.Trainer.from_argparse_args( - args, + trainer = pl.Trainer( + accelerator="cuda" if args.gpus else "cpu", + devices=args.gpus if args.gpus else "auto", + fast_dev_run=args.fast_dev_run, callbacks=[ vis, pruning, diff --git a/examples/median_lvq_iris.py b/examples/median_lvq_iris.py index d709f80..f23177e 100644 --- a/examples/median_lvq_iris.py +++ b/examples/median_lvq_iris.py @@ -6,9 +6,9 @@ import warnings import prototorch as pt import pytorch_lightning as pl import torch +from lightning_fabric.utilities.seed import seed_everything from prototorch.models import MedianLVQ, VisGLVQ2D from pytorch_lightning.callbacks import EarlyStopping -from pytorch_lightning.utilities.seed import seed_everything from pytorch_lightning.utilities.warnings import PossibleUserWarning from torch.utils.data import DataLoader @@ -20,7 +20,8 @@ if __name__ == "__main__": seed_everything(seed=4) # Command-line arguments parser = argparse.ArgumentParser() - parser = pl.Trainer.add_argparse_args(parser) + parser.add_argument("--gpus", type=int, default=0) + parser.add_argument("--fast_dev_run", type=bool, default=False) args = parser.parse_args() # Dataset @@ -53,8 +54,10 @@ if __name__ == "__main__": ) # Setup trainer - trainer = pl.Trainer.from_argparse_args( - args, + trainer = pl.Trainer( + accelerator="cuda" if args.gpus else "cpu", + devices=args.gpus if args.gpus else "auto", + fast_dev_run=args.fast_dev_run, callbacks=[ vis, es, diff --git a/examples/ng_iris.py b/examples/ng_iris.py index c897fd0..f5157ea 100644 --- a/examples/ng_iris.py +++ b/examples/ng_iris.py @@ -6,8 +6,8 @@ import warnings import prototorch as pt import pytorch_lightning as pl import torch +from lightning_fabric.utilities.seed import seed_everything from prototorch.models import NeuralGas, VisNG2D -from pytorch_lightning.utilities.seed import seed_everything from pytorch_lightning.utilities.warnings import PossibleUserWarning from sklearn.datasets import load_iris from sklearn.preprocessing import StandardScaler @@ -23,7 +23,8 @@ if __name__ == "__main__": # Command-line arguments parser = argparse.ArgumentParser() - parser = pl.Trainer.add_argparse_args(parser) + parser.add_argument("--gpus", type=int, default=0) + parser.add_argument("--fast_dev_run", type=bool, default=False) args = parser.parse_args() # Prepare and pre-process the dataset @@ -60,8 +61,10 @@ if __name__ == "__main__": vis = VisNG2D(data=train_ds) # Setup trainer - trainer = pl.Trainer.from_argparse_args( - args, + trainer = pl.Trainer( + accelerator="cuda" if args.gpus else "cpu", + devices=args.gpus if args.gpus else "auto", + fast_dev_run=args.fast_dev_run, callbacks=[ vis, ], diff --git a/examples/rslvq_iris.py b/examples/rslvq_iris.py index c561cbe..b343d34 100644 --- a/examples/rslvq_iris.py +++ b/examples/rslvq_iris.py @@ -6,8 +6,8 @@ import warnings import prototorch as pt import pytorch_lightning as pl import torch +from lightning_fabric.utilities.seed import seed_everything from prototorch.models import RSLVQ, VisGLVQ2D -from pytorch_lightning.utilities.seed import seed_everything from pytorch_lightning.utilities.warnings import PossibleUserWarning from torch.utils.data import DataLoader @@ -17,7 +17,8 @@ warnings.filterwarnings("ignore", category=UserWarning) if __name__ == "__main__": # Command-line arguments parser = argparse.ArgumentParser() - parser = pl.Trainer.add_argparse_args(parser) + parser.add_argument("--gpus", type=int, default=0) + parser.add_argument("--fast_dev_run", type=bool, default=False) args = parser.parse_args() # Reproducibility @@ -54,8 +55,10 @@ if __name__ == "__main__": vis = VisGLVQ2D(data=train_ds) # Setup trainer - trainer = pl.Trainer.from_argparse_args( - args, + trainer = pl.Trainer( + accelerator="cuda" if args.gpus else "cpu", + devices=args.gpus if args.gpus else "auto", + fast_dev_run=args.fast_dev_run, callbacks=[ vis, ], diff --git a/examples/siamese_glvq_iris.py b/examples/siamese_glvq_iris.py index 19f990e..b00e308 100644 --- a/examples/siamese_glvq_iris.py +++ b/examples/siamese_glvq_iris.py @@ -6,8 +6,8 @@ import warnings import prototorch as pt import pytorch_lightning as pl import torch +from lightning_fabric.utilities.seed import seed_everything from prototorch.models import SiameseGLVQ, VisSiameseGLVQ2D -from pytorch_lightning.utilities.seed import seed_everything from pytorch_lightning.utilities.warnings import PossibleUserWarning from torch.utils.data import DataLoader @@ -35,7 +35,8 @@ class Backbone(torch.nn.Module): if __name__ == "__main__": # Command-line arguments parser = argparse.ArgumentParser() - parser = pl.Trainer.add_argparse_args(parser) + parser.add_argument("--gpus", type=int, default=0) + parser.add_argument("--fast_dev_run", type=bool, default=False) args = parser.parse_args() # Dataset @@ -69,8 +70,10 @@ if __name__ == "__main__": vis = VisSiameseGLVQ2D(data=train_ds, border=0.1) # Setup trainer - trainer = pl.Trainer.from_argparse_args( - args, + trainer = pl.Trainer( + accelerator="cuda" if args.gpus else "cpu", + devices=args.gpus if args.gpus else "auto", + fast_dev_run=args.fast_dev_run, callbacks=[ vis, ], diff --git a/examples/siamese_gtlvq_iris.py b/examples/siamese_gtlvq_iris.py index 16db174..4f036d1 100644 --- a/examples/siamese_gtlvq_iris.py +++ b/examples/siamese_gtlvq_iris.py @@ -6,8 +6,8 @@ import warnings import prototorch as pt import pytorch_lightning as pl import torch +from lightning_fabric.utilities.seed import seed_everything from prototorch.models import SiameseGTLVQ, VisSiameseGLVQ2D -from pytorch_lightning.utilities.seed import seed_everything from pytorch_lightning.utilities.warnings import PossibleUserWarning from torch.utils.data import DataLoader @@ -35,7 +35,8 @@ class Backbone(torch.nn.Module): if __name__ == "__main__": # Command-line arguments parser = argparse.ArgumentParser() - parser = pl.Trainer.add_argparse_args(parser) + parser.add_argument("--gpus", type=int, default=0) + parser.add_argument("--fast_dev_run", type=bool, default=False) args = parser.parse_args() # Dataset @@ -71,8 +72,10 @@ if __name__ == "__main__": vis = VisSiameseGLVQ2D(data=train_ds, border=0.1) # Setup trainer - trainer = pl.Trainer.from_argparse_args( - args, + trainer = pl.Trainer( + accelerator="cuda" if args.gpus else "cpu", + devices=args.gpus if args.gpus else "auto", + fast_dev_run=args.fast_dev_run, callbacks=[ vis, ], diff --git a/examples/warm_starting.py b/examples/warm_starting.py index 98d0668..a9e17dc 100644 --- a/examples/warm_starting.py +++ b/examples/warm_starting.py @@ -6,6 +6,7 @@ import warnings import prototorch as pt import pytorch_lightning as pl import torch +from lightning_fabric.utilities.seed import seed_everything from prototorch.models import ( GLVQ, KNN, @@ -14,7 +15,6 @@ from prototorch.models import ( VisGLVQ2D, ) from pytorch_lightning.callbacks import EarlyStopping -from pytorch_lightning.utilities.seed import seed_everything from pytorch_lightning.utilities.warnings import PossibleUserWarning from torch.optim.lr_scheduler import ExponentialLR from torch.utils.data import DataLoader @@ -27,7 +27,8 @@ if __name__ == "__main__": seed_everything(seed=4) # Command-line arguments parser = argparse.ArgumentParser() - parser = pl.Trainer.add_argparse_args(parser) + parser.add_argument("--gpus", type=int, default=0) + parser.add_argument("--fast_dev_run", type=bool, default=False) args = parser.parse_args() # Prepare the data @@ -108,8 +109,10 @@ if __name__ == "__main__": ) # Setup trainer - trainer = pl.Trainer.from_argparse_args( - args, + trainer = pl.Trainer( + accelerator="cuda" if args.gpus else "cpu", + devices=args.gpus if args.gpus else "auto", + fast_dev_run=args.fast_dev_run, callbacks=[ vis, pruning, diff --git a/setup.py b/setup.py index bb4b478..9e56977 100644 --- a/setup.py +++ b/setup.py @@ -22,10 +22,10 @@ with open("README.md") as fh: long_description = fh.read() INSTALL_REQUIRES = [ - "prototorch>=0.7.3", - "pytorch_lightning>=1.6.0", - "torchmetrics<0.10.0", - "protobuf<3.20.0", + "prototorch>=0.7.5", + "lightning>=2.0.0", + "torchmetrics", + "protobuf", ] CLI = [ "jsonargparse", @@ -65,7 +65,7 @@ setup( url=PROJECT_URL, download_url=DOWNLOAD_URL, license="MIT", - python_requires=">=3.7", + python_requires=">=3.8", install_requires=INSTALL_REQUIRES, extras_require={ "dev": DEV, @@ -82,10 +82,10 @@ setup( "License :: OSI Approved :: MIT License", "Natural Language :: English", "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.8", - "Programming Language :: Python :: 3.7", "Operating System :: OS Independent", "Topic :: Scientific/Engineering :: Artificial Intelligence", "Topic :: Software Development :: Libraries",