GLVQ with configurable distance.
This commit is contained in:
@@ -4,11 +4,11 @@ import numpy as np
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from matplotlib import pyplot as plt
|
||||
from prototorch.datasets.abstract import NumpyDataset
|
||||
from prototorch.models.glvq import SiameseGLVQ
|
||||
from sklearn.datasets import load_iris
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from prototorch.datasets.abstract import NumpyDataset
|
||||
from prototorch.models.glvq import SiameseGLVQ
|
||||
|
||||
|
||||
class VisualizationCallback(pl.Callback):
|
||||
@@ -57,10 +57,12 @@ class VisualizationCallback(pl.Callback):
|
||||
ax.set_xlim(left=x_min + 0, right=x_max - 0)
|
||||
ax.set_ylim(bottom=y_min + 0, top=y_max - 0)
|
||||
tb = pl_module.logger.experiment
|
||||
tb.add_figure(tag=f"{self.title}",
|
||||
figure=self.fig,
|
||||
global_step=trainer.current_epoch,
|
||||
close=False)
|
||||
tb.add_figure(
|
||||
tag=f"{self.title}",
|
||||
figure=self.fig,
|
||||
global_step=trainer.current_epoch,
|
||||
close=False,
|
||||
)
|
||||
plt.pause(0.1)
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user