GLVQ with configurable distance.

This commit is contained in:
Alexander Engelsberger
2021-04-27 15:38:57 +02:00
parent 1fb197077c
commit eeb684b3b6
3 changed files with 17 additions and 10 deletions

View File

@@ -4,11 +4,11 @@ import numpy as np
import pytorch_lightning as pl
import torch
from matplotlib import pyplot as plt
from prototorch.datasets.abstract import NumpyDataset
from prototorch.models.glvq import SiameseGLVQ
from sklearn.datasets import load_iris
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from prototorch.datasets.abstract import NumpyDataset
from prototorch.models.glvq import SiameseGLVQ
class VisualizationCallback(pl.Callback):
@@ -57,10 +57,12 @@ class VisualizationCallback(pl.Callback):
ax.set_xlim(left=x_min + 0, right=x_max - 0)
ax.set_ylim(bottom=y_min + 0, top=y_max - 0)
tb = pl_module.logger.experiment
tb.add_figure(tag=f"{self.title}",
figure=self.fig,
global_step=trainer.current_epoch,
close=False)
tb.add_figure(
tag=f"{self.title}",
figure=self.fig,
global_step=trainer.current_epoch,
close=False,
)
plt.pause(0.1)