GLVQ with configurable distance.

This commit is contained in:
Alexander Engelsberger 2021-04-27 15:38:57 +02:00
parent 1fb197077c
commit eeb684b3b6
3 changed files with 17 additions and 10 deletions

View File

@ -4,11 +4,11 @@ import numpy as np
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
from prototorch.datasets.abstract import NumpyDataset
from prototorch.models.glvq import SiameseGLVQ
from sklearn.datasets import load_iris from sklearn.datasets import load_iris
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from prototorch.datasets.abstract import NumpyDataset
from prototorch.models.glvq import SiameseGLVQ
class VisualizationCallback(pl.Callback): class VisualizationCallback(pl.Callback):
@ -57,10 +57,12 @@ class VisualizationCallback(pl.Callback):
ax.set_xlim(left=x_min + 0, right=x_max - 0) ax.set_xlim(left=x_min + 0, right=x_max - 0)
ax.set_ylim(bottom=y_min + 0, top=y_max - 0) ax.set_ylim(bottom=y_min + 0, top=y_max - 0)
tb = pl_module.logger.experiment tb = pl_module.logger.experiment
tb.add_figure(tag=f"{self.title}", tb.add_figure(
figure=self.fig, tag=f"{self.title}",
global_step=trainer.current_epoch, figure=self.fig,
close=False) global_step=trainer.current_epoch,
close=False,
)
plt.pause(0.1) plt.pause(0.1)

View File

@ -7,8 +7,7 @@ from matplotlib.offsetbox import AnchoredText
from prototorch.utils.celluloid import Camera from prototorch.utils.celluloid import Camera
from prototorch.utils.colors import color_scheme from prototorch.utils.colors import color_scheme
from prototorch.utils.utils import (gif_from_dir, make_directory, from prototorch.utils.utils import gif_from_dir, make_directory, prettify_string
prettify_string)
class VisWeights(Callback): class VisWeights(Callback):

View File

@ -12,13 +12,19 @@ class GLVQ(pl.LightningModule):
"""Generalized Learning Vector Quantization.""" """Generalized Learning Vector Quantization."""
def __init__(self, hparams, **kwargs): def __init__(self, hparams, **kwargs):
super().__init__() super().__init__()
self.save_hyperparameters(hparams) self.save_hyperparameters(hparams)
# Default Values
self.hparams.setdefault("distance", euclidean_distance)
self.proto_layer = Prototypes1D( self.proto_layer = Prototypes1D(
input_dim=self.hparams.input_dim, input_dim=self.hparams.input_dim,
nclasses=self.hparams.nclasses, nclasses=self.hparams.nclasses,
prototypes_per_class=self.hparams.prototypes_per_class, prototypes_per_class=self.hparams.prototypes_per_class,
prototype_initializer=self.hparams.prototype_initializer, prototype_initializer=self.hparams.prototype_initializer,
**kwargs) **kwargs)
self.train_acc = torchmetrics.Accuracy() self.train_acc = torchmetrics.Accuracy()
@property @property
@ -35,7 +41,7 @@ class GLVQ(pl.LightningModule):
def forward(self, x): def forward(self, x):
protos = self.proto_layer.prototypes protos = self.proto_layer.prototypes
dis = euclidean_distance(x, protos) dis = self.hparams.distance(x, protos)
return dis return dis
def training_step(self, train_batch, batch_idx): def training_step(self, train_batch, batch_idx):