Refactor non-gradient-lvq models into lvq.py

Jensun Ravichandran 2021-05-25 20:37:34 +02:00
parent 32d6f95db0
commit d411e52be4
4 changed files with 96 additions and 89 deletions

prototorch/models/__init__.py View File

@@ -2,8 +2,9 @@ from importlib.metadata import PackageNotFoundError, version
 from . import probabilistic
 from .cbc import CBC, ImageCBC
-from .glvq import (GLVQ, GLVQ1, GLVQ21, GMLVQ, GRLVQ, LVQ1, LVQ21, LVQMLN,
-                   ImageGLVQ, ImageGMLVQ, SiameseGLVQ)
+from .glvq import (GLVQ, GLVQ1, GLVQ21, GMLVQ, GRLVQ, LVQMLN, ImageGLVQ,
+                   ImageGMLVQ, SiameseGLVQ)
+from .lvq import LVQ1, LVQ21, MedianLVQ
 from .unsupervised import KNN, NeuralGas
 from .vis import *
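
Because __init__.py re-exports the moved classes, package-level imports are unchanged by this refactor; only code that imported the classes from their defining submodule needs updating. A minimal sketch of the resulting import surface (assuming the package is installed):

from prototorch.models import LVQ1, LVQ21, MedianLVQ  # unchanged: re-exported in __init__.py
from prototorch.models.lvq import LVQ1                # new defining module
from prototorch.models.glvq import GLVQ, GLVQ1        # GLVQ-family models stay in glvq.py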

prototorch/models/glvq.py View File

@@ -1,4 +1,5 @@
-"""Models based on the GLVQ Framework"""
+"""Models based on the GLVQ framework."""
+
 import torch
 import torchmetrics
 from prototorch.components import LabeledComponents
@@ -7,8 +8,7 @@ from prototorch.functions.competitions import wtac
 from prototorch.functions.distances import (euclidean_distance, omega_distance,
                                             sed)
 from prototorch.functions.helper import get_flat
-from prototorch.functions.losses import (_get_dp_dm, glvq_loss, lvq1_loss,
-                                         lvq21_loss)
+from prototorch.functions.losses import (glvq_loss, lvq1_loss, lvq21_loss)
 from .abstract import AbstractPrototypeModel, PrototypeImageModel
@@ -260,78 +260,6 @@ class LVQMLN(SiameseGLVQ):
         return distances
 
 
-class NonGradientGLVQ(GLVQ):
-    """Abstract Model for Models that do not use gradients in their update phase."""
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.automatic_optimization = False
-
-    def training_step(self, train_batch, batch_idx, optimizer_idx=None):
-        raise NotImplementedError
-
-
-class LVQ1(NonGradientGLVQ):
-    """Learning Vector Quantization 1."""
-    def training_step(self, train_batch, batch_idx, optimizer_idx=None):
-        protos = self.proto_layer.components
-        plabels = self.proto_layer.component_labels
-
-        x, y = train_batch
-        dis = self._forward(x)
-        # TODO Vectorized implementation
-
-        for xi, yi in zip(x, y):
-            d = self._forward(xi.view(1, -1))
-            preds = wtac(d, plabels)
-            w = d.argmin(1)
-            if yi == preds:
-                shift = xi - protos[w]
-            else:
-                shift = protos[w] - xi
-            updated_protos = protos + 0.0
-            updated_protos[w] = protos[w] + (self.hparams.lr * shift)
-            self.proto_layer.load_state_dict({"_components": updated_protos},
-                                             strict=False)
-
-        # Logging
-        self.log_acc(dis, y, tag="train_acc")
-
-        return None
-
-
-class LVQ21(NonGradientGLVQ):
-    """Learning Vector Quantization 2.1."""
-    def training_step(self, train_batch, batch_idx, optimizer_idx=None):
-        protos = self.proto_layer.components
-        plabels = self.proto_layer.component_labels
-
-        x, y = train_batch
-        dis = self._forward(x)
-        # TODO Vectorized implementation
-
-        for xi, yi in zip(x, y):
-            xi = xi.view(1, -1)
-            yi = yi.view(1, )
-            d = self._forward(xi)
-            (_, wp), (_, wn) = _get_dp_dm(d, yi, plabels, with_indices=True)
-            shiftp = xi - protos[wp]
-            shiftn = protos[wn] - xi
-            updated_protos = protos + 0.0
-            updated_protos[wp] = protos[wp] + (self.hparams.lr * shiftp)
-            updated_protos[wn] = protos[wn] + (self.hparams.lr * shiftn)
-            self.proto_layer.load_state_dict({"_components": updated_protos},
-                                             strict=False)
-
-        # Logging
-        self.log_acc(dis, y, tag="train_acc")
-
-        return None
-
-
-class MedianLVQ(NonGradientGLVQ):
-    """Median LVQ"""
-
-
 class GLVQ1(GLVQ):
     """Generalized Learning Vector Quantization 1."""
     def __init__(self, hparams, **kwargs):

prototorch/models/lvq.py (new file, 78 lines) View File

@@ -0,0 +1,78 @@
+"""LVQ models that are optimized using non-gradient methods."""
+
+from prototorch.functions.competitions import wtac
+from prototorch.functions.losses import _get_dp_dm
+
+from .glvq import GLVQ
+
+
+class NonGradientLVQ(GLVQ):
+    """Abstract model for models that do not use gradients in their update phase."""
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.automatic_optimization = False
+
+    def training_step(self, train_batch, batch_idx, optimizer_idx=None):
+        raise NotImplementedError
+
+
+class LVQ1(NonGradientLVQ):
+    """Learning Vector Quantization 1."""
+    def training_step(self, train_batch, batch_idx, optimizer_idx=None):
+        protos = self.proto_layer.components
+        plabels = self.proto_layer.component_labels
+
+        x, y = train_batch
+        dis = self._forward(x)
+        # TODO Vectorized implementation
+
+        for xi, yi in zip(x, y):
+            d = self._forward(xi.view(1, -1))
+            preds = wtac(d, plabels)
+            w = d.argmin(1)
+            if yi == preds:
+                shift = xi - protos[w]
+            else:
+                shift = protos[w] - xi
+            updated_protos = protos + 0.0
+            updated_protos[w] = protos[w] + (self.hparams.lr * shift)
+            self.proto_layer.load_state_dict({"_components": updated_protos},
+                                             strict=False)
+
+        # Logging
+        self.log_acc(dis, y, tag="train_acc")
+
+        return None
+
+
+class LVQ21(NonGradientLVQ):
+    """Learning Vector Quantization 2.1."""
+    def training_step(self, train_batch, batch_idx, optimizer_idx=None):
+        protos = self.proto_layer.components
+        plabels = self.proto_layer.component_labels
+
+        x, y = train_batch
+        dis = self._forward(x)
+        # TODO Vectorized implementation
+
+        for xi, yi in zip(x, y):
+            xi = xi.view(1, -1)
+            yi = yi.view(1, )
+            d = self._forward(xi)
+            (_, wp), (_, wn) = _get_dp_dm(d, yi, plabels, with_indices=True)
+            shiftp = xi - protos[wp]
+            shiftn = protos[wn] - xi
+            updated_protos = protos + 0.0
+            updated_protos[wp] = protos[wp] + (self.hparams.lr * shiftp)
+            updated_protos[wn] = protos[wn] + (self.hparams.lr * shiftn)
+            self.proto_layer.load_state_dict({"_components": updated_protos},
+                                             strict=False)
+
+        # Logging
+        self.log_acc(dis, y, tag="train_acc")
+
+        return None
+
+
+class MedianLVQ(NonGradientLVQ):
+    """Median LVQ."""

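Both training steps above carry a "# TODO Vectorized implementation" marker. One way the LVQ1 update could be batched is sketched below; lvq1_batch_update is a hypothetical helper, not part of the library, and unlike the sequential loop it applies every shift against the same prototype snapshot, so it approximates rather than reproduces the per-sample update:

import torch

def lvq1_batch_update(x, y, protos, plabels, lr):
    """Hypothetical batched LVQ1 step (assumes Euclidean distance)."""
    d = torch.cdist(x, protos)            # (batch, num_protos) distances
    w = d.argmin(dim=1)                   # winning prototype per sample
    correct = (plabels[w] == y).float()   # 1.0 where the winner's label matches y
    sign = (2.0 * correct - 1.0).unsqueeze(1)  # +1 attracts, -1 repels
    updated_protos = protos.clone()
    # Accumulate all shifts at once; several samples may share a winner
    updated_protos.index_add_(0, w, lr * sign * (x - protos[w]))
    return updated_protos

The LVQ2.1 step could be batched the same way, using the closest-correct and closest-incorrect indices that _get_dp_dm returns in place of the single winner.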
prototorch/models/probabilistic.py View File

@@ -6,7 +6,7 @@ from .glvq import GLVQ
 
 
 # HELPER
-# TODO: Refactor into general files, if usefull
+# TODO: Refactor into general files, if useful
 def probability(distance, variance):
     return torch.exp(-(distance * distance) / (2 * variance))
 
@@ -14,30 +14,30 @@ def probability(distance, variance):
 def grouped_sum(value: torch.Tensor,
                 labels: torch.LongTensor) -> (torch.Tensor, torch.LongTensor):
     """Group-wise average for (sparse) grouped tensors
 
     Args:
         value (torch.Tensor): values to average (# samples, latent dimension)
         labels (torch.LongTensor): labels for embedding parameters (# samples,)
 
-    Returns: 
+    Returns:
         result (torch.Tensor): (# unique labels, latent dimension)
         new_labels (torch.LongTensor): (# unique labels,)
 
     Examples:
         >>> samples = torch.Tensor([
                 [0.15, 0.15, 0.15],  #-> group / class 1
-                [0.2, 0.2, 0.2],  #-> group / class 3
-                [0.4, 0.4, 0.4],  #-> group / class 3
-                [0.0, 0.0, 0.0]  #-> group / class 0
+                [0.2, 0.2, 0.2 ],  #-> group / class 3
+                [0.4, 0.4, 0.4 ],  #-> group / class 3
+                [0.0, 0.0, 0.0 ]  #-> group / class 0
             ])
         >>> labels = torch.LongTensor([1, 5, 5, 0])
         >>> result, new_labels = groupby_mean(samples, labels)
         >>> result
         tensor([[0.0000, 0.0000, 0.0000],
-            [0.1500, 0.1500, 0.1500],
-            [0.3000, 0.3000, 0.3000]])
+                [0.1500, 0.1500, 0.1500],
+                [0.3000, 0.3000, 0.3000]])
         >>> new_labels
         tensor([0, 1, 5])
     """