Custom non-gradient training

Jensun Ravichandran 2021-05-18 19:49:16 +02:00
parent 246719b837
commit eefec19c9b
3 changed files with 101 additions and 18 deletions

View File

@@ -72,6 +72,8 @@ git checkout dev
 pip install -e .[all] # \[all\] if you are using zsh or MacOS
 ```
+
+**Note: Please avoid installing Tensorflow in this environment.**
 
 To assist in the development process, you may also find it useful to install
 `yapf`, `isort` and `autoflake`. You can install them easily with `pip`.
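
(All three are ordinary PyPI packages, so presumably a plain `pip install yapf isort autoflake` suffices; the README does not pin versions.)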

View File

@ -1,8 +1,8 @@
from importlib.metadata import PackageNotFoundError, version from importlib.metadata import PackageNotFoundError, version
from .cbc import CBC from .cbc import CBC
from .glvq import (GLVQ, GMLVQ, GRLVQ, LVQ1, LVQ21, LVQMLN, ImageGLVQ, from .glvq import (GLVQ, GMLVQ, GRLVQ, GLVQ1, GLVQ21, LVQ1, LVQ21, LVQMLN,
ImageGMLVQ, SiameseGLVQ) ImageGLVQ, ImageGMLVQ, SiameseGLVQ)
from .knn import KNN from .knn import KNN
from .neural_gas import NeuralGas from .neural_gas import NeuralGas
from .vis import * from .vis import *
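
After this re-export, `LVQ1`/`LVQ21` name the new non-gradient models while `GLVQ1`/`GLVQ21` are the renamed SGD-based variants (see glvq.py below). A quick sanity check, assuming the package is importable as `prototorch.models` (the diff does not show the package path explicitly):

```python
# Hypothetical import check: both naming layers are exposed at the package root.
from prototorch.models import GLVQ1, GLVQ21, LVQ1, LVQ21
```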

View File

@@ -6,7 +6,8 @@ from prototorch.functions.competitions import wtac
 from prototorch.functions.distances import (euclidean_distance, omega_distance,
                                             sed)
 from prototorch.functions.helper import get_flat
-from prototorch.functions.losses import glvq_loss, lvq1_loss, lvq21_loss
+from prototorch.functions.losses import (_get_dp_dm, _get_matcher, glvq_loss,
+                                         lvq1_loss, lvq21_loss)
 
 from .abstract import (AbstractPrototypeModel, PrototypeImageModel,
                        SiamesePrototypeModel)
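
`_get_dp_dm` and `_get_matcher` are private helpers in `prototorch.functions.losses` whose bodies are not shown here. Judging from the call in `LVQ21.training_step` below, `_get_dp_dm(d, y, plabels, with_indices=True)` yields the closest same-class and closest different-class distances along with the indices of the winning prototypes. A rough sketch of that behavior (not the library's actual implementation):

```python
import torch

def get_dp_dm_sketch(distances, targets, plabels):
    """Sketch: per sample, the closest same-class (d+) and closest
    different-class (d-) prototype distances, with winner indices."""
    matcher = targets.unsqueeze(1) == plabels.unsqueeze(0)  # (N, num_protos)
    inf = torch.full_like(distances, float("inf"))
    dp, wp = torch.where(matcher, distances, inf).min(dim=1)
    dm, wm = torch.where(~matcher, distances, inf).min(dim=1)
    return (dp, wp), (dm, wm)
```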
@@ -33,6 +34,7 @@ class GLVQ(AbstractPrototypeModel):
         # Default Values
         self.hparams.setdefault("transfer_function", "identity")
         self.hparams.setdefault("transfer_beta", 10.0)
+        self.hparams.setdefault("lr", 0.01)
 
         self.proto_layer = LabeledComponents(
             distribution=self.hparams.distribution,
@@ -52,6 +54,23 @@ class GLVQ(AbstractPrototypeModel):
         dis = self.distance_fn(x, protos)
         return dis
 
+    def log_acc(self, distances, targets):
+        plabels = self.proto_layer.component_labels
+        # Compute training accuracy
+        with torch.no_grad():
+            preds = wtac(distances, plabels)
+        self.train_acc(preds.int(), targets.int())
+        # `.int()` because FloatTensors are assumed to be class probabilities
+        self.log("acc",
+                 self.train_acc,
+                 on_step=False,
+                 on_epoch=True,
+                 prog_bar=True,
+                 logger=True)
+
     def training_step(self, train_batch, batch_idx, optimizer_idx=None):
         x, y = train_batch
         dis = self(x)
@@ -61,21 +80,9 @@ class GLVQ(AbstractPrototypeModel):
                                        beta=self.hparams.transfer_beta)
         loss = batch_loss.sum(dim=0)
 
-        # Compute training accuracy
-        with torch.no_grad():
-            preds = wtac(dis, plabels)
-        self.train_acc(preds.int(), y.int())
-        # `.int()` because FloatTensors are assumed to be class probabilities
-
         # Logging
         self.log("train_loss", loss)
-        self.log("acc",
-                 self.train_acc,
-                 on_step=False,
-                 on_epoch=True,
-                 prog_bar=True,
-                 logger=True)
+        self.log_acc(dis, y)
 
         return loss
@@ -87,6 +94,10 @@ class GLVQ(AbstractPrototypeModel):
         y_pred = wtac(d, plabels)
         return y_pred
 
+    def __repr__(self):
+        super_repr = super().__repr__()
+        return f"{super_repr}"
+
 
 class SiameseGLVQ(SiamesePrototypeModel, GLVQ):
     """GLVQ in a Siamese setting.
@@ -198,7 +209,77 @@ class LVQMLN(SiamesePrototypeModel, GLVQ):
         return dis
 
 
-class LVQ1(GLVQ):
+class NonGradientGLVQ(GLVQ):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.automatic_optimization = False
+
+    def training_step(self, train_batch, batch_idx, optimizer_idx=None):
+        raise NotImplementedError
+
+
+class LVQ1(NonGradientGLVQ):
+    def training_step(self, train_batch, batch_idx, optimizer_idx=None):
+        protos = self.proto_layer.components
+        plabels = self.proto_layer.component_labels
+
+        x, y = train_batch
+        dis = self(x)
+        # TODO Vectorized implementation
+
+        for xi, yi in zip(x, y):
+            d = self(xi.view(1, -1))
+            preds = wtac(d, plabels)
+            w = d.argmin(1)
+            if yi == preds:
+                shift = xi - protos[w]
+            else:
+                shift = protos[w] - xi
+            updated_protos = protos + 0.0
+            updated_protos[w] = protos[w] + (self.hparams.lr * shift)
+            self.proto_layer.load_state_dict({"_components": updated_protos},
+                                             strict=False)
+
+        # Logging
+        self.log_acc(dis, y)
+
+        return None
+
+
+class LVQ21(NonGradientGLVQ):
+    def training_step(self, train_batch, batch_idx, optimizer_idx=None):
+        protos = self.proto_layer.components
+        plabels = self.proto_layer.component_labels
+
+        x, y = train_batch
+        dis = self(x)
+        # TODO Vectorized implementation
+
+        for xi, yi in zip(x, y):
+            xi = xi.view(1, -1)
+            yi = yi.view(1, )
+            d = self(xi)
+            preds = wtac(d, plabels)
+            (dp, wp), (dn, wn) = _get_dp_dm(d, yi, plabels, with_indices=True)
+            shiftp = xi - protos[wp]
+            shiftn = protos[wn] - xi
+            updated_protos = protos + 0.0
+            updated_protos[wp] = protos[wp] + (self.hparams.lr * shiftp)
+            updated_protos[wn] = protos[wn] + (self.hparams.lr * shiftn)
+            self.proto_layer.load_state_dict({"_components": updated_protos},
+                                             strict=False)
+
+        # Logging
+        self.log_acc(dis, y)
+
+        return None
+
+
+class MedianLVQ(NonGradientGLVQ):
+    ...
+
+
+class GLVQ1(GLVQ):
     """Learning Vector Quantization 1."""
     def __init__(self, hparams, **kwargs):
         super().__init__(hparams, **kwargs)
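
The `# TODO Vectorized implementation` in `LVQ1.training_step` could later be resolved along these lines. This is only a sketch, not part of the commit: `torch.cdist` stands in for the model's own `distance_fn`, and accumulating all shifts against the start-of-batch prototypes is not exactly equivalent to the sample-by-sample loop above.

```python
import torch

def lvq1_batch_update(x, y, protos, plabels, lr):
    """Sketch of a batch-vectorized LVQ1 step: attract each winning
    prototype for correctly classified samples, repel it otherwise."""
    dis = torch.cdist(x, protos)              # (N, num_protos)
    w = dis.argmin(dim=1)                     # winning prototype per sample
    sign = (plabels[w] == y).float() * 2 - 1  # +1.0 attract, -1.0 repel
    shift = sign.unsqueeze(1) * (x - protos[w])
    updated = protos.clone()
    updated.index_add_(0, w, lr * shift)      # accumulate shifts per prototype
    return updated
```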
@@ -206,7 +287,7 @@ class LVQ1(GLVQ):
         self.optimizer = torch.optim.SGD
 
 
-class LVQ21(GLVQ):
+class GLVQ21(GLVQ):
     """Learning Vector Quantization 2.1."""
     def __init__(self, hparams, **kwargs):
         super().__init__(hparams, **kwargs)
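
For completeness, a minimal usage sketch for the new non-gradient `LVQ1`. The import path, toy data, and trainer settings are assumptions on top of this diff; `distribution` and `lr` are the hparams the code above actually reads, and any prototype-initializer kwargs the constructor may require are omitted.

```python
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader, TensorDataset

from prototorch.models import LVQ1  # assumed import path

# Toy two-blob dataset (placeholder for real data)
x = torch.cat([torch.randn(50, 2) + 2.0, torch.randn(50, 2) - 2.0])
y = torch.cat([torch.zeros(50), torch.ones(50)]).long()
loader = DataLoader(TensorDataset(x, y), batch_size=32, shuffle=True)

# `distribution` format assumed: one prototype per class; `lr` defaults to 0.01.
model = LVQ1(hparams=dict(distribution=[1, 1], lr=0.01))

# With automatic_optimization = False, Lightning only drives the loop;
# the prototype updates happen inside LVQ1.training_step itself.
trainer = pl.Trainer(max_epochs=10)
trainer.fit(model, loader)
```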