chore: Remove PytorchLightning CLI related code

Could be moved in a seperate plugin.
This commit is contained in:
Alexander Engelsberger 2021-10-11 15:16:12 +02:00
parent 859e2cae69
commit 6ffd27d12a
No known key found for this signature in database
GPG Key ID: BE3F5909FF0D83E3
4 changed files with 0 additions and 161 deletions

View File

@ -1,8 +0,0 @@
# Examples using Lightning CLI
Examples in this folder use the experimental [Lightning CLI](https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_cli.html).
To use the example run
```
python gmlvq.py --config gmlvq.yaml
```

View File

@ -1,19 +0,0 @@
"""GMLVQ example using the MNIST dataset."""
import prototorch as pt
import torch
from prototorch.models import ImageGMLVQ
from prototorch.models.abstract import PrototypeModel
from prototorch.models.data import MNISTDataModule
from pytorch_lightning.utilities.cli import LightningCLI
class ExperimentClass(ImageGMLVQ):
def __init__(self, hparams, **kwargs):
super().__init__(hparams,
optimizer=torch.optim.Adam,
prototype_initializer=pt.components.zeros(28 * 28),
**kwargs)
cli = LightningCLI(ImageGMLVQ, MNISTDataModule)

View File

@ -1,11 +0,0 @@
model:
hparams:
input_dim: 784
latent_dim: 784
distribution:
num_classes: 10
prototypes_per_class: 2
proto_lr: 0.01
bb_lr: 0.01
data:
batch_size: 32

View File

@ -1,123 +0,0 @@
"""Prototorch Data Modules
This allows to store the used dataset inside a Lightning Module.
Mainly used for PytorchLightningCLI configurations.
"""
from typing import Any, Optional, Type
import prototorch as pt
import pytorch_lightning as pl
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import transforms
from torchvision.datasets import MNIST
# MNIST
class MNISTDataModule(pl.LightningDataModule):
def __init__(self, batch_size=32):
super().__init__()
self.batch_size = batch_size
# Download mnist dataset as side-effect, only called on the first cpu
def prepare_data(self):
MNIST("~/datasets", train=True, download=True)
MNIST("~/datasets", train=False, download=True)
# called for every GPU/machine (assigning state is OK)
def setup(self, stage=None):
# Transforms
transform = transforms.Compose([
transforms.ToTensor(),
])
# Split dataset
if stage in (None, "fit"):
mnist_train = MNIST("~/datasets", train=True, transform=transform)
self.mnist_train, self.mnist_val = random_split(
mnist_train,
[55000, 5000],
)
if stage == (None, "test"):
self.mnist_test = MNIST(
"~/datasets",
train=False,
transform=transform,
)
# Dataloaders
def train_dataloader(self):
mnist_train = DataLoader(self.mnist_train, batch_size=self.batch_size)
return mnist_train
def val_dataloader(self):
mnist_val = DataLoader(self.mnist_val, batch_size=self.batch_size)
return mnist_val
def test_dataloader(self):
mnist_test = DataLoader(self.mnist_test, batch_size=self.batch_size)
return mnist_test
# def train_on_mnist(batch_size=256) -> type:
# class DataClass(pl.LightningModule):
# datamodule = MNISTDataModule(batch_size=batch_size)
# def __init__(self, *args, **kwargs):
# prototype_initializer = kwargs.pop(
# "prototype_initializer", pt.components.Zeros((28, 28, 1)))
# super().__init__(*args,
# prototype_initializer=prototype_initializer,
# **kwargs)
# dc: Type[DataClass] = DataClass
# return dc
# ABSTRACT
class GeneralDataModule(pl.LightningDataModule):
def __init__(self, dataset: Dataset, batch_size: int = 32) -> None:
super().__init__()
self.train_dataset = dataset
self.batch_size = batch_size
def train_dataloader(self) -> DataLoader:
return DataLoader(self.train_dataset, batch_size=self.batch_size)
# def train_on_dataset(dataset: Dataset, batch_size: int = 256):
# class DataClass(pl.LightningModule):
# datamodule = GeneralDataModule(dataset, batch_size)
# datashape = dataset[0][0].shape
# example_input_array = torch.zeros_like(dataset[0][0]).unsqueeze(0)
# def __init__(self, *args: Any, **kwargs: Any) -> None:
# prototype_initializer = kwargs.pop(
# "prototype_initializer",
# pt.components.Zeros(self.datashape),
# )
# super().__init__(*args,
# prototype_initializer=prototype_initializer,
# **kwargs)
# return DataClass
# if __name__ == "__main__":
# from prototorch.models import GLVQ
# demo_dataset = pt.datasets.Iris()
# TrainingClass: Type = train_on_dataset(demo_dataset)
# class DemoGLVQ(TrainingClass, GLVQ):
# """Model Definition."""
# # Hyperparameters
# hparams = dict(
# distribution={
# "num_classes": 3,
# "prototypes_per_class": 4
# },
# lr=0.01,
# )
# initialized = DemoGLVQ(hparams)
# print(initialized)