36 lines
1001 B
Python
36 lines
1001 B
Python
import matplotlib.pyplot as plt
|
|
import prototorch as pt
|
|
import pytorch_lightning as pl
|
|
import torch
|
|
from prototorch.core.initializers import SMCI, RandomNormalCompInitializer
|
|
from prototorch.models.expanded.clcc_glvq import GLVQ, GLVQhparams
|
|
from torch.utils.data import DataLoader, Dataset
|
|
from torchvision import datasets
|
|
from torchvision.transforms import Compose, Lambda, ToTensor
|
|
|
|
plt.gray()
|
|
|
|
if __name__ == "__main__":
|
|
# Dataset
|
|
train_ds = pt.datasets.Iris(dims=[0, 2])
|
|
# Dataloaders
|
|
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=64)
|
|
|
|
components_initializer = SMCI(train_ds)
|
|
|
|
hparams = GLVQhparams(
|
|
distribution=dict(
|
|
num_classes=3,
|
|
per_class=2,
|
|
),
|
|
component_initializer=components_initializer,
|
|
)
|
|
model = GLVQ(hparams)
|
|
|
|
print(model)
|
|
# Callbacks
|
|
vis = pt.models.VisGLVQ2D(data=train_ds)
|
|
# Train
|
|
trainer = pl.Trainer(callbacks=[vis], gpus=1)
|
|
trainer.fit(model, train_loader)
|