[REFACTOR] Major cleanup
parent 20471bfb1c
commit 016fcb4060
prototorch/models/__init__.py
@@ -4,11 +4,23 @@ from importlib.metadata import PackageNotFoundError, version
 from .callbacks import PrototypeConvergence, PruneLoserPrototypes
 from .cbc import CBC, ImageCBC
-from .glvq import (GLVQ, GLVQ1, GLVQ21, GMLVQ, GRLVQ, LGMLVQ, LVQMLN,
-                   ImageGLVQ, ImageGMLVQ, SiameseGLVQ, SiameseGMLVQ)
+from .glvq import (
+    GLVQ,
+    GLVQ1,
+    GLVQ21,
+    GMLVQ,
+    GRLVQ,
+    LGMLVQ,
+    LVQMLN,
+    ImageGLVQ,
+    ImageGMLVQ,
+    SiameseGLVQ,
+    SiameseGMLVQ,
+)
+from .knn import KNN
 from .lvq import LVQ1, LVQ21, MedianLVQ
 from .probabilistic import CELVQ, RSLVQ, LikelihoodRatioLVQ
-from .unsupervised import KNN, GrowingNeuralGas, NeuralGas
+from .unsupervised import GrowingNeuralGas, NeuralGas
 from .vis import *

 __version__ = "0.1.7"
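Note: the public API surface is unchanged by the cleanup; KNN simply moves from the unsupervised module into its own file. A quick sanity check (hypothetical usage; assumes prototorch and prototorch_models are installed at this commit):

    from prototorch.models import GLVQ, KNN, GrowingNeuralGas

    print(GLVQ, KNN, GrowingNeuralGas)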
prototorch/models/abstract.py
@@ -1,7 +1,39 @@
+"""Abstract classes to be inherited by prototorch models."""
+
+from typing import Final, final
+
 import pytorch_lightning as pl
+import torch
+import torchmetrics
+from prototorch.components import Components, LabeledComponents
+from prototorch.functions.distances import euclidean_distance
+from prototorch.modules import WTAC, LambdaLayer


-class AbstractPrototypeModel(pl.LightningModule):
+class ProtoTorchBolt(pl.LightningModule):
+    def __repr__(self):
+        super_repr = super().__repr__()
+        return f"ProtoTorch Bolt:\n{super_repr}"
+
+
+class PrototypeModel(ProtoTorchBolt):
+    def __init__(self, hparams, **kwargs):
+        super().__init__()
+
+        # Hyperparameters
+        self.save_hyperparameters(hparams)
+
+        # Default hparams
+        self.hparams.setdefault("lr", 0.01)
+
+        # Default config
+        self.optimizer = kwargs.get("optimizer", torch.optim.Adam)
+        self.lr_scheduler = kwargs.get("lr_scheduler", None)
+        self.lr_scheduler_kwargs = kwargs.get("lr_scheduler_kwargs", dict())
+
+        distance_fn = kwargs.get("distance_fn", euclidean_distance)
+        self.distance_layer = LambdaLayer(distance_fn)
+
     @property
     def num_prototypes(self):
         return len(self.proto_layer.components)
@@ -28,9 +60,115 @@ class AbstractPrototypeModel(pl.LightningModule):
         else:
             return optimizer

+    @final
+    def reconfigure_optimizers(self):
+        self.trainer.accelerator_backend.setup_optimizers(self.trainer)
+
-class PrototypeImageModel(pl.LightningModule):
+    def add_prototypes(self, *args, **kwargs):
+        self.proto_layer.add_components(*args, **kwargs)
+        self.reconfigure_optimizers()
+
+    def remove_prototypes(self, indices):
+        self.proto_layer.remove_components(indices)
+        self.reconfigure_optimizers()
+
+
+class UnsupervisedPrototypeModel(PrototypeModel):
+    def __init__(self, hparams, **kwargs):
+        super().__init__(hparams, **kwargs)
+
+        # Layers
+        prototype_initializer = kwargs.get("prototype_initializer", None)
+        if prototype_initializer is not None:
+            self.proto_layer = Components(
+                self.hparams.num_prototypes,
+                initializer=prototype_initializer,
+            )
+
+    def compute_distances(self, x):
+        protos = self.proto_layer()
+        distances = self.distance_layer(x, protos)
+        return distances
+
+    def forward(self, x):
+        distances = self.compute_distances(x)
+        return distances
+
+
+class SupervisedPrototypeModel(PrototypeModel):
+    def __init__(self, hparams, **kwargs):
+        super().__init__(hparams, **kwargs)
+
+        # Layers
+        prototype_initializer = kwargs.get("prototype_initializer", None)
+        if prototype_initializer is not None:
+            self.proto_layer = LabeledComponents(
+                distribution=self.hparams.distribution,
+                initializer=prototype_initializer,
+            )
+        self.competition_layer = WTAC()
+
+    @property
+    def prototype_labels(self):
+        return self.proto_layer.component_labels.detach().cpu()
+
+    @property
+    def num_classes(self):
+        return len(self.proto_layer.distribution)
+
+    def compute_distances(self, x):
+        protos, _ = self.proto_layer()
+        distances = self.distance_layer(x, protos)
+        return distances
+
+    def forward(self, x):
+        distances = self.compute_distances(x)
+        y_pred = self.predict_from_distances(distances)
+        # TODO
+        y_pred = torch.eye(self.num_classes, device=self.device)[
+            y_pred.long()]  # depends on labels {0,...,num_classes}
+        return y_pred
+
+    def predict_from_distances(self, distances):
+        with torch.no_grad():
+            plabels = self.proto_layer.component_labels
+            y_pred = self.competition_layer(distances, plabels)
+        return y_pred
+
+    def predict(self, x):
+        with torch.no_grad():
+            distances = self.compute_distances(x)
+            y_pred = self.predict_from_distances(distances)
+        return y_pred
+
+    def log_acc(self, distances, targets, tag):
+        preds = self.predict_from_distances(distances)
+        accuracy = torchmetrics.functional.accuracy(preds.int(), targets.int())
+        # `.int()` because FloatTensors are assumed to be class probabilities
+
+        self.log(tag,
+                 accuracy,
+                 on_step=False,
+                 on_epoch=True,
+                 prog_bar=True,
+                 logger=True)
+
+
+class NonGradientMixin():
+    """Mixin for custom non-gradient optimization."""
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.automatic_optimization: Final = False
+
+    def training_step(self, train_batch, batch_idx, optimizer_idx=None):
+        raise NotImplementedError
+
+
+class ImagePrototypesMixin(ProtoTorchBolt):
+    """Mixin for models with image prototypes."""
+    @final
     def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
+        """Constrain the components to the range [0, 1] by clamping after updates."""
         self.proto_layer.components.data.clamp_(0.0, 1.0)

     def get_prototype_grid(self, num_columns=2, return_channels_last=True):
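The old monolithic AbstractPrototypeModel/PrototypeImageModel pair becomes a small hierarchy (ProtoTorchBolt -> PrototypeModel -> Supervised/UnsupervisedPrototypeModel) plus opt-in mixins such as ImagePrototypesMixin. A self-contained sketch (plain Python, illustrative class names only, not the library's) of why the mixin must come first in the base list:

    class Bolt:  # stands in for ProtoTorchBolt
        def on_train_batch_end(self):
            print("base hook")

    class Model(Bolt):  # stands in for GLVQ
        pass

    class ImageMixin(Bolt):  # stands in for ImagePrototypesMixin
        def on_train_batch_end(self):
            print("clamp components to [0, 1]")  # what the real mixin does

    class ImageModel(ImageMixin, Model):  # mixin first, model second
        pass

    ImageModel().on_train_batch_end()  # -> "clamp components to [0, 1]"

With the mixin first, Python's MRO resolves the Lightning hook to the mixin's clamping version while the model class still supplies the layers.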
prototorch/models/callbacks.py
@@ -1,7 +1,12 @@
 """Lightning Callbacks."""

+import logging
+
 import pytorch_lightning as pl
 import torch
+from prototorch.components import Components
+
+from .extras import ConnectionTopology


 class PruneLoserPrototypes(pl.Callback):
@@ -26,25 +31,29 @@ class PruneLoserPrototypes(pl.Callback):
             return None
         if (trainer.current_epoch + 1) % self.frequency:
             return None

         ratios = pl_module.prototype_win_ratios.mean(dim=0)
         to_prune = torch.arange(len(ratios))[ratios < self.threshold]
-        prune_labels = pl_module.prototype_labels[to_prune.tolist()]
+        to_prune = to_prune.tolist()
+        prune_labels = pl_module.prototype_labels[to_prune]
         if self.prune_quota_per_epoch > 0:
             to_prune = to_prune[:self.prune_quota_per_epoch]
             prune_labels = prune_labels[:self.prune_quota_per_epoch]

         if len(to_prune) > 0:
             if self.verbose:
                 print(f"\nPrototype win ratios: {ratios}")
-                print(f"Pruning prototypes at: {to_prune.tolist()}")
+                print(f"Pruning prototypes at: {to_prune}")
+                print(f"Corresponding labels are: {prune_labels}")
             cur_num_protos = pl_module.num_prototypes
             pl_module.remove_prototypes(indices=to_prune)
             if self.replace:
-                if self.verbose:
-                    print(f"Re-adding prototypes at: {to_prune.tolist()}")
                 labels, counts = torch.unique(prune_labels,
                                               sorted=True,
                                               return_counts=True)
                 distribution = dict(zip(labels.tolist(), counts.tolist()))
+                if self.verbose:
+                    print(f"Re-adding pruned prototypes...")
                 print(f"{distribution=}")
                 pl_module.add_prototypes(distribution=distribution,
                                          initializer=self.initializer)
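The pruning criterion itself is a one-liner on the win-ratio buffer. A minimal torch-only sketch of the selection step above, with made-up numbers:

    import torch

    ratios = torch.tensor([0.30, 0.01, 0.25, 0.02])  # per-prototype win ratios
    threshold = 0.05
    to_prune = torch.arange(len(ratios))[ratios < threshold].tolist()
    print(to_prune)  # [1, 3] -> prototypes that rarely win the competition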
@@ -68,3 +77,58 @@ class PrototypeConvergence(pl.Callback):
         print("Stopping...")
         # TODO
         return True
+
+
+class GNGCallback(pl.Callback):
+    """GNG Callback.
+
+    Applies growing algorithm based on accumulated error and topology.
+
+    Based on "A Growing Neural Gas Network Learns Topologies" by Bernd Fritzke.
+
+    """
+    def __init__(self, reduction=0.1, freq=10):
+        self.reduction = reduction
+        self.freq = freq
+
+    def on_epoch_end(self, trainer: pl.Trainer, pl_module):
+        if (trainer.current_epoch + 1) % self.freq == 0:
+            # Get information
+            errors = pl_module.errors
+            topology: ConnectionTopology = pl_module.topology_layer
+            components: Components = pl_module.proto_layer.components
+
+            # Insertion point
+            worst = torch.argmax(errors)
+
+            neighbors = topology.get_neighbors(worst)[0]
+
+            if len(neighbors) == 0:
+                logging.log(level=20, msg="No neighbor-pairs found!")
+                return
+
+            neighbors_errors = errors[neighbors]
+            worst_neighbor = neighbors[torch.argmax(neighbors_errors)]
+
+            # New Prototype
+            new_component = 0.5 * (components[worst] +
+                                   components[worst_neighbor])
+
+            # Add component
+            pl_module.proto_layer.add_components(
+                initialized_components=new_component.unsqueeze(0))
+
+            # Adjust Topology
+            topology.add_prototype()
+            topology.add_connection(worst, -1)
+            topology.add_connection(worst_neighbor, -1)
+            topology.remove_connection(worst, worst_neighbor)
+
+            # New errors
+            worst_error = errors[worst].unsqueeze(0)
+            pl_module.errors = torch.cat([pl_module.errors, worst_error])
+            pl_module.errors[worst] = errors[worst] * self.reduction
+            pl_module.errors[
+                worst_neighbor] = errors[worst_neighbor] * self.reduction
+
+            trainer.accelerator_backend.setup_optimizers(trainer)
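Fritzke's insertion rule, as implemented above: place a new prototype halfway between the highest-error unit and its highest-error neighbor, then decay both errors by `reduction`. A self-contained numeric sketch (made-up values):

    import torch

    errors = torch.tensor([0.2, 0.9, 0.4])          # accumulated per-prototype error
    components = torch.tensor([[0., 0.], [2., 2.], [2., 0.]])
    worst = torch.argmax(errors)                    # index 1
    worst_neighbor = 2                              # assume unit 2 is its worst neighbor
    new_component = 0.5 * (components[worst] + components[worst_neighbor])
    print(new_component)                            # tensor([2., 1.])
    errors[worst] *= 0.1                            # reduction=0.1, the default above
    errors[worst_neighbor] *= 0.1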
prototorch/models/cbc.py
@@ -1,86 +1,18 @@
 import torch
 import torchmetrics
-from prototorch.functions.distances import euclidean_distance
-from prototorch.functions.similarities import cosine_similarity

+from .abstract import ImagePrototypesMixin
+from .extras import (
+    CosineSimilarity,
+    MarginLoss,
+    ReasoningLayer,
+    euclidean_similarity,
+    rescaled_cosine_similarity,
+    shift_activation,
+)
 from .glvq import SiameseGLVQ


-def rescaled_cosine_similarity(x, y):
-    """Cosine Similarity rescaled to [0, 1]."""
-    similarities = cosine_similarity(x, y)
-    return (similarities + 1.0) / 2.0
-
-
-def shift_activation(x):
-    return (x + 1.0) / 2.0
-
-
-def euclidean_similarity(x, y, variance=1.0):
-    d = euclidean_distance(x, y)
-    return torch.exp(-(d * d) / (2 * variance))
-
-
-class CosineSimilarity(torch.nn.Module):
-    def __init__(self, activation=shift_activation):
-        super().__init__()
-        self.activation = activation
-
-    def forward(self, x, y):
-        epsilon = torch.finfo(x.dtype).eps
-        normed_x = (x / x.pow(2).sum(dim=tuple(range(
-            1, x.ndim)), keepdim=True).clamp(min=epsilon).sqrt()).flatten(
-                start_dim=1)
-        normed_y = (y / y.pow(2).sum(dim=tuple(range(
-            1, y.ndim)), keepdim=True).clamp(min=epsilon).sqrt()).flatten(
-                start_dim=1)
-        # normed_x = (x / torch.linalg.norm(x, dim=1))
-        diss = torch.inner(normed_x, normed_y)
-        return self.activation(diss)
-
-
-class MarginLoss(torch.nn.modules.loss._Loss):
-    def __init__(self,
-                 margin=0.3,
-                 size_average=None,
-                 reduce=None,
-                 reduction="mean"):
-        super().__init__(size_average, reduce, reduction)
-        self.margin = margin
-
-    def forward(self, input_, target):
-        dp = torch.sum(target * input_, dim=-1)
-        dm = torch.max(input_ - target, dim=-1).values
-        return torch.nn.functional.relu(dm - dp + self.margin)
-
-
-class ReasoningLayer(torch.nn.Module):
-    def __init__(self, num_components, num_classes, num_replicas=1):
-        super().__init__()
-        self.num_replicas = num_replicas
-        self.num_classes = num_classes
-        probabilities_init = torch.zeros(2, 1, num_components,
-                                         self.num_classes)
-        probabilities_init.uniform_(0.4, 0.6)
-        self.reasoning_probabilities = torch.nn.Parameter(probabilities_init)
-
-    @property
-    def reasonings(self):
-        pk = self.reasoning_probabilities[0]
-        nk = (1 - pk) * self.reasoning_probabilities[1]
-        ik = 1 - pk - nk
-        img = torch.cat([pk, nk, ik], dim=0).permute(1, 0, 2)
-        return img.unsqueeze(1)
-
-    def forward(self, detections):
-        pk = self.reasoning_probabilities[0].clamp(0, 1)
-        nk = (1 - pk) * self.reasoning_probabilities[1].clamp(0, 1)
-        numerator = (detections @ (pk - nk)) + nk.sum(1)
-        probs = numerator / (pk + nk).sum(1)
-        probs = probs.squeeze(0)
-        return probs
-
-
 class CBC(SiameseGLVQ):
     """Classification-By-Components."""
     def __init__(self,
@@ -143,10 +75,11 @@ class CBC(SiameseGLVQ):
         return y_pred


-class ImageCBC(CBC):
+class ImageCBC(ImagePrototypesMixin, CBC):
     """CBC model that constrains the components to the range [0, 1] by
     clamping after updates.
     """
-    def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
-        # super().on_train_batch_end(outputs, batch, batch_idx, dataloader_idx)
-        self.component_layer.components.data.clamp_(0.0, 1.0)
+    def __init__(self, hparams, **kwargs):
+        super().__init__(hparams, **kwargs)
+        # Namespace hook
+        self.proto_layer = self.component_layer
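Why the "namespace hook" works: ImagePrototypesMixin clamps `self.proto_layer.components`, but CBC keeps its components in `self.component_layer`; aliasing one attribute to the other makes the inherited hook act on the right tensor without a bespoke override. A sketch with plain objects (illustrative only):

    class Layer:
        pass

    obj = type("ImageCBCish", (), {})()
    obj.component_layer = Layer()
    obj.proto_layer = obj.component_layer   # the hook added in this commit
    assert obj.proto_layer is obj.component_layer  # same object, two names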
prototorch/models/extras.py (new file, 142 lines)
@@ -0,0 +1,142 @@
+"""prototorch.models.extras
+
+Modules not yet available in prototorch go here temporarily.
+
+"""
+
+import torch
+from prototorch.functions.distances import euclidean_distance
+from prototorch.functions.similarities import cosine_similarity
+
+
+def rescaled_cosine_similarity(x, y):
+    """Cosine Similarity rescaled to [0, 1]."""
+    similarities = cosine_similarity(x, y)
+    return (similarities + 1.0) / 2.0
+
+
+def shift_activation(x):
+    return (x + 1.0) / 2.0
+
+
+def euclidean_similarity(x, y, variance=1.0):
+    d = euclidean_distance(x, y)
+    return torch.exp(-(d * d) / (2 * variance))
+
+
+class ConnectionTopology(torch.nn.Module):
+    def __init__(self, agelimit, num_prototypes):
+        super().__init__()
+        self.agelimit = agelimit
+        self.num_prototypes = num_prototypes
+
+        self.cmat = torch.zeros((self.num_prototypes, self.num_prototypes))
+        self.age = torch.zeros_like(self.cmat)
+
+    def forward(self, d):
+        order = torch.argsort(d, dim=1)
+
+        for element in order:
+            i0, i1 = element[0], element[1]
+
+            self.cmat[i0][i1] = 1
+            self.cmat[i1][i0] = 1
+
+            self.age[i0][i1] = 0
+            self.age[i1][i0] = 0
+
+            self.age[i0][self.cmat[i0] == 1] += 1
+            self.age[i1][self.cmat[i1] == 1] += 1
+
+            self.cmat[i0][self.age[i0] > self.agelimit] = 0
+            self.cmat[i1][self.age[i1] > self.agelimit] = 0
+
+    def get_neighbors(self, position):
+        return torch.where(self.cmat[position])
+
+    def add_prototype(self):
+        new_cmat = torch.zeros([dim + 1 for dim in self.cmat.shape])
+        new_cmat[:-1, :-1] = self.cmat
+        self.cmat = new_cmat
+
+        new_age = torch.zeros([dim + 1 for dim in self.age.shape])
+        new_age[:-1, :-1] = self.age
+        self.age = new_age
+
+    def add_connection(self, a, b):
+        self.cmat[a][b] = 1
+        self.cmat[b][a] = 1
+
+        self.age[a][b] = 0
+        self.age[b][a] = 0
+
+    def remove_connection(self, a, b):
+        self.cmat[a][b] = 0
+        self.cmat[b][a] = 0
+
+        self.age[a][b] = 0
+        self.age[b][a] = 0
+
+    def extra_repr(self):
+        return f"(agelimit): ({self.agelimit})"
+
+
+class CosineSimilarity(torch.nn.Module):
+    def __init__(self, activation=shift_activation):
+        super().__init__()
+        self.activation = activation
+
+    def forward(self, x, y):
+        epsilon = torch.finfo(x.dtype).eps
+        normed_x = (x / x.pow(2).sum(dim=tuple(range(
+            1, x.ndim)), keepdim=True).clamp(min=epsilon).sqrt()).flatten(
+                start_dim=1)
+        normed_y = (y / y.pow(2).sum(dim=tuple(range(
+            1, y.ndim)), keepdim=True).clamp(min=epsilon).sqrt()).flatten(
+                start_dim=1)
+        # normed_x = (x / torch.linalg.norm(x, dim=1))
+        diss = torch.inner(normed_x, normed_y)
+        return self.activation(diss)
+
+
+class MarginLoss(torch.nn.modules.loss._Loss):
+    def __init__(self,
+                 margin=0.3,
+                 size_average=None,
+                 reduce=None,
+                 reduction="mean"):
+        super().__init__(size_average, reduce, reduction)
+        self.margin = margin
+
+    def forward(self, input_, target):
+        dp = torch.sum(target * input_, dim=-1)
+        dm = torch.max(input_ - target, dim=-1).values
+        return torch.nn.functional.relu(dm - dp + self.margin)
+
+
+class ReasoningLayer(torch.nn.Module):
+    def __init__(self, num_components, num_classes, num_replicas=1):
+        super().__init__()
+        self.num_replicas = num_replicas
+        self.num_classes = num_classes
+        probabilities_init = torch.zeros(2, 1, num_components,
+                                         self.num_classes)
+        probabilities_init.uniform_(0.4, 0.6)
+        # TODO Use `self.register_parameter("param", Parameter(param))` instead
+        self.reasoning_probabilities = torch.nn.Parameter(probabilities_init)
+
+    @property
+    def reasonings(self):
+        pk = self.reasoning_probabilities[0]
+        nk = (1 - pk) * self.reasoning_probabilities[1]
+        ik = 1 - pk - nk
+        img = torch.cat([pk, nk, ik], dim=0).permute(1, 0, 2)
+        return img.unsqueeze(1)
+
+    def forward(self, detections):
+        pk = self.reasoning_probabilities[0].clamp(0, 1)
+        nk = (1 - pk) * self.reasoning_probabilities[1].clamp(0, 1)
+        numerator = (detections @ (pk - nk)) + nk.sum(1)
+        probs = numerator / (pk + nk).sum(1)
+        probs = probs.squeeze(0)
+        return probs
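The `euclidean_similarity` helper maps a distance d to exp(-d^2 / (2 * variance)), so distance 0 gives similarity 1 and large distances decay toward 0. A self-contained sketch applied directly to precomputed distances (the real helper computes d from x and y first):

    import torch

    def euclidean_similarity_from_d(d, variance=1.0):
        # similarity = exp(-d^2 / (2 * variance))
        return torch.exp(-(d * d) / (2 * variance))

    d = torch.tensor([0.0, 1.0, 3.0])
    print(euclidean_similarity_from_d(d))  # tensor([1.0000, 0.6065, 0.0111])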
prototorch/models/glvq.py
@@ -1,101 +1,40 @@
 """Models based on the GLVQ framework."""

 import torch
-import torchmetrics
-from prototorch.components import LabeledComponents
 from prototorch.functions.activations import get_activation
 from prototorch.functions.competitions import wtac
 from prototorch.functions.distances import (
-    euclidean_distance,
     lomega_distance,
     omega_distance,
     squared_euclidean_distance,
 )
 from prototorch.functions.helper import get_flat
 from prototorch.functions.losses import glvq_loss, lvq1_loss, lvq21_loss
-from prototorch.modules import LambdaLayer
+from prototorch.modules import LambdaLayer, LossLayer
 from torch.nn.parameter import Parameter

-from .abstract import AbstractPrototypeModel, PrototypeImageModel
+from .abstract import ImagePrototypesMixin, SupervisedPrototypeModel


-class GLVQ(AbstractPrototypeModel):
+class GLVQ(SupervisedPrototypeModel):
     """Generalized Learning Vector Quantization."""
     def __init__(self, hparams, **kwargs):
-        super().__init__()
-
-        # Hyperparameters
-        self.save_hyperparameters(hparams)
+        super().__init__(hparams, **kwargs)

-        # Defaults
+        # Default hparams
         self.hparams.setdefault("transfer_fn", "identity")
         self.hparams.setdefault("transfer_beta", 10.0)
-        self.hparams.setdefault("lr", 0.01)
-
-        distance_fn = kwargs.get("distance_fn", euclidean_distance)
-        transfer_fn = get_activation(self.hparams.transfer_fn)

         # Layers
-        prototype_initializer = kwargs.get("prototype_initializer", None)
-        self.proto_layer = LabeledComponents(
-            distribution=self.hparams.distribution,
-            initializer=prototype_initializer)
-
-        self.distance_layer = LambdaLayer(distance_fn)
+        transfer_fn = get_activation(self.hparams.transfer_fn)
         self.transfer_layer = LambdaLayer(transfer_fn)
-        self.loss = LambdaLayer(glvq_loss)
+
+        # Loss
+        self.loss = LossLayer(glvq_loss)

         # Prototype metrics
         self.initialize_prototype_win_ratios()

-        self.optimizer = kwargs.get("optimizer", torch.optim.Adam)
-        self.lr_scheduler = kwargs.get("lr_scheduler", None)
-        self.lr_scheduler_kwargs = kwargs.get("lr_scheduler_kwargs", dict())
-
-    @property
-    def prototype_labels(self):
-        return self.proto_layer.component_labels.detach().cpu()
-
-    @property
-    def num_classes(self):
-        return len(self.proto_layer.distribution)
-
-    def _forward(self, x):
-        protos, _ = self.proto_layer()
-        distances = self.distance_layer(x, protos)
-        return distances
-
-    def forward(self, x):
-        distances = self._forward(x)
-        y_pred = self.predict_from_distances(distances)
-        y_pred = torch.eye(self.num_classes, device=self.device)[y_pred.long()]
-        return y_pred
-
-    def predict_from_distances(self, distances):
-        with torch.no_grad():
-            plabels = self.proto_layer.component_labels
-            y_pred = wtac(distances, plabels)
-        return y_pred
-
-    def predict(self, x):
-        with torch.no_grad():
-            distances = self._forward(x)
-            y_pred = self.predict_from_distances(distances)
-        return y_pred
-
-    def log_acc(self, distances, targets, tag):
-        preds = self.predict_from_distances(distances)
-        accuracy = torchmetrics.functional.accuracy(preds.int(), targets.int())
-        # `.int()` because FloatTensors are assumed to be class probabilities
-
-        self.log(tag,
-                 accuracy,
-                 on_step=False,
-                 on_epoch=True,
-                 prog_bar=True,
-                 logger=True)
-
     def initialize_prototype_win_ratios(self):
         self.register_buffer(
             "prototype_win_ratios",
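For reference, `glvq_loss` (now wrapped in a LossLayer) is built on the GLVQ classifier function mu = (d+ - d-) / (d+ + d-), which lies in [-1, 1] and is negative when the closest correct prototype beats the closest incorrect one. A torch-only sketch of that quantity (a hypothetical helper, not the prototorch implementation):

    import torch

    def glvq_mu(distances, targets, plabels):
        matching = targets.unsqueeze(1) == plabels.unsqueeze(0)
        inf = torch.full_like(distances, float("inf"))
        dp = torch.where(matching, distances, inf).min(dim=1).values   # d+
        dm = torch.where(~matching, distances, inf).min(dim=1).values  # d-
        return (dp - dm) / (dp + dm)

    d = torch.tensor([[0.2, 1.0], [0.9, 0.1]])  # distances to two prototypes
    y = torch.tensor([0, 0])                    # sample labels
    plabels = torch.tensor([0, 1])              # prototype labels
    print(glvq_mu(d, y, plabels))               # tensor([-0.6667,  0.8000])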
@@ -121,7 +60,7 @@ class GLVQ(AbstractPrototypeModel):

     def shared_step(self, batch, batch_idx, optimizer_idx=None):
         x, y = batch
-        out = self._forward(x)
+        out = self.compute_distances(x)
         plabels = self.proto_layer.component_labels
         mu = self.loss(out, y, prototype_labels=plabels)
         batch_loss = self.transfer_layer(mu, beta=self.hparams.transfer_beta)

@@ -158,18 +97,6 @@ class GLVQ(AbstractPrototypeModel):
     # def predict_step(self, batch, batch_idx, dataloader_idx=None):
     #     pass

-    def add_prototypes(self, initializer, distribution):
-        self.proto_layer.add_components(initializer, distribution)
-        self.trainer.accelerator_backend.setup_optimizers(self.trainer)
-
-    def remove_prototypes(self, indices):
-        self.proto_layer.remove_components(indices)
-        self.trainer.accelerator_backend.setup_optimizers(self.trainer)
-
-    def __repr__(self):
-        super_repr = super().__repr__()
-        return f"{super_repr}"
-

 class SiameseGLVQ(GLVQ):
     """GLVQ in a Siamese setting.

@@ -212,7 +139,7 @@ class SiameseGLVQ(GLVQ):
         else:
             return optimizer

-    def _forward(self, x):
+    def compute_distances(self, x):
         protos, _ = self.proto_layer()
         latent_x = self.backbone(x)
         self.backbone.requires_grad_(self.both_path_gradients)

@@ -256,7 +183,7 @@ class GRLVQ(SiameseGLVQ):
     def relevance_profile(self):
         return self.relevances.detach().cpu()

-    def _forward(self, x):
+    def compute_distances(self, x):
         protos, _ = self.proto_layer()
         distances = self.distance_layer(x, protos, torch.diag(self.relevances))
         return distances

@@ -285,7 +212,7 @@ class SiameseGMLVQ(SiameseGLVQ):
         lam = omega.T @ omega
         return lam.detach().cpu()

-    def _forward(self, x):
+    def compute_distances(self, x):
         protos, _ = self.proto_layer()
         x, protos = get_flat(x, protos)
         latent_x = self.backbone(x)

@@ -305,7 +232,7 @@ class LVQMLN(SiameseGLVQ):
     rather in the embedding space.

     """
-    def _forward(self, x):
+    def compute_distances(self, x):
         latent_protos, _ = self.proto_layer()
         latent_x = self.backbone(x)
         distances = self.distance_layer(latent_x, latent_protos)

@@ -327,7 +254,7 @@ class GMLVQ(GLVQ):
                             device=self.device)
         self.register_parameter("_omega", Parameter(omega))

-    def _forward(self, x):
+    def compute_distances(self, x):
         protos, _ = self.proto_layer()
         distances = self.distance_layer(x, protos, self._omega)
         return distances

@@ -355,7 +282,7 @@ class GLVQ1(GLVQ):
     """Generalized Learning Vector Quantization 1."""
     def __init__(self, hparams, **kwargs):
         super().__init__(hparams, **kwargs)
-        self.loss = lvq1_loss
+        self.loss = LossLayer(lvq1_loss)
         self.optimizer = torch.optim.SGD


@@ -363,11 +290,11 @@ class GLVQ21(GLVQ):
     """Generalized Learning Vector Quantization 2.1."""
     def __init__(self, hparams, **kwargs):
         super().__init__(hparams, **kwargs)
-        self.loss = lvq21_loss
+        self.loss = LossLayer(lvq21_loss)
         self.optimizer = torch.optim.SGD


-class ImageGLVQ(PrototypeImageModel, GLVQ):
+class ImageGLVQ(ImagePrototypesMixin, GLVQ):
     """GLVQ for training on image data.

     GLVQ model that constrains the prototypes to the range [0, 1] by clamping

@@ -376,7 +303,7 @@ class ImageGLVQ(PrototypeImageModel, GLVQ):
     """


-class ImageGMLVQ(PrototypeImageModel, GMLVQ):
+class ImageGMLVQ(ImagePrototypesMixin, GMLVQ):
     """GMLVQ for training on image data.

     GMLVQ model that constrains the prototypes to the range [0, 1] by clamping
prototorch/models/knn.py (new file, 38 lines)
@@ -0,0 +1,38 @@
+"""ProtoTorch KNN model."""
+
+import warnings
+
+from prototorch.components import LabeledComponents
+from prototorch.modules import KNNC
+
+from .abstract import SupervisedPrototypeModel
+
+
+class KNN(SupervisedPrototypeModel):
+    """K-Nearest-Neighbors classification algorithm."""
+    def __init__(self, hparams, **kwargs):
+        super().__init__(hparams, **kwargs)
+
+        # Default hparams
+        self.hparams.setdefault("k", 1)
+
+        data = kwargs.get("data", None)
+        if data is None:
+            raise ValueError("KNN requires data, but was not provided!")
+
+        # Layers
+        self.proto_layer = LabeledComponents(initialized_components=data)
+        self.competition_layer = KNNC(k=self.hparams.k)
+
+    def training_step(self, train_batch, batch_idx, optimizer_idx=None):
+        return 1  # skip training step
+
+    def on_train_batch_start(self,
+                             train_batch,
+                             batch_idx,
+                             dataloader_idx=None):
+        warnings.warn("k-NN has no training, skipping!")
+        return -1
+
+    def configure_optimizers(self):
+        return None
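Recast as a prototype model, k-NN simply stores every training sample as a labeled component and replaces winner-takes-all with a k-nearest majority vote. A torch-only sketch of that competition (a hypothetical stand-in for prototorch's KNNC layer):

    import torch

    def knn_predict(distances, plabels, k):
        _, idx = distances.topk(k, dim=1, largest=False)  # k closest samples
        votes = plabels[idx]                              # [batch, k]
        return votes.mode(dim=1).values                   # majority label

    d = torch.tensor([[0.1, 0.5, 0.2], [0.9, 0.3, 0.4]])
    plabels = torch.tensor([0, 1, 1])
    print(knn_predict(d, plabels, k=3))  # tensor([1, 1]) -- label 1 wins 2 of 3 votes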
prototorch/models/lvq.py
@@ -1,34 +1,24 @@
 """LVQ models that are optimized using non-gradient methods."""

-from prototorch.functions.competitions import wtac
 from prototorch.functions.losses import _get_dp_dm

+from .abstract import NonGradientMixin
 from .glvq import GLVQ


-class NonGradientLVQ(GLVQ):
-    """Abstract Model for Models that do not use gradients in their update phase."""
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.automatic_optimization = False
-
-    def training_step(self, train_batch, batch_idx, optimizer_idx=None):
-        raise NotImplementedError
-
-
-class LVQ1(NonGradientLVQ):
+class LVQ1(NonGradientMixin, GLVQ):
     """Learning Vector Quantization 1."""
     def training_step(self, train_batch, batch_idx, optimizer_idx=None):
         protos = self.proto_layer.components
         plabels = self.proto_layer.component_labels

         x, y = train_batch
-        dis = self._forward(x)
+        dis = self.compute_distances(x)
         # TODO Vectorized implementation

         for xi, yi in zip(x, y):
-            d = self._forward(xi.view(1, -1))
-            preds = wtac(d, plabels)
+            d = self.compute_distances(xi.view(1, -1))
+            preds = self.competition_layer(d, plabels)
             w = d.argmin(1)
             if yi == preds:
                 shift = xi - protos[w]

@@ -45,20 +35,20 @@ class LVQ1(NonGradientLVQ):
         return None


-class LVQ21(NonGradientLVQ):
+class LVQ21(NonGradientMixin, GLVQ):
     """Learning Vector Quantization 2.1."""
     def training_step(self, train_batch, batch_idx, optimizer_idx=None):
         protos = self.proto_layer.components
         plabels = self.proto_layer.component_labels

         x, y = train_batch
-        dis = self._forward(x)
+        dis = self.compute_distances(x)
         # TODO Vectorized implementation

         for xi, yi in zip(x, y):
             xi = xi.view(1, -1)
             yi = yi.view(1, )
-            d = self._forward(xi)
+            d = self.compute_distances(xi)
             (_, wp), (_, wn) = _get_dp_dm(d, yi, plabels, with_indices=True)
             shiftp = xi - protos[wp]
             shiftn = protos[wn] - xi

@@ -74,5 +64,5 @@ class LVQ21(NonGradientLVQ):
         return None


-class MedianLVQ(NonGradientLVQ):
+class MedianLVQ(NonGradientMixin, GLVQ):
     """Median LVQ"""
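The loop in LVQ1's training_step applies the classic non-gradient update: attract the winning prototype when its label matches the sample, repel it otherwise. A one-sample sketch of that step (torch only, made-up values):

    import torch

    lr = 0.1
    proto, plabel = torch.tensor([0.0, 0.0]), 0
    xi, yi = torch.tensor([1.0, 1.0]), 0

    shift = xi - proto
    proto = proto + lr * shift if yi == plabel else proto - lr * shift
    print(proto)  # tensor([0.1000, 0.1000]) -- moved toward the matching sample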
prototorch/models/probabilistic.py
@@ -1,10 +1,10 @@
 """Probabilistic GLVQ methods"""

 import torch
-from prototorch.functions.competitions import stratified_min, stratified_sum
-from prototorch.functions.losses import (log_likelihood_ratio_loss,
-                                         robust_soft_loss)
+from prototorch.functions.losses import nllr_loss, rslvq_loss
+from prototorch.functions.pooling import stratified_min_pooling, stratified_sum_pooling
 from prototorch.functions.transforms import gaussian
+from prototorch.modules import LambdaLayer, LossLayer

 from .glvq import GLVQ

@@ -13,13 +13,16 @@ class CELVQ(GLVQ):
     """Cross-Entropy Learning Vector Quantization."""
     def __init__(self, hparams, **kwargs):
         super().__init__(hparams, **kwargs)

+        # Loss
         self.loss = torch.nn.CrossEntropyLoss()

     def shared_step(self, batch, batch_idx, optimizer_idx=None):
         x, y = batch
-        out = self._forward(x)  # [None, num_protos]
+        out = self.compute_distances(x)  # [None, num_protos]
         plabels = self.proto_layer.component_labels
-        probs = -1.0 * stratified_min(out, plabels)  # [None, num_classes]
+        winning = stratified_min_pooling(out, plabels)  # [None, num_classes]
+        probs = -1.0 * winning
         batch_loss = self.loss(probs, y.long())
         loss = batch_loss.sum(dim=0)
         return out, loss

@@ -33,14 +36,14 @@ class ProbabilisticLVQ(GLVQ):
         self.rejection_confidence = rejection_confidence

     def forward(self, x):
-        distances = self._forward(x)
+        distances = self.compute_distances(x)
         conditional = self.conditional_distribution(distances,
                                                     self.hparams.variance)
         prior = (1. / self.num_prototypes) * torch.ones(self.num_prototypes,
                                                         device=self.device)
         posterior = conditional * prior
         plabels = self.proto_layer._labels
-        y_pred = stratified_sum(posterior, plabels)
+        y_pred = stratified_sum_pooling(posterior, plabels)
         return y_pred

     def predict(self, x):

@@ -50,12 +53,11 @@ class ProbabilisticLVQ(GLVQ):
         return prediction

     def training_step(self, batch, batch_idx, optimizer_idx=None):
-        X, y = batch
-        out = self.forward(X)
+        x, y = batch
+        out = self.forward(x)
         plabels = self.proto_layer.component_labels
-        batch_loss = self.loss_fn(out, y, plabels)
+        batch_loss = self.loss(out, y, plabels)
         loss = batch_loss.sum(dim=0)
-
         return loss


@@ -63,11 +65,11 @@ class LikelihoodRatioLVQ(ProbabilisticLVQ):
     """Learning Vector Quantization based on Likelihood Ratios."""
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self.loss_fn = log_likelihood_ratio_loss
+        self.loss = LossLayer(nllr_loss)


 class RSLVQ(ProbabilisticLVQ):
     """Robust Soft Learning Vector Quantization."""
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self.loss_fn = robust_soft_loss
+        self.loss = LossLayer(rslvq_loss)
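The renamed pooling functions make the class-posterior step explicit: per-prototype posteriors are summed within each class label. A self-contained sketch of that sum pooling (an illustrative helper, not prototorch's implementation):

    import torch

    def stratified_sum(posterior, plabels, num_classes):
        # sum per-prototype scores into per-class scores
        out = torch.zeros(posterior.shape[0], num_classes)
        out.index_add_(1, plabels, posterior)
        return out

    posterior = torch.tensor([[0.5, 0.2, 0.3]])
    plabels = torch.tensor([0, 1, 1])
    print(stratified_sum(posterior, plabels, num_classes=2))  # tensor([[0.5000, 0.5000]])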
prototorch/models/unsupervised.py
@@ -8,195 +8,30 @@ import pytorch_lightning as pl
 import torch
 import torchmetrics
 from prototorch.components import Components, LabeledComponents
-from prototorch.components.initializers import ZerosInitializer, parse_data_arg
+from prototorch.components.initializers import ZerosInitializer
 from prototorch.functions.competitions import knnc
 from prototorch.functions.distances import euclidean_distance
 from prototorch.modules import LambdaLayer
 from prototorch.modules.losses import NeuralGasEnergy
 from pytorch_lightning.callbacks import Callback

-from .abstract import AbstractPrototypeModel
+from .abstract import UnsupervisedPrototypeModel
+from .callbacks import GNGCallback
+from .extras import ConnectionTopology


-class GNGCallback(Callback):
-    """GNG Callback.
-
-    Applies growing algorithm based on accumulated error and topology.
-
-    Based on "A Growing Neural Gas Network Learns Topologies" by Bernd Fritzke.
-    """
-    def __init__(self, reduction=0.1, freq=10):
-        self.reduction = reduction
-        self.freq = freq
-
-    def on_epoch_end(self, trainer: pl.Trainer, pl_module):
-        if (trainer.current_epoch + 1) % self.freq == 0:
-            # Get information
-            errors = pl_module.errors
-            topology: ConnectionTopology = pl_module.topology_layer
-            components: pt.components.Components = pl_module.proto_layer.components
-
-            # Insertion point
-            worst = torch.argmax(errors)
-
-            neighbors = topology.get_neighbors(worst)[0]
-
-            if len(neighbors) == 0:
-                logging.log(level=20, msg="No neighbor-pairs found!")
-                return
-
-            neighbors_errors = errors[neighbors]
-            worst_neighbor = neighbors[torch.argmax(neighbors_errors)]
-
-            # New Prototype
-            new_component = 0.5 * (components[worst] +
-                                   components[worst_neighbor])
-
-            # Add component
-            pl_module.proto_layer.add_components(
-                initialized_components=new_component.unsqueeze(0))
-
-            # Adjust Topology
-            topology.add_prototype()
-            topology.add_connection(worst, -1)
-            topology.add_connection(worst_neighbor, -1)
-            topology.remove_connection(worst, worst_neighbor)
-
-            # New errors
-            worst_error = errors[worst].unsqueeze(0)
-            pl_module.errors = torch.cat([pl_module.errors, worst_error])
-            pl_module.errors[worst] = errors[worst] * self.reduction
-            pl_module.errors[
-                worst_neighbor] = errors[worst_neighbor] * self.reduction
-
-            trainer.accelerator_backend.setup_optimizers(trainer)
-
-
-class ConnectionTopology(torch.nn.Module):
-    def __init__(self, agelimit, num_prototypes):
-        super().__init__()
-        self.agelimit = agelimit
-        self.num_prototypes = num_prototypes
-
-        self.cmat = torch.zeros((self.num_prototypes, self.num_prototypes))
-        self.age = torch.zeros_like(self.cmat)
-
-    def forward(self, d):
-        order = torch.argsort(d, dim=1)
-
-        for element in order:
-            i0, i1 = element[0], element[1]
-
-            self.cmat[i0][i1] = 1
-            self.cmat[i1][i0] = 1
-
-            self.age[i0][i1] = 0
-            self.age[i1][i0] = 0
-
-            self.age[i0][self.cmat[i0] == 1] += 1
-            self.age[i1][self.cmat[i1] == 1] += 1
-
-            self.cmat[i0][self.age[i0] > self.agelimit] = 0
-            self.cmat[i1][self.age[i1] > self.agelimit] = 0
-
-    def get_neighbors(self, position):
-        return torch.where(self.cmat[position])
-
-    def add_prototype(self):
-        new_cmat = torch.zeros([dim + 1 for dim in self.cmat.shape])
-        new_cmat[:-1, :-1] = self.cmat
-        self.cmat = new_cmat
-
-        new_age = torch.zeros([dim + 1 for dim in self.age.shape])
-        new_age[:-1, :-1] = self.age
-        self.age = new_age
-
-    def add_connection(self, a, b):
-        self.cmat[a][b] = 1
-        self.cmat[b][a] = 1
-
-        self.age[a][b] = 0
-        self.age[b][a] = 0
-
-    def remove_connection(self, a, b):
-        self.cmat[a][b] = 0
-        self.cmat[b][a] = 0
-
-        self.age[a][b] = 0
-        self.age[b][a] = 0
-
-    def extra_repr(self):
-        return f"(agelimit): ({self.agelimit})"
-
-
-class KNN(AbstractPrototypeModel):
-    """K-Nearest-Neighbors classification algorithm."""
+class NeuralGas(UnsupervisedPrototypeModel):
     def __init__(self, hparams, **kwargs):
-        super().__init__()
+        super().__init__(hparams, **kwargs)

+        # Hyperparameters
         self.save_hyperparameters(hparams)

-        # Default Values
-        self.hparams.setdefault("k", 1)
-        self.hparams.setdefault("distance", euclidean_distance)
-
-        data = kwargs.get("data")
-        x_train, y_train = parse_data_arg(data)
-
-        self.proto_layer = LabeledComponents(initialized_components=(x_train,
-                                                                     y_train))
-
-        self.train_acc = torchmetrics.Accuracy()
-
-    @property
-    def prototype_labels(self):
-        return self.proto_layer.component_labels.detach()
-
-    def forward(self, x):
-        protos, _ = self.proto_layer()
-        dis = self.hparams.distance(x, protos)
-        return dis
-
-    def predict(self, x):
-        # model.eval()  # ?!
-        with torch.no_grad():
-            d = self(x)
-            plabels = self.proto_layer.component_labels
-            y_pred = knnc(d, plabels, k=self.hparams.k)
-        return y_pred
-
-    def training_step(self, train_batch, batch_idx, optimizer_idx=None):
-        return 1
-
-    def on_train_batch_start(self,
-                             train_batch,
-                             batch_idx,
-                             dataloader_idx=None):
-        warnings.warn("k-NN has no training, skipping!")
-        return -1
-
-    def configure_optimizers(self):
-        return None
-
-
-class NeuralGas(AbstractPrototypeModel):
-    def __init__(self, hparams, **kwargs):
-        super().__init__()
-
-        self.save_hyperparameters(hparams)
-
-        self.optimizer = kwargs.get("optimizer", torch.optim.Adam)
-
-        # Default Values
+        # Default hparams
         self.hparams.setdefault("input_dim", 2)
         self.hparams.setdefault("agelimit", 10)
         self.hparams.setdefault("lm", 1)

-        self.proto_layer = Components(
-            self.hparams.num_prototypes,
-            initializer=kwargs.get("prototype_initializer"))
-
-        self.distance_layer = LambdaLayer(euclidean_distance)
         self.energy_layer = NeuralGasEnergy(lm=self.hparams.lm)
         self.topology_layer = ConnectionTopology(
             agelimit=self.hparams.agelimit,

@@ -204,9 +39,10 @@ class NeuralGas(AbstractPrototypeModel):
         )

     def training_step(self, train_batch, batch_idx):
+        # x = train_batch
+        # TODO Check if the batch has labels
         x = train_batch[0]
-        protos = self.proto_layer()
-        d = self.distance_layer(x, protos)
+        d = self.compute_distances(x)
         cost, _ = self.energy_layer(d)
         self.topology_layer(d)
         return cost

@@ -216,26 +52,26 @@ class GrowingNeuralGas(NeuralGas):
     def __init__(self, hparams, **kwargs):
         super().__init__(hparams, **kwargs)

-        # defaults
+        # Defaults
         self.hparams.setdefault("step_reduction", 0.5)
         self.hparams.setdefault("insert_reduction", 0.1)
         self.hparams.setdefault("insert_freq", 10)

-        self.register_buffer(
-            "errors",
-            torch.zeros(self.hparams.num_prototypes, device=self.device))
+        errors = torch.zeros(self.hparams.num_prototypes, device=self.device)
+        self.register_buffer("errors", errors)

     def training_step(self, train_batch, _batch_idx):
+        # x = train_batch
+        # TODO Check if the batch has labels
         x = train_batch[0]
-        protos = self.proto_layer()
-        d = self.distance_layer(x, protos)
+        d = self.compute_distances(x)
         cost, order = self.energy_layer(d)
         winner = order[:, 0]
         mask = torch.zeros_like(d)
         mask[torch.arange(len(mask)), winner] = 1.0
-        winner_distances = d * mask
+        dp = d * mask

-        self.errors += torch.sum(winner_distances * winner_distances, dim=0)
+        self.errors += torch.sum(dp * dp, dim=0)
         self.errors *= self.hparams.step_reduction

         self.topology_layer(d)
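GrowingNeuralGas accumulates squared winner distances per prototype and decays them each step; GNGCallback later inserts new units where this error is largest. A self-contained sketch of the accumulation (torch only, made-up values):

    import torch

    d = torch.tensor([[0.5, 0.1], [0.2, 0.8]])  # batch distances to 2 prototypes
    winner = d.argmin(dim=1)                    # tensor([1, 0])
    mask = torch.zeros_like(d)
    mask[torch.arange(len(mask)), winner] = 1.0
    dp = d * mask                               # keep only winner distances
    errors = torch.sum(dp * dp, dim=0)          # tensor([0.0400, 0.0100])
    errors *= 0.5                               # step_reduction decays old error
    print(errors)                               # tensor([0.0200, 0.0050])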
prototorch/models/vis.py
@@ -140,7 +140,7 @@ class VisGLVQ2D(Vis2DAbstract):
         x = np.vstack((x_train, protos))
         mesh_input, xx, yy = self.get_mesh_input(x)
         _components = pl_module.proto_layer._components
-        mesh_input = torch.Tensor(mesh_input).type_as(_components)
+        mesh_input = torch.from_numpy(mesh_input).type_as(_components)
         y_pred = pl_module.predict(mesh_input)
         y_pred = y_pred.cpu().reshape(xx.shape)
         ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
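The one-line change matters because `torch.Tensor(...)` always builds a float32 copy, while `torch.from_numpy` is the idiomatic numpy conversion that keeps the source dtype and shares memory; `.type_as` then matches the model's parameter dtype and device in one step. A minimal sketch:

    import numpy as np
    import torch

    mesh_input = np.array([[0.1, 0.2], [0.3, 0.4]])       # float64 by default
    _components = torch.zeros(2, 2, dtype=torch.float32)  # stand-in parameters

    t = torch.from_numpy(mesh_input).type_as(_components)
    print(t.dtype)  # torch.float32, converted explicitly rather than silently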