GLVQ with configurable distance.
This commit is contained in:
@@ -7,8 +7,7 @@ from matplotlib.offsetbox import AnchoredText
|
||||
|
||||
from prototorch.utils.celluloid import Camera
|
||||
from prototorch.utils.colors import color_scheme
|
||||
from prototorch.utils.utils import (gif_from_dir, make_directory,
|
||||
prettify_string)
|
||||
from prototorch.utils.utils import gif_from_dir, make_directory, prettify_string
|
||||
|
||||
|
||||
class VisWeights(Callback):
|
||||
|
@@ -12,13 +12,19 @@ class GLVQ(pl.LightningModule):
|
||||
"""Generalized Learning Vector Quantization."""
|
||||
def __init__(self, hparams, **kwargs):
|
||||
super().__init__()
|
||||
|
||||
self.save_hyperparameters(hparams)
|
||||
|
||||
# Default Values
|
||||
self.hparams.setdefault("distance", euclidean_distance)
|
||||
|
||||
self.proto_layer = Prototypes1D(
|
||||
input_dim=self.hparams.input_dim,
|
||||
nclasses=self.hparams.nclasses,
|
||||
prototypes_per_class=self.hparams.prototypes_per_class,
|
||||
prototype_initializer=self.hparams.prototype_initializer,
|
||||
**kwargs)
|
||||
|
||||
self.train_acc = torchmetrics.Accuracy()
|
||||
|
||||
@property
|
||||
@@ -35,7 +41,7 @@ class GLVQ(pl.LightningModule):
|
||||
|
||||
def forward(self, x):
|
||||
protos = self.proto_layer.prototypes
|
||||
dis = euclidean_distance(x, protos)
|
||||
dis = self.hparams.distance(x, protos)
|
||||
return dis
|
||||
|
||||
def training_step(self, train_batch, batch_idx):
|
||||
|
Reference in New Issue
Block a user