feat: Improve 2D visualization with Voronoi Cells

This commit is contained in:
Alexander Engelsberger
2021-10-15 13:01:01 +02:00
parent 967953442b
commit d1985571b3
9 changed files with 109 additions and 238 deletions

View File

@@ -38,10 +38,12 @@ if __name__ == "__main__":
)
# Callbacks
vis = pt.models.VisCBC2D(data=train_ds,
title="CBC Iris Example",
resolution=100,
axis_off=True)
vis = pt.models.Visualize2DVoronoiCallback(
data=train_ds,
title="CBC Iris Example",
resolution=100,
axis_off=True,
)
# Setup trainer
trainer = pl.Trainer.from_argparse_args(

View File

@@ -3,7 +3,7 @@
import argparse
import prototorch as pt
import prototorch.models.expanded
import prototorch.models.clcc
import pytorch_lightning as pl
import torch
from torch.optim.lr_scheduler import ExponentialLR
@@ -30,7 +30,7 @@ if __name__ == "__main__":
)
# Initialize the model
model = prototorch.models.expanded.GLVQ(
model = prototorch.models.GLVQ(
hparams,
optimizer=torch.optim.Adam,
prototypes_initializer=pt.initializers.SMCI(train_ds),
@@ -42,7 +42,13 @@ if __name__ == "__main__":
model.example_input_array = torch.zeros(4, 2)
# Callbacks
vis = pt.models.VisGLVQ2D(data=train_ds)
vis = pt.models.Visualize2DVoronoiCallback(
data=train_ds,
resolution=200,
title="Example: GLVQ on Iris",
x_label="sepal length",
y_label="petal length",
)
# Setup trainer
trainer = pl.Trainer.from_argparse_args(