fix: remove removed CLI syntax from examples
This commit is contained in:
parent
f5e1edf31f
commit
72e9587a10
@ -5,8 +5,8 @@ import warnings
|
|||||||
|
|
||||||
import prototorch as pt
|
import prototorch as pt
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
|
from lightning_fabric.utilities.seed import seed_everything
|
||||||
from prototorch.models import CBC, VisCBC2D
|
from prototorch.models import CBC, VisCBC2D
|
||||||
from pytorch_lightning.utilities.seed import seed_everything
|
|
||||||
from pytorch_lightning.utilities.warnings import PossibleUserWarning
|
from pytorch_lightning.utilities.warnings import PossibleUserWarning
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
@ -19,7 +19,8 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
# Command-line arguments
|
# Command-line arguments
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser = pl.Trainer.add_argparse_args(parser)
|
parser.add_argument("--gpus", type=int, default=0)
|
||||||
|
parser.add_argument("--fast_dev_run", type=bool, default=False)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Dataset
|
# Dataset
|
||||||
@ -53,8 +54,10 @@ if __name__ == "__main__":
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Setup trainer
|
# Setup trainer
|
||||||
trainer = pl.Trainer.from_argparse_args(
|
trainer = pl.Trainer(
|
||||||
args,
|
accelerator="cuda" if args.gpus else "cpu",
|
||||||
|
devices=args.gpus if args.gpus else "auto",
|
||||||
|
fast_dev_run=args.fast_dev_run,
|
||||||
callbacks=[
|
callbacks=[
|
||||||
vis,
|
vis,
|
||||||
],
|
],
|
||||||
|
@ -7,13 +7,13 @@ import warnings
|
|||||||
import prototorch as pt
|
import prototorch as pt
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
|
from lightning_fabric.utilities.seed import seed_everything
|
||||||
from prototorch.models import (
|
from prototorch.models import (
|
||||||
CELVQ,
|
CELVQ,
|
||||||
PruneLoserPrototypes,
|
PruneLoserPrototypes,
|
||||||
VisGLVQ2D,
|
VisGLVQ2D,
|
||||||
)
|
)
|
||||||
from pytorch_lightning.callbacks import EarlyStopping
|
from pytorch_lightning.callbacks import EarlyStopping
|
||||||
from pytorch_lightning.utilities.seed import seed_everything
|
|
||||||
from pytorch_lightning.utilities.warnings import PossibleUserWarning
|
from pytorch_lightning.utilities.warnings import PossibleUserWarning
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
@ -26,7 +26,8 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
# Command-line arguments
|
# Command-line arguments
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser = pl.Trainer.add_argparse_args(parser)
|
parser.add_argument("--gpus", type=int, default=0)
|
||||||
|
parser.add_argument("--fast_dev_run", type=bool, default=False)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Dataset
|
# Dataset
|
||||||
@ -83,8 +84,10 @@ if __name__ == "__main__":
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Setup trainer
|
# Setup trainer
|
||||||
trainer = pl.Trainer.from_argparse_args(
|
trainer = pl.Trainer(
|
||||||
args,
|
accelerator="cuda" if args.gpus else "cpu",
|
||||||
|
devices=args.gpus if args.gpus else "auto",
|
||||||
|
fast_dev_run=args.fast_dev_run,
|
||||||
callbacks=[
|
callbacks=[
|
||||||
vis,
|
vis,
|
||||||
pruning,
|
pruning,
|
||||||
|
@ -7,8 +7,8 @@ import warnings
|
|||||||
import prototorch as pt
|
import prototorch as pt
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
|
from lightning_fabric.utilities.seed import seed_everything
|
||||||
from prototorch.models import GLVQ, VisGLVQ2D
|
from prototorch.models import GLVQ, VisGLVQ2D
|
||||||
from pytorch_lightning.utilities.seed import seed_everything
|
|
||||||
from pytorch_lightning.utilities.warnings import PossibleUserWarning
|
from pytorch_lightning.utilities.warnings import PossibleUserWarning
|
||||||
from torch.optim.lr_scheduler import ExponentialLR
|
from torch.optim.lr_scheduler import ExponentialLR
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
@ -21,7 +21,8 @@ if __name__ == "__main__":
|
|||||||
seed_everything(seed=4)
|
seed_everything(seed=4)
|
||||||
# Command-line arguments
|
# Command-line arguments
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser = pl.Trainer.add_argparse_args(parser)
|
parser.add_argument("--gpus", type=int, default=0)
|
||||||
|
parser.add_argument("--fast_dev_run", type=bool, default=False)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Dataset
|
# Dataset
|
||||||
@ -55,8 +56,10 @@ if __name__ == "__main__":
|
|||||||
vis = VisGLVQ2D(data=train_ds)
|
vis = VisGLVQ2D(data=train_ds)
|
||||||
|
|
||||||
# Setup trainer
|
# Setup trainer
|
||||||
trainer = pl.Trainer.from_argparse_args(
|
trainer = pl.Trainer(
|
||||||
args,
|
accelerator="cuda" if args.gpus else "cpu",
|
||||||
|
devices=args.gpus if args.gpus else "auto",
|
||||||
|
fast_dev_run=args.fast_dev_run,
|
||||||
callbacks=[
|
callbacks=[
|
||||||
vis,
|
vis,
|
||||||
],
|
],
|
||||||
|
@ -6,8 +6,8 @@ import warnings
|
|||||||
import prototorch as pt
|
import prototorch as pt
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
|
from lightning_fabric.utilities.seed import seed_everything
|
||||||
from prototorch.models import GMLVQ, VisGMLVQ2D
|
from prototorch.models import GMLVQ, VisGMLVQ2D
|
||||||
from pytorch_lightning.utilities.seed import seed_everything
|
|
||||||
from pytorch_lightning.utilities.warnings import PossibleUserWarning
|
from pytorch_lightning.utilities.warnings import PossibleUserWarning
|
||||||
from torch.optim.lr_scheduler import ExponentialLR
|
from torch.optim.lr_scheduler import ExponentialLR
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
@ -22,7 +22,8 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
# Command-line arguments
|
# Command-line arguments
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser = pl.Trainer.add_argparse_args(parser)
|
parser.add_argument("--gpus", type=int, default=0)
|
||||||
|
parser.add_argument("--fast_dev_run", type=bool, default=False)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Dataset
|
# Dataset
|
||||||
@ -59,8 +60,10 @@ if __name__ == "__main__":
|
|||||||
vis = VisGMLVQ2D(data=train_ds)
|
vis = VisGMLVQ2D(data=train_ds)
|
||||||
|
|
||||||
# Setup trainer
|
# Setup trainer
|
||||||
trainer = pl.Trainer.from_argparse_args(
|
trainer = pl.Trainer(
|
||||||
args,
|
accelerator="cuda" if args.gpus else "cpu",
|
||||||
|
devices=args.gpus if args.gpus else "auto",
|
||||||
|
fast_dev_run=args.fast_dev_run,
|
||||||
callbacks=[
|
callbacks=[
|
||||||
vis,
|
vis,
|
||||||
],
|
],
|
||||||
|
@ -6,13 +6,13 @@ import warnings
|
|||||||
import prototorch as pt
|
import prototorch as pt
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
|
from lightning_fabric.utilities.seed import seed_everything
|
||||||
from prototorch.models import (
|
from prototorch.models import (
|
||||||
ImageGMLVQ,
|
ImageGMLVQ,
|
||||||
PruneLoserPrototypes,
|
PruneLoserPrototypes,
|
||||||
VisImgComp,
|
VisImgComp,
|
||||||
)
|
)
|
||||||
from pytorch_lightning.callbacks import EarlyStopping
|
from pytorch_lightning.callbacks import EarlyStopping
|
||||||
from pytorch_lightning.utilities.seed import seed_everything
|
|
||||||
from pytorch_lightning.utilities.warnings import PossibleUserWarning
|
from pytorch_lightning.utilities.warnings import PossibleUserWarning
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from torchvision import transforms
|
from torchvision import transforms
|
||||||
@ -26,7 +26,8 @@ if __name__ == "__main__":
|
|||||||
seed_everything(seed=4)
|
seed_everything(seed=4)
|
||||||
# Command-line arguments
|
# Command-line arguments
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser = pl.Trainer.add_argparse_args(parser)
|
parser.add_argument("--gpus", type=int, default=0)
|
||||||
|
parser.add_argument("--fast_dev_run", type=bool, default=False)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Dataset
|
# Dataset
|
||||||
@ -96,8 +97,10 @@ if __name__ == "__main__":
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Setup trainer
|
# Setup trainer
|
||||||
trainer = pl.Trainer.from_argparse_args(
|
trainer = pl.Trainer(
|
||||||
args,
|
accelerator="cuda" if args.gpus else "cpu",
|
||||||
|
devices=args.gpus if args.gpus else "auto",
|
||||||
|
fast_dev_run=args.fast_dev_run,
|
||||||
callbacks=[
|
callbacks=[
|
||||||
vis,
|
vis,
|
||||||
pruning,
|
pruning,
|
||||||
|
@ -6,13 +6,13 @@ import warnings
|
|||||||
import prototorch as pt
|
import prototorch as pt
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
|
from lightning_fabric.utilities.seed import seed_everything
|
||||||
from prototorch.models import (
|
from prototorch.models import (
|
||||||
GMLVQ,
|
GMLVQ,
|
||||||
PruneLoserPrototypes,
|
PruneLoserPrototypes,
|
||||||
VisGLVQ2D,
|
VisGLVQ2D,
|
||||||
)
|
)
|
||||||
from pytorch_lightning.callbacks import EarlyStopping
|
from pytorch_lightning.callbacks import EarlyStopping
|
||||||
from pytorch_lightning.utilities.seed import seed_everything
|
|
||||||
from pytorch_lightning.utilities.warnings import PossibleUserWarning
|
from pytorch_lightning.utilities.warnings import PossibleUserWarning
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
@ -25,7 +25,8 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
# Command-line arguments
|
# Command-line arguments
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser = pl.Trainer.add_argparse_args(parser)
|
parser.add_argument("--gpus", type=int, default=0)
|
||||||
|
parser.add_argument("--fast_dev_run", type=bool, default=False)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Dataset
|
# Dataset
|
||||||
@ -78,8 +79,10 @@ if __name__ == "__main__":
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Setup trainer
|
# Setup trainer
|
||||||
trainer = pl.Trainer.from_argparse_args(
|
trainer = pl.Trainer(
|
||||||
args,
|
accelerator="cuda" if args.gpus else "cpu",
|
||||||
|
devices=args.gpus if args.gpus else "auto",
|
||||||
|
fast_dev_run=args.fast_dev_run,
|
||||||
callbacks=[
|
callbacks=[
|
||||||
vis,
|
vis,
|
||||||
es,
|
es,
|
||||||
|
@ -7,8 +7,8 @@ import warnings
|
|||||||
import prototorch as pt
|
import prototorch as pt
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
|
from lightning_fabric.utilities.seed import seed_everything
|
||||||
from prototorch.models import GrowingNeuralGas, VisNG2D
|
from prototorch.models import GrowingNeuralGas, VisNG2D
|
||||||
from pytorch_lightning.utilities.seed import seed_everything
|
|
||||||
from pytorch_lightning.utilities.warnings import PossibleUserWarning
|
from pytorch_lightning.utilities.warnings import PossibleUserWarning
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
@ -18,7 +18,8 @@ warnings.filterwarnings("ignore", category=UserWarning)
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Command-line arguments
|
# Command-line arguments
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser = pl.Trainer.add_argparse_args(parser)
|
parser.add_argument("--gpus", type=int, default=0)
|
||||||
|
parser.add_argument("--fast_dev_run", type=bool, default=False)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Reproducibility
|
# Reproducibility
|
||||||
@ -51,8 +52,10 @@ if __name__ == "__main__":
|
|||||||
vis = VisNG2D(data=train_loader)
|
vis = VisNG2D(data=train_loader)
|
||||||
|
|
||||||
# Setup trainer
|
# Setup trainer
|
||||||
trainer = pl.Trainer.from_argparse_args(
|
trainer = pl.Trainer(
|
||||||
args,
|
accelerator="cuda" if args.gpus else "cpu",
|
||||||
|
devices=args.gpus if args.gpus else "auto",
|
||||||
|
fast_dev_run=args.fast_dev_run,
|
||||||
callbacks=[
|
callbacks=[
|
||||||
vis,
|
vis,
|
||||||
],
|
],
|
||||||
|
@ -6,8 +6,8 @@ import warnings
|
|||||||
import prototorch as pt
|
import prototorch as pt
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
|
from lightning_fabric.utilities.seed import seed_everything
|
||||||
from prototorch.models import GRLVQ, VisSiameseGLVQ2D
|
from prototorch.models import GRLVQ, VisSiameseGLVQ2D
|
||||||
from pytorch_lightning.utilities.seed import seed_everything
|
|
||||||
from pytorch_lightning.utilities.warnings import PossibleUserWarning
|
from pytorch_lightning.utilities.warnings import PossibleUserWarning
|
||||||
from torch.optim.lr_scheduler import ExponentialLR
|
from torch.optim.lr_scheduler import ExponentialLR
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
@ -22,7 +22,8 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
# Command-line arguments
|
# Command-line arguments
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser = pl.Trainer.add_argparse_args(parser)
|
parser.add_argument("--gpus", type=int, default=0)
|
||||||
|
parser.add_argument("--fast_dev_run", type=bool, default=False)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Dataset
|
# Dataset
|
||||||
@ -58,8 +59,10 @@ if __name__ == "__main__":
|
|||||||
vis = VisSiameseGLVQ2D(data=train_ds)
|
vis = VisSiameseGLVQ2D(data=train_ds)
|
||||||
|
|
||||||
# Setup trainer
|
# Setup trainer
|
||||||
trainer = pl.Trainer.from_argparse_args(
|
trainer = pl.Trainer(
|
||||||
args,
|
accelerator="cuda" if args.gpus else "cpu",
|
||||||
|
devices=args.gpus if args.gpus else "auto",
|
||||||
|
fast_dev_run=args.fast_dev_run,
|
||||||
callbacks=[
|
callbacks=[
|
||||||
vis,
|
vis,
|
||||||
],
|
],
|
||||||
|
@ -6,13 +6,13 @@ import warnings
|
|||||||
import prototorch as pt
|
import prototorch as pt
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
|
from lightning_fabric.utilities.seed import seed_everything
|
||||||
from prototorch.models import (
|
from prototorch.models import (
|
||||||
ImageGTLVQ,
|
ImageGTLVQ,
|
||||||
PruneLoserPrototypes,
|
PruneLoserPrototypes,
|
||||||
VisImgComp,
|
VisImgComp,
|
||||||
)
|
)
|
||||||
from pytorch_lightning.callbacks import EarlyStopping
|
from pytorch_lightning.callbacks import EarlyStopping
|
||||||
from pytorch_lightning.utilities.seed import seed_everything
|
|
||||||
from pytorch_lightning.utilities.warnings import PossibleUserWarning
|
from pytorch_lightning.utilities.warnings import PossibleUserWarning
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from torchvision import transforms
|
from torchvision import transforms
|
||||||
@ -27,7 +27,8 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
# Command-line arguments
|
# Command-line arguments
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser = pl.Trainer.add_argparse_args(parser)
|
parser.add_argument("--gpus", type=int, default=0)
|
||||||
|
parser.add_argument("--fast_dev_run", type=bool, default=False)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Dataset
|
# Dataset
|
||||||
@ -100,8 +101,10 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
# Setup trainer
|
# Setup trainer
|
||||||
# using GPUs here is strongly recommended!
|
# using GPUs here is strongly recommended!
|
||||||
trainer = pl.Trainer.from_argparse_args(
|
trainer = pl.Trainer(
|
||||||
args,
|
accelerator="cuda" if args.gpus else "cpu",
|
||||||
|
devices=args.gpus if args.gpus else "auto",
|
||||||
|
fast_dev_run=args.fast_dev_run,
|
||||||
callbacks=[
|
callbacks=[
|
||||||
vis,
|
vis,
|
||||||
pruning,
|
pruning,
|
||||||
|
@ -7,9 +7,9 @@ import warnings
|
|||||||
import prototorch as pt
|
import prototorch as pt
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
|
from lightning_fabric.utilities.seed import seed_everything
|
||||||
from prototorch.models import GTLVQ, VisGLVQ2D
|
from prototorch.models import GTLVQ, VisGLVQ2D
|
||||||
from pytorch_lightning.callbacks import EarlyStopping
|
from pytorch_lightning.callbacks import EarlyStopping
|
||||||
from pytorch_lightning.utilities.seed import seed_everything
|
|
||||||
from pytorch_lightning.utilities.warnings import PossibleUserWarning
|
from pytorch_lightning.utilities.warnings import PossibleUserWarning
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
@ -19,7 +19,8 @@ warnings.filterwarnings("ignore", category=UserWarning)
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Command-line arguments
|
# Command-line arguments
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser = pl.Trainer.add_argparse_args(parser)
|
parser.add_argument("--gpus", type=int, default=0)
|
||||||
|
parser.add_argument("--fast_dev_run", type=bool, default=False)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Reproducibility
|
# Reproducibility
|
||||||
@ -61,8 +62,10 @@ if __name__ == "__main__":
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Setup trainer
|
# Setup trainer
|
||||||
trainer = pl.Trainer.from_argparse_args(
|
trainer = pl.Trainer(
|
||||||
args,
|
accelerator="cuda" if args.gpus else "cpu",
|
||||||
|
devices=args.gpus if args.gpus else "auto",
|
||||||
|
fast_dev_run=args.fast_dev_run,
|
||||||
callbacks=[
|
callbacks=[
|
||||||
vis,
|
vis,
|
||||||
es,
|
es,
|
||||||
|
@ -18,7 +18,8 @@ warnings.filterwarnings("ignore", category=PossibleUserWarning)
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Command-line arguments
|
# Command-line arguments
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser = pl.Trainer.add_argparse_args(parser)
|
parser.add_argument("--gpus", type=int, default=0)
|
||||||
|
parser.add_argument("--fast_dev_run", type=bool, default=False)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Dataset
|
# Dataset
|
||||||
@ -59,8 +60,10 @@ if __name__ == "__main__":
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Setup trainer
|
# Setup trainer
|
||||||
trainer = pl.Trainer.from_argparse_args(
|
trainer = pl.Trainer(
|
||||||
args,
|
accelerator="cuda" if args.gpus else "cpu",
|
||||||
|
devices=args.gpus if args.gpus else "auto",
|
||||||
|
fast_dev_run=args.fast_dev_run,
|
||||||
max_epochs=1,
|
max_epochs=1,
|
||||||
callbacks=[
|
callbacks=[
|
||||||
vis,
|
vis,
|
||||||
|
@ -7,10 +7,10 @@ import warnings
|
|||||||
import prototorch as pt
|
import prototorch as pt
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
|
from lightning_fabric.utilities.seed import seed_everything
|
||||||
from matplotlib import pyplot as plt
|
from matplotlib import pyplot as plt
|
||||||
from prototorch.models import KohonenSOM
|
from prototorch.models import KohonenSOM
|
||||||
from prototorch.utils.colors import hex_to_rgb
|
from prototorch.utils.colors import hex_to_rgb
|
||||||
from pytorch_lightning.utilities.seed import seed_everything
|
|
||||||
from pytorch_lightning.utilities.warnings import PossibleUserWarning
|
from pytorch_lightning.utilities.warnings import PossibleUserWarning
|
||||||
from torch.utils.data import DataLoader, TensorDataset
|
from torch.utils.data import DataLoader, TensorDataset
|
||||||
|
|
||||||
@ -58,7 +58,8 @@ class Vis2DColorSOM(pl.Callback):
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Command-line arguments
|
# Command-line arguments
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser = pl.Trainer.add_argparse_args(parser)
|
parser.add_argument("--gpus", type=int, default=0)
|
||||||
|
parser.add_argument("--fast_dev_run", type=bool, default=False)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Reproducibility
|
# Reproducibility
|
||||||
@ -104,8 +105,10 @@ if __name__ == "__main__":
|
|||||||
vis = Vis2DColorSOM(data=data)
|
vis = Vis2DColorSOM(data=data)
|
||||||
|
|
||||||
# Setup trainer
|
# Setup trainer
|
||||||
trainer = pl.Trainer.from_argparse_args(
|
trainer = pl.Trainer(
|
||||||
args,
|
accelerator="cuda" if args.gpus else "cpu",
|
||||||
|
devices=args.gpus if args.gpus else "auto",
|
||||||
|
fast_dev_run=args.fast_dev_run,
|
||||||
max_epochs=500,
|
max_epochs=500,
|
||||||
callbacks=[
|
callbacks=[
|
||||||
vis,
|
vis,
|
||||||
|
@ -7,9 +7,9 @@ import warnings
|
|||||||
import prototorch as pt
|
import prototorch as pt
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
|
from lightning_fabric.utilities.seed import seed_everything
|
||||||
from prototorch.models import LGMLVQ, VisGLVQ2D
|
from prototorch.models import LGMLVQ, VisGLVQ2D
|
||||||
from pytorch_lightning.callbacks import EarlyStopping
|
from pytorch_lightning.callbacks import EarlyStopping
|
||||||
from pytorch_lightning.utilities.seed import seed_everything
|
|
||||||
from pytorch_lightning.utilities.warnings import PossibleUserWarning
|
from pytorch_lightning.utilities.warnings import PossibleUserWarning
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
@ -19,7 +19,8 @@ warnings.filterwarnings("ignore", category=UserWarning)
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Command-line arguments
|
# Command-line arguments
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser = pl.Trainer.add_argparse_args(parser)
|
parser.add_argument("--gpus", type=int, default=0)
|
||||||
|
parser.add_argument("--fast_dev_run", type=bool, default=False)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Reproducibility
|
# Reproducibility
|
||||||
@ -62,8 +63,10 @@ if __name__ == "__main__":
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Setup trainer
|
# Setup trainer
|
||||||
trainer = pl.Trainer.from_argparse_args(
|
trainer = pl.Trainer(
|
||||||
args,
|
accelerator="cuda" if args.gpus else "cpu",
|
||||||
|
devices=args.gpus if args.gpus else "auto",
|
||||||
|
fast_dev_run=args.fast_dev_run,
|
||||||
callbacks=[
|
callbacks=[
|
||||||
vis,
|
vis,
|
||||||
es,
|
es,
|
||||||
|
@ -6,12 +6,12 @@ import warnings
|
|||||||
import prototorch as pt
|
import prototorch as pt
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
|
from lightning_fabric.utilities.seed import seed_everything
|
||||||
from prototorch.models import (
|
from prototorch.models import (
|
||||||
LVQMLN,
|
LVQMLN,
|
||||||
PruneLoserPrototypes,
|
PruneLoserPrototypes,
|
||||||
VisSiameseGLVQ2D,
|
VisSiameseGLVQ2D,
|
||||||
)
|
)
|
||||||
from pytorch_lightning.utilities.seed import seed_everything
|
|
||||||
from pytorch_lightning.utilities.warnings import PossibleUserWarning
|
from pytorch_lightning.utilities.warnings import PossibleUserWarning
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
@ -39,7 +39,8 @@ class Backbone(torch.nn.Module):
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Command-line arguments
|
# Command-line arguments
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser = pl.Trainer.add_argparse_args(parser)
|
parser.add_argument("--gpus", type=int, default=0)
|
||||||
|
parser.add_argument("--fast_dev_run", type=bool, default=False)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Dataset
|
# Dataset
|
||||||
@ -88,8 +89,10 @@ if __name__ == "__main__":
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Setup trainer
|
# Setup trainer
|
||||||
trainer = pl.Trainer.from_argparse_args(
|
trainer = pl.Trainer(
|
||||||
args,
|
accelerator="cuda" if args.gpus else "cpu",
|
||||||
|
devices=args.gpus if args.gpus else "auto",
|
||||||
|
fast_dev_run=args.fast_dev_run,
|
||||||
callbacks=[
|
callbacks=[
|
||||||
vis,
|
vis,
|
||||||
pruning,
|
pruning,
|
||||||
|
@ -6,9 +6,9 @@ import warnings
|
|||||||
import prototorch as pt
|
import prototorch as pt
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
|
from lightning_fabric.utilities.seed import seed_everything
|
||||||
from prototorch.models import MedianLVQ, VisGLVQ2D
|
from prototorch.models import MedianLVQ, VisGLVQ2D
|
||||||
from pytorch_lightning.callbacks import EarlyStopping
|
from pytorch_lightning.callbacks import EarlyStopping
|
||||||
from pytorch_lightning.utilities.seed import seed_everything
|
|
||||||
from pytorch_lightning.utilities.warnings import PossibleUserWarning
|
from pytorch_lightning.utilities.warnings import PossibleUserWarning
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
@ -20,7 +20,8 @@ if __name__ == "__main__":
|
|||||||
seed_everything(seed=4)
|
seed_everything(seed=4)
|
||||||
# Command-line arguments
|
# Command-line arguments
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser = pl.Trainer.add_argparse_args(parser)
|
parser.add_argument("--gpus", type=int, default=0)
|
||||||
|
parser.add_argument("--fast_dev_run", type=bool, default=False)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Dataset
|
# Dataset
|
||||||
@ -53,8 +54,10 @@ if __name__ == "__main__":
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Setup trainer
|
# Setup trainer
|
||||||
trainer = pl.Trainer.from_argparse_args(
|
trainer = pl.Trainer(
|
||||||
args,
|
accelerator="cuda" if args.gpus else "cpu",
|
||||||
|
devices=args.gpus if args.gpus else "auto",
|
||||||
|
fast_dev_run=args.fast_dev_run,
|
||||||
callbacks=[
|
callbacks=[
|
||||||
vis,
|
vis,
|
||||||
es,
|
es,
|
||||||
|
@ -6,8 +6,8 @@ import warnings
|
|||||||
import prototorch as pt
|
import prototorch as pt
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
|
from lightning_fabric.utilities.seed import seed_everything
|
||||||
from prototorch.models import NeuralGas, VisNG2D
|
from prototorch.models import NeuralGas, VisNG2D
|
||||||
from pytorch_lightning.utilities.seed import seed_everything
|
|
||||||
from pytorch_lightning.utilities.warnings import PossibleUserWarning
|
from pytorch_lightning.utilities.warnings import PossibleUserWarning
|
||||||
from sklearn.datasets import load_iris
|
from sklearn.datasets import load_iris
|
||||||
from sklearn.preprocessing import StandardScaler
|
from sklearn.preprocessing import StandardScaler
|
||||||
@ -23,7 +23,8 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
# Command-line arguments
|
# Command-line arguments
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser = pl.Trainer.add_argparse_args(parser)
|
parser.add_argument("--gpus", type=int, default=0)
|
||||||
|
parser.add_argument("--fast_dev_run", type=bool, default=False)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Prepare and pre-process the dataset
|
# Prepare and pre-process the dataset
|
||||||
@ -60,8 +61,10 @@ if __name__ == "__main__":
|
|||||||
vis = VisNG2D(data=train_ds)
|
vis = VisNG2D(data=train_ds)
|
||||||
|
|
||||||
# Setup trainer
|
# Setup trainer
|
||||||
trainer = pl.Trainer.from_argparse_args(
|
trainer = pl.Trainer(
|
||||||
args,
|
accelerator="cuda" if args.gpus else "cpu",
|
||||||
|
devices=args.gpus if args.gpus else "auto",
|
||||||
|
fast_dev_run=args.fast_dev_run,
|
||||||
callbacks=[
|
callbacks=[
|
||||||
vis,
|
vis,
|
||||||
],
|
],
|
||||||
|
@ -6,8 +6,8 @@ import warnings
|
|||||||
import prototorch as pt
|
import prototorch as pt
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
|
from lightning_fabric.utilities.seed import seed_everything
|
||||||
from prototorch.models import RSLVQ, VisGLVQ2D
|
from prototorch.models import RSLVQ, VisGLVQ2D
|
||||||
from pytorch_lightning.utilities.seed import seed_everything
|
|
||||||
from pytorch_lightning.utilities.warnings import PossibleUserWarning
|
from pytorch_lightning.utilities.warnings import PossibleUserWarning
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
@ -17,7 +17,8 @@ warnings.filterwarnings("ignore", category=UserWarning)
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Command-line arguments
|
# Command-line arguments
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser = pl.Trainer.add_argparse_args(parser)
|
parser.add_argument("--gpus", type=int, default=0)
|
||||||
|
parser.add_argument("--fast_dev_run", type=bool, default=False)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Reproducibility
|
# Reproducibility
|
||||||
@ -54,8 +55,10 @@ if __name__ == "__main__":
|
|||||||
vis = VisGLVQ2D(data=train_ds)
|
vis = VisGLVQ2D(data=train_ds)
|
||||||
|
|
||||||
# Setup trainer
|
# Setup trainer
|
||||||
trainer = pl.Trainer.from_argparse_args(
|
trainer = pl.Trainer(
|
||||||
args,
|
accelerator="cuda" if args.gpus else "cpu",
|
||||||
|
devices=args.gpus if args.gpus else "auto",
|
||||||
|
fast_dev_run=args.fast_dev_run,
|
||||||
callbacks=[
|
callbacks=[
|
||||||
vis,
|
vis,
|
||||||
],
|
],
|
||||||
|
@ -6,8 +6,8 @@ import warnings
|
|||||||
import prototorch as pt
|
import prototorch as pt
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
|
from lightning_fabric.utilities.seed import seed_everything
|
||||||
from prototorch.models import SiameseGLVQ, VisSiameseGLVQ2D
|
from prototorch.models import SiameseGLVQ, VisSiameseGLVQ2D
|
||||||
from pytorch_lightning.utilities.seed import seed_everything
|
|
||||||
from pytorch_lightning.utilities.warnings import PossibleUserWarning
|
from pytorch_lightning.utilities.warnings import PossibleUserWarning
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
@ -35,7 +35,8 @@ class Backbone(torch.nn.Module):
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Command-line arguments
|
# Command-line arguments
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser = pl.Trainer.add_argparse_args(parser)
|
parser.add_argument("--gpus", type=int, default=0)
|
||||||
|
parser.add_argument("--fast_dev_run", type=bool, default=False)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Dataset
|
# Dataset
|
||||||
@ -69,8 +70,10 @@ if __name__ == "__main__":
|
|||||||
vis = VisSiameseGLVQ2D(data=train_ds, border=0.1)
|
vis = VisSiameseGLVQ2D(data=train_ds, border=0.1)
|
||||||
|
|
||||||
# Setup trainer
|
# Setup trainer
|
||||||
trainer = pl.Trainer.from_argparse_args(
|
trainer = pl.Trainer(
|
||||||
args,
|
accelerator="cuda" if args.gpus else "cpu",
|
||||||
|
devices=args.gpus if args.gpus else "auto",
|
||||||
|
fast_dev_run=args.fast_dev_run,
|
||||||
callbacks=[
|
callbacks=[
|
||||||
vis,
|
vis,
|
||||||
],
|
],
|
||||||
|
@ -6,8 +6,8 @@ import warnings
|
|||||||
import prototorch as pt
|
import prototorch as pt
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
|
from lightning_fabric.utilities.seed import seed_everything
|
||||||
from prototorch.models import SiameseGTLVQ, VisSiameseGLVQ2D
|
from prototorch.models import SiameseGTLVQ, VisSiameseGLVQ2D
|
||||||
from pytorch_lightning.utilities.seed import seed_everything
|
|
||||||
from pytorch_lightning.utilities.warnings import PossibleUserWarning
|
from pytorch_lightning.utilities.warnings import PossibleUserWarning
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
@ -35,7 +35,8 @@ class Backbone(torch.nn.Module):
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Command-line arguments
|
# Command-line arguments
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser = pl.Trainer.add_argparse_args(parser)
|
parser.add_argument("--gpus", type=int, default=0)
|
||||||
|
parser.add_argument("--fast_dev_run", type=bool, default=False)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Dataset
|
# Dataset
|
||||||
@ -71,8 +72,10 @@ if __name__ == "__main__":
|
|||||||
vis = VisSiameseGLVQ2D(data=train_ds, border=0.1)
|
vis = VisSiameseGLVQ2D(data=train_ds, border=0.1)
|
||||||
|
|
||||||
# Setup trainer
|
# Setup trainer
|
||||||
trainer = pl.Trainer.from_argparse_args(
|
trainer = pl.Trainer(
|
||||||
args,
|
accelerator="cuda" if args.gpus else "cpu",
|
||||||
|
devices=args.gpus if args.gpus else "auto",
|
||||||
|
fast_dev_run=args.fast_dev_run,
|
||||||
callbacks=[
|
callbacks=[
|
||||||
vis,
|
vis,
|
||||||
],
|
],
|
||||||
|
@ -6,6 +6,7 @@ import warnings
|
|||||||
import prototorch as pt
|
import prototorch as pt
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
|
from lightning_fabric.utilities.seed import seed_everything
|
||||||
from prototorch.models import (
|
from prototorch.models import (
|
||||||
GLVQ,
|
GLVQ,
|
||||||
KNN,
|
KNN,
|
||||||
@ -14,7 +15,6 @@ from prototorch.models import (
|
|||||||
VisGLVQ2D,
|
VisGLVQ2D,
|
||||||
)
|
)
|
||||||
from pytorch_lightning.callbacks import EarlyStopping
|
from pytorch_lightning.callbacks import EarlyStopping
|
||||||
from pytorch_lightning.utilities.seed import seed_everything
|
|
||||||
from pytorch_lightning.utilities.warnings import PossibleUserWarning
|
from pytorch_lightning.utilities.warnings import PossibleUserWarning
|
||||||
from torch.optim.lr_scheduler import ExponentialLR
|
from torch.optim.lr_scheduler import ExponentialLR
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
@ -27,7 +27,8 @@ if __name__ == "__main__":
|
|||||||
seed_everything(seed=4)
|
seed_everything(seed=4)
|
||||||
# Command-line arguments
|
# Command-line arguments
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser = pl.Trainer.add_argparse_args(parser)
|
parser.add_argument("--gpus", type=int, default=0)
|
||||||
|
parser.add_argument("--fast_dev_run", type=bool, default=False)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Prepare the data
|
# Prepare the data
|
||||||
@ -108,8 +109,10 @@ if __name__ == "__main__":
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Setup trainer
|
# Setup trainer
|
||||||
trainer = pl.Trainer.from_argparse_args(
|
trainer = pl.Trainer(
|
||||||
args,
|
accelerator="cuda" if args.gpus else "cpu",
|
||||||
|
devices=args.gpus if args.gpus else "auto",
|
||||||
|
fast_dev_run=args.fast_dev_run,
|
||||||
callbacks=[
|
callbacks=[
|
||||||
vis,
|
vis,
|
||||||
pruning,
|
pruning,
|
||||||
|
12
setup.py
12
setup.py
@ -22,10 +22,10 @@ with open("README.md") as fh:
|
|||||||
long_description = fh.read()
|
long_description = fh.read()
|
||||||
|
|
||||||
INSTALL_REQUIRES = [
|
INSTALL_REQUIRES = [
|
||||||
"prototorch>=0.7.3",
|
"prototorch>=0.7.5",
|
||||||
"pytorch_lightning>=1.6.0",
|
"lightning>=2.0.0",
|
||||||
"torchmetrics<0.10.0",
|
"torchmetrics",
|
||||||
"protobuf<3.20.0",
|
"protobuf",
|
||||||
]
|
]
|
||||||
CLI = [
|
CLI = [
|
||||||
"jsonargparse",
|
"jsonargparse",
|
||||||
@ -65,7 +65,7 @@ setup(
|
|||||||
url=PROJECT_URL,
|
url=PROJECT_URL,
|
||||||
download_url=DOWNLOAD_URL,
|
download_url=DOWNLOAD_URL,
|
||||||
license="MIT",
|
license="MIT",
|
||||||
python_requires=">=3.7",
|
python_requires=">=3.8",
|
||||||
install_requires=INSTALL_REQUIRES,
|
install_requires=INSTALL_REQUIRES,
|
||||||
extras_require={
|
extras_require={
|
||||||
"dev": DEV,
|
"dev": DEV,
|
||||||
@ -82,10 +82,10 @@ setup(
|
|||||||
"License :: OSI Approved :: MIT License",
|
"License :: OSI Approved :: MIT License",
|
||||||
"Natural Language :: English",
|
"Natural Language :: English",
|
||||||
"Programming Language :: Python :: 3",
|
"Programming Language :: Python :: 3",
|
||||||
|
"Programming Language :: Python :: 3.11",
|
||||||
"Programming Language :: Python :: 3.10",
|
"Programming Language :: Python :: 3.10",
|
||||||
"Programming Language :: Python :: 3.9",
|
"Programming Language :: Python :: 3.9",
|
||||||
"Programming Language :: Python :: 3.8",
|
"Programming Language :: Python :: 3.8",
|
||||||
"Programming Language :: Python :: 3.7",
|
|
||||||
"Operating System :: OS Independent",
|
"Operating System :: OS Independent",
|
||||||
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
||||||
"Topic :: Software Development :: Libraries",
|
"Topic :: Software Development :: Libraries",
|
||||||
|
Loading…
Reference in New Issue
Block a user