Update Documentation

Clean up project
This commit is contained in:
Alexander Engelsberger
2021-05-21 15:42:45 +02:00
parent a5e086ce0d
commit 7b4f7d84e0
11 changed files with 146 additions and 126 deletions

View File

@@ -3,8 +3,7 @@ from importlib.metadata import PackageNotFoundError, version
from .cbc import CBC, ImageCBC
from .glvq import (GLVQ, GLVQ1, GLVQ21, GMLVQ, GRLVQ, LVQ1, LVQ21, LVQMLN,
ImageGLVQ, ImageGMLVQ, SiameseGLVQ)
from .knn import KNN
from .neural_gas import NeuralGas
from .unsupervised import KNN, NeuralGas
from .vis import *
__version__ = "0.1.7"

View File

@@ -1,14 +0,0 @@
"""Callbacks for Pytorch Lighning Modules"""
import pytorch_lightning as pl
import torch
class StopOnNaN(pl.Callback):
def __init__(self, param):
super().__init__()
self.param = param
def on_epoch_end(self, trainer, pl_module, logs={}):
if torch.isnan(self.param).any():
raise ValueError("NaN encountered. Stopping.")

View File

@@ -1,3 +1,4 @@
"""Models based on the GLVQ Framework"""
import torch
import torchmetrics
from prototorch.components import LabeledComponents
@@ -6,15 +7,8 @@ from prototorch.functions.competitions import wtac
from prototorch.functions.distances import (euclidean_distance, omega_distance,
sed)
from prototorch.functions.helper import get_flat
from prototorch.functions.losses import (_get_dp_dm, _get_matcher, glvq_loss,
lvq1_loss, lvq21_loss)
from .abstract import AbstractPrototypeModel, PrototypeImageModel
class GLVQ(AbstractPrototypeModel):
"""Generalized Learning Vector Quantization."""
from prototorch.functions.losses import (_get_dp_dm, glvq_loss, lvq1_loss,
lvq21_loss)
from .abstract import AbstractPrototypeModel, PrototypeImageModel
@@ -192,11 +186,14 @@ class GRLVQ(SiameseGLVQ):
self.relevances = torch.nn.parameter.Parameter(
torch.ones(self.hparams.input_dim))
# Overwrite backbone
self.backbone = self._backbone
@property
def relevance_profile(self):
return self.relevances.detach().cpu()
def backbone(self, x):
def _backbone(self, x):
"""Namespace hook for the visualization callbacks to work."""
return x @ torch.diag(self.relevances)
@@ -262,6 +259,7 @@ class LVQMLN(SiameseGLVQ):
class NonGradientGLVQ(GLVQ):
"""Abstract Model for Models that do not use gradients in their update phase."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.automatic_optimization = False
@@ -271,6 +269,7 @@ class NonGradientGLVQ(GLVQ):
class LVQ1(NonGradientGLVQ):
"""Learning Vector Quantization 1."""
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
protos = self.proto_layer.components
plabels = self.proto_layer.component_labels
@@ -299,6 +298,7 @@ class LVQ1(NonGradientGLVQ):
class LVQ21(NonGradientGLVQ):
"""Learning Vector Quantization 2.1."""
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
protos = self.proto_layer.components
plabels = self.proto_layer.component_labels
@@ -311,8 +311,7 @@ class LVQ21(NonGradientGLVQ):
xi = xi.view(1, -1)
yi = yi.view(1, )
d = self(xi)
preds = wtac(d, plabels)
(dp, wp), (dn, wn) = _get_dp_dm(d, yi, plabels, with_indices=True)
(_, wp), (_, wn) = _get_dp_dm(d, yi, plabels, with_indices=True)
shiftp = xi - protos[wp]
shiftn = protos[wn] - xi
updated_protos = protos + 0.0
@@ -328,11 +327,11 @@ class LVQ21(NonGradientGLVQ):
class MedianLVQ(NonGradientGLVQ):
...
"""Median LVQ"""
class GLVQ1(GLVQ):
"""Learning Vector Quantization 1."""
"""Generalized Learning Vector Quantization 1."""
def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs)
self.loss = lvq1_loss
@@ -340,7 +339,7 @@ class GLVQ1(GLVQ):
class GLVQ21(GLVQ):
"""Learning Vector Quantization 2.1."""
"""Generalized Learning Vector Quantization 2.1."""
def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs)
self.loss = lvq21_loss
@@ -354,7 +353,6 @@ class ImageGLVQ(PrototypeImageModel, GLVQ):
after updates.
"""
pass
class ImageGMLVQ(PrototypeImageModel, GMLVQ):
@@ -364,4 +362,3 @@ class ImageGMLVQ(PrototypeImageModel, GMLVQ):
after updates.
"""
pass

View File

@@ -1,62 +0,0 @@
"""The popular K-Nearest-Neighbors classification algorithm."""
import warnings
import torch
import torchmetrics
from prototorch.components import LabeledComponents
from prototorch.components.initializers import parse_data_arg
from prototorch.functions.competitions import knnc
from prototorch.functions.distances import euclidean_distance
from .abstract import AbstractPrototypeModel
class KNN(AbstractPrototypeModel):
"""K-Nearest-Neighbors classification algorithm."""
def __init__(self, hparams, **kwargs):
super().__init__()
self.save_hyperparameters(hparams)
# Default Values
self.hparams.setdefault("k", 1)
self.hparams.setdefault("distance", euclidean_distance)
data = kwargs.get("data")
x_train, y_train = parse_data_arg(data)
self.proto_layer = LabeledComponents(initialized_components=(x_train,
y_train))
self.train_acc = torchmetrics.Accuracy()
@property
def prototype_labels(self):
return self.proto_layer.component_labels.detach()
def forward(self, x):
protos, _ = self.proto_layer()
dis = self.hparams.distance(x, protos)
return dis
def predict(self, x):
# model.eval() # ?!
with torch.no_grad():
d = self(x)
plabels = self.proto_layer.component_labels
y_pred = knnc(d, plabels, k=self.hparams.k)
return y_pred
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
return 1
def on_train_batch_start(self,
train_batch,
batch_idx,
dataloader_idx=None):
warnings.warn("k-NN has no training, skipping!")
return -1
def configure_optimizers(self):
return None

View File

@@ -1,7 +1,13 @@
"""Unsupervised prototype learning algorithms."""
import warnings
import torch
from prototorch.components import Components
import torchmetrics
from prototorch.components import Components, LabeledComponents
from prototorch.components import initializers as cinit
from prototorch.components.initializers import ZerosInitializer
from prototorch.components.initializers import ZerosInitializer, parse_data_arg
from prototorch.functions.competitions import knnc
from prototorch.functions.distances import euclidean_distance
from prototorch.modules.losses import NeuralGasEnergy
@@ -36,6 +42,56 @@ class ConnectionTopology(torch.nn.Module):
return f"agelimit: {self.agelimit}"
class KNN(AbstractPrototypeModel):
"""K-Nearest-Neighbors classification algorithm."""
def __init__(self, hparams, **kwargs):
super().__init__()
self.save_hyperparameters(hparams)
# Default Values
self.hparams.setdefault("k", 1)
self.hparams.setdefault("distance", euclidean_distance)
data = kwargs.get("data")
x_train, y_train = parse_data_arg(data)
self.proto_layer = LabeledComponents(initialized_components=(x_train,
y_train))
self.train_acc = torchmetrics.Accuracy()
@property
def prototype_labels(self):
return self.proto_layer.component_labels.detach()
def forward(self, x):
protos, _ = self.proto_layer()
dis = self.hparams.distance(x, protos)
return dis
def predict(self, x):
# model.eval() # ?!
with torch.no_grad():
d = self(x)
plabels = self.proto_layer.component_labels
y_pred = knnc(d, plabels, k=self.hparams.k)
return y_pred
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
return 1
def on_train_batch_start(self,
train_batch,
batch_idx,
dataloader_idx=None):
warnings.warn("k-NN has no training, skipping!")
return -1
def configure_optimizers(self):
return None
class NeuralGas(AbstractPrototypeModel):
def __init__(self, hparams, **kwargs):
super().__init__()