diff --git a/prototorch/models/__init__.py b/prototorch/models/__init__.py
index 653a0c3..bd8297e 100644
--- a/prototorch/models/__init__.py
+++ b/prototorch/models/__init__.py
@@ -2,8 +2,9 @@
 from importlib.metadata import PackageNotFoundError, version
 
 from . import probabilistic
 from .cbc import CBC, ImageCBC
-from .glvq import (GLVQ, GLVQ1, GLVQ21, GMLVQ, GRLVQ, LVQ1, LVQ21, LVQMLN,
-                   ImageGLVQ, ImageGMLVQ, SiameseGLVQ)
+from .glvq import (GLVQ, GLVQ1, GLVQ21, GMLVQ, GRLVQ, LVQMLN, ImageGLVQ,
+                   ImageGMLVQ, SiameseGLVQ)
+from .lvq import LVQ1, LVQ21, MedianLVQ
 from .unsupervised import KNN, NeuralGas
 from .vis import *
diff --git a/prototorch/models/glvq.py b/prototorch/models/glvq.py
index 78371e3..d681d28 100644
--- a/prototorch/models/glvq.py
+++ b/prototorch/models/glvq.py
@@ -1,4 +1,5 @@
-"""Models based on the GLVQ Framework"""
+"""Models based on the GLVQ framework."""
+
 import torch
 import torchmetrics
 from prototorch.components import LabeledComponents
@@ -7,8 +8,7 @@
 from prototorch.functions.competitions import wtac
 from prototorch.functions.distances import (euclidean_distance,
                                             omega_distance, sed)
 from prototorch.functions.helper import get_flat
-from prototorch.functions.losses import (_get_dp_dm, glvq_loss, lvq1_loss,
-                                         lvq21_loss)
+from prototorch.functions.losses import (glvq_loss, lvq1_loss, lvq21_loss)
 
 from .abstract import AbstractPrototypeModel, PrototypeImageModel
@@ -260,78 +260,6 @@ class LVQMLN(SiameseGLVQ):
         return distances
 
 
-class NonGradientGLVQ(GLVQ):
-    """Abstract Model for Models that do not use gradients in their update phase."""
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.automatic_optimization = False
-
-    def training_step(self, train_batch, batch_idx, optimizer_idx=None):
-        raise NotImplementedError
-
-
-class LVQ1(NonGradientGLVQ):
-    """Learning Vector Quantization 1."""
-    def training_step(self, train_batch, batch_idx, optimizer_idx=None):
-        protos = self.proto_layer.components
-        plabels = self.proto_layer.component_labels
-
-        x, y = train_batch
-        dis = self._forward(x)
-        # TODO Vectorized implementation
-
-        for xi, yi in zip(x, y):
-            d = self._forward(xi.view(1, -1))
-            preds = wtac(d, plabels)
-            w = d.argmin(1)
-            if yi == preds:
-                shift = xi - protos[w]
-            else:
-                shift = protos[w] - xi
-            updated_protos = protos + 0.0
-            updated_protos[w] = protos[w] + (self.hparams.lr * shift)
-            self.proto_layer.load_state_dict({"_components": updated_protos},
-                                             strict=False)
-
-        # Logging
-        self.log_acc(dis, y, tag="train_acc")
-
-        return None
-
-
-class LVQ21(NonGradientGLVQ):
-    """Learning Vector Quantization 2.1."""
-    def training_step(self, train_batch, batch_idx, optimizer_idx=None):
-        protos = self.proto_layer.components
-        plabels = self.proto_layer.component_labels
-
-        x, y = train_batch
-        dis = self._forward(x)
-        # TODO Vectorized implementation
-
-        for xi, yi in zip(x, y):
-            xi = xi.view(1, -1)
-            yi = yi.view(1, )
-            d = self._forward(xi)
-            (_, wp), (_, wn) = _get_dp_dm(d, yi, plabels, with_indices=True)
-            shiftp = xi - protos[wp]
-            shiftn = protos[wn] - xi
-            updated_protos = protos + 0.0
-            updated_protos[wp] = protos[wp] + (self.hparams.lr * shiftp)
-            updated_protos[wn] = protos[wn] + (self.hparams.lr * shiftn)
-            self.proto_layer.load_state_dict({"_components": updated_protos},
-                                             strict=False)
-
-        # Logging
-        self.log_acc(dis, y, tag="train_acc")
-
-        return None
-
-
-class MedianLVQ(NonGradientGLVQ):
-    """Median LVQ"""
-
-
 class GLVQ1(GLVQ):
     """Generalized Learning Vector Quantization 1."""
     def __init__(self, hparams, **kwargs):
diff --git a/prototorch/models/lvq.py b/prototorch/models/lvq.py
new file mode 100644
index 0000000..716f9bf
--- /dev/null
+++ b/prototorch/models/lvq.py
@@ -0,0 +1,78 @@
+"""LVQ models that are optimized using non-gradient methods."""
+
+from prototorch.functions.competitions import wtac
+from prototorch.functions.losses import _get_dp_dm
+
+from .glvq import GLVQ
+
+
+class NonGradientLVQ(GLVQ):
+    """Abstract model for models that do not use gradients in their update phase."""
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.automatic_optimization = False
+
+    def training_step(self, train_batch, batch_idx, optimizer_idx=None):
+        raise NotImplementedError
+
+
+class LVQ1(NonGradientLVQ):
+    """Learning Vector Quantization 1."""
+    def training_step(self, train_batch, batch_idx, optimizer_idx=None):
+        protos = self.proto_layer.components
+        plabels = self.proto_layer.component_labels
+
+        x, y = train_batch
+        dis = self._forward(x)
+        # TODO Vectorized implementation
+
+        for xi, yi in zip(x, y):
+            d = self._forward(xi.view(1, -1))
+            preds = wtac(d, plabels)
+            w = d.argmin(1)
+            if yi == preds:
+                shift = xi - protos[w]
+            else:
+                shift = protos[w] - xi
+            updated_protos = protos + 0.0
+            updated_protos[w] = protos[w] + (self.hparams.lr * shift)
+            self.proto_layer.load_state_dict({"_components": updated_protos},
+                                             strict=False)
+
+        # Logging
+        self.log_acc(dis, y, tag="train_acc")
+
+        return None
+
+
+class LVQ21(NonGradientLVQ):
+    """Learning Vector Quantization 2.1."""
+    def training_step(self, train_batch, batch_idx, optimizer_idx=None):
+        protos = self.proto_layer.components
+        plabels = self.proto_layer.component_labels
+
+        x, y = train_batch
+        dis = self._forward(x)
+        # TODO Vectorized implementation
+
+        for xi, yi in zip(x, y):
+            xi = xi.view(1, -1)
+            yi = yi.view(1, )
+            d = self._forward(xi)
+            (_, wp), (_, wn) = _get_dp_dm(d, yi, plabels, with_indices=True)
+            shiftp = xi - protos[wp]
+            shiftn = protos[wn] - xi
+            updated_protos = protos + 0.0
+            updated_protos[wp] = protos[wp] + (self.hparams.lr * shiftp)
+            updated_protos[wn] = protos[wn] + (self.hparams.lr * shiftn)
+            self.proto_layer.load_state_dict({"_components": updated_protos},
+                                             strict=False)
+
+        # Logging
+        self.log_acc(dis, y, tag="train_acc")
+
+        return None
+
+
+class MedianLVQ(NonGradientLVQ):
+    """Median LVQ."""
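
Both training loops above carry a "# TODO Vectorized implementation" marker. For reference, a batch-parallel LVQ1 update might look like the sketch below; it is not part of the patch. The names protos, plabels, x, y, dis, and lr mirror those used in training_step, while lvq1_batch_update is a hypothetical helper. Because every shift is computed against the prototypes as they stood at the start of the batch, and same-winner shifts accumulate via index_add_, this approximates rather than exactly reproduces the sequential loop.

import torch

def lvq1_batch_update(protos, plabels, x, y, dis, lr):
    # Winning (closest) prototype for every sample in the batch.
    w = dis.argmin(dim=1)
    # +1 where the winner's label matches the target, -1 otherwise.
    sign = (plabels[w] == y).float() * 2 - 1
    # Attract matching winners toward their samples, repel the rest.
    shift = sign.unsqueeze(1) * (x - protos[w])
    # Accumulate all shifts against the batch-start prototypes; unlike
    # the sequential loop, repeated winners sum their shifts here.
    updated = protos.clone()
    updated.index_add_(0, w, lr * shift)
    return updated
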
diff --git a/prototorch/models/probabilistic.py b/prototorch/models/probabilistic.py
index da47c64..62e2e62 100644
--- a/prototorch/models/probabilistic.py
+++ b/prototorch/models/probabilistic.py
@@ -6,7 +6,7 @@ from .glvq import GLVQ
 
 
 # HELPER
-# TODO: Refactor into general files, if usefull
+# TODO: Refactor into general files, if useful
 def probability(distance, variance):
     return torch.exp(-(distance * distance) / (2 * variance))
 
@@ -14,30 +14,30 @@ def probability(distance, variance):
 def grouped_sum(value: torch.Tensor,
                 labels: torch.LongTensor) -> (torch.Tensor, torch.LongTensor):
     """Group-wise average for (sparse) grouped tensors
-    
+
     Args:
         value (torch.Tensor): values to average (# samples, latent dimension)
         labels (torch.LongTensor): labels for embedding parameters (# samples,)
-    
-    Returns: 
+
+    Returns:
         result (torch.Tensor): (# unique labels, latent dimension)
         new_labels (torch.LongTensor): (# unique labels,)
-    
+
     Examples:
         >>> samples = torch.Tensor([
                 [0.15, 0.15, 0.15],    #-> group / class 1
-                [0.2, 0.2, 0.2],    #-> group / class 3
-                [0.4, 0.4, 0.4],    #-> group / class 3
-                [0.0, 0.0, 0.0]    #-> group / class 0
+                [0.2, 0.2, 0.2 ],    #-> group / class 3
+                [0.4, 0.4, 0.4 ],    #-> group / class 3
+                [0.0, 0.0, 0.0 ]    #-> group / class 0
             ])
         >>> labels = torch.LongTensor([1, 5, 5, 0])
        >>> result, new_labels = groupby_mean(samples, labels)
-    
+
     >>> result
     tensor([[0.0000, 0.0000, 0.0000],
-        [0.1500, 0.1500, 0.1500],
-        [0.3000, 0.3000, 0.3000]])
-    
+            [0.1500, 0.1500, 0.1500],
+            [0.3000, 0.3000, 0.3000]])
+
     >>> new_labels
     tensor([0, 1, 5])
     """
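
After this patch, the non-gradient models live in prototorch.models.lvq but remain importable from the package root through the updated __init__.py. The following minimal training sketch is not part of the patch: the distribution hyperparameter key and the toy dataset are assumptions made for illustration, while lr is the key actually read by the training steps above, and PyTorch Lightning's Trainer drives the manual-optimization loop.

import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader, TensorDataset

from prototorch.models import LVQ1  # re-exported by the updated __init__.py

# Toy two-class dataset: 100 samples with 2 features each.
x = torch.randn(100, 2)
y = torch.randint(0, 2, (100, ))
train_loader = DataLoader(TensorDataset(x, y), batch_size=32)

# lr is read by LVQ1.training_step; distribution (assumed key:
# classes and prototypes per class) follows the GLVQ-family convention.
model = LVQ1(hparams=dict(distribution=(2, 1), lr=0.1))

# Prototype updates run without gradients (automatic_optimization = False).
trainer = pl.Trainer(max_epochs=10)
trainer.fit(model, train_loader)
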