[REFACTOR] Major cleanup

Jensun Ravichandran 2021-06-04 22:20:32 +02:00
parent 20471bfb1c
commit 016fcb4060
11 changed files with 481 additions and 399 deletions

prototorch/models/__init__.py

@@ -4,11 +4,23 @@ from importlib.metadata import PackageNotFoundError, version
from .callbacks import PrototypeConvergence, PruneLoserPrototypes
from .cbc import CBC, ImageCBC
from .glvq import (GLVQ, GLVQ1, GLVQ21, GMLVQ, GRLVQ, LGMLVQ, LVQMLN,
ImageGLVQ, ImageGMLVQ, SiameseGLVQ, SiameseGMLVQ)
from .glvq import (
GLVQ,
GLVQ1,
GLVQ21,
GMLVQ,
GRLVQ,
LGMLVQ,
LVQMLN,
ImageGLVQ,
ImageGMLVQ,
SiameseGLVQ,
SiameseGMLVQ,
)
from .knn import KNN
from .lvq import LVQ1, LVQ21, MedianLVQ
from .probabilistic import CELVQ, RSLVQ, LikelihoodRatioLVQ
from .unsupervised import KNN, GrowingNeuralGas, NeuralGas
from .unsupervised import GrowingNeuralGas, NeuralGas
from .vis import *
__version__ = "0.1.7"
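
A quick smoke test of the flattened namespace after this cleanup (the duplicate KNN import from .unsupervised is gone; KNN now comes from its own module). A minimal sketch, assuming prototorch-models is installed at this commit:

import prototorch.models as pm

print(pm.__version__)  # "0.1.7"
print([cls.__name__ for cls in (pm.GLVQ, pm.KNN, pm.NeuralGas, pm.CELVQ)])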

prototorch/models/abstract.py

@@ -1,7 +1,39 @@
"""Abstract classes to be inherited by prototorch models."""
from typing import Final, final
import pytorch_lightning as pl
import torch
import torchmetrics
from prototorch.components import Components, LabeledComponents
from prototorch.functions.distances import euclidean_distance
from prototorch.modules import WTAC, LambdaLayer
class AbstractPrototypeModel(pl.LightningModule):
class ProtoTorchBolt(pl.LightningModule):
def __repr__(self):
super_repr = super().__repr__()
return f"ProtoTorch Bolt:\n{super_repr}"
class PrototypeModel(ProtoTorchBolt):
def __init__(self, hparams, **kwargs):
super().__init__()
# Hyperparameters
self.save_hyperparameters(hparams)
# Default hparams
self.hparams.setdefault("lr", 0.01)
# Default config
self.optimizer = kwargs.get("optimizer", torch.optim.Adam)
self.lr_scheduler = kwargs.get("lr_scheduler", None)
self.lr_scheduler_kwargs = kwargs.get("lr_scheduler_kwargs", dict())
distance_fn = kwargs.get("distance_fn", euclidean_distance)
self.distance_layer = LambdaLayer(distance_fn)
@property
def num_prototypes(self):
return len(self.proto_layer.components)
@@ -28,9 +60,115 @@ class AbstractPrototypeModel(pl.LightningModule):
else:
return optimizer
@final
def reconfigure_optimizers(self):
self.trainer.accelerator_backend.setup_optimizers(self.trainer)
class PrototypeImageModel(pl.LightningModule):
def add_prototypes(self, *args, **kwargs):
self.proto_layer.add_components(*args, **kwargs)
self.reconfigure_optimizers()
def remove_prototypes(self, indices):
self.proto_layer.remove_components(indices)
self.reconfigure_optimizers()
class UnsupervisedPrototypeModel(PrototypeModel):
def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs)
# Layers
prototype_initializer = kwargs.get("prototype_initializer", None)
if prototype_initializer is not None:
self.proto_layer = Components(
self.hparams.num_prototypes,
initializer=prototype_initializer,
)
def compute_distances(self, x):
protos = self.proto_layer()
distances = self.distance_layer(x, protos)
return distances
def forward(self, x):
distances = self.compute_distances(x)
return distances
class SupervisedPrototypeModel(PrototypeModel):
def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs)
# Layers
prototype_initializer = kwargs.get("prototype_initializer", None)
if prototype_initializer is not None:
self.proto_layer = LabeledComponents(
distribution=self.hparams.distribution,
initializer=prototype_initializer,
)
self.competition_layer = WTAC()
@property
def prototype_labels(self):
return self.proto_layer.component_labels.detach().cpu()
@property
def num_classes(self):
return len(self.proto_layer.distribution)
def compute_distances(self, x):
protos, _ = self.proto_layer()
distances = self.distance_layer(x, protos)
return distances
def forward(self, x):
distances = self.compute_distances(x)
y_pred = self.predict_from_distances(distances)
# TODO
y_pred = torch.eye(self.num_classes, device=self.device)[
y_pred.long()] # depends on labels {0,...,num_classes}
return y_pred
def predict_from_distances(self, distances):
with torch.no_grad():
plabels = self.proto_layer.component_labels
y_pred = self.competition_layer(distances, plabels)
return y_pred
def predict(self, x):
with torch.no_grad():
distances = self.compute_distances(x)
y_pred = self.predict_from_distances(distances)
return y_pred
def log_acc(self, distances, targets, tag):
preds = self.predict_from_distances(distances)
accuracy = torchmetrics.functional.accuracy(preds.int(), targets.int())
# `.int()` because FloatTensors are assumed to be class probabilities
self.log(tag,
accuracy,
on_step=False,
on_epoch=True,
prog_bar=True,
logger=True)
class NonGradientMixin():
"""Mixin for custom non-gradient optimization."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.automatic_optimization: Final = False
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
raise NotImplementedError
class ImagePrototypesMixin(ProtoTorchBolt):
"""Mixin for models with image prototypes."""
@final
def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
"""Constrain the components to the range [0, 1] by clamping after updates."""
self.proto_layer.components.data.clamp_(0.0, 1.0)
def get_prototype_grid(self, num_columns=2, return_channels_last=True):
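
Why the mixins are listed first in the bases: a plain-Python sketch (no prototorch required, stand-in class names) of how a mixin hook such as ImagePrototypesMixin.on_train_batch_end wins over the default via the MRO:

class Base:
    def on_train_batch_end(self):
        print("default hook")

class ClampMixin:
    def on_train_batch_end(self):
        print("clamp components to [0, 1]")  # stand-in for the clamp above

class ImageModel(ClampMixin, Base):  # mixin first, so its hook is found first
    pass

ImageModel().on_train_batch_end()  # -> clamp components to [0, 1]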

prototorch/models/callbacks.py

@@ -1,7 +1,12 @@
"""Lightning Callbacks."""
import logging
import pytorch_lightning as pl
import torch
from prototorch.components import Components
from .extras import ConnectionTopology
class PruneLoserPrototypes(pl.Callback):
@@ -26,26 +31,30 @@ class PruneLoserPrototypes(pl.Callback):
return None
if (trainer.current_epoch + 1) % self.frequency:
return None
ratios = pl_module.prototype_win_ratios.mean(dim=0)
to_prune = torch.arange(len(ratios))[ratios < self.threshold]
prune_labels = pl_module.prototype_labels[to_prune.tolist()]
to_prune = to_prune.tolist()
prune_labels = pl_module.prototype_labels[to_prune]
if self.prune_quota_per_epoch > 0:
to_prune = to_prune[:self.prune_quota_per_epoch]
prune_labels = prune_labels[:self.prune_quota_per_epoch]
if len(to_prune) > 0:
if self.verbose:
print(f"\nPrototype win ratios: {ratios}")
print(f"Pruning prototypes at: {to_prune.tolist()}")
print(f"Pruning prototypes at: {to_prune}")
print(f"Corresponding labels are: {prune_labels}")
cur_num_protos = pl_module.num_prototypes
pl_module.remove_prototypes(indices=to_prune)
if self.replace:
if self.verbose:
print(f"Re-adding prototypes at: {to_prune.tolist()}")
labels, counts = torch.unique(prune_labels,
sorted=True,
return_counts=True)
distribution = dict(zip(labels.tolist(), counts.tolist()))
print(f"{distribution=}")
if self.verbose:
print(f"Re-adding pruned prototypes...")
print(f"{distribution=}")
pl_module.add_prototypes(distribution=distribution,
initializer=self.initializer)
new_num_protos = pl_module.num_prototypes
@@ -68,3 +77,58 @@ class PrototypeConvergence(pl.Callback):
print("Stopping...")
# TODO
return True
class GNGCallback(pl.Callback):
"""GNG Callback.
Applies growing algorithm based on accumulated error and topology.
Based on "A Growing Neural Gas Network Learns Topologies" by Bernd Fritzke.
"""
def __init__(self, reduction=0.1, freq=10):
self.reduction = reduction
self.freq = freq
def on_epoch_end(self, trainer: pl.Trainer, pl_module):
if (trainer.current_epoch + 1) % self.freq == 0:
# Get information
errors = pl_module.errors
topology: ConnectionTopology = pl_module.topology_layer
components: Components = pl_module.proto_layer.components
# Insertion point
worst = torch.argmax(errors)
neighbors = topology.get_neighbors(worst)[0]
if len(neighbors) == 0:
logging.log(level=20, msg="No neighbor-pairs found!")
return
neighbors_errors = errors[neighbors]
worst_neighbor = neighbors[torch.argmax(neighbors_errors)]
# New Prototype
new_component = 0.5 * (components[worst] +
components[worst_neighbor])
# Add component
pl_module.proto_layer.add_components(
initialized_components=new_component.unsqueeze(0))
# Adjust Topology
topology.add_prototype()
topology.add_connection(worst, -1)
topology.add_connection(worst_neighbor, -1)
topology.remove_connection(worst, worst_neighbor)
# New errors
worst_error = errors[worst].unsqueeze(0)
pl_module.errors = torch.cat([pl_module.errors, worst_error])
pl_module.errors[worst] = errors[worst] * self.reduction
pl_module.errors[
worst_neighbor] = errors[worst_neighbor] * self.reduction
trainer.accelerator_backend.setup_optimizers(trainer)
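
A hedged wiring sketch for the callbacks above. GNGCallback(reduction, freq) matches the constructor in this file; the PruneLoserPrototypes keyword names are assumptions inferred from the attributes its hook reads (threshold, frequency, prune_quota_per_epoch, verbose):

import pytorch_lightning as pl
from prototorch.models.callbacks import GNGCallback, PruneLoserPrototypes

trainer = pl.Trainer(
    max_epochs=100,
    callbacks=[
        # keyword names are assumptions, see note above
        PruneLoserPrototypes(threshold=0.01, frequency=5, verbose=True),
        GNGCallback(reduction=0.1, freq=10),
    ],
)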

prototorch/models/cbc.py

@@ -1,86 +1,18 @@
import torch
import torchmetrics
from prototorch.functions.distances import euclidean_distance
from prototorch.functions.similarities import cosine_similarity
from .abstract import ImagePrototypesMixin
from .extras import (
CosineSimilarity,
MarginLoss,
ReasoningLayer,
euclidean_similarity,
rescaled_cosine_similarity,
shift_activation,
)
from .glvq import SiameseGLVQ
def rescaled_cosine_similarity(x, y):
"""Cosine Similarity rescaled to [0, 1]."""
similarities = cosine_similarity(x, y)
return (similarities + 1.0) / 2.0
def shift_activation(x):
return (x + 1.0) / 2.0
def euclidean_similarity(x, y, variance=1.0):
d = euclidean_distance(x, y)
return torch.exp(-(d * d) / (2 * variance))
class CosineSimilarity(torch.nn.Module):
def __init__(self, activation=shift_activation):
super().__init__()
self.activation = activation
def forward(self, x, y):
epsilon = torch.finfo(x.dtype).eps
normed_x = (x / x.pow(2).sum(dim=tuple(range(
1, x.ndim)), keepdim=True).clamp(min=epsilon).sqrt()).flatten(
start_dim=1)
normed_y = (y / y.pow(2).sum(dim=tuple(range(
1, y.ndim)), keepdim=True).clamp(min=epsilon).sqrt()).flatten(
start_dim=1)
# normed_x = (x / torch.linalg.norm(x, dim=1))
diss = torch.inner(normed_x, normed_y)
return self.activation(diss)
class MarginLoss(torch.nn.modules.loss._Loss):
def __init__(self,
margin=0.3,
size_average=None,
reduce=None,
reduction="mean"):
super().__init__(size_average, reduce, reduction)
self.margin = margin
def forward(self, input_, target):
dp = torch.sum(target * input_, dim=-1)
dm = torch.max(input_ - target, dim=-1).values
return torch.nn.functional.relu(dm - dp + self.margin)
class ReasoningLayer(torch.nn.Module):
def __init__(self, num_components, num_classes, num_replicas=1):
super().__init__()
self.num_replicas = num_replicas
self.num_classes = num_classes
probabilities_init = torch.zeros(2, 1, num_components,
self.num_classes)
probabilities_init.uniform_(0.4, 0.6)
self.reasoning_probabilities = torch.nn.Parameter(probabilities_init)
@property
def reasonings(self):
pk = self.reasoning_probabilities[0]
nk = (1 - pk) * self.reasoning_probabilities[1]
ik = 1 - pk - nk
img = torch.cat([pk, nk, ik], dim=0).permute(1, 0, 2)
return img.unsqueeze(1)
def forward(self, detections):
pk = self.reasoning_probabilities[0].clamp(0, 1)
nk = (1 - pk) * self.reasoning_probabilities[1].clamp(0, 1)
numerator = (detections @ (pk - nk)) + nk.sum(1)
probs = numerator / (pk + nk).sum(1)
probs = probs.squeeze(0)
return probs
class CBC(SiameseGLVQ):
"""Classification-By-Components."""
def __init__(self,
@@ -143,10 +75,11 @@ class CBC(SiameseGLVQ):
return y_pred
class ImageCBC(CBC):
class ImageCBC(ImagePrototypesMixin, CBC):
"""CBC model that constrains the components to the range [0, 1] by
clamping after updates.
"""
def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
# super().on_train_batch_end(outputs, batch, batch_idx, dataloader_idx)
self.component_layer.components.data.clamp_(0.0, 1.0)
def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs)
# Namespace hook
self.proto_layer = self.component_layer
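
The "namespace hook" works because proto_layer and component_layer are then the same object, so the clamp that ImagePrototypesMixin applies to proto_layer.components also constrains the CBC components. A plain-torch sketch of that aliasing (stand-in Layer class, no prototorch required):

import torch

class Layer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.components = torch.nn.Parameter(torch.randn(3, 4))

component_layer = Layer()
proto_layer = component_layer  # alias, as in ImageCBC.__init__
proto_layer.components.data.clamp_(0.0, 1.0)
print(bool(component_layer.components.min() >= 0))  # True: same tensor clamped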

prototorch/models/extras.py (new file, 142 lines)

@@ -0,0 +1,142 @@
"""prototorch.models.extras
Modules not yet available in prototorch go here temporarily.
"""
import torch
from prototorch.functions.distances import euclidean_distance
from prototorch.functions.similarities import cosine_similarity
def rescaled_cosine_similarity(x, y):
"""Cosine Similarity rescaled to [0, 1]."""
similarities = cosine_similarity(x, y)
return (similarities + 1.0) / 2.0
def shift_activation(x):
return (x + 1.0) / 2.0
def euclidean_similarity(x, y, variance=1.0):
d = euclidean_distance(x, y)
return torch.exp(-(d * d) / (2 * variance))
class ConnectionTopology(torch.nn.Module):
def __init__(self, agelimit, num_prototypes):
super().__init__()
self.agelimit = agelimit
self.num_prototypes = num_prototypes
self.cmat = torch.zeros((self.num_prototypes, self.num_prototypes))
self.age = torch.zeros_like(self.cmat)
def forward(self, d):
order = torch.argsort(d, dim=1)
for element in order:
i0, i1 = element[0], element[1]
self.cmat[i0][i1] = 1
self.cmat[i1][i0] = 1
self.age[i0][i1] = 0
self.age[i1][i0] = 0
self.age[i0][self.cmat[i0] == 1] += 1
self.age[i1][self.cmat[i1] == 1] += 1
self.cmat[i0][self.age[i0] > self.agelimit] = 0
self.cmat[i1][self.age[i1] > self.agelimit] = 0
def get_neighbors(self, position):
return torch.where(self.cmat[position])
def add_prototype(self):
new_cmat = torch.zeros([dim + 1 for dim in self.cmat.shape])
new_cmat[:-1, :-1] = self.cmat
self.cmat = new_cmat
new_age = torch.zeros([dim + 1 for dim in self.age.shape])
new_age[:-1, :-1] = self.age
self.age = new_age
def add_connection(self, a, b):
self.cmat[a][b] = 1
self.cmat[b][a] = 1
self.age[a][b] = 0
self.age[b][a] = 0
def remove_connection(self, a, b):
self.cmat[a][b] = 0
self.cmat[b][a] = 0
self.age[a][b] = 0
self.age[b][a] = 0
def extra_repr(self):
return f"(agelimit): ({self.agelimit})"
class CosineSimilarity(torch.nn.Module):
def __init__(self, activation=shift_activation):
super().__init__()
self.activation = activation
def forward(self, x, y):
epsilon = torch.finfo(x.dtype).eps
normed_x = (x / x.pow(2).sum(dim=tuple(range(
1, x.ndim)), keepdim=True).clamp(min=epsilon).sqrt()).flatten(
start_dim=1)
normed_y = (y / y.pow(2).sum(dim=tuple(range(
1, y.ndim)), keepdim=True).clamp(min=epsilon).sqrt()).flatten(
start_dim=1)
# normed_x = (x / torch.linalg.norm(x, dim=1))
diss = torch.inner(normed_x, normed_y)
return self.activation(diss)
class MarginLoss(torch.nn.modules.loss._Loss):
def __init__(self,
margin=0.3,
size_average=None,
reduce=None,
reduction="mean"):
super().__init__(size_average, reduce, reduction)
self.margin = margin
def forward(self, input_, target):
dp = torch.sum(target * input_, dim=-1)
dm = torch.max(input_ - target, dim=-1).values
return torch.nn.functional.relu(dm - dp + self.margin)
class ReasoningLayer(torch.nn.Module):
def __init__(self, num_components, num_classes, num_replicas=1):
super().__init__()
self.num_replicas = num_replicas
self.num_classes = num_classes
probabilities_init = torch.zeros(2, 1, num_components,
self.num_classes)
probabilities_init.uniform_(0.4, 0.6)
# TODO Use `self.register_parameter("param", Paramater(param))` instead
self.reasoning_probabilities = torch.nn.Parameter(probabilities_init)
@property
def reasonings(self):
pk = self.reasoning_probabilities[0]
nk = (1 - pk) * self.reasoning_probabilities[1]
ik = 1 - pk - nk
img = torch.cat([pk, nk, ik], dim=0).permute(1, 0, 2)
return img.unsqueeze(1)
def forward(self, detections):
pk = self.reasoning_probabilities[0].clamp(0, 1)
nk = (1 - pk) * self.reasoning_probabilities[1].clamp(0, 1)
numerator = (detections @ (pk - nk)) + nk.sum(1)
probs = numerator / (pk + nk).sum(1)
probs = probs.squeeze(0)
return probs
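
A quick check of the helpers above, assuming prototorch.models.extras is importable at this commit and that euclidean_distance returns the pairwise distance matrix; the printed values follow from euclidean_similarity(x, y) = exp(-d^2 / (2 * variance)):

import torch
from prototorch.models.extras import ConnectionTopology, euclidean_similarity

x = torch.zeros(1, 2)
y = torch.tensor([[0.0, 0.0], [3.0, 4.0]])
print(euclidean_similarity(x, y))  # ~[[1.0, 3.7e-6]] for d = 0 and d = 5

topology = ConnectionTopology(agelimit=10, num_prototypes=3)
d = torch.tensor([[0.1, 0.2, 0.9]])  # one sample, distances to three protos
topology(d)                          # connects the two nearest prototypes
print(topology.get_neighbors(0))     # -> (tensor([1]),)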

prototorch/models/glvq.py

@@ -1,101 +1,40 @@
"""Models based on the GLVQ framework."""
import torch
import torchmetrics
from prototorch.components import LabeledComponents
from prototorch.functions.activations import get_activation
from prototorch.functions.competitions import wtac
from prototorch.functions.distances import (
euclidean_distance,
lomega_distance,
omega_distance,
squared_euclidean_distance,
)
from prototorch.functions.helper import get_flat
from prototorch.functions.losses import glvq_loss, lvq1_loss, lvq21_loss
from prototorch.modules import LambdaLayer
from prototorch.modules import LambdaLayer, LossLayer
from torch.nn.parameter import Parameter
from .abstract import AbstractPrototypeModel, PrototypeImageModel
from .abstract import ImagePrototypesMixin, SupervisedPrototypeModel
class GLVQ(AbstractPrototypeModel):
class GLVQ(SupervisedPrototypeModel):
"""Generalized Learning Vector Quantization."""
def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs)
super().__init__()
# Hyperparameters
self.save_hyperparameters(hparams)
# Defaults
# Default hparams
self.hparams.setdefault("transfer_fn", "identity")
self.hparams.setdefault("transfer_beta", 10.0)
self.hparams.setdefault("lr", 0.01)
distance_fn = kwargs.get("distance_fn", euclidean_distance)
transfer_fn = get_activation(self.hparams.transfer_fn)
# Layers
prototype_initializer = kwargs.get("prototype_initializer", None)
self.proto_layer = LabeledComponents(
distribution=self.hparams.distribution,
initializer=prototype_initializer)
self.distance_layer = LambdaLayer(distance_fn)
transfer_fn = get_activation(self.hparams.transfer_fn)
self.transfer_layer = LambdaLayer(transfer_fn)
self.loss = LambdaLayer(glvq_loss)
# Loss
self.loss = LossLayer(glvq_loss)
# Prototype metrics
self.initialize_prototype_win_ratios()
self.optimizer = kwargs.get("optimizer", torch.optim.Adam)
self.lr_scheduler = kwargs.get("lr_scheduler", None)
self.lr_scheduler_kwargs = kwargs.get("lr_scheduler_kwargs", dict())
@property
def prototype_labels(self):
return self.proto_layer.component_labels.detach().cpu()
@property
def num_classes(self):
return len(self.proto_layer.distribution)
def _forward(self, x):
protos, _ = self.proto_layer()
distances = self.distance_layer(x, protos)
return distances
def forward(self, x):
distances = self._forward(x)
y_pred = self.predict_from_distances(distances)
y_pred = torch.eye(self.num_classes, device=self.device)[y_pred.long()]
return y_pred
def predict_from_distances(self, distances):
with torch.no_grad():
plabels = self.proto_layer.component_labels
y_pred = wtac(distances, plabels)
return y_pred
def predict(self, x):
with torch.no_grad():
distances = self._forward(x)
y_pred = self.predict_from_distances(distances)
return y_pred
def log_acc(self, distances, targets, tag):
preds = self.predict_from_distances(distances)
accuracy = torchmetrics.functional.accuracy(preds.int(), targets.int())
# `.int()` because FloatTensors are assumed to be class probabilities
self.log(tag,
accuracy,
on_step=False,
on_epoch=True,
prog_bar=True,
logger=True)
def initialize_prototype_win_ratios(self):
self.register_buffer(
"prototype_win_ratios",
@@ -121,7 +60,7 @@ class GLVQ(AbstractPrototypeModel):
def shared_step(self, batch, batch_idx, optimizer_idx=None):
x, y = batch
out = self._forward(x)
out = self.compute_distances(x)
plabels = self.proto_layer.component_labels
mu = self.loss(out, y, prototype_labels=plabels)
batch_loss = self.transfer_layer(mu, beta=self.hparams.transfer_beta)
@@ -158,18 +97,6 @@ class GLVQ(AbstractPrototypeModel):
# def predict_step(self, batch, batch_idx, dataloader_idx=None):
# pass
def add_prototypes(self, initializer, distribution):
self.proto_layer.add_components(initializer, distribution)
self.trainer.accelerator_backend.setup_optimizers(self.trainer)
def remove_prototypes(self, indices):
self.proto_layer.remove_components(indices)
self.trainer.accelerator_backend.setup_optimizers(self.trainer)
def __repr__(self):
super_repr = super().__repr__()
return f"{super_repr}"
class SiameseGLVQ(GLVQ):
"""GLVQ in a Siamese setting.
@@ -212,7 +139,7 @@ class SiameseGLVQ(GLVQ):
else:
return optimizer
def _forward(self, x):
def compute_distances(self, x):
protos, _ = self.proto_layer()
latent_x = self.backbone(x)
self.backbone.requires_grad_(self.both_path_gradients)
@@ -256,7 +183,7 @@ class GRLVQ(SiameseGLVQ):
def relevance_profile(self):
return self.relevances.detach().cpu()
def _forward(self, x):
def compute_distances(self, x):
protos, _ = self.proto_layer()
distances = self.distance_layer(x, protos, torch.diag(self.relevances))
return distances
@@ -285,7 +212,7 @@ class SiameseGMLVQ(SiameseGLVQ):
lam = omega.T @ omega
return lam.detach().cpu()
def _forward(self, x):
def compute_distances(self, x):
protos, _ = self.proto_layer()
x, protos = get_flat(x, protos)
latent_x = self.backbone(x)
@@ -305,7 +232,7 @@ class LVQMLN(SiameseGLVQ):
rather in the embedding space.
"""
def _forward(self, x):
def compute_distances(self, x):
latent_protos, _ = self.proto_layer()
latent_x = self.backbone(x)
distances = self.distance_layer(latent_x, latent_protos)
@@ -327,7 +254,7 @@ class GMLVQ(GLVQ):
device=self.device)
self.register_parameter("_omega", Parameter(omega))
def _forward(self, x):
def compute_distances(self, x):
protos, _ = self.proto_layer()
distances = self.distance_layer(x, protos, self._omega)
return distances
@@ -355,7 +282,7 @@ class GLVQ1(GLVQ):
"""Generalized Learning Vector Quantization 1."""
def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs)
self.loss = lvq1_loss
self.loss = LossLayer(lvq1_loss)
self.optimizer = torch.optim.SGD
@@ -363,11 +290,11 @@ class GLVQ21(GLVQ):
"""Generalized Learning Vector Quantization 2.1."""
def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs)
self.loss = lvq21_loss
self.loss = LossLayer(lvq21_loss)
self.optimizer = torch.optim.SGD
class ImageGLVQ(PrototypeImageModel, GLVQ):
class ImageGLVQ(ImagePrototypesMixin, GLVQ):
"""GLVQ for training on image data.
GLVQ model that constrains the prototypes to the range [0, 1] by clamping
@@ -376,7 +303,7 @@ class ImageGLVQ(PrototypeImageModel, GLVQ):
"""
class ImageGMLVQ(PrototypeImageModel, GMLVQ):
class ImageGMLVQ(ImagePrototypesMixin, GMLVQ):
"""GMLVQ for training on image data.
GMLVQ model that constrains the prototypes to the range [0, 1] by clamping
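
A hedged end-to-end sketch for the refactored GLVQ. The hparams keys (distribution, lr) and the prototype_initializer kwarg appear in this file; the list form of distribution and the ZerosInitializer signature (the class is imported elsewhere in this commit) are assumptions:

from prototorch.components.initializers import ZerosInitializer
from prototorch.models import GLVQ

# distribution format and initializer argument are assumptions, see above
hparams = dict(distribution=[2, 2], lr=0.05)  # two classes, two protos each
model = GLVQ(hparams, prototype_initializer=ZerosInitializer(2))
print(model)  # repr is now prefixed with "ProtoTorch Bolt:"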

prototorch/models/knn.py (new file, 38 lines)

@@ -0,0 +1,38 @@
"""ProtoTorch KNN model."""
import warnings
from prototorch.components import LabeledComponents
from prototorch.modules import KNNC
from .abstract import SupervisedPrototypeModel
class KNN(SupervisedPrototypeModel):
"""K-Nearest-Neighbors classification algorithm."""
def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs)
# Default hparams
self.hparams.setdefault("k", 1)
data = kwargs.get("data", None)
if data is None:
raise ValueError("KNN requires data, but was not provided!")
# Layers
self.proto_layer = LabeledComponents(initialized_components=data)
self.competition_layer = KNNC(k=self.hparams.k)
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
return 1 # skip training step
def on_train_batch_start(self,
train_batch,
batch_idx,
dataloader_idx=None):
warnings.warn("k-NN has no training, skipping!")
return -1
def configure_optimizers(self):
return None
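
A minimal sketch for the new KNN module. Passing data as an (x, y) pair is an assumption inferred from LabeledComponents(initialized_components=data) and from the tuple used by the old unsupervised.py version:

import torch
from prototorch.models import KNN

x_train = torch.rand(100, 2)
y_train = torch.randint(0, 2, (100, ))
model = KNN(dict(k=3), data=(x_train, y_train))
y_pred = model.predict(torch.rand(5, 2))  # no fit needed; k-NN is lazy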

prototorch/models/lvq.py

@@ -1,34 +1,24 @@
"""LVQ models that are optimized using non-gradient methods."""
from prototorch.functions.competitions import wtac
from prototorch.functions.losses import _get_dp_dm
from .abstract import NonGradientMixin
from .glvq import GLVQ
class NonGradientLVQ(GLVQ):
"""Abstract Model for Models that do not use gradients in their update phase."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.automatic_optimization = False
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
raise NotImplementedError
class LVQ1(NonGradientLVQ):
class LVQ1(NonGradientMixin, GLVQ):
"""Learning Vector Quantization 1."""
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
protos = self.proto_layer.components
plabels = self.proto_layer.component_labels
x, y = train_batch
dis = self._forward(x)
dis = self.compute_distances(x)
# TODO Vectorized implementation
for xi, yi in zip(x, y):
d = self._forward(xi.view(1, -1))
preds = wtac(d, plabels)
d = self.compute_distances(xi.view(1, -1))
preds = self.competition_layer(d, plabels)
w = d.argmin(1)
if yi == preds:
shift = xi - protos[w]
@@ -45,20 +35,20 @@ class LVQ1(NonGradientLVQ):
return None
class LVQ21(NonGradientLVQ):
class LVQ21(NonGradientMixin, GLVQ):
"""Learning Vector Quantization 2.1."""
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
protos = self.proto_layer.components
plabels = self.proto_layer.component_labels
x, y = train_batch
dis = self._forward(x)
dis = self.compute_distances(x)
# TODO Vectorized implementation
for xi, yi in zip(x, y):
xi = xi.view(1, -1)
yi = yi.view(1, )
d = self._forward(xi)
d = self.compute_distances(xi)
(_, wp), (_, wn) = _get_dp_dm(d, yi, plabels, with_indices=True)
shiftp = xi - protos[wp]
shiftn = protos[wn] - xi
@@ -74,5 +64,5 @@ class LVQ21(NonGradientLVQ):
return None
class MedianLVQ(NonGradientLVQ):
class MedianLVQ(NonGradientMixin, GLVQ):
"""Median LVQ"""

prototorch/models/probabilistic.py

@@ -1,10 +1,10 @@
"""Probabilistic GLVQ methods"""
import torch
from prototorch.functions.competitions import stratified_min, stratified_sum
from prototorch.functions.losses import (log_likelihood_ratio_loss,
robust_soft_loss)
from prototorch.functions.losses import nllr_loss, rslvq_loss
from prototorch.functions.pooling import stratified_min_pooling, stratified_sum_pooling
from prototorch.functions.transforms import gaussian
from prototorch.modules import LambdaLayer, LossLayer
from .glvq import GLVQ
@@ -13,13 +13,16 @@ class CELVQ(GLVQ):
"""Cross-Entropy Learning Vector Quantization."""
def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs)
# Loss
self.loss = torch.nn.CrossEntropyLoss()
def shared_step(self, batch, batch_idx, optimizer_idx=None):
x, y = batch
out = self._forward(x) # [None, num_protos]
out = self.compute_distances(x) # [None, num_protos]
plabels = self.proto_layer.component_labels
probs = -1.0 * stratified_min(out, plabels) # [None, num_classes]
winning = stratified_min_pooling(out, plabels) # [None, num_classes]
probs = -1.0 * winning
batch_loss = self.loss(probs, y.long())
loss = batch_loss.sum(dim=0)
return out, loss
@@ -33,14 +36,14 @@ class ProbabilisticLVQ(GLVQ):
self.rejection_confidence = rejection_confidence
def forward(self, x):
distances = self._forward(x)
distances = self.compute_distances(x)
conditional = self.conditional_distribution(distances,
self.hparams.variance)
prior = (1. / self.num_prototypes) * torch.ones(self.num_prototypes,
device=self.device)
posterior = conditional * prior
plabels = self.proto_layer._labels
y_pred = stratified_sum(posterior, plabels)
y_pred = stratified_sum_pooling(posterior, plabels)
return y_pred
def predict(self, x):
@@ -50,12 +53,11 @@ class ProbabilisticLVQ(GLVQ):
return prediction
def training_step(self, batch, batch_idx, optimizer_idx=None):
X, y = batch
out = self.forward(X)
x, y = batch
out = self.forward(x)
plabels = self.proto_layer.component_labels
batch_loss = self.loss_fn(out, y, plabels)
batch_loss = self.loss(out, y, plabels)
loss = batch_loss.sum(dim=0)
return loss
@@ -63,11 +65,11 @@ class LikelihoodRatioLVQ(ProbabilisticLVQ):
"""Learning Vector Quantization based on Likelihood Ratios."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.loss_fn = log_likelihood_ratio_loss
self.loss = LossLayer(nllr_loss)
class RSLVQ(ProbabilisticLVQ):
"""Robust Soft Learning Vector Quantization."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.loss_fn = robust_soft_loss
self.loss = LossLayer(rslvq_loss)
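
A plain-torch sketch of the probabilistic forward pass above: a Gaussian conditional over distances, a uniform prior, and a per-class sum over the posterior (the job of stratified_sum_pooling). The exact form of prototorch's gaussian transform is an assumption:

import torch

distances = torch.tensor([[0.1, 0.4, 0.2]])  # one sample, three prototypes
plabels = torch.tensor([0, 0, 1])            # prototype class labels
variance = 0.3

conditional = torch.exp(-distances**2 / (2 * variance))  # assumed form
posterior = conditional * (1.0 / 3)                      # uniform prior
y_pred = torch.zeros(1, 2).index_add_(1, plabels, posterior)
print(y_pred)  # per-class scores; argmax gives the predicted class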

prototorch/models/unsupervised.py

@@ -8,195 +8,30 @@ import pytorch_lightning as pl
import torch
import torchmetrics
from prototorch.components import Components, LabeledComponents
from prototorch.components.initializers import ZerosInitializer, parse_data_arg
from prototorch.components.initializers import ZerosInitializer
from prototorch.functions.competitions import knnc
from prototorch.functions.distances import euclidean_distance
from prototorch.modules import LambdaLayer
from prototorch.modules.losses import NeuralGasEnergy
from pytorch_lightning.callbacks import Callback
from .abstract import AbstractPrototypeModel
from .abstract import UnsupervisedPrototypeModel
from .callbacks import GNGCallback
from .extras import ConnectionTopology
class GNGCallback(Callback):
"""GNG Callback.
Applies growing algorithm based on accumulated error and topology.
Based on "A Growing Neural Gas Network Learns Topologies" by Bernd Fritzke.
"""
def __init__(self, reduction=0.1, freq=10):
self.reduction = reduction
self.freq = freq
def on_epoch_end(self, trainer: pl.Trainer, pl_module):
if (trainer.current_epoch + 1) % self.freq == 0:
# Get information
errors = pl_module.errors
topology: ConnectionTopology = pl_module.topology_layer
components: pt.components.Components = pl_module.proto_layer.components
# Insertion point
worst = torch.argmax(errors)
neighbors = topology.get_neighbors(worst)[0]
if len(neighbors) == 0:
logging.log(level=20, msg="No neighbor-pairs found!")
return
neighbors_errors = errors[neighbors]
worst_neighbor = neighbors[torch.argmax(neighbors_errors)]
# New Prototype
new_component = 0.5 * (components[worst] +
components[worst_neighbor])
# Add component
pl_module.proto_layer.add_components(
initialized_components=new_component.unsqueeze(0))
# Adjust Topology
topology.add_prototype()
topology.add_connection(worst, -1)
topology.add_connection(worst_neighbor, -1)
topology.remove_connection(worst, worst_neighbor)
# New errors
worst_error = errors[worst].unsqueeze(0)
pl_module.errors = torch.cat([pl_module.errors, worst_error])
pl_module.errors[worst] = errors[worst] * self.reduction
pl_module.errors[
worst_neighbor] = errors[worst_neighbor] * self.reduction
trainer.accelerator_backend.setup_optimizers(trainer)
class ConnectionTopology(torch.nn.Module):
def __init__(self, agelimit, num_prototypes):
super().__init__()
self.agelimit = agelimit
self.num_prototypes = num_prototypes
self.cmat = torch.zeros((self.num_prototypes, self.num_prototypes))
self.age = torch.zeros_like(self.cmat)
def forward(self, d):
order = torch.argsort(d, dim=1)
for element in order:
i0, i1 = element[0], element[1]
self.cmat[i0][i1] = 1
self.cmat[i1][i0] = 1
self.age[i0][i1] = 0
self.age[i1][i0] = 0
self.age[i0][self.cmat[i0] == 1] += 1
self.age[i1][self.cmat[i1] == 1] += 1
self.cmat[i0][self.age[i0] > self.agelimit] = 0
self.cmat[i1][self.age[i1] > self.agelimit] = 0
def get_neighbors(self, position):
return torch.where(self.cmat[position])
def add_prototype(self):
new_cmat = torch.zeros([dim + 1 for dim in self.cmat.shape])
new_cmat[:-1, :-1] = self.cmat
self.cmat = new_cmat
new_age = torch.zeros([dim + 1 for dim in self.age.shape])
new_age[:-1, :-1] = self.age
self.age = new_age
def add_connection(self, a, b):
self.cmat[a][b] = 1
self.cmat[b][a] = 1
self.age[a][b] = 0
self.age[b][a] = 0
def remove_connection(self, a, b):
self.cmat[a][b] = 0
self.cmat[b][a] = 0
self.age[a][b] = 0
self.age[b][a] = 0
def extra_repr(self):
return f"(agelimit): ({self.agelimit})"
class KNN(AbstractPrototypeModel):
"""K-Nearest-Neighbors classification algorithm."""
class NeuralGas(UnsupervisedPrototypeModel):
def __init__(self, hparams, **kwargs):
super().__init__()
super().__init__(hparams, **kwargs)
# Hyperparameters
self.save_hyperparameters(hparams)
# Default Values
self.hparams.setdefault("k", 1)
self.hparams.setdefault("distance", euclidean_distance)
data = kwargs.get("data")
x_train, y_train = parse_data_arg(data)
self.proto_layer = LabeledComponents(initialized_components=(x_train,
y_train))
self.train_acc = torchmetrics.Accuracy()
@property
def prototype_labels(self):
return self.proto_layer.component_labels.detach()
def forward(self, x):
protos, _ = self.proto_layer()
dis = self.hparams.distance(x, protos)
return dis
def predict(self, x):
# model.eval() # ?!
with torch.no_grad():
d = self(x)
plabels = self.proto_layer.component_labels
y_pred = knnc(d, plabels, k=self.hparams.k)
return y_pred
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
return 1
def on_train_batch_start(self,
train_batch,
batch_idx,
dataloader_idx=None):
warnings.warn("k-NN has no training, skipping!")
return -1
def configure_optimizers(self):
return None
class NeuralGas(AbstractPrototypeModel):
def __init__(self, hparams, **kwargs):
super().__init__()
self.save_hyperparameters(hparams)
self.optimizer = kwargs.get("optimizer", torch.optim.Adam)
# Default Values
# Default hparams
self.hparams.setdefault("input_dim", 2)
self.hparams.setdefault("agelimit", 10)
self.hparams.setdefault("lm", 1)
self.proto_layer = Components(
self.hparams.num_prototypes,
initializer=kwargs.get("prototype_initializer"))
self.distance_layer = LambdaLayer(euclidean_distance)
self.energy_layer = NeuralGasEnergy(lm=self.hparams.lm)
self.topology_layer = ConnectionTopology(
agelimit=self.hparams.agelimit,
@@ -204,9 +39,10 @@ class NeuralGas(AbstractPrototypeModel):
)
def training_step(self, train_batch, batch_idx):
# x = train_batch
# TODO Check if the batch has labels
x = train_batch[0]
protos = self.proto_layer()
d = self.distance_layer(x, protos)
d = self.compute_distances(x)
cost, _ = self.energy_layer(d)
self.topology_layer(d)
return cost
@@ -216,26 +52,26 @@ class GrowingNeuralGas(NeuralGas):
def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs)
# defaults
# Defaults
self.hparams.setdefault("step_reduction", 0.5)
self.hparams.setdefault("insert_reduction", 0.1)
self.hparams.setdefault("insert_freq", 10)
self.register_buffer(
"errors",
torch.zeros(self.hparams.num_prototypes, device=self.device))
errors = torch.zeros(self.hparams.num_prototypes, device=self.device)
self.register_buffer("errors", errors)
def training_step(self, train_batch, _batch_idx):
# x = train_batch
# TODO Check if the batch has labels
x = train_batch[0]
protos = self.proto_layer()
d = self.distance_layer(x, protos)
d = self.compute_distances(x)
cost, order = self.energy_layer(d)
winner = order[:, 0]
mask = torch.zeros_like(d)
mask[torch.arange(len(mask)), winner] = 1.0
winner_distances = d * mask
dp = d * mask
self.errors += torch.sum(winner_distances * winner_distances, dim=0)
self.errors += torch.sum(dp * dp, dim=0)
self.errors *= self.hparams.step_reduction
self.topology_layer(d)
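
A hedged training sketch for the relocated models. The hparams keys (num_prototypes, input_dim, lm, agelimit) come from the defaults in this file; the initializer choice and its signature are assumptions:

import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader, TensorDataset
from prototorch.components.initializers import ZerosInitializer
from prototorch.models import GrowingNeuralGas
from prototorch.models.callbacks import GNGCallback

data = TensorDataset(torch.rand(256, 2))  # unlabeled 2D toy data
hparams = dict(num_prototypes=5, input_dim=2, lm=1, agelimit=10, lr=0.1)
model = GrowingNeuralGas(hparams, prototype_initializer=ZerosInitializer(2))
trainer = pl.Trainer(max_epochs=50, callbacks=[GNGCallback(freq=10)])
trainer.fit(model, DataLoader(data, batch_size=32))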

prototorch/models/vis.py

@@ -140,7 +140,7 @@ class VisGLVQ2D(Vis2DAbstract):
x = np.vstack((x_train, protos))
mesh_input, xx, yy = self.get_mesh_input(x)
_components = pl_module.proto_layer._components
mesh_input = torch.Tensor(mesh_input).type_as(_components)
mesh_input = torch.from_numpy(mesh_input).type_as(_components)
y_pred = pl_module.predict(mesh_input)
y_pred = y_pred.cpu().reshape(xx.shape)
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
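
The one-line change above swaps torch.Tensor(mesh_input) for torch.from_numpy(mesh_input): torch.Tensor always copies and casts to the default float32, while from_numpy shares memory and preserves the numpy dtype, leaving .type_as(_components) to perform the single intended cast. A quick demonstration:

import numpy as np
import torch

a = np.linspace(0, 1, 4)          # float64 array
print(torch.Tensor(a).dtype)      # torch.float32 (copy + implicit cast)
print(torch.from_numpy(a).dtype)  # torch.float64 (zero-copy, dtype kept)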