GLVQ with configurable distance.
This commit is contained in:
parent
1fb197077c
commit
eeb684b3b6
@ -4,11 +4,11 @@ import numpy as np
|
|||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
from matplotlib import pyplot as plt
|
from matplotlib import pyplot as plt
|
||||||
from prototorch.datasets.abstract import NumpyDataset
|
|
||||||
from prototorch.models.glvq import SiameseGLVQ
|
|
||||||
from sklearn.datasets import load_iris
|
from sklearn.datasets import load_iris
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
|
||||||
|
from prototorch.datasets.abstract import NumpyDataset
|
||||||
|
from prototorch.models.glvq import SiameseGLVQ
|
||||||
|
|
||||||
|
|
||||||
class VisualizationCallback(pl.Callback):
|
class VisualizationCallback(pl.Callback):
|
||||||
@ -57,10 +57,12 @@ class VisualizationCallback(pl.Callback):
|
|||||||
ax.set_xlim(left=x_min + 0, right=x_max - 0)
|
ax.set_xlim(left=x_min + 0, right=x_max - 0)
|
||||||
ax.set_ylim(bottom=y_min + 0, top=y_max - 0)
|
ax.set_ylim(bottom=y_min + 0, top=y_max - 0)
|
||||||
tb = pl_module.logger.experiment
|
tb = pl_module.logger.experiment
|
||||||
tb.add_figure(tag=f"{self.title}",
|
tb.add_figure(
|
||||||
figure=self.fig,
|
tag=f"{self.title}",
|
||||||
global_step=trainer.current_epoch,
|
figure=self.fig,
|
||||||
close=False)
|
global_step=trainer.current_epoch,
|
||||||
|
close=False,
|
||||||
|
)
|
||||||
plt.pause(0.1)
|
plt.pause(0.1)
|
||||||
|
|
||||||
|
|
||||||
|
@ -7,8 +7,7 @@ from matplotlib.offsetbox import AnchoredText
|
|||||||
|
|
||||||
from prototorch.utils.celluloid import Camera
|
from prototorch.utils.celluloid import Camera
|
||||||
from prototorch.utils.colors import color_scheme
|
from prototorch.utils.colors import color_scheme
|
||||||
from prototorch.utils.utils import (gif_from_dir, make_directory,
|
from prototorch.utils.utils import gif_from_dir, make_directory, prettify_string
|
||||||
prettify_string)
|
|
||||||
|
|
||||||
|
|
||||||
class VisWeights(Callback):
|
class VisWeights(Callback):
|
||||||
|
@ -12,13 +12,19 @@ class GLVQ(pl.LightningModule):
|
|||||||
"""Generalized Learning Vector Quantization."""
|
"""Generalized Learning Vector Quantization."""
|
||||||
def __init__(self, hparams, **kwargs):
|
def __init__(self, hparams, **kwargs):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.save_hyperparameters(hparams)
|
self.save_hyperparameters(hparams)
|
||||||
|
|
||||||
|
# Default Values
|
||||||
|
self.hparams.setdefault("distance", euclidean_distance)
|
||||||
|
|
||||||
self.proto_layer = Prototypes1D(
|
self.proto_layer = Prototypes1D(
|
||||||
input_dim=self.hparams.input_dim,
|
input_dim=self.hparams.input_dim,
|
||||||
nclasses=self.hparams.nclasses,
|
nclasses=self.hparams.nclasses,
|
||||||
prototypes_per_class=self.hparams.prototypes_per_class,
|
prototypes_per_class=self.hparams.prototypes_per_class,
|
||||||
prototype_initializer=self.hparams.prototype_initializer,
|
prototype_initializer=self.hparams.prototype_initializer,
|
||||||
**kwargs)
|
**kwargs)
|
||||||
|
|
||||||
self.train_acc = torchmetrics.Accuracy()
|
self.train_acc = torchmetrics.Accuracy()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -35,7 +41,7 @@ class GLVQ(pl.LightningModule):
|
|||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
protos = self.proto_layer.prototypes
|
protos = self.proto_layer.prototypes
|
||||||
dis = euclidean_distance(x, protos)
|
dis = self.hparams.distance(x, protos)
|
||||||
return dis
|
return dis
|
||||||
|
|
||||||
def training_step(self, train_batch, batch_idx):
|
def training_step(self, train_batch, batch_idx):
|
||||||
|
Loading…
Reference in New Issue
Block a user