feat(model): implement MedianLVQ

Jensun Ravichandran 2021-07-06 17:12:51 +02:00
parent 9d38123114
commit 4be9fb81eb
3 changed files with 113 additions and 2 deletions

View File

@@ -36,6 +36,7 @@ be available for use in your Python environment as `prototorch.models`.
 - Soft Learning Vector Quantization (SLVQ)
 - Robust Soft Learning Vector Quantization (RSLVQ)
 - Probabilistic Learning Vector Quantization (PLVQ)
+- Median-LVQ
 
 ### Other
@@ -51,7 +52,6 @@ be available for use in your Python environment as `prototorch.models`.
 ## Planned models
 
-- Median-LVQ
 - Generalized Tangent Learning Vector Quantization (GTLVQ)
 - Self-Incremental Learning Vector Quantization (SILVQ)

View File

@@ -0,0 +1,52 @@
"""Median-LVQ example using the Iris dataset."""

import argparse

import prototorch as pt
import pytorch_lightning as pl
import torch

if __name__ == "__main__":
    # Command-line arguments
    parser = argparse.ArgumentParser()
    parser = pl.Trainer.add_argparse_args(parser)
    args = parser.parse_args()

    # Dataset
    train_ds = pt.datasets.Iris(dims=[0, 2])

    # Dataloaders
    train_loader = torch.utils.data.DataLoader(
        train_ds,
        batch_size=len(train_ds),  # MedianLVQ cannot handle mini-batches
    )

    # Initialize the model
    model = pt.models.MedianLVQ(
        hparams=dict(distribution=(3, 2), lr=0.01),
        prototypes_initializer=pt.initializers.SSCI(train_ds),
    )

    # Compute intermediate input and output sizes
    model.example_input_array = torch.zeros(4, 2)

    # Callbacks
    vis = pt.models.VisGLVQ2D(data=train_ds)
    es = pl.callbacks.EarlyStopping(
        monitor="train_acc",
        min_delta=0.01,
        patience=5,
        mode="max",
        verbose=True,
        check_on_train_epoch_end=True,
    )

    # Setup trainer
    trainer = pl.Trainer.from_argparse_args(
        args,
        callbacks=[vis, es],
        weights_summary="full",
    )

    # Training loop
    trainer.fit(model, train_loader)
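
A quick way to sanity-check the fitted model after `trainer.fit(...)` returns (a sketch, not part of the commit; it assumes the `predict` helper that prototorch.models classifiers expose):

    # Sketch only: score the trained model on the one full-size batch.
    x, y = next(iter(train_loader))           # full dataset, see batch_size above
    y_pred = model.predict(x)                 # assumed prototorch.models API
    train_acc = (y_pred == y).float().mean()
    print(f"Training accuracy: {train_acc:.2%}")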

View File

@@ -1,6 +1,8 @@
"""LVQ models that are optimized using non-gradient methods."""

from ..core.losses import _get_dp_dm
from ..nn.activations import get_activation
from ..nn.wrappers import LambdaLayer
from .abstract import NonGradientMixin
from .glvq import GLVQ
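
For readers unfamiliar with the imported helper: `_get_dp_dm` returns, for each sample, the distance to the closest prototype carrying the same label (d+) and to the closest prototype carrying a different label (d-). A minimal sketch of that contract, as a hypothetical re-implementation rather than the actual prototorch.core code:

    import torch

    def dp_dm_sketch(distances, targets, plabels):
        """Hypothetical stand-in for _get_dp_dm, for illustration only.

        distances: (N, P) sample-to-prototype distances
        targets:   (N,) sample labels
        plabels:   (P,) prototype labels
        """
        matching = targets.unsqueeze(1) == plabels.unsqueeze(0)  # (N, P)
        inf = torch.tensor(float("inf"))
        dp = torch.where(matching, distances, inf).min(dim=1).values
        dm = torch.where(~matching, distances, inf).min(dim=1).values
        return dp, dm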
@@ -66,4 +68,61 @@ class LVQ21(NonGradientMixin, GLVQ):

class MedianLVQ(NonGradientMixin, GLVQ):
    """Median LVQ

    # TODO Avoid computing distances over and over
    """
    def __init__(self, hparams, verbose=True, **kwargs):
        self.verbose = verbose
        super().__init__(hparams, **kwargs)

        self.transfer_layer = LambdaLayer(
            get_activation(self.hparams.transfer_fn))

    def _f(self, x, y, protos, plabels):
        d = self.distance_layer(x, protos)
        dp, dm = _get_dp_dm(d, y, plabels)
        mu = (dp - dm) / (dp + dm)
        invmu = -1.0 * mu
        f = self.transfer_layer(invmu, beta=self.hparams.transfer_beta) + 1.0
        return f

    def expectation(self, x, y, protos, plabels):
        f = self._f(x, y, protos, plabels)
        gamma = f / f.sum()
        return gamma

    def lower_bound(self, x, y, protos, plabels, gamma):
        f = self._f(x, y, protos, plabels)
        lower_bound = (gamma * f.log()).sum()
        return lower_bound

    def training_step(self, train_batch, batch_idx, optimizer_idx=None):
        protos = self.proto_layer.components
        plabels = self.proto_layer.labels

        x, y = train_batch
        dis = self.compute_distances(x)

        for i, _ in enumerate(protos):
            # Expectation step
            gamma = self.expectation(x, y, protos, plabels)
            lower_bound = self.lower_bound(x, y, protos, plabels, gamma)

            # Maximization step
            _protos = protos + 0
            for k, xk in enumerate(x):
                _protos[i] = xk
                _lower_bound = self.lower_bound(x, y, _protos, plabels, gamma)
                if _lower_bound > lower_bound:
                    if self.verbose:
                        print(f"Updating prototype {i} to data {k}...")
                    self.proto_layer.load_state_dict({"_components": _protos},
                                                     strict=False)
                    break

        # Logging
        self.log_acc(dis, y, tag="train_acc")

        return None
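
Read as a generalized EM scheme (the notation below is mine, not text from the commit), the methods above compute, for each sample x_k with distances d+_k and d-_k to its closest correct and incorrect prototypes and transfer function \varphi:

    \[
      \mu_k = \frac{d^{+}_k - d^{-}_k}{d^{+}_k + d^{-}_k}, \qquad
      f_k = \varphi(-\mu_k;\ \beta) + 1, \qquad
      \gamma_k = \frac{f_k}{\sum_j f_j}, \qquad
      L = \sum_k \gamma_k \log f_k .
    \]

The expectation step fixes the weights \gamma_k; the maximization step greedily tries each data point as a replacement for prototype i and accepts the first candidate that raises the bound L. This is what keeps median-LVQ prototypes restricted to actual data points, and it is also why the example above trains with a single full-size batch.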