chore: Remove PytorchLightning CLI related code
Could be moved in a seperate plugin.
This commit is contained in:
parent
859e2cae69
commit
6ffd27d12a
@ -1,8 +0,0 @@
|
|||||||
# Examples using Lightning CLI
|
|
||||||
|
|
||||||
Examples in this folder use the experimental [Lightning CLI](https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_cli.html).
|
|
||||||
|
|
||||||
To use the example run
|
|
||||||
```
|
|
||||||
python gmlvq.py --config gmlvq.yaml
|
|
||||||
```
|
|
@ -1,19 +0,0 @@
|
|||||||
"""GMLVQ example using the MNIST dataset."""
|
|
||||||
|
|
||||||
import prototorch as pt
|
|
||||||
import torch
|
|
||||||
from prototorch.models import ImageGMLVQ
|
|
||||||
from prototorch.models.abstract import PrototypeModel
|
|
||||||
from prototorch.models.data import MNISTDataModule
|
|
||||||
from pytorch_lightning.utilities.cli import LightningCLI
|
|
||||||
|
|
||||||
|
|
||||||
class ExperimentClass(ImageGMLVQ):
|
|
||||||
def __init__(self, hparams, **kwargs):
|
|
||||||
super().__init__(hparams,
|
|
||||||
optimizer=torch.optim.Adam,
|
|
||||||
prototype_initializer=pt.components.zeros(28 * 28),
|
|
||||||
**kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
cli = LightningCLI(ImageGMLVQ, MNISTDataModule)
|
|
@ -1,11 +0,0 @@
|
|||||||
model:
|
|
||||||
hparams:
|
|
||||||
input_dim: 784
|
|
||||||
latent_dim: 784
|
|
||||||
distribution:
|
|
||||||
num_classes: 10
|
|
||||||
prototypes_per_class: 2
|
|
||||||
proto_lr: 0.01
|
|
||||||
bb_lr: 0.01
|
|
||||||
data:
|
|
||||||
batch_size: 32
|
|
@ -1,123 +0,0 @@
|
|||||||
"""Prototorch Data Modules
|
|
||||||
|
|
||||||
This allows to store the used dataset inside a Lightning Module.
|
|
||||||
Mainly used for PytorchLightningCLI configurations.
|
|
||||||
"""
|
|
||||||
from typing import Any, Optional, Type
|
|
||||||
|
|
||||||
import prototorch as pt
|
|
||||||
import pytorch_lightning as pl
|
|
||||||
from torch.utils.data import DataLoader, Dataset, random_split
|
|
||||||
from torchvision import transforms
|
|
||||||
from torchvision.datasets import MNIST
|
|
||||||
|
|
||||||
|
|
||||||
# MNIST
|
|
||||||
class MNISTDataModule(pl.LightningDataModule):
|
|
||||||
def __init__(self, batch_size=32):
|
|
||||||
super().__init__()
|
|
||||||
self.batch_size = batch_size
|
|
||||||
|
|
||||||
# Download mnist dataset as side-effect, only called on the first cpu
|
|
||||||
def prepare_data(self):
|
|
||||||
MNIST("~/datasets", train=True, download=True)
|
|
||||||
MNIST("~/datasets", train=False, download=True)
|
|
||||||
|
|
||||||
# called for every GPU/machine (assigning state is OK)
|
|
||||||
def setup(self, stage=None):
|
|
||||||
# Transforms
|
|
||||||
transform = transforms.Compose([
|
|
||||||
transforms.ToTensor(),
|
|
||||||
])
|
|
||||||
# Split dataset
|
|
||||||
if stage in (None, "fit"):
|
|
||||||
mnist_train = MNIST("~/datasets", train=True, transform=transform)
|
|
||||||
self.mnist_train, self.mnist_val = random_split(
|
|
||||||
mnist_train,
|
|
||||||
[55000, 5000],
|
|
||||||
)
|
|
||||||
if stage == (None, "test"):
|
|
||||||
self.mnist_test = MNIST(
|
|
||||||
"~/datasets",
|
|
||||||
train=False,
|
|
||||||
transform=transform,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Dataloaders
|
|
||||||
def train_dataloader(self):
|
|
||||||
mnist_train = DataLoader(self.mnist_train, batch_size=self.batch_size)
|
|
||||||
return mnist_train
|
|
||||||
|
|
||||||
def val_dataloader(self):
|
|
||||||
mnist_val = DataLoader(self.mnist_val, batch_size=self.batch_size)
|
|
||||||
return mnist_val
|
|
||||||
|
|
||||||
def test_dataloader(self):
|
|
||||||
mnist_test = DataLoader(self.mnist_test, batch_size=self.batch_size)
|
|
||||||
return mnist_test
|
|
||||||
|
|
||||||
|
|
||||||
# def train_on_mnist(batch_size=256) -> type:
|
|
||||||
# class DataClass(pl.LightningModule):
|
|
||||||
# datamodule = MNISTDataModule(batch_size=batch_size)
|
|
||||||
|
|
||||||
# def __init__(self, *args, **kwargs):
|
|
||||||
# prototype_initializer = kwargs.pop(
|
|
||||||
# "prototype_initializer", pt.components.Zeros((28, 28, 1)))
|
|
||||||
# super().__init__(*args,
|
|
||||||
# prototype_initializer=prototype_initializer,
|
|
||||||
# **kwargs)
|
|
||||||
|
|
||||||
# dc: Type[DataClass] = DataClass
|
|
||||||
# return dc
|
|
||||||
|
|
||||||
|
|
||||||
# ABSTRACT
|
|
||||||
class GeneralDataModule(pl.LightningDataModule):
|
|
||||||
def __init__(self, dataset: Dataset, batch_size: int = 32) -> None:
|
|
||||||
super().__init__()
|
|
||||||
self.train_dataset = dataset
|
|
||||||
self.batch_size = batch_size
|
|
||||||
|
|
||||||
def train_dataloader(self) -> DataLoader:
|
|
||||||
return DataLoader(self.train_dataset, batch_size=self.batch_size)
|
|
||||||
|
|
||||||
|
|
||||||
# def train_on_dataset(dataset: Dataset, batch_size: int = 256):
|
|
||||||
# class DataClass(pl.LightningModule):
|
|
||||||
# datamodule = GeneralDataModule(dataset, batch_size)
|
|
||||||
# datashape = dataset[0][0].shape
|
|
||||||
# example_input_array = torch.zeros_like(dataset[0][0]).unsqueeze(0)
|
|
||||||
|
|
||||||
# def __init__(self, *args: Any, **kwargs: Any) -> None:
|
|
||||||
# prototype_initializer = kwargs.pop(
|
|
||||||
# "prototype_initializer",
|
|
||||||
# pt.components.Zeros(self.datashape),
|
|
||||||
# )
|
|
||||||
# super().__init__(*args,
|
|
||||||
# prototype_initializer=prototype_initializer,
|
|
||||||
# **kwargs)
|
|
||||||
|
|
||||||
# return DataClass
|
|
||||||
|
|
||||||
# if __name__ == "__main__":
|
|
||||||
# from prototorch.models import GLVQ
|
|
||||||
|
|
||||||
# demo_dataset = pt.datasets.Iris()
|
|
||||||
|
|
||||||
# TrainingClass: Type = train_on_dataset(demo_dataset)
|
|
||||||
|
|
||||||
# class DemoGLVQ(TrainingClass, GLVQ):
|
|
||||||
# """Model Definition."""
|
|
||||||
|
|
||||||
# # Hyperparameters
|
|
||||||
# hparams = dict(
|
|
||||||
# distribution={
|
|
||||||
# "num_classes": 3,
|
|
||||||
# "prototypes_per_class": 4
|
|
||||||
# },
|
|
||||||
# lr=0.01,
|
|
||||||
# )
|
|
||||||
|
|
||||||
# initialized = DemoGLVQ(hparams)
|
|
||||||
# print(initialized)
|
|
Loading…
Reference in New Issue
Block a user