Update example scripts
This commit is contained in:
parent
ee39ac516d
commit
985cdd3120
@ -60,12 +60,15 @@ class VisualizationCallback(pl.Callback):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
# Dataset
|
||||||
x_train, y_train = load_iris(return_X_y=True)
|
x_train, y_train = load_iris(return_X_y=True)
|
||||||
x_train = x_train[:, [0, 2]]
|
x_train = x_train[:, [0, 2]]
|
||||||
|
|
||||||
train_ds = NumpyDataset(x_train, y_train)
|
train_ds = NumpyDataset(x_train, y_train)
|
||||||
|
|
||||||
|
# Dataloaders
|
||||||
train_loader = DataLoader(train_ds, num_workers=0, batch_size=150)
|
train_loader = DataLoader(train_ds, num_workers=0, batch_size=150)
|
||||||
|
|
||||||
|
# Initialize the model
|
||||||
model = GLVQ(
|
model = GLVQ(
|
||||||
input_dim=x_train.shape[1],
|
input_dim=x_train.shape[1],
|
||||||
nclasses=3,
|
nclasses=3,
|
||||||
@ -74,13 +77,20 @@ if __name__ == "__main__":
|
|||||||
data=[x_train, y_train],
|
data=[x_train, y_train],
|
||||||
lr=0.1,
|
lr=0.1,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Model summary
|
||||||
print(model)
|
print(model)
|
||||||
|
|
||||||
|
# Callbacks
|
||||||
vis = VisualizationCallback(x_train, y_train)
|
vis = VisualizationCallback(x_train, y_train)
|
||||||
|
|
||||||
|
# Setup trainer
|
||||||
trainer = pl.Trainer(max_epochs=1000, callbacks=[vis])
|
trainer = pl.Trainer(max_epochs=1000, callbacks=[vis])
|
||||||
|
|
||||||
|
# Training loop
|
||||||
trainer.fit(model, train_loader)
|
trainer.fit(model, train_loader)
|
||||||
|
|
||||||
|
# Visualization
|
||||||
protos = model.prototypes
|
protos = model.prototypes
|
||||||
plabels = model.prototype_labels
|
plabels = model.prototype_labels
|
||||||
visualize(x_train, y_train, protos, plabels)
|
visualize(x_train, y_train, protos, plabels)
|
||||||
|
@ -1,44 +1,85 @@
|
|||||||
|
"""GLVQ example using the MNIST dataset.
|
||||||
|
|
||||||
|
This script also shows how to use Tensorboard for visualizing the prototypes.
|
||||||
|
"""
|
||||||
|
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torchvision
|
import torchvision
|
||||||
from matplotlib import pyplot as plt
|
from matplotlib import pyplot as plt
|
||||||
from prototorch.functions.initializers import stratified_mean
|
from prototorch.functions.initializers import stratified_mean
|
||||||
from prototorch.models.glvq import ImageGLVQ
|
from prototorch.models.glvq import ImageGLVQ
|
||||||
from torch.utils.data import DataLoader, random_split
|
from torch.utils.data import DataLoader
|
||||||
from torchvision import transforms
|
from torchvision import transforms
|
||||||
from torchvision.datasets import MNIST
|
from torchvision.datasets import MNIST
|
||||||
|
|
||||||
|
|
||||||
def plot_protos(protos, shape=(-1, 1, 28, 28), nrow=2):
|
class VisualizationCallback(pl.Callback):
|
||||||
grid = torchvision.utils.make_grid(protos.reshape(*shape), nrow=nrow)
|
def __init__(self, to_shape=(-1, 1, 28, 28), nrow=2):
|
||||||
grid = grid.permute((1, 2, 0))
|
super().__init__()
|
||||||
plt.imshow(grid)
|
self.to_shape = to_shape
|
||||||
|
self.nrow = nrow
|
||||||
|
|
||||||
|
def on_epoch_end(self, trainer, pl_module):
|
||||||
|
protos = pl_module.proto_layer.prototypes.detach().cpu()
|
||||||
|
protos_img = protos.reshape(self.to_shape)
|
||||||
|
grid = torchvision.utils.make_grid(protos_img, nrow=self.nrow)
|
||||||
|
# grid = grid.permute((1, 2, 0))
|
||||||
|
tb = pl_module.logger.experiment
|
||||||
|
tb.add_image(tag="MNIST Prototypes",
|
||||||
|
img_tensor=grid,
|
||||||
|
global_step=trainer.current_epoch,
|
||||||
|
dataformats="CHW")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
dataset = MNIST("./datasets",
|
# Dataset
|
||||||
train=True,
|
mnist_train = MNIST(
|
||||||
download=True,
|
"./datasets",
|
||||||
transform=transforms.ToTensor())
|
train=True,
|
||||||
mnist_train, mnist_val = random_split(dataset, [55000, 5000])
|
download=True,
|
||||||
|
transform=transforms.Compose([
|
||||||
train_loader = DataLoader(mnist_train, batch_size=1024)
|
transforms.ToTensor(),
|
||||||
val_loader = DataLoader(mnist_val, batch_size=1024)
|
transforms.Normalize((0.1307, ), (0.3081, ))
|
||||||
|
]),
|
||||||
model = ImageGLVQ(input_dim=28 * 28, nclasses=10, prototypes_per_class=2)
|
)
|
||||||
|
mnist_test = MNIST(
|
||||||
# Warm-start prototypes
|
"./datasets",
|
||||||
prototypes, prototype_labels = stratified_mean(
|
train=False,
|
||||||
x_train,
|
download=True,
|
||||||
y_train,
|
transform=transforms.Compose([
|
||||||
prototype_distribution=self.prototype_distribution,
|
transforms.ToTensor(),
|
||||||
one_hot=one_hot_labels,
|
transforms.Normalize((0.1307, ), (0.3081, ))
|
||||||
|
]),
|
||||||
)
|
)
|
||||||
|
|
||||||
trainer = pl.Trainer(gpus=0, max_epochs=3)
|
# Dataloaders
|
||||||
|
train_loader = DataLoader(mnist_train, batch_size=1024)
|
||||||
|
test_loader = DataLoader(mnist_test, batch_size=1024)
|
||||||
|
|
||||||
trainer.fit(model, train_loader, val_loader)
|
# Grab the full dataset to warm-start prototypes
|
||||||
|
x, y = next(iter(DataLoader(mnist_train, batch_size=len(mnist_train))))
|
||||||
|
x = x.view(len(mnist_train), -1)
|
||||||
|
|
||||||
protos = model.proto_layer.prototypes.detach().cpu()
|
# Initialize the model
|
||||||
plot_protos(protos, shape=(-1, 1, 28, 28), nrow=4)
|
model = ImageGLVQ(input_dim=28 * 28,
|
||||||
plt.show(block=True)
|
nclasses=10,
|
||||||
|
prototypes_per_class=10,
|
||||||
|
prototype_initializer="stratified_mean",
|
||||||
|
data=[x, y])
|
||||||
|
# Model summary
|
||||||
|
print(model)
|
||||||
|
|
||||||
|
# Callbacks
|
||||||
|
vis = VisualizationCallback(to_shape=(-1, 1, 28, 28), nrow=10)
|
||||||
|
|
||||||
|
# Setup trainer
|
||||||
|
trainer = pl.Trainer(
|
||||||
|
gpus=0, # change to use GPUs for training
|
||||||
|
max_epochs=10,
|
||||||
|
callbacks=[vis],
|
||||||
|
# accelerator="ddp_cpu", # DEBUG-ONLY
|
||||||
|
# num_processes=2, # DEBUG-ONLY
|
||||||
|
)
|
||||||
|
|
||||||
|
# Training loop
|
||||||
|
trainer.fit(model, train_loader, test_loader)
|
||||||
|
Loading…
Reference in New Issue
Block a user