[REFACTOR] Major cleanup

Jensun Ravichandran 2021-06-04 22:20:32 +02:00
parent 20471bfb1c
commit 016fcb4060
11 changed files with 481 additions and 399 deletions

prototorch/models/__init__.py

@@ -4,11 +4,23 @@ from importlib.metadata import PackageNotFoundError, version
 from .callbacks import PrototypeConvergence, PruneLoserPrototypes
 from .cbc import CBC, ImageCBC
-from .glvq import (GLVQ, GLVQ1, GLVQ21, GMLVQ, GRLVQ, LGMLVQ, LVQMLN,
-                   ImageGLVQ, ImageGMLVQ, SiameseGLVQ, SiameseGMLVQ)
+from .glvq import (
+    GLVQ,
+    GLVQ1,
+    GLVQ21,
+    GMLVQ,
+    GRLVQ,
+    LGMLVQ,
+    LVQMLN,
+    ImageGLVQ,
+    ImageGMLVQ,
+    SiameseGLVQ,
+    SiameseGMLVQ,
+)
+from .knn import KNN
 from .lvq import LVQ1, LVQ21, MedianLVQ
 from .probabilistic import CELVQ, RSLVQ, LikelihoodRatioLVQ
-from .unsupervised import KNN, GrowingNeuralGas, NeuralGas
+from .unsupervised import GrowingNeuralGas, NeuralGas
 from .vis import *

 __version__ = "0.1.7"
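Taken together, this hunk flattens the `.glvq` import into one name per line and moves `KNN` out of `.unsupervised` into the new `.knn` module. A quick sketch of the resulting top-level import surface (all names taken from the import list above):

from prototorch.models import GLVQ, GMLVQ, SiameseGLVQ   # GLVQ family
from prototorch.models import KNN                         # now provided by .knn
from prototorch.models import GrowingNeuralGas, NeuralGas # unsupervised models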

prototorch/models/abstract.py

@@ -1,7 +1,39 @@
 """Abstract classes to be inherited by prototorch models."""

+from typing import Final, final
+
 import pytorch_lightning as pl
+import torch
+import torchmetrics
+from prototorch.components import Components, LabeledComponents
+from prototorch.functions.distances import euclidean_distance
+from prototorch.modules import WTAC, LambdaLayer


-class AbstractPrototypeModel(pl.LightningModule):
+class ProtoTorchBolt(pl.LightningModule):
+    def __repr__(self):
+        super_repr = super().__repr__()
+        return f"ProtoTorch Bolt:\n{super_repr}"
+
+
+class PrototypeModel(ProtoTorchBolt):
+    def __init__(self, hparams, **kwargs):
+        super().__init__()
+
+        # Hyperparameters
+        self.save_hyperparameters(hparams)
+
+        # Default hparams
+        self.hparams.setdefault("lr", 0.01)
+
+        # Default config
+        self.optimizer = kwargs.get("optimizer", torch.optim.Adam)
+        self.lr_scheduler = kwargs.get("lr_scheduler", None)
+        self.lr_scheduler_kwargs = kwargs.get("lr_scheduler_kwargs", dict())
+
+        distance_fn = kwargs.get("distance_fn", euclidean_distance)
+        self.distance_layer = LambdaLayer(distance_fn)
+
     @property
     def num_prototypes(self):
         return len(self.proto_layer.components)
@@ -28,9 +60,115 @@ class AbstractPrototypeModel(pl.LightningModule):
         else:
             return optimizer

+    @final
+    def reconfigure_optimizers(self):
+        self.trainer.accelerator_backend.setup_optimizers(self.trainer)
+
+    def add_prototypes(self, *args, **kwargs):
+        self.proto_layer.add_components(*args, **kwargs)
+        self.reconfigure_optimizers()
+
+    def remove_prototypes(self, indices):
+        self.proto_layer.remove_components(indices)
+        self.reconfigure_optimizers()
+
+
+class UnsupervisedPrototypeModel(PrototypeModel):
+    def __init__(self, hparams, **kwargs):
+        super().__init__(hparams, **kwargs)
+
+        # Layers
+        prototype_initializer = kwargs.get("prototype_initializer", None)
+        if prototype_initializer is not None:
+            self.proto_layer = Components(
+                self.hparams.num_prototypes,
+                initializer=prototype_initializer,
+            )
+
+    def compute_distances(self, x):
+        protos = self.proto_layer()
+        distances = self.distance_layer(x, protos)
+        return distances
+
+    def forward(self, x):
+        distances = self.compute_distances(x)
+        return distances
+
+
+class SupervisedPrototypeModel(PrototypeModel):
+    def __init__(self, hparams, **kwargs):
+        super().__init__(hparams, **kwargs)
+
+        # Layers
+        prototype_initializer = kwargs.get("prototype_initializer", None)
+        if prototype_initializer is not None:
+            self.proto_layer = LabeledComponents(
+                distribution=self.hparams.distribution,
+                initializer=prototype_initializer,
+            )
+        self.competition_layer = WTAC()
+
+    @property
+    def prototype_labels(self):
+        return self.proto_layer.component_labels.detach().cpu()
+
+    @property
+    def num_classes(self):
+        return len(self.proto_layer.distribution)
+
+    def compute_distances(self, x):
+        protos, _ = self.proto_layer()
+        distances = self.distance_layer(x, protos)
+        return distances
+
+    def forward(self, x):
+        distances = self.compute_distances(x)
+        y_pred = self.predict_from_distances(distances)
+        # TODO This assumes the labels are {0, ..., num_classes - 1}
+        y_pred = torch.eye(self.num_classes, device=self.device)[
+            y_pred.long()]
+        return y_pred
+
+    def predict_from_distances(self, distances):
+        with torch.no_grad():
+            plabels = self.proto_layer.component_labels
+            y_pred = self.competition_layer(distances, plabels)
+        return y_pred
+
+    def predict(self, x):
+        with torch.no_grad():
+            distances = self.compute_distances(x)
+            y_pred = self.predict_from_distances(distances)
+        return y_pred
+
+    def log_acc(self, distances, targets, tag):
+        preds = self.predict_from_distances(distances)
+        accuracy = torchmetrics.functional.accuracy(preds.int(), targets.int())
+        # `.int()` because FloatTensors are assumed to be class probabilities
+        self.log(tag,
+                 accuracy,
+                 on_step=False,
+                 on_epoch=True,
+                 prog_bar=True,
+                 logger=True)
+
+
+class NonGradientMixin():
+    """Mixin for custom non-gradient optimization."""
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.automatic_optimization: Final = False
+
+    def training_step(self, train_batch, batch_idx, optimizer_idx=None):
+        raise NotImplementedError
+
+
-class PrototypeImageModel(pl.LightningModule):
+class ImagePrototypesMixin(ProtoTorchBolt):
+    """Mixin for models with image prototypes."""
+    @final
     def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
+        """Constrain the components to the range [0, 1] by clamping after updates."""
         self.proto_layer.components.data.clamp_(0.0, 1.0)

     def get_prototype_grid(self, num_columns=2, return_channels_last=True):
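The new hierarchy splits the old `AbstractPrototypeModel` into `ProtoTorchBolt`, `PrototypeModel`, and supervised/unsupervised variants. A minimal sketch of how a subclass plugs in, assuming the `hparams`/`prototype_initializer` conventions used above (the subclass and its loss are hypothetical, not part of this commit):

import torch
from prototorch.models.abstract import SupervisedPrototypeModel

class MyLVQ(SupervisedPrototypeModel):  # hypothetical example model
    def training_step(self, batch, batch_idx):
        x, y = batch
        distances = self.compute_distances(x)  # inherited: proto_layer + distance_layer
        self.log_acc(distances, y, tag="train_acc")  # inherited WTAC-based accuracy
        return distances.min(dim=1).values.sum()  # placeholder loss, not a real LVQ loss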

prototorch/models/callbacks.py

@@ -1,7 +1,12 @@
 """Lightning Callbacks."""

+import logging
+
 import pytorch_lightning as pl
 import torch
+from prototorch.components import Components
+
+from .extras import ConnectionTopology


 class PruneLoserPrototypes(pl.Callback):
@@ -26,26 +31,30 @@ class PruneLoserPrototypes(pl.Callback):
             return None
         if (trainer.current_epoch + 1) % self.frequency:
             return None

         ratios = pl_module.prototype_win_ratios.mean(dim=0)
         to_prune = torch.arange(len(ratios))[ratios < self.threshold]
-        prune_labels = pl_module.prototype_labels[to_prune.tolist()]
+        to_prune = to_prune.tolist()
+        prune_labels = pl_module.prototype_labels[to_prune]
         if self.prune_quota_per_epoch > 0:
             to_prune = to_prune[:self.prune_quota_per_epoch]
             prune_labels = prune_labels[:self.prune_quota_per_epoch]

         if len(to_prune) > 0:
             if self.verbose:
                 print(f"\nPrototype win ratios: {ratios}")
-                print(f"Pruning prototypes at: {to_prune.tolist()}")
+                print(f"Pruning prototypes at: {to_prune}")
+                print(f"Corresponding labels are: {prune_labels}")
             cur_num_protos = pl_module.num_prototypes
             pl_module.remove_prototypes(indices=to_prune)
             if self.replace:
-                if self.verbose:
-                    print(f"Re-adding prototypes at: {to_prune.tolist()}")
                 labels, counts = torch.unique(prune_labels,
                                               sorted=True,
                                               return_counts=True)
                 distribution = dict(zip(labels.tolist(), counts.tolist()))
-                print(f"{distribution=}")
+                if self.verbose:
+                    print("Re-adding pruned prototypes...")
+                    print(f"{distribution=}")
                 pl_module.add_prototypes(distribution=distribution,
                                          initializer=self.initializer)
             new_num_protos = pl_module.num_prototypes
@@ -68,3 +77,58 @@ class PrototypeConvergence(pl.Callback):
             print("Stopping...")
             # TODO
         return True
+
+
+class GNGCallback(pl.Callback):
+    """GNG Callback.
+
+    Applies growing algorithm based on accumulated error and topology.
+
+    Based on "A Growing Neural Gas Network Learns Topologies" by Bernd Fritzke.
+    """
+    def __init__(self, reduction=0.1, freq=10):
+        self.reduction = reduction
+        self.freq = freq
+
+    def on_epoch_end(self, trainer: pl.Trainer, pl_module):
+        if (trainer.current_epoch + 1) % self.freq == 0:
+            # Get information
+            errors = pl_module.errors
+            topology: ConnectionTopology = pl_module.topology_layer
+            components: Components = pl_module.proto_layer.components
+
+            # Insertion point
+            worst = torch.argmax(errors)
+
+            neighbors = topology.get_neighbors(worst)[0]
+
+            if len(neighbors) == 0:
+                logging.log(level=20, msg="No neighbor-pairs found!")
+                return
+
+            neighbors_errors = errors[neighbors]
+            worst_neighbor = neighbors[torch.argmax(neighbors_errors)]
+
+            # New prototype halfway between the worst unit and its worst neighbor
+            new_component = 0.5 * (components[worst] +
+                                   components[worst_neighbor])
+
+            # Add component
+            pl_module.proto_layer.add_components(
+                initialized_components=new_component.unsqueeze(0))
+
+            # Adjust topology
+            topology.add_prototype()
+            topology.add_connection(worst, -1)
+            topology.add_connection(worst_neighbor, -1)
+            topology.remove_connection(worst, worst_neighbor)
+
+            # New errors
+            worst_error = errors[worst].unsqueeze(0)
+            pl_module.errors = torch.cat([pl_module.errors, worst_error])
+            pl_module.errors[worst] = errors[worst] * self.reduction
+            pl_module.errors[
+                worst_neighbor] = errors[worst_neighbor] * self.reduction
+
+            trainer.accelerator_backend.setup_optimizers(trainer)
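A sketch of wiring both callbacks into a Lightning Trainer. The `GNGCallback` signature is shown in the diff above; the `PruneLoserPrototypes` keyword names (threshold, frequency, verbose) are inferred from the attributes its hook reads, so treat them as assumptions:

import pytorch_lightning as pl
from prototorch.models.callbacks import GNGCallback, PruneLoserPrototypes

trainer = pl.Trainer(
    max_epochs=100,
    callbacks=[
        PruneLoserPrototypes(threshold=0.01, frequency=5, verbose=True),  # kwargs inferred
        GNGCallback(reduction=0.1, freq=10),  # signature as defined above
    ],
)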

prototorch/models/cbc.py

@@ -1,86 +1,18 @@
 import torch
 import torchmetrics
-from prototorch.functions.distances import euclidean_distance
-from prototorch.functions.similarities import cosine_similarity

+from .abstract import ImagePrototypesMixin
+from .extras import (
+    CosineSimilarity,
+    MarginLoss,
+    ReasoningLayer,
+    euclidean_similarity,
+    rescaled_cosine_similarity,
+    shift_activation,
+)
 from .glvq import SiameseGLVQ


-def rescaled_cosine_similarity(x, y):
-    """Cosine Similarity rescaled to [0, 1]."""
-    similarities = cosine_similarity(x, y)
-    return (similarities + 1.0) / 2.0
-
-
-def shift_activation(x):
-    return (x + 1.0) / 2.0
-
-
-def euclidean_similarity(x, y, variance=1.0):
-    d = euclidean_distance(x, y)
-    return torch.exp(-(d * d) / (2 * variance))
-
-
-class CosineSimilarity(torch.nn.Module):
-    def __init__(self, activation=shift_activation):
-        super().__init__()
-        self.activation = activation
-
-    def forward(self, x, y):
-        epsilon = torch.finfo(x.dtype).eps
-        normed_x = (x / x.pow(2).sum(dim=tuple(range(
-            1, x.ndim)), keepdim=True).clamp(min=epsilon).sqrt()).flatten(
-                start_dim=1)
-        normed_y = (y / y.pow(2).sum(dim=tuple(range(
-            1, y.ndim)), keepdim=True).clamp(min=epsilon).sqrt()).flatten(
-                start_dim=1)
-        # normed_x = (x / torch.linalg.norm(x, dim=1))
-        diss = torch.inner(normed_x, normed_y)
-        return self.activation(diss)
-
-
-class MarginLoss(torch.nn.modules.loss._Loss):
-    def __init__(self,
-                 margin=0.3,
-                 size_average=None,
-                 reduce=None,
-                 reduction="mean"):
-        super().__init__(size_average, reduce, reduction)
-        self.margin = margin
-
-    def forward(self, input_, target):
-        dp = torch.sum(target * input_, dim=-1)
-        dm = torch.max(input_ - target, dim=-1).values
-        return torch.nn.functional.relu(dm - dp + self.margin)
-
-
-class ReasoningLayer(torch.nn.Module):
-    def __init__(self, num_components, num_classes, num_replicas=1):
-        super().__init__()
-        self.num_replicas = num_replicas
-        self.num_classes = num_classes
-        probabilities_init = torch.zeros(2, 1, num_components,
-                                         self.num_classes)
-        probabilities_init.uniform_(0.4, 0.6)
-        self.reasoning_probabilities = torch.nn.Parameter(probabilities_init)
-
-    @property
-    def reasonings(self):
-        pk = self.reasoning_probabilities[0]
-        nk = (1 - pk) * self.reasoning_probabilities[1]
-        ik = 1 - pk - nk
-        img = torch.cat([pk, nk, ik], dim=0).permute(1, 0, 2)
-        return img.unsqueeze(1)
-
-    def forward(self, detections):
-        pk = self.reasoning_probabilities[0].clamp(0, 1)
-        nk = (1 - pk) * self.reasoning_probabilities[1].clamp(0, 1)
-        numerator = (detections @ (pk - nk)) + nk.sum(1)
-        probs = numerator / (pk + nk).sum(1)
-        probs = probs.squeeze(0)
-        return probs
-

 class CBC(SiameseGLVQ):
     """Classification-By-Components."""
     def __init__(self,
@@ -143,10 +75,11 @@ class CBC(SiameseGLVQ):
         return y_pred


-class ImageCBC(CBC):
+class ImageCBC(ImagePrototypesMixin, CBC):
     """CBC model that constrains the components to the range [0, 1] by
     clamping after updates.
     """
-    def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
-        # super().on_train_batch_end(outputs, batch, batch_idx, dataloader_idx)
-        self.component_layer.components.data.clamp_(0.0, 1.0)
+    def __init__(self, hparams, **kwargs):
+        super().__init__(hparams, **kwargs)
+        # Namespace hook
+        self.proto_layer = self.component_layer
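The "namespace hook" works because the two attribute names point at the same layer object, so the clamp that `ImagePrototypesMixin.on_train_batch_end` applies to `proto_layer.components` hits the CBC component tensor too. A self-contained illustration (toy `Layer` class, not the real component layer):

import torch

class Layer:  # stand-in for the real component layer
    def __init__(self):
        self.components = torch.nn.Parameter(torch.rand(3, 2) * 2 - 0.5)

layer = Layer()
proto_layer = component_layer = layer         # the "namespace hook": one object, two names
proto_layer.components.data.clamp_(0.0, 1.0)  # what the mixin does after each batch
assert float(component_layer.components.min()) >= 0.0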

prototorch/models/extras.py Normal file

@@ -0,0 +1,142 @@
"""prototorch.models.extras

Modules not yet available in prototorch go here temporarily.
"""

import torch
from prototorch.functions.distances import euclidean_distance
from prototorch.functions.similarities import cosine_similarity


def rescaled_cosine_similarity(x, y):
    """Cosine Similarity rescaled to [0, 1]."""
    similarities = cosine_similarity(x, y)
    return (similarities + 1.0) / 2.0


def shift_activation(x):
    return (x + 1.0) / 2.0


def euclidean_similarity(x, y, variance=1.0):
    d = euclidean_distance(x, y)
    return torch.exp(-(d * d) / (2 * variance))


class ConnectionTopology(torch.nn.Module):
    def __init__(self, agelimit, num_prototypes):
        super().__init__()
        self.agelimit = agelimit
        self.num_prototypes = num_prototypes

        self.cmat = torch.zeros((self.num_prototypes, self.num_prototypes))
        self.age = torch.zeros_like(self.cmat)

    def forward(self, d):
        order = torch.argsort(d, dim=1)

        for element in order:
            i0, i1 = element[0], element[1]

            self.cmat[i0][i1] = 1
            self.cmat[i1][i0] = 1

            self.age[i0][i1] = 0
            self.age[i1][i0] = 0

            self.age[i0][self.cmat[i0] == 1] += 1
            self.age[i1][self.cmat[i1] == 1] += 1

            self.cmat[i0][self.age[i0] > self.agelimit] = 0
            self.cmat[i1][self.age[i1] > self.agelimit] = 0

    def get_neighbors(self, position):
        return torch.where(self.cmat[position])

    def add_prototype(self):
        new_cmat = torch.zeros([dim + 1 for dim in self.cmat.shape])
        new_cmat[:-1, :-1] = self.cmat
        self.cmat = new_cmat

        new_age = torch.zeros([dim + 1 for dim in self.age.shape])
        new_age[:-1, :-1] = self.age
        self.age = new_age

    def add_connection(self, a, b):
        self.cmat[a][b] = 1
        self.cmat[b][a] = 1

        self.age[a][b] = 0
        self.age[b][a] = 0

    def remove_connection(self, a, b):
        self.cmat[a][b] = 0
        self.cmat[b][a] = 0

        self.age[a][b] = 0
        self.age[b][a] = 0

    def extra_repr(self):
        return f"(agelimit): ({self.agelimit})"


class CosineSimilarity(torch.nn.Module):
    def __init__(self, activation=shift_activation):
        super().__init__()
        self.activation = activation

    def forward(self, x, y):
        epsilon = torch.finfo(x.dtype).eps
        normed_x = (x / x.pow(2).sum(dim=tuple(range(
            1, x.ndim)), keepdim=True).clamp(min=epsilon).sqrt()).flatten(
                start_dim=1)
        normed_y = (y / y.pow(2).sum(dim=tuple(range(
            1, y.ndim)), keepdim=True).clamp(min=epsilon).sqrt()).flatten(
                start_dim=1)
        # normed_x = (x / torch.linalg.norm(x, dim=1))
        diss = torch.inner(normed_x, normed_y)
        return self.activation(diss)


class MarginLoss(torch.nn.modules.loss._Loss):
    def __init__(self,
                 margin=0.3,
                 size_average=None,
                 reduce=None,
                 reduction="mean"):
        super().__init__(size_average, reduce, reduction)
        self.margin = margin

    def forward(self, input_, target):
        dp = torch.sum(target * input_, dim=-1)
        dm = torch.max(input_ - target, dim=-1).values
        return torch.nn.functional.relu(dm - dp + self.margin)


class ReasoningLayer(torch.nn.Module):
    def __init__(self, num_components, num_classes, num_replicas=1):
        super().__init__()
        self.num_replicas = num_replicas
        self.num_classes = num_classes
        probabilities_init = torch.zeros(2, 1, num_components,
                                         self.num_classes)
        probabilities_init.uniform_(0.4, 0.6)
        # TODO Use `self.register_parameter("param", Parameter(param))` instead
        self.reasoning_probabilities = torch.nn.Parameter(probabilities_init)

    @property
    def reasonings(self):
        pk = self.reasoning_probabilities[0]
        nk = (1 - pk) * self.reasoning_probabilities[1]
        ik = 1 - pk - nk
        img = torch.cat([pk, nk, ik], dim=0).permute(1, 0, 2)
        return img.unsqueeze(1)

    def forward(self, detections):
        pk = self.reasoning_probabilities[0].clamp(0, 1)
        nk = (1 - pk) * self.reasoning_probabilities[1].clamp(0, 1)
        numerator = (detections @ (pk - nk)) + nk.sum(1)
        probs = numerator / (pk + nk).sum(1)
        probs = probs.squeeze(0)
        return probs
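The similarity helpers above all map distances or cosine values into [0, 1]. A quick check of `euclidean_similarity`, assuming `euclidean_distance` returns the pairwise distance matrix; the values follow directly from exp(-d²/2σ²):

import torch
from prototorch.models.extras import euclidean_similarity

x = torch.zeros(1, 2)
y = torch.tensor([[0.0, 0.0], [3.0, 4.0]])
sim = euclidean_similarity(x, y, variance=1.0)
# d = 0 gives exp(0) = 1.0; d = 5 gives exp(-12.5), roughly 3.7e-6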

prototorch/models/glvq.py

@@ -1,101 +1,40 @@
 """Models based on the GLVQ framework."""
 import torch
-import torchmetrics
-from prototorch.components import LabeledComponents
 from prototorch.functions.activations import get_activation
 from prototorch.functions.competitions import wtac
 from prototorch.functions.distances import (
-    euclidean_distance,
     lomega_distance,
     omega_distance,
     squared_euclidean_distance,
 )
 from prototorch.functions.helper import get_flat
 from prototorch.functions.losses import glvq_loss, lvq1_loss, lvq21_loss
-from prototorch.modules import LambdaLayer
+from prototorch.modules import LambdaLayer, LossLayer
 from torch.nn.parameter import Parameter

-from .abstract import AbstractPrototypeModel, PrototypeImageModel
+from .abstract import ImagePrototypesMixin, SupervisedPrototypeModel


-class GLVQ(AbstractPrototypeModel):
+class GLVQ(SupervisedPrototypeModel):
     """Generalized Learning Vector Quantization."""
     def __init__(self, hparams, **kwargs):
-        super().__init__()
+        super().__init__(hparams, **kwargs)

-        # Hyperparameters
-        self.save_hyperparameters(hparams)
-
-        # Defaults
+        # Default hparams
         self.hparams.setdefault("transfer_fn", "identity")
         self.hparams.setdefault("transfer_beta", 10.0)
-        self.hparams.setdefault("lr", 0.01)
-
-        distance_fn = kwargs.get("distance_fn", euclidean_distance)
-        transfer_fn = get_activation(self.hparams.transfer_fn)

         # Layers
-        prototype_initializer = kwargs.get("prototype_initializer", None)
-        self.proto_layer = LabeledComponents(
-            distribution=self.hparams.distribution,
-            initializer=prototype_initializer)
-
-        self.distance_layer = LambdaLayer(distance_fn)
+        transfer_fn = get_activation(self.hparams.transfer_fn)
         self.transfer_layer = LambdaLayer(transfer_fn)
-        self.loss = LambdaLayer(glvq_loss)
+
+        # Loss
+        self.loss = LossLayer(glvq_loss)

         # Prototype metrics
         self.initialize_prototype_win_ratios()

-        self.optimizer = kwargs.get("optimizer", torch.optim.Adam)
-        self.lr_scheduler = kwargs.get("lr_scheduler", None)
-        self.lr_scheduler_kwargs = kwargs.get("lr_scheduler_kwargs", dict())
-
-    @property
-    def prototype_labels(self):
-        return self.proto_layer.component_labels.detach().cpu()
-
-    @property
-    def num_classes(self):
-        return len(self.proto_layer.distribution)
-
-    def _forward(self, x):
-        protos, _ = self.proto_layer()
-        distances = self.distance_layer(x, protos)
-        return distances
-
-    def forward(self, x):
-        distances = self._forward(x)
-        y_pred = self.predict_from_distances(distances)
-        y_pred = torch.eye(self.num_classes, device=self.device)[y_pred.long()]
-        return y_pred
-
-    def predict_from_distances(self, distances):
-        with torch.no_grad():
-            plabels = self.proto_layer.component_labels
-            y_pred = wtac(distances, plabels)
-        return y_pred
-
-    def predict(self, x):
-        with torch.no_grad():
-            distances = self._forward(x)
-            y_pred = self.predict_from_distances(distances)
-        return y_pred
-
-    def log_acc(self, distances, targets, tag):
-        preds = self.predict_from_distances(distances)
-        accuracy = torchmetrics.functional.accuracy(preds.int(), targets.int())
-        # `.int()` because FloatTensors are assumed to be class probabilities
-        self.log(tag,
-                 accuracy,
-                 on_step=False,
-                 on_epoch=True,
-                 prog_bar=True,
-                 logger=True)
-
     def initialize_prototype_win_ratios(self):
         self.register_buffer(
             "prototype_win_ratios",
@@ -121,7 +60,7 @@ class GLVQ(AbstractPrototypeModel):

     def shared_step(self, batch, batch_idx, optimizer_idx=None):
         x, y = batch
-        out = self._forward(x)
+        out = self.compute_distances(x)
         plabels = self.proto_layer.component_labels
         mu = self.loss(out, y, prototype_labels=plabels)
         batch_loss = self.transfer_layer(mu, beta=self.hparams.transfer_beta)
@@ -158,18 +97,6 @@ class GLVQ(AbstractPrototypeModel):
     # def predict_step(self, batch, batch_idx, dataloader_idx=None):
     #     pass

-    def add_prototypes(self, initializer, distribution):
-        self.proto_layer.add_components(initializer, distribution)
-        self.trainer.accelerator_backend.setup_optimizers(self.trainer)
-
-    def remove_prototypes(self, indices):
-        self.proto_layer.remove_components(indices)
-        self.trainer.accelerator_backend.setup_optimizers(self.trainer)
-
-    def __repr__(self):
-        super_repr = super().__repr__()
-        return f"{super_repr}"
-

 class SiameseGLVQ(GLVQ):
     """GLVQ in a Siamese setting.
@@ -212,7 +139,7 @@ class SiameseGLVQ(GLVQ):
         else:
             return optimizer

-    def _forward(self, x):
+    def compute_distances(self, x):
         protos, _ = self.proto_layer()
         latent_x = self.backbone(x)
         self.backbone.requires_grad_(self.both_path_gradients)
@@ -256,7 +183,7 @@ class GRLVQ(SiameseGLVQ):
     def relevance_profile(self):
         return self.relevances.detach().cpu()

-    def _forward(self, x):
+    def compute_distances(self, x):
         protos, _ = self.proto_layer()
         distances = self.distance_layer(x, protos, torch.diag(self.relevances))
         return distances
@@ -285,7 +212,7 @@ class SiameseGMLVQ(SiameseGLVQ):
         lam = omega.T @ omega
         return lam.detach().cpu()

-    def _forward(self, x):
+    def compute_distances(self, x):
         protos, _ = self.proto_layer()
         x, protos = get_flat(x, protos)
         latent_x = self.backbone(x)
@@ -305,7 +232,7 @@ class LVQMLN(SiameseGLVQ):
     rather in the embedding space.
     """
-    def _forward(self, x):
+    def compute_distances(self, x):
         latent_protos, _ = self.proto_layer()
         latent_x = self.backbone(x)
         distances = self.distance_layer(latent_x, latent_protos)
@@ -327,7 +254,7 @@ class GMLVQ(GLVQ):
                             device=self.device)
         self.register_parameter("_omega", Parameter(omega))

-    def _forward(self, x):
+    def compute_distances(self, x):
         protos, _ = self.proto_layer()
         distances = self.distance_layer(x, protos, self._omega)
         return distances
@@ -355,7 +282,7 @@ class GLVQ1(GLVQ):
     """Generalized Learning Vector Quantization 1."""
     def __init__(self, hparams, **kwargs):
         super().__init__(hparams, **kwargs)
-        self.loss = lvq1_loss
+        self.loss = LossLayer(lvq1_loss)
         self.optimizer = torch.optim.SGD
@@ -363,11 +290,11 @@ class GLVQ21(GLVQ):
     """Generalized Learning Vector Quantization 2.1."""
     def __init__(self, hparams, **kwargs):
         super().__init__(hparams, **kwargs)
-        self.loss = lvq21_loss
+        self.loss = LossLayer(lvq21_loss)
         self.optimizer = torch.optim.SGD


-class ImageGLVQ(PrototypeImageModel, GLVQ):
+class ImageGLVQ(ImagePrototypesMixin, GLVQ):
     """GLVQ for training on image data.

     GLVQ model that constrains the prototypes to the range [0, 1] by clamping
@@ -376,7 +303,7 @@ class ImageGLVQ(PrototypeImageModel, GLVQ):
     """


-class ImageGMLVQ(PrototypeImageModel, GMLVQ):
+class ImageGMLVQ(ImagePrototypesMixin, GMLVQ):
     """GMLVQ for training on image data.

     GMLVQ model that constrains the prototypes to the range [0, 1] by clamping
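With the boilerplate hoisted into `SupervisedPrototypeModel`, constructing a GLVQ model reduces to hparams plus an initializer. A sketch, where the `distribution` format and the `ZerosInitializer` argument are assumptions rather than facts from this diff:

from prototorch.components.initializers import ZerosInitializer
from prototorch.models import GLVQ

model = GLVQ(
    hparams=dict(distribution=[2, 2, 2], lr=0.05),  # two prototypes per class (format assumed)
    prototype_initializer=ZerosInitializer(2),      # assumed: feature-dimension argument
)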

prototorch/models/knn.py Normal file

@@ -0,0 +1,38 @@
"""ProtoTorch KNN model."""

import warnings

from prototorch.components import LabeledComponents
from prototorch.modules import KNNC

from .abstract import SupervisedPrototypeModel


class KNN(SupervisedPrototypeModel):
    """K-Nearest-Neighbors classification algorithm."""
    def __init__(self, hparams, **kwargs):
        super().__init__(hparams, **kwargs)

        # Default hparams
        self.hparams.setdefault("k", 1)

        data = kwargs.get("data", None)
        if data is None:
            raise ValueError("KNN requires data, but none was provided!")

        # Layers
        self.proto_layer = LabeledComponents(initialized_components=data)
        self.competition_layer = KNNC(k=self.hparams.k)

    def training_step(self, train_batch, batch_idx, optimizer_idx=None):
        return 1  # skip training step

    def on_train_batch_start(self,
                             train_batch,
                             batch_idx,
                             dataloader_idx=None):
        warnings.warn("k-NN has no training, skipping!")
        return -1

    def configure_optimizers(self):
        return None
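Since `KNN` stores the training data as labeled components, construction needs the data up front; a short sketch, assuming `initialized_components` accepts an `(x, y)` tuple as the old unsupervised KNN did:

import torch
from prototorch.models import KNN

x_train = torch.rand(100, 2)
y_train = torch.randint(0, 3, (100, ))
model = KNN(hparams=dict(k=5), data=(x_train, y_train))
y_pred = model.predict(torch.rand(10, 2))  # predict() is inherited from SupervisedPrototypeModel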

prototorch/models/lvq.py

@@ -1,34 +1,24 @@
 """LVQ models that are optimized using non-gradient methods."""

-from prototorch.functions.competitions import wtac
 from prototorch.functions.losses import _get_dp_dm

+from .abstract import NonGradientMixin
 from .glvq import GLVQ


-class NonGradientLVQ(GLVQ):
-    """Abstract Model for Models that do not use gradients in their update phase."""
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.automatic_optimization = False
-
-    def training_step(self, train_batch, batch_idx, optimizer_idx=None):
-        raise NotImplementedError
-
-
-class LVQ1(NonGradientLVQ):
+class LVQ1(NonGradientMixin, GLVQ):
     """Learning Vector Quantization 1."""
     def training_step(self, train_batch, batch_idx, optimizer_idx=None):
         protos = self.proto_layer.components
         plabels = self.proto_layer.component_labels

         x, y = train_batch
-        dis = self._forward(x)
+        dis = self.compute_distances(x)
         # TODO Vectorized implementation

         for xi, yi in zip(x, y):
-            d = self._forward(xi.view(1, -1))
-            preds = wtac(d, plabels)
+            d = self.compute_distances(xi.view(1, -1))
+            preds = self.competition_layer(d, plabels)
             w = d.argmin(1)
             if yi == preds:
                 shift = xi - protos[w]
@@ -45,20 +35,20 @@ class LVQ1(NonGradientLVQ):
         return None


-class LVQ21(NonGradientLVQ):
+class LVQ21(NonGradientMixin, GLVQ):
     """Learning Vector Quantization 2.1."""
     def training_step(self, train_batch, batch_idx, optimizer_idx=None):
         protos = self.proto_layer.components
         plabels = self.proto_layer.component_labels

         x, y = train_batch
-        dis = self._forward(x)
+        dis = self.compute_distances(x)
         # TODO Vectorized implementation

         for xi, yi in zip(x, y):
             xi = xi.view(1, -1)
             yi = yi.view(1, )
-            d = self._forward(xi)
+            d = self.compute_distances(xi)
             (_, wp), (_, wn) = _get_dp_dm(d, yi, plabels, with_indices=True)
             shiftp = xi - protos[wp]
             shiftn = protos[wn] - xi
@@ -74,5 +64,5 @@ class LVQ21(NonGradientLVQ):
         return None


-class MedianLVQ(NonGradientLVQ):
+class MedianLVQ(NonGradientMixin, GLVQ):
     """Median LVQ"""

prototorch/models/probabilistic.py

@@ -1,10 +1,10 @@
 """Probabilistic GLVQ methods"""

 import torch
-from prototorch.functions.competitions import stratified_min, stratified_sum
-from prototorch.functions.losses import (log_likelihood_ratio_loss,
-                                         robust_soft_loss)
+from prototorch.functions.losses import nllr_loss, rslvq_loss
+from prototorch.functions.pooling import stratified_min_pooling, stratified_sum_pooling
 from prototorch.functions.transforms import gaussian
+from prototorch.modules import LambdaLayer, LossLayer

 from .glvq import GLVQ
@@ -13,13 +13,16 @@ class CELVQ(GLVQ):
     """Cross-Entropy Learning Vector Quantization."""
     def __init__(self, hparams, **kwargs):
         super().__init__(hparams, **kwargs)
+
+        # Loss
         self.loss = torch.nn.CrossEntropyLoss()

     def shared_step(self, batch, batch_idx, optimizer_idx=None):
         x, y = batch
-        out = self._forward(x)  # [None, num_protos]
+        out = self.compute_distances(x)  # [None, num_protos]
         plabels = self.proto_layer.component_labels
-        probs = -1.0 * stratified_min(out, plabels)  # [None, num_classes]
+        winning = stratified_min_pooling(out, plabels)  # [None, num_classes]
+        probs = -1.0 * winning
         batch_loss = self.loss(probs, y.long())
         loss = batch_loss.sum(dim=0)
         return out, loss
@@ -33,14 +36,14 @@ class ProbabilisticLVQ(GLVQ):
         self.rejection_confidence = rejection_confidence

     def forward(self, x):
-        distances = self._forward(x)
+        distances = self.compute_distances(x)
         conditional = self.conditional_distribution(distances,
                                                     self.hparams.variance)
         prior = (1. / self.num_prototypes) * torch.ones(self.num_prototypes,
                                                         device=self.device)
         posterior = conditional * prior

         plabels = self.proto_layer._labels
-        y_pred = stratified_sum(posterior, plabels)
+        y_pred = stratified_sum_pooling(posterior, plabels)
         return y_pred

     def predict(self, x):
@@ -50,12 +53,11 @@ class ProbabilisticLVQ(GLVQ):
         return prediction

     def training_step(self, batch, batch_idx, optimizer_idx=None):
-        X, y = batch
-        out = self.forward(X)
+        x, y = batch
+        out = self.forward(x)
         plabels = self.proto_layer.component_labels
-        batch_loss = self.loss_fn(out, y, plabels)
+        batch_loss = self.loss(out, y, plabels)
         loss = batch_loss.sum(dim=0)
         return loss
@@ -63,11 +65,11 @@ class LikelihoodRatioLVQ(ProbabilisticLVQ):
     """Learning Vector Quantization based on Likelihood Ratios."""
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self.loss_fn = log_likelihood_ratio_loss
+        self.loss = LossLayer(nllr_loss)


 class RSLVQ(ProbabilisticLVQ):
     """Robust Soft Learning Vector Quantization."""
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self.loss_fn = robust_soft_loss
+        self.loss = LossLayer(rslvq_loss)
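Both models now share `ProbabilisticLVQ.forward`, which turns distances into a Gaussian posterior scaled by a uniform prior; only the loss differs. A construction sketch (hparams format assumed, initializer elided):

from prototorch.models import RSLVQ

model = RSLVQ(
    hparams=dict(distribution=[1, 1], variance=0.3, lr=0.01),  # variance feeds conditional_distribution
    # prototype_initializer=...,  # required, as for GLVQ
)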

prototorch/models/unsupervised.py

@@ -8,195 +8,30 @@ import pytorch_lightning as pl
 import torch
 import torchmetrics
 from prototorch.components import Components, LabeledComponents
-from prototorch.components.initializers import ZerosInitializer, parse_data_arg
+from prototorch.components.initializers import ZerosInitializer
 from prototorch.functions.competitions import knnc
 from prototorch.functions.distances import euclidean_distance
 from prototorch.modules import LambdaLayer
 from prototorch.modules.losses import NeuralGasEnergy
 from pytorch_lightning.callbacks import Callback

-from .abstract import AbstractPrototypeModel
+from .abstract import UnsupervisedPrototypeModel
+from .callbacks import GNGCallback
+from .extras import ConnectionTopology


-class GNGCallback(Callback):
-    """GNG Callback.
-
-    Applies growing algorithm based on accumulated error and topology.
-
-    Based on "A Growing Neural Gas Network Learns Topologies" by Bernd Fritzke.
-    """
-    def __init__(self, reduction=0.1, freq=10):
-        self.reduction = reduction
-        self.freq = freq
-
-    def on_epoch_end(self, trainer: pl.Trainer, pl_module):
-        if (trainer.current_epoch + 1) % self.freq == 0:
-            # Get information
-            errors = pl_module.errors
-            topology: ConnectionTopology = pl_module.topology_layer
-            components: pt.components.Components = pl_module.proto_layer.components
-
-            # Insertion point
-            worst = torch.argmax(errors)
-
-            neighbors = topology.get_neighbors(worst)[0]
-
-            if len(neighbors) == 0:
-                logging.log(level=20, msg="No neighbor-pairs found!")
-                return
-
-            neighbors_errors = errors[neighbors]
-            worst_neighbor = neighbors[torch.argmax(neighbors_errors)]
-
-            # New Prototype
-            new_component = 0.5 * (components[worst] +
-                                   components[worst_neighbor])
-
-            # Add component
-            pl_module.proto_layer.add_components(
-                initialized_components=new_component.unsqueeze(0))
-
-            # Adjust Topology
-            topology.add_prototype()
-            topology.add_connection(worst, -1)
-            topology.add_connection(worst_neighbor, -1)
-            topology.remove_connection(worst, worst_neighbor)
-
-            # New errors
-            worst_error = errors[worst].unsqueeze(0)
-            pl_module.errors = torch.cat([pl_module.errors, worst_error])
-            pl_module.errors[worst] = errors[worst] * self.reduction
-            pl_module.errors[
-                worst_neighbor] = errors[worst_neighbor] * self.reduction
-
-            trainer.accelerator_backend.setup_optimizers(trainer)
-
-
-class ConnectionTopology(torch.nn.Module):
-    def __init__(self, agelimit, num_prototypes):
-        super().__init__()
-        self.agelimit = agelimit
-        self.num_prototypes = num_prototypes
-
-        self.cmat = torch.zeros((self.num_prototypes, self.num_prototypes))
-        self.age = torch.zeros_like(self.cmat)
-
-    def forward(self, d):
-        order = torch.argsort(d, dim=1)
-
-        for element in order:
-            i0, i1 = element[0], element[1]
-            self.cmat[i0][i1] = 1
-            self.cmat[i1][i0] = 1
-            self.age[i0][i1] = 0
-            self.age[i1][i0] = 0
-            self.age[i0][self.cmat[i0] == 1] += 1
-            self.age[i1][self.cmat[i1] == 1] += 1
-            self.cmat[i0][self.age[i0] > self.agelimit] = 0
-            self.cmat[i1][self.age[i1] > self.agelimit] = 0
-
-    def get_neighbors(self, position):
-        return torch.where(self.cmat[position])
-
-    def add_prototype(self):
-        new_cmat = torch.zeros([dim + 1 for dim in self.cmat.shape])
-        new_cmat[:-1, :-1] = self.cmat
-        self.cmat = new_cmat
-
-        new_age = torch.zeros([dim + 1 for dim in self.age.shape])
-        new_age[:-1, :-1] = self.age
-        self.age = new_age
-
-    def add_connection(self, a, b):
-        self.cmat[a][b] = 1
-        self.cmat[b][a] = 1
-        self.age[a][b] = 0
-        self.age[b][a] = 0
-
-    def remove_connection(self, a, b):
-        self.cmat[a][b] = 0
-        self.cmat[b][a] = 0
-        self.age[a][b] = 0
-        self.age[b][a] = 0
-
-    def extra_repr(self):
-        return f"(agelimit): ({self.agelimit})"
-
-
-class KNN(AbstractPrototypeModel):
-    """K-Nearest-Neighbors classification algorithm."""
-    def __init__(self, hparams, **kwargs):
-        super().__init__()
-
-        # Hyperparameters
-        self.save_hyperparameters(hparams)
-
-        # Default Values
-        self.hparams.setdefault("k", 1)
-        self.hparams.setdefault("distance", euclidean_distance)
-
-        data = kwargs.get("data")
-        x_train, y_train = parse_data_arg(data)
-
-        self.proto_layer = LabeledComponents(initialized_components=(x_train,
-                                                                     y_train))
-
-        self.train_acc = torchmetrics.Accuracy()
-
-    @property
-    def prototype_labels(self):
-        return self.proto_layer.component_labels.detach()
-
-    def forward(self, x):
-        protos, _ = self.proto_layer()
-        dis = self.hparams.distance(x, protos)
-        return dis
-
-    def predict(self, x):
-        # model.eval()  # ?!
-        with torch.no_grad():
-            d = self(x)
-            plabels = self.proto_layer.component_labels
-            y_pred = knnc(d, plabels, k=self.hparams.k)
-        return y_pred
-
-    def training_step(self, train_batch, batch_idx, optimizer_idx=None):
-        return 1
-
-    def on_train_batch_start(self,
-                             train_batch,
-                             batch_idx,
-                             dataloader_idx=None):
-        warnings.warn("k-NN has no training, skipping!")
-        return -1
-
-    def configure_optimizers(self):
-        return None
-
-
-class NeuralGas(AbstractPrototypeModel):
+class NeuralGas(UnsupervisedPrototypeModel):
     def __init__(self, hparams, **kwargs):
-        super().__init__()
+        super().__init__(hparams, **kwargs)

         self.save_hyperparameters(hparams)

-        self.optimizer = kwargs.get("optimizer", torch.optim.Adam)
-
-        # Default Values
+        # Default hparams
         self.hparams.setdefault("input_dim", 2)
         self.hparams.setdefault("agelimit", 10)
         self.hparams.setdefault("lm", 1)

-        self.proto_layer = Components(
-            self.hparams.num_prototypes,
-            initializer=kwargs.get("prototype_initializer"))
-
-        self.distance_layer = LambdaLayer(euclidean_distance)
         self.energy_layer = NeuralGasEnergy(lm=self.hparams.lm)
         self.topology_layer = ConnectionTopology(
             agelimit=self.hparams.agelimit,
@@ -204,9 +39,10 @@ class NeuralGas(AbstractPrototypeModel):
         )

     def training_step(self, train_batch, batch_idx):
         # x = train_batch
         # TODO Check if the batch has labels
         x = train_batch[0]
-        protos = self.proto_layer()
-        d = self.distance_layer(x, protos)
+        d = self.compute_distances(x)
         cost, _ = self.energy_layer(d)
         self.topology_layer(d)
         return cost
@@ -216,26 +52,26 @@ class GrowingNeuralGas(NeuralGas):
     def __init__(self, hparams, **kwargs):
         super().__init__(hparams, **kwargs)

-        # defaults
+        # Defaults
         self.hparams.setdefault("step_reduction", 0.5)
         self.hparams.setdefault("insert_reduction", 0.1)
         self.hparams.setdefault("insert_freq", 10)

-        self.register_buffer(
-            "errors",
-            torch.zeros(self.hparams.num_prototypes, device=self.device))
+        errors = torch.zeros(self.hparams.num_prototypes, device=self.device)
+        self.register_buffer("errors", errors)

     def training_step(self, train_batch, _batch_idx):
         # x = train_batch
         # TODO Check if the batch has labels
         x = train_batch[0]
-        protos = self.proto_layer()
-        d = self.distance_layer(x, protos)
+        d = self.compute_distances(x)
         cost, order = self.energy_layer(d)
         winner = order[:, 0]
         mask = torch.zeros_like(d)
         mask[torch.arange(len(mask)), winner] = 1.0

-        winner_distances = d * mask
-        self.errors += torch.sum(winner_distances * winner_distances, dim=0)
+        dp = d * mask
+        self.errors += torch.sum(dp * dp, dim=0)
         self.errors *= self.hparams.step_reduction

         self.topology_layer(d)
prototorch/models/vis.py

@@ -140,7 +140,7 @@ class VisGLVQ2D(Vis2DAbstract):
         x = np.vstack((x_train, protos))
         mesh_input, xx, yy = self.get_mesh_input(x)
         _components = pl_module.proto_layer._components
-        mesh_input = torch.Tensor(mesh_input).type_as(_components)
+        mesh_input = torch.from_numpy(mesh_input).type_as(_components)
         y_pred = pl_module.predict(mesh_input)
         y_pred = y_pred.cpu().reshape(xx.shape)
         ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
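The switch to `torch.from_numpy` matters because `torch.Tensor(arr)` always copies and coerces to the default float dtype, whereas `from_numpy` wraps the NumPy buffer and preserves its dtype; the subsequent `.type_as(_components)` then casts and moves the result to match the prototypes. For illustration:

import numpy as np
import torch

arr = np.linspace(0, 1, 4, dtype=np.float64)
t = torch.from_numpy(arr)        # torch.float64, shares memory with arr
t2 = t.type_as(torch.zeros(1))   # cast to the float32 dtype (and device) of the target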