GLVQ with configurable distance.

This commit is contained in:
Alexander Engelsberger 2021-04-27 15:38:57 +02:00
parent 1fb197077c
commit eeb684b3b6
3 changed files with 17 additions and 10 deletions

View File

@ -4,11 +4,11 @@ import numpy as np
import pytorch_lightning as pl
import torch
from matplotlib import pyplot as plt
from prototorch.datasets.abstract import NumpyDataset
from prototorch.models.glvq import SiameseGLVQ
from sklearn.datasets import load_iris
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from prototorch.datasets.abstract import NumpyDataset
from prototorch.models.glvq import SiameseGLVQ
class VisualizationCallback(pl.Callback):
@ -57,10 +57,12 @@ class VisualizationCallback(pl.Callback):
ax.set_xlim(left=x_min + 0, right=x_max - 0)
ax.set_ylim(bottom=y_min + 0, top=y_max - 0)
tb = pl_module.logger.experiment
tb.add_figure(tag=f"{self.title}",
figure=self.fig,
global_step=trainer.current_epoch,
close=False)
tb.add_figure(
tag=f"{self.title}",
figure=self.fig,
global_step=trainer.current_epoch,
close=False,
)
plt.pause(0.1)

View File

@ -7,8 +7,7 @@ from matplotlib.offsetbox import AnchoredText
from prototorch.utils.celluloid import Camera
from prototorch.utils.colors import color_scheme
from prototorch.utils.utils import (gif_from_dir, make_directory,
prettify_string)
from prototorch.utils.utils import gif_from_dir, make_directory, prettify_string
class VisWeights(Callback):

View File

@ -12,13 +12,19 @@ class GLVQ(pl.LightningModule):
"""Generalized Learning Vector Quantization."""
def __init__(self, hparams, **kwargs):
super().__init__()
self.save_hyperparameters(hparams)
# Default Values
self.hparams.setdefault("distance", euclidean_distance)
self.proto_layer = Prototypes1D(
input_dim=self.hparams.input_dim,
nclasses=self.hparams.nclasses,
prototypes_per_class=self.hparams.prototypes_per_class,
prototype_initializer=self.hparams.prototype_initializer,
**kwargs)
self.train_acc = torchmetrics.Accuracy()
@property
@ -35,7 +41,7 @@ class GLVQ(pl.LightningModule):
def forward(self, x):
protos = self.proto_layer.prototypes
dis = euclidean_distance(x, protos)
dis = self.hparams.distance(x, protos)
return dis
def training_step(self, train_batch, batch_idx):