[REFACTOR] Major cleanup
parent 20471bfb1c
commit 016fcb4060
prototorch/models/__init__.py
@@ -4,11 +4,23 @@ from importlib.metadata import PackageNotFoundError, version
 from .callbacks import PrototypeConvergence, PruneLoserPrototypes
 from .cbc import CBC, ImageCBC
-from .glvq import (GLVQ, GLVQ1, GLVQ21, GMLVQ, GRLVQ, LGMLVQ, LVQMLN,
-                   ImageGLVQ, ImageGMLVQ, SiameseGLVQ, SiameseGMLVQ)
+from .glvq import (
+    GLVQ,
+    GLVQ1,
+    GLVQ21,
+    GMLVQ,
+    GRLVQ,
+    LGMLVQ,
+    LVQMLN,
+    ImageGLVQ,
+    ImageGMLVQ,
+    SiameseGLVQ,
+    SiameseGMLVQ,
+)
+from .knn import KNN
 from .lvq import LVQ1, LVQ21, MedianLVQ
 from .probabilistic import CELVQ, RSLVQ, LikelihoodRatioLVQ
-from .unsupervised import KNN, GrowingNeuralGas, NeuralGas
+from .unsupervised import GrowingNeuralGas, NeuralGas
 from .vis import *

 __version__ = "0.1.7"
prototorch/models/abstract.py
@@ -1,7 +1,39 @@
 """Abstract classes to be inherited by prototorch models."""

+from typing import Final, final
+
 import pytorch_lightning as pl
 import torch
+import torchmetrics
+from prototorch.components import Components, LabeledComponents
+from prototorch.functions.distances import euclidean_distance
+from prototorch.modules import WTAC, LambdaLayer


-class AbstractPrototypeModel(pl.LightningModule):
+class ProtoTorchBolt(pl.LightningModule):
+    def __repr__(self):
+        super_repr = super().__repr__()
+        return f"ProtoTorch Bolt:\n{super_repr}"
+
+
+class PrototypeModel(ProtoTorchBolt):
     def __init__(self, hparams, **kwargs):
         super().__init__()

         # Hyperparameters
         self.save_hyperparameters(hparams)

         # Default hparams
         self.hparams.setdefault("lr", 0.01)

         # Default config
         self.optimizer = kwargs.get("optimizer", torch.optim.Adam)
         self.lr_scheduler = kwargs.get("lr_scheduler", None)
         self.lr_scheduler_kwargs = kwargs.get("lr_scheduler_kwargs", dict())

         distance_fn = kwargs.get("distance_fn", euclidean_distance)
         self.distance_layer = LambdaLayer(distance_fn)

     @property
     def num_prototypes(self):
         return len(self.proto_layer.components)
@@ -28,9 +60,115 @@ class AbstractPrototypeModel(pl.LightningModule):
         else:
             return optimizer

+    @final
+    def reconfigure_optimizers(self):
+        self.trainer.accelerator_backend.setup_optimizers(self.trainer)

-class PrototypeImageModel(pl.LightningModule):
+    def add_prototypes(self, *args, **kwargs):
+        self.proto_layer.add_components(*args, **kwargs)
+        self.reconfigure_optimizers()
+
+    def remove_prototypes(self, indices):
+        self.proto_layer.remove_components(indices)
+        self.reconfigure_optimizers()
+
+
+class UnsupervisedPrototypeModel(PrototypeModel):
+    def __init__(self, hparams, **kwargs):
+        super().__init__(hparams, **kwargs)
+
+        # Layers
+        prototype_initializer = kwargs.get("prototype_initializer", None)
+        if prototype_initializer is not None:
+            self.proto_layer = Components(
+                self.hparams.num_prototypes,
+                initializer=prototype_initializer,
+            )
+
+    def compute_distances(self, x):
+        protos = self.proto_layer()
+        distances = self.distance_layer(x, protos)
+        return distances
+
+    def forward(self, x):
+        distances = self.compute_distances(x)
+        return distances
+
+
+class SupervisedPrototypeModel(PrototypeModel):
+    def __init__(self, hparams, **kwargs):
+        super().__init__(hparams, **kwargs)
+
+        # Layers
+        prototype_initializer = kwargs.get("prototype_initializer", None)
+        if prototype_initializer is not None:
+            self.proto_layer = LabeledComponents(
+                distribution=self.hparams.distribution,
+                initializer=prototype_initializer,
+            )
+        self.competition_layer = WTAC()
+
+    @property
+    def prototype_labels(self):
+        return self.proto_layer.component_labels.detach().cpu()
+
+    @property
+    def num_classes(self):
+        return len(self.proto_layer.distribution)
+
+    def compute_distances(self, x):
+        protos, _ = self.proto_layer()
+        distances = self.distance_layer(x, protos)
+        return distances
+
+    def forward(self, x):
+        distances = self.compute_distances(x)
+        y_pred = self.predict_from_distances(distances)
+        # TODO
+        y_pred = torch.eye(self.num_classes, device=self.device)[
+            y_pred.long()]  # depends on labels {0,...,num_classes}
+        return y_pred
+
+    def predict_from_distances(self, distances):
+        with torch.no_grad():
+            plabels = self.proto_layer.component_labels
+            y_pred = self.competition_layer(distances, plabels)
+        return y_pred
+
+    def predict(self, x):
+        with torch.no_grad():
+            distances = self.compute_distances(x)
+            y_pred = self.predict_from_distances(distances)
+        return y_pred
+
+    def log_acc(self, distances, targets, tag):
+        preds = self.predict_from_distances(distances)
+        accuracy = torchmetrics.functional.accuracy(preds.int(), targets.int())
+        # `.int()` because FloatTensors are assumed to be class probabilities
+
+        self.log(tag,
+                 accuracy,
+                 on_step=False,
+                 on_epoch=True,
+                 prog_bar=True,
+                 logger=True)
+
+
+class NonGradientMixin():
+    """Mixin for custom non-gradient optimization."""
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.automatic_optimization: Final = False
+
+    def training_step(self, train_batch, batch_idx, optimizer_idx=None):
+        raise NotImplementedError
+
+
+class ImagePrototypesMixin(ProtoTorchBolt):
+    """Mixin for models with image prototypes."""
+    @final
+    def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
+        """Constrain the components to the range [0, 1] by clamping after updates."""
+        self.proto_layer.components.data.clamp_(0.0, 1.0)
+
     def get_prototype_grid(self, num_columns=2, return_channels_last=True):
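For orientation, a minimal usage sketch of the refactored hierarchy: concrete models take their hyperparameters as a dict and their configuration via keyword arguments, following the PrototypeModel constructor above. The class and kwarg names come from this diff; the hparams values are illustrative assumptions, not prescribed.

# Hedged usage sketch; values are illustrative.
import torch
from prototorch.models import GLVQ

model = GLVQ(
    hparams=dict(distribution=[2, 2, 2], lr=0.05),  # assumed: 3 classes, 2 protos each
    optimizer=torch.optim.SGD,  # default is torch.optim.Adam
)
print(model)  # "ProtoTorch Bolt:\n..." via ProtoTorchBolt.__repr__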
prototorch/models/callbacks.py
@@ -1,7 +1,12 @@
 """Lightning Callbacks."""

+import logging
+
 import pytorch_lightning as pl
 import torch
+from prototorch.components import Components
+
 from .extras import ConnectionTopology


 class PruneLoserPrototypes(pl.Callback):
@@ -26,25 +31,29 @@ class PruneLoserPrototypes(pl.Callback):
             return None
         if (trainer.current_epoch + 1) % self.frequency:
             return None

         ratios = pl_module.prototype_win_ratios.mean(dim=0)
         to_prune = torch.arange(len(ratios))[ratios < self.threshold]
-        prune_labels = pl_module.prototype_labels[to_prune.tolist()]
+        to_prune = to_prune.tolist()
+        prune_labels = pl_module.prototype_labels[to_prune]
         if self.prune_quota_per_epoch > 0:
             to_prune = to_prune[:self.prune_quota_per_epoch]
             prune_labels = prune_labels[:self.prune_quota_per_epoch]

         if len(to_prune) > 0:
             if self.verbose:
                 print(f"\nPrototype win ratios: {ratios}")
-                print(f"Pruning prototypes at: {to_prune.tolist()}")
+                print(f"Pruning prototypes at: {to_prune}")
                 print(f"Corresponding labels are: {prune_labels}")
             cur_num_protos = pl_module.num_prototypes
             pl_module.remove_prototypes(indices=to_prune)
             if self.replace:
-                if self.verbose:
-                    print(f"Re-adding prototypes at: {to_prune.tolist()}")
+                labels, counts = torch.unique(prune_labels,
+                                              sorted=True,
+                                              return_counts=True)
+                distribution = dict(zip(labels.tolist(), counts.tolist()))
+                if self.verbose:
+                    print(f"Re-adding pruned prototypes...")
+                    print(f"{distribution=}")
                 pl_module.add_prototypes(distribution=distribution,
                                          initializer=self.initializer)
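The callback reads `prototype_win_ratios` from the module and prunes prototypes whose mean win ratio falls below the threshold. A hedged wiring sketch; the argument names follow the attributes referenced above (threshold, frequency, verbose) and may not match the exact signature.

import pytorch_lightning as pl
from prototorch.models import PruneLoserPrototypes

trainer = pl.Trainer(
    max_epochs=100,
    callbacks=[
        PruneLoserPrototypes(threshold=0.02, frequency=5, verbose=True),
    ],
)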
@@ -68,3 +77,58 @@ class PrototypeConvergence(pl.Callback):
         print("Stopping...")
         # TODO
         return True


+class GNGCallback(pl.Callback):
+    """GNG Callback.
+
+    Applies growing algorithm based on accumulated error and topology.
+
+    Based on "A Growing Neural Gas Network Learns Topologies" by Bernd Fritzke.
+
+    """
+    def __init__(self, reduction=0.1, freq=10):
+        self.reduction = reduction
+        self.freq = freq
+
+    def on_epoch_end(self, trainer: pl.Trainer, pl_module):
+        if (trainer.current_epoch + 1) % self.freq == 0:
+            # Get information
+            errors = pl_module.errors
+            topology: ConnectionTopology = pl_module.topology_layer
+            components: Components = pl_module.proto_layer.components
+
+            # Insertion point
+            worst = torch.argmax(errors)
+
+            neighbors = topology.get_neighbors(worst)[0]
+
+            if len(neighbors) == 0:
+                logging.log(level=20, msg="No neighbor-pairs found!")
+                return
+
+            neighbors_errors = errors[neighbors]
+            worst_neighbor = neighbors[torch.argmax(neighbors_errors)]
+
+            # New Prototype
+            new_component = 0.5 * (components[worst] +
+                                   components[worst_neighbor])
+
+            # Add component
+            pl_module.proto_layer.add_components(
+                initialized_components=new_component.unsqueeze(0))
+
+            # Adjust Topology
+            topology.add_prototype()
+            topology.add_connection(worst, -1)
+            topology.add_connection(worst_neighbor, -1)
+            topology.remove_connection(worst, worst_neighbor)
+
+            # New errors
+            worst_error = errors[worst].unsqueeze(0)
+            pl_module.errors = torch.cat([pl_module.errors, worst_error])
+            pl_module.errors[worst] = errors[worst] * self.reduction
+            pl_module.errors[
+                worst_neighbor] = errors[worst_neighbor] * self.reduction
+
+            trainer.accelerator_backend.setup_optimizers(trainer)
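The insertion step above picks the unit with the largest accumulated error, finds its highest-error topological neighbor, and places a new unit at their midpoint; both errors are then damped by `reduction`. A standalone tensor illustration of that step:

import torch

errors = torch.tensor([0.1, 0.9, 0.4])
components = torch.tensor([[0.0, 0.0], [2.0, 2.0], [2.0, 0.0]])
worst = torch.argmax(errors)  # unit 1
neighbors = torch.tensor([0, 2])  # taken from the connection matrix
worst_neighbor = neighbors[torch.argmax(errors[neighbors])]  # unit 2
new_component = 0.5 * (components[worst] + components[worst_neighbor])
print(new_component)  # tensor([2., 1.])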
prototorch/models/cbc.py
@@ -1,86 +1,18 @@
 import torch
 import torchmetrics
-from prototorch.functions.distances import euclidean_distance
-from prototorch.functions.similarities import cosine_similarity

+from .abstract import ImagePrototypesMixin
+from .extras import (
+    CosineSimilarity,
+    MarginLoss,
+    ReasoningLayer,
+    euclidean_similarity,
+    rescaled_cosine_similarity,
+    shift_activation,
+)
 from .glvq import SiameseGLVQ


-def rescaled_cosine_similarity(x, y):
-    """Cosine Similarity rescaled to [0, 1]."""
-    similarities = cosine_similarity(x, y)
-    return (similarities + 1.0) / 2.0
-
-
-def shift_activation(x):
-    return (x + 1.0) / 2.0
-
-
-def euclidean_similarity(x, y, variance=1.0):
-    d = euclidean_distance(x, y)
-    return torch.exp(-(d * d) / (2 * variance))
-
-
-class CosineSimilarity(torch.nn.Module):
-    def __init__(self, activation=shift_activation):
-        super().__init__()
-        self.activation = activation
-
-    def forward(self, x, y):
-        epsilon = torch.finfo(x.dtype).eps
-        normed_x = (x / x.pow(2).sum(dim=tuple(range(
-            1, x.ndim)), keepdim=True).clamp(min=epsilon).sqrt()).flatten(
-                start_dim=1)
-        normed_y = (y / y.pow(2).sum(dim=tuple(range(
-            1, y.ndim)), keepdim=True).clamp(min=epsilon).sqrt()).flatten(
-                start_dim=1)
-        # normed_x = (x / torch.linalg.norm(x, dim=1))
-        diss = torch.inner(normed_x, normed_y)
-        return self.activation(diss)
-
-
-class MarginLoss(torch.nn.modules.loss._Loss):
-    def __init__(self,
-                 margin=0.3,
-                 size_average=None,
-                 reduce=None,
-                 reduction="mean"):
-        super().__init__(size_average, reduce, reduction)
-        self.margin = margin
-
-    def forward(self, input_, target):
-        dp = torch.sum(target * input_, dim=-1)
-        dm = torch.max(input_ - target, dim=-1).values
-        return torch.nn.functional.relu(dm - dp + self.margin)
-
-
-class ReasoningLayer(torch.nn.Module):
-    def __init__(self, num_components, num_classes, num_replicas=1):
-        super().__init__()
-        self.num_replicas = num_replicas
-        self.num_classes = num_classes
-        probabilities_init = torch.zeros(2, 1, num_components,
-                                         self.num_classes)
-        probabilities_init.uniform_(0.4, 0.6)
-        self.reasoning_probabilities = torch.nn.Parameter(probabilities_init)
-
-    @property
-    def reasonings(self):
-        pk = self.reasoning_probabilities[0]
-        nk = (1 - pk) * self.reasoning_probabilities[1]
-        ik = 1 - pk - nk
-        img = torch.cat([pk, nk, ik], dim=0).permute(1, 0, 2)
-        return img.unsqueeze(1)
-
-    def forward(self, detections):
-        pk = self.reasoning_probabilities[0].clamp(0, 1)
-        nk = (1 - pk) * self.reasoning_probabilities[1].clamp(0, 1)
-        numerator = (detections @ (pk - nk)) + nk.sum(1)
-        probs = numerator / (pk + nk).sum(1)
-        probs = probs.squeeze(0)
-        return probs
-
-
 class CBC(SiameseGLVQ):
     """Classification-By-Components."""
     def __init__(self,
@@ -143,10 +75,11 @@ class CBC(SiameseGLVQ):
         return y_pred


-class ImageCBC(CBC):
+class ImageCBC(ImagePrototypesMixin, CBC):
     """CBC model that constrains the components to the range [0, 1] by
     clamping after updates.
     """
-    def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
-        # super().on_train_batch_end(outputs, batch, batch_idx, dataloader_idx)
-        self.component_layer.components.data.clamp_(0.0, 1.0)
+    def __init__(self, hparams, **kwargs):
+        super().__init__(hparams, **kwargs)
+        # Namespace hook
+        self.proto_layer = self.component_layer
prototorch/models/extras.py (new file, 142 lines)
@@ -0,0 +1,142 @@
"""prototorch.models.extras

Modules not yet available in prototorch go here temporarily.

"""

import torch
from prototorch.functions.distances import euclidean_distance
from prototorch.functions.similarities import cosine_similarity


def rescaled_cosine_similarity(x, y):
    """Cosine Similarity rescaled to [0, 1]."""
    similarities = cosine_similarity(x, y)
    return (similarities + 1.0) / 2.0


def shift_activation(x):
    return (x + 1.0) / 2.0


def euclidean_similarity(x, y, variance=1.0):
    d = euclidean_distance(x, y)
    return torch.exp(-(d * d) / (2 * variance))


class ConnectionTopology(torch.nn.Module):
    def __init__(self, agelimit, num_prototypes):
        super().__init__()
        self.agelimit = agelimit
        self.num_prototypes = num_prototypes

        self.cmat = torch.zeros((self.num_prototypes, self.num_prototypes))
        self.age = torch.zeros_like(self.cmat)

    def forward(self, d):
        order = torch.argsort(d, dim=1)

        for element in order:
            i0, i1 = element[0], element[1]

            self.cmat[i0][i1] = 1
            self.cmat[i1][i0] = 1

            self.age[i0][i1] = 0
            self.age[i1][i0] = 0

            self.age[i0][self.cmat[i0] == 1] += 1
            self.age[i1][self.cmat[i1] == 1] += 1

            self.cmat[i0][self.age[i0] > self.agelimit] = 0
            self.cmat[i1][self.age[i1] > self.agelimit] = 0

    def get_neighbors(self, position):
        return torch.where(self.cmat[position])

    def add_prototype(self):
        new_cmat = torch.zeros([dim + 1 for dim in self.cmat.shape])
        new_cmat[:-1, :-1] = self.cmat
        self.cmat = new_cmat

        new_age = torch.zeros([dim + 1 for dim in self.age.shape])
        new_age[:-1, :-1] = self.age
        self.age = new_age

    def add_connection(self, a, b):
        self.cmat[a][b] = 1
        self.cmat[b][a] = 1

        self.age[a][b] = 0
        self.age[b][a] = 0

    def remove_connection(self, a, b):
        self.cmat[a][b] = 0
        self.cmat[b][a] = 0

        self.age[a][b] = 0
        self.age[b][a] = 0

    def extra_repr(self):
        return f"(agelimit): ({self.agelimit})"


class CosineSimilarity(torch.nn.Module):
    def __init__(self, activation=shift_activation):
        super().__init__()
        self.activation = activation

    def forward(self, x, y):
        epsilon = torch.finfo(x.dtype).eps
        normed_x = (x / x.pow(2).sum(dim=tuple(range(
            1, x.ndim)), keepdim=True).clamp(min=epsilon).sqrt()).flatten(
                start_dim=1)
        normed_y = (y / y.pow(2).sum(dim=tuple(range(
            1, y.ndim)), keepdim=True).clamp(min=epsilon).sqrt()).flatten(
                start_dim=1)
        # normed_x = (x / torch.linalg.norm(x, dim=1))
        diss = torch.inner(normed_x, normed_y)
        return self.activation(diss)


class MarginLoss(torch.nn.modules.loss._Loss):
    def __init__(self,
                 margin=0.3,
                 size_average=None,
                 reduce=None,
                 reduction="mean"):
        super().__init__(size_average, reduce, reduction)
        self.margin = margin

    def forward(self, input_, target):
        dp = torch.sum(target * input_, dim=-1)
        dm = torch.max(input_ - target, dim=-1).values
        return torch.nn.functional.relu(dm - dp + self.margin)


class ReasoningLayer(torch.nn.Module):
    def __init__(self, num_components, num_classes, num_replicas=1):
        super().__init__()
        self.num_replicas = num_replicas
        self.num_classes = num_classes
        probabilities_init = torch.zeros(2, 1, num_components,
                                         self.num_classes)
        probabilities_init.uniform_(0.4, 0.6)
        # TODO Use `self.register_parameter("param", Paramater(param))` instead
        self.reasoning_probabilities = torch.nn.Parameter(probabilities_init)

    @property
    def reasonings(self):
        pk = self.reasoning_probabilities[0]
        nk = (1 - pk) * self.reasoning_probabilities[1]
        ik = 1 - pk - nk
        img = torch.cat([pk, nk, ik], dim=0).permute(1, 0, 2)
        return img.unsqueeze(1)

    def forward(self, detections):
        pk = self.reasoning_probabilities[0].clamp(0, 1)
        nk = (1 - pk) * self.reasoning_probabilities[1].clamp(0, 1)
        numerator = (detections @ (pk - nk)) + nk.sum(1)
        probs = numerator / (pk + nk).sum(1)
        probs = probs.squeeze(0)
        return probs
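A quick check of `ConnectionTopology.forward` as defined above: for each sample, the two nearest prototypes get connected and that edge's age resets, while the other edges of both units age by one step.

import torch
from prototorch.models.extras import ConnectionTopology

topo = ConnectionTopology(agelimit=10, num_prototypes=3)
d = torch.tensor([[0.1, 0.5, 0.9]])  # one sample, distances to 3 prototypes
topo(d)  # connects prototypes 0 and 1
print(topo.cmat)
# tensor([[0., 1., 0.],
#         [1., 0., 0.],
#         [0., 0., 0.]])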
prototorch/models/glvq.py
@@ -1,101 +1,40 @@
 """Models based on the GLVQ framework."""

 import torch
 import torchmetrics
 from prototorch.components import LabeledComponents
 from prototorch.functions.activations import get_activation
 from prototorch.functions.competitions import wtac
 from prototorch.functions.distances import (
     euclidean_distance,
     lomega_distance,
     omega_distance,
     squared_euclidean_distance,
 )
 from prototorch.functions.helper import get_flat
 from prototorch.functions.losses import glvq_loss, lvq1_loss, lvq21_loss
-from prototorch.modules import LambdaLayer
+from prototorch.modules import LambdaLayer, LossLayer
 from torch.nn.parameter import Parameter

-from .abstract import AbstractPrototypeModel, PrototypeImageModel
+from .abstract import ImagePrototypesMixin, SupervisedPrototypeModel


-class GLVQ(AbstractPrototypeModel):
+class GLVQ(SupervisedPrototypeModel):
     """Generalized Learning Vector Quantization."""
     def __init__(self, hparams, **kwargs):
+        super().__init__(hparams, **kwargs)
-        super().__init__()
-
-        # Hyperparameters
-        self.save_hyperparameters(hparams)

-        # Defaults
+        # Default hparams
         self.hparams.setdefault("transfer_fn", "identity")
         self.hparams.setdefault("transfer_beta", 10.0)
-        self.hparams.setdefault("lr", 0.01)
-
-        distance_fn = kwargs.get("distance_fn", euclidean_distance)
         transfer_fn = get_activation(self.hparams.transfer_fn)
-
-        # Layers
-        prototype_initializer = kwargs.get("prototype_initializer", None)
-        self.proto_layer = LabeledComponents(
-            distribution=self.hparams.distribution,
-            initializer=prototype_initializer)
-
-        self.distance_layer = LambdaLayer(distance_fn)
         self.transfer_layer = LambdaLayer(transfer_fn)
-        self.loss = LambdaLayer(glvq_loss)

+        # Loss
+        self.loss = LossLayer(glvq_loss)
+
+        # Prototype metrics
+        self.initialize_prototype_win_ratios()

-        self.optimizer = kwargs.get("optimizer", torch.optim.Adam)
-        self.lr_scheduler = kwargs.get("lr_scheduler", None)
-        self.lr_scheduler_kwargs = kwargs.get("lr_scheduler_kwargs", dict())
-
-    @property
-    def prototype_labels(self):
-        return self.proto_layer.component_labels.detach().cpu()
-
-    @property
-    def num_classes(self):
-        return len(self.proto_layer.distribution)
-
-    def _forward(self, x):
-        protos, _ = self.proto_layer()
-        distances = self.distance_layer(x, protos)
-        return distances
-
-    def forward(self, x):
-        distances = self._forward(x)
-        y_pred = self.predict_from_distances(distances)
-        y_pred = torch.eye(self.num_classes, device=self.device)[y_pred.long()]
-        return y_pred
-
-    def predict_from_distances(self, distances):
-        with torch.no_grad():
-            plabels = self.proto_layer.component_labels
-            y_pred = wtac(distances, plabels)
-        return y_pred
-
-    def predict(self, x):
-        with torch.no_grad():
-            distances = self._forward(x)
-            y_pred = self.predict_from_distances(distances)
-        return y_pred
-
-    def log_acc(self, distances, targets, tag):
-        preds = self.predict_from_distances(distances)
-        accuracy = torchmetrics.functional.accuracy(preds.int(), targets.int())
-        # `.int()` because FloatTensors are assumed to be class probabilities
-
-        self.log(tag,
-                 accuracy,
-                 on_step=False,
-                 on_epoch=True,
-                 prog_bar=True,
-                 logger=True)

     def initialize_prototype_win_ratios(self):
         self.register_buffer(
             "prototype_win_ratios",
@@ -121,7 +60,7 @@ class GLVQ(AbstractPrototypeModel):

     def shared_step(self, batch, batch_idx, optimizer_idx=None):
         x, y = batch
-        out = self._forward(x)
+        out = self.compute_distances(x)
         plabels = self.proto_layer.component_labels
         mu = self.loss(out, y, prototype_labels=plabels)
         batch_loss = self.transfer_layer(mu, beta=self.hparams.transfer_beta)
@@ -158,18 +97,6 @@ class GLVQ(AbstractPrototypeModel):
     # def predict_step(self, batch, batch_idx, dataloader_idx=None):
     #     pass

-    def add_prototypes(self, initializer, distribution):
-        self.proto_layer.add_components(initializer, distribution)
-        self.trainer.accelerator_backend.setup_optimizers(self.trainer)
-
-    def remove_prototypes(self, indices):
-        self.proto_layer.remove_components(indices)
-        self.trainer.accelerator_backend.setup_optimizers(self.trainer)
-
-    def __repr__(self):
-        super_repr = super().__repr__()
-        return f"{super_repr}"
-

 class SiameseGLVQ(GLVQ):
     """GLVQ in a Siamese setting.
@@ -212,7 +139,7 @@ class SiameseGLVQ(GLVQ):
         else:
             return optimizer

-    def _forward(self, x):
+    def compute_distances(self, x):
         protos, _ = self.proto_layer()
         latent_x = self.backbone(x)
         self.backbone.requires_grad_(self.both_path_gradients)
@@ -256,7 +183,7 @@ class GRLVQ(SiameseGLVQ):
     def relevance_profile(self):
         return self.relevances.detach().cpu()

-    def _forward(self, x):
+    def compute_distances(self, x):
         protos, _ = self.proto_layer()
         distances = self.distance_layer(x, protos, torch.diag(self.relevances))
         return distances
@@ -285,7 +212,7 @@ class SiameseGMLVQ(SiameseGLVQ):
         lam = omega.T @ omega
         return lam.detach().cpu()

-    def _forward(self, x):
+    def compute_distances(self, x):
         protos, _ = self.proto_layer()
         x, protos = get_flat(x, protos)
         latent_x = self.backbone(x)
@@ -305,7 +232,7 @@ class LVQMLN(SiameseGLVQ):
     rather in the embedding space.

     """
-    def _forward(self, x):
+    def compute_distances(self, x):
         latent_protos, _ = self.proto_layer()
         latent_x = self.backbone(x)
         distances = self.distance_layer(latent_x, latent_protos)
@@ -327,7 +254,7 @@ class GMLVQ(GLVQ):
                             device=self.device)
         self.register_parameter("_omega", Parameter(omega))

-    def _forward(self, x):
+    def compute_distances(self, x):
         protos, _ = self.proto_layer()
         distances = self.distance_layer(x, protos, self._omega)
         return distances
@@ -355,7 +282,7 @@ class GLVQ1(GLVQ):
     """Generalized Learning Vector Quantization 1."""
     def __init__(self, hparams, **kwargs):
         super().__init__(hparams, **kwargs)
-        self.loss = lvq1_loss
+        self.loss = LossLayer(lvq1_loss)
         self.optimizer = torch.optim.SGD


@@ -363,11 +290,11 @@ class GLVQ21(GLVQ):
     """Generalized Learning Vector Quantization 2.1."""
     def __init__(self, hparams, **kwargs):
         super().__init__(hparams, **kwargs)
-        self.loss = lvq21_loss
+        self.loss = LossLayer(lvq21_loss)
         self.optimizer = torch.optim.SGD


-class ImageGLVQ(PrototypeImageModel, GLVQ):
+class ImageGLVQ(ImagePrototypesMixin, GLVQ):
     """GLVQ for training on image data.

     GLVQ model that constrains the prototypes to the range [0, 1] by clamping
@@ -376,7 +303,7 @@ class ImageGLVQ(PrototypeImageModel, GLVQ):
     """


-class ImageGMLVQ(PrototypeImageModel, GMLVQ):
+class ImageGMLVQ(ImagePrototypesMixin, GMLVQ):
     """GMLVQ for training on image data.

     GMLVQ model that constrains the prototypes to the range [0, 1] by clamping
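The `compute_distances` renames above all route through the distance layer; in GMLVQ the layer additionally receives the learned `_omega` matrix. A sketch of the underlying computation, assuming `omega_distance` measures squared distances in the omega-mapped space:

import torch

omega = torch.eye(2)  # identity recovers plain squared Euclidean distance
x = torch.tensor([[1.0, 0.0]])
protos = torch.tensor([[0.0, 0.0], [1.0, 1.0]])
diff = x.unsqueeze(1) - protos  # [batch, num_protos, dim]
d = torch.sum((diff @ omega.T)**2, dim=-1)
print(d)  # tensor([[1., 1.]])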
prototorch/models/knn.py (new file, 38 lines)
@@ -0,0 +1,38 @@
"""ProtoTorch KNN model."""

import warnings

from prototorch.components import LabeledComponents
from prototorch.modules import KNNC

from .abstract import SupervisedPrototypeModel


class KNN(SupervisedPrototypeModel):
    """K-Nearest-Neighbors classification algorithm."""
    def __init__(self, hparams, **kwargs):
        super().__init__(hparams, **kwargs)

        # Default hparams
        self.hparams.setdefault("k", 1)

        data = kwargs.get("data", None)
        if data is None:
            raise ValueError("KNN requires data, but was not provided!")

        # Layers
        self.proto_layer = LabeledComponents(initialized_components=data)
        self.competition_layer = KNNC(k=self.hparams.k)

    def training_step(self, train_batch, batch_idx, optimizer_idx=None):
        return 1  # skip training step

    def on_train_batch_start(self,
                             train_batch,
                             batch_idx,
                             dataloader_idx=None):
        warnings.warn("k-NN has no training, skipping!")
        return -1

    def configure_optimizers(self):
        return None
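A hedged usage sketch for the new KNN model: the training data is passed via the `data` kwarg (a `(x, y)` tuple layout is assumed here) and no optimizer is ever configured.

import torch
from prototorch.models import KNN

x_train = torch.randn(100, 2)
y_train = torch.randint(0, 3, (100, ))
model = KNN(hparams=dict(k=3), data=(x_train, y_train))
y_pred = model.predict(torch.randn(5, 2))  # no gradients, pure competition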
prototorch/models/lvq.py
@@ -1,34 +1,24 @@
 """LVQ models that are optimized using non-gradient methods."""

 from prototorch.functions.competitions import wtac
 from prototorch.functions.losses import _get_dp_dm

+from .abstract import NonGradientMixin
 from .glvq import GLVQ


-class NonGradientLVQ(GLVQ):
-    """Abstract Model for Models that do not use gradients in their update phase."""
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.automatic_optimization = False
-
-    def training_step(self, train_batch, batch_idx, optimizer_idx=None):
-        raise NotImplementedError
-
-
-class LVQ1(NonGradientLVQ):
+class LVQ1(NonGradientMixin, GLVQ):
     """Learning Vector Quantization 1."""
     def training_step(self, train_batch, batch_idx, optimizer_idx=None):
         protos = self.proto_layer.components
         plabels = self.proto_layer.component_labels

         x, y = train_batch
-        dis = self._forward(x)
+        dis = self.compute_distances(x)
         # TODO Vectorized implementation

         for xi, yi in zip(x, y):
-            d = self._forward(xi.view(1, -1))
-            preds = wtac(d, plabels)
+            d = self.compute_distances(xi.view(1, -1))
+            preds = self.competition_layer(d, plabels)
             w = d.argmin(1)
             if yi == preds:
                 shift = xi - protos[w]
@@ -45,20 +35,20 @@ class LVQ1(NonGradientLVQ):
         return None


-class LVQ21(NonGradientLVQ):
+class LVQ21(NonGradientMixin, GLVQ):
     """Learning Vector Quantization 2.1."""
     def training_step(self, train_batch, batch_idx, optimizer_idx=None):
         protos = self.proto_layer.components
         plabels = self.proto_layer.component_labels

         x, y = train_batch
-        dis = self._forward(x)
+        dis = self.compute_distances(x)
         # TODO Vectorized implementation

         for xi, yi in zip(x, y):
             xi = xi.view(1, -1)
             yi = yi.view(1, )
-            d = self._forward(xi)
+            d = self.compute_distances(xi)
             (_, wp), (_, wn) = _get_dp_dm(d, yi, plabels, with_indices=True)
             shiftp = xi - protos[wp]
             shiftn = protos[wn] - xi
@@ -74,5 +64,5 @@ class LVQ21(NonGradientLVQ):
         return None


-class MedianLVQ(NonGradientLVQ):
+class MedianLVQ(NonGradientMixin, GLVQ):
     """Median LVQ"""
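The loop above implements the classic LVQ1 update: the winning prototype is attracted to a correctly classified sample and repelled otherwise. In numbers (the learning rate and how the shift is applied are assumptions, since those lines are truncated in this hunk):

import torch

lr = 0.1
proto = torch.tensor([0.0, 0.0])
xi = torch.tensor([1.0, 1.0])
shift = xi - proto
proto_if_correct = proto + lr * shift  # attract -> tensor([0.1000, 0.1000])
proto_if_wrong = proto - lr * shift    # repel   -> tensor([-0.1000, -0.1000])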
prototorch/models/probabilistic.py
@@ -1,10 +1,10 @@
 """Probabilistic GLVQ methods"""

 import torch
-from prototorch.functions.competitions import stratified_min, stratified_sum
-from prototorch.functions.losses import (log_likelihood_ratio_loss,
-                                         robust_soft_loss)
+from prototorch.functions.losses import nllr_loss, rslvq_loss
+from prototorch.functions.pooling import stratified_min_pooling, stratified_sum_pooling
+from prototorch.functions.transforms import gaussian
+from prototorch.modules import LambdaLayer, LossLayer

 from .glvq import GLVQ

@@ -13,13 +13,16 @@ class CELVQ(GLVQ):
     """Cross-Entropy Learning Vector Quantization."""
     def __init__(self, hparams, **kwargs):
         super().__init__(hparams, **kwargs)

+        # Loss
         self.loss = torch.nn.CrossEntropyLoss()

     def shared_step(self, batch, batch_idx, optimizer_idx=None):
         x, y = batch
-        out = self._forward(x)  # [None, num_protos]
+        out = self.compute_distances(x)  # [None, num_protos]
         plabels = self.proto_layer.component_labels
-        probs = -1.0 * stratified_min(out, plabels)  # [None, num_classes]
+        winning = stratified_min_pooling(out, plabels)  # [None, num_classes]
+        probs = -1.0 * winning
         batch_loss = self.loss(probs, y.long())
         loss = batch_loss.sum(dim=0)
         return out, loss
@@ -33,14 +36,14 @@ class ProbabilisticLVQ(GLVQ):
         self.rejection_confidence = rejection_confidence

     def forward(self, x):
-        distances = self._forward(x)
+        distances = self.compute_distances(x)
         conditional = self.conditional_distribution(distances,
                                                     self.hparams.variance)
         prior = (1. / self.num_prototypes) * torch.ones(self.num_prototypes,
                                                         device=self.device)
         posterior = conditional * prior
         plabels = self.proto_layer._labels
-        y_pred = stratified_sum(posterior, plabels)
+        y_pred = stratified_sum_pooling(posterior, plabels)
         return y_pred

     def predict(self, x):
@@ -50,12 +53,11 @@ class ProbabilisticLVQ(GLVQ):
         return prediction

     def training_step(self, batch, batch_idx, optimizer_idx=None):
-        X, y = batch
-        out = self.forward(X)
+        x, y = batch
+        out = self.forward(x)
         plabels = self.proto_layer.component_labels
-        batch_loss = self.loss_fn(out, y, plabels)
+        batch_loss = self.loss(out, y, plabels)
         loss = batch_loss.sum(dim=0)

         return loss


@@ -63,11 +65,11 @@ class LikelihoodRatioLVQ(ProbabilisticLVQ):
     """Learning Vector Quantization based on Likelihood Ratios."""
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self.loss_fn = log_likelihood_ratio_loss
+        self.loss = LossLayer(nllr_loss)


 class RSLVQ(ProbabilisticLVQ):
     """Robust Soft Learning Vector Quantization."""
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self.loss_fn = robust_soft_loss
+        self.loss = LossLayer(rslvq_loss)
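`ProbabilisticLVQ.forward` above builds a per-prototype posterior from a Gaussian conditional and a uniform prior, then pools it per class. Written out with plain tensors (the exact form of `gaussian` is an assumption here):

import torch

distances = torch.tensor([[0.5, 2.0, 1.0]])  # one sample, three prototypes
variance = 1.0
conditional = torch.exp(-distances / (2 * variance))  # assumed gaussian form
prior = torch.full((3, ), 1.0 / 3.0)
posterior = conditional * prior
# stratified_sum_pooling would now sum these responsibilities per class label.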
prototorch/models/unsupervised.py
@@ -8,195 +8,30 @@ import pytorch_lightning as pl
 import torch
 import torchmetrics
 from prototorch.components import Components, LabeledComponents
-from prototorch.components.initializers import ZerosInitializer, parse_data_arg
+from prototorch.components.initializers import ZerosInitializer
-from prototorch.functions.competitions import knnc
 from prototorch.functions.distances import euclidean_distance
 from prototorch.modules import LambdaLayer
 from prototorch.modules.losses import NeuralGasEnergy
-from pytorch_lightning.callbacks import Callback

-from .abstract import AbstractPrototypeModel
+from .abstract import UnsupervisedPrototypeModel
+from .callbacks import GNGCallback
+from .extras import ConnectionTopology


-class GNGCallback(Callback):
-    """GNG Callback.
-
-    Applies growing algorithm based on accumulated error and topology.
-
-    Based on "A Growing Neural Gas Network Learns Topologies" by Bernd Fritzke.
-    """
-    def __init__(self, reduction=0.1, freq=10):
-        self.reduction = reduction
-        self.freq = freq
-
-    def on_epoch_end(self, trainer: pl.Trainer, pl_module):
-        if (trainer.current_epoch + 1) % self.freq == 0:
-            # Get information
-            errors = pl_module.errors
-            topology: ConnectionTopology = pl_module.topology_layer
-            components: pt.components.Components = pl_module.proto_layer.components
-
-            # Insertion point
-            worst = torch.argmax(errors)
-
-            neighbors = topology.get_neighbors(worst)[0]
-
-            if len(neighbors) == 0:
-                logging.log(level=20, msg="No neighbor-pairs found!")
-                return
-
-            neighbors_errors = errors[neighbors]
-            worst_neighbor = neighbors[torch.argmax(neighbors_errors)]
-
-            # New Prototype
-            new_component = 0.5 * (components[worst] +
-                                   components[worst_neighbor])
-
-            # Add component
-            pl_module.proto_layer.add_components(
-                initialized_components=new_component.unsqueeze(0))
-
-            # Adjust Topology
-            topology.add_prototype()
-            topology.add_connection(worst, -1)
-            topology.add_connection(worst_neighbor, -1)
-            topology.remove_connection(worst, worst_neighbor)
-
-            # New errors
-            worst_error = errors[worst].unsqueeze(0)
-            pl_module.errors = torch.cat([pl_module.errors, worst_error])
-            pl_module.errors[worst] = errors[worst] * self.reduction
-            pl_module.errors[
-                worst_neighbor] = errors[worst_neighbor] * self.reduction
-
-            trainer.accelerator_backend.setup_optimizers(trainer)
-
-
-class ConnectionTopology(torch.nn.Module):
-    def __init__(self, agelimit, num_prototypes):
-        super().__init__()
-        self.agelimit = agelimit
-        self.num_prototypes = num_prototypes
-
-        self.cmat = torch.zeros((self.num_prototypes, self.num_prototypes))
-        self.age = torch.zeros_like(self.cmat)
-
-    def forward(self, d):
-        order = torch.argsort(d, dim=1)
-
-        for element in order:
-            i0, i1 = element[0], element[1]
-
-            self.cmat[i0][i1] = 1
-            self.cmat[i1][i0] = 1
-
-            self.age[i0][i1] = 0
-            self.age[i1][i0] = 0
-
-            self.age[i0][self.cmat[i0] == 1] += 1
-            self.age[i1][self.cmat[i1] == 1] += 1
-
-            self.cmat[i0][self.age[i0] > self.agelimit] = 0
-            self.cmat[i1][self.age[i1] > self.agelimit] = 0
-
-    def get_neighbors(self, position):
-        return torch.where(self.cmat[position])
-
-    def add_prototype(self):
-        new_cmat = torch.zeros([dim + 1 for dim in self.cmat.shape])
-        new_cmat[:-1, :-1] = self.cmat
-        self.cmat = new_cmat
-
-        new_age = torch.zeros([dim + 1 for dim in self.age.shape])
-        new_age[:-1, :-1] = self.age
-        self.age = new_age
-
-    def add_connection(self, a, b):
-        self.cmat[a][b] = 1
-        self.cmat[b][a] = 1
-
-        self.age[a][b] = 0
-        self.age[b][a] = 0
-
-    def remove_connection(self, a, b):
-        self.cmat[a][b] = 0
-        self.cmat[b][a] = 0
-
-        self.age[a][b] = 0
-        self.age[b][a] = 0
-
-    def extra_repr(self):
-        return f"(agelimit): ({self.agelimit})"
-
-
-class KNN(AbstractPrototypeModel):
-    """K-Nearest-Neighbors classification algorithm."""
+class NeuralGas(UnsupervisedPrototypeModel):
     def __init__(self, hparams, **kwargs):
-        super().__init__()
+        super().__init__(hparams, **kwargs)

-        # Hyperparameters
-        self.save_hyperparameters(hparams)
-
-        # Default Values
-        self.hparams.setdefault("k", 1)
-        self.hparams.setdefault("distance", euclidean_distance)
-
-        data = kwargs.get("data")
-        x_train, y_train = parse_data_arg(data)
-
-        self.proto_layer = LabeledComponents(initialized_components=(x_train,
-                                                                     y_train))
-
-        self.train_acc = torchmetrics.Accuracy()
-
-    @property
-    def prototype_labels(self):
-        return self.proto_layer.component_labels.detach()
-
-    def forward(self, x):
-        protos, _ = self.proto_layer()
-        dis = self.hparams.distance(x, protos)
-        return dis
-
-    def predict(self, x):
-        # model.eval()  # ?!
-        with torch.no_grad():
-            d = self(x)
-            plabels = self.proto_layer.component_labels
-            y_pred = knnc(d, plabels, k=self.hparams.k)
-        return y_pred
-
-    def training_step(self, train_batch, batch_idx, optimizer_idx=None):
-        return 1
-
-    def on_train_batch_start(self,
-                             train_batch,
-                             batch_idx,
-                             dataloader_idx=None):
-        warnings.warn("k-NN has no training, skipping!")
-        return -1
-
-    def configure_optimizers(self):
-        return None
-
-
-class NeuralGas(AbstractPrototypeModel):
-    def __init__(self, hparams, **kwargs):
-        super().__init__()
-
-        self.save_hyperparameters(hparams)
-
-        self.optimizer = kwargs.get("optimizer", torch.optim.Adam)
-
-        # Default Values
+        # Default hparams
         self.hparams.setdefault("input_dim", 2)
         self.hparams.setdefault("agelimit", 10)
         self.hparams.setdefault("lm", 1)

-        self.proto_layer = Components(
-            self.hparams.num_prototypes,
-            initializer=kwargs.get("prototype_initializer"))
-
-        self.distance_layer = LambdaLayer(euclidean_distance)
         self.energy_layer = NeuralGasEnergy(lm=self.hparams.lm)
         self.topology_layer = ConnectionTopology(
             agelimit=self.hparams.agelimit,
@@ -204,9 +39,10 @@ class NeuralGas(AbstractPrototypeModel):
         )

     def training_step(self, train_batch, batch_idx):
-        # x = train_batch
+        # TODO Check if the batch has labels
         x = train_batch[0]
-        protos = self.proto_layer()
-        d = self.distance_layer(x, protos)
+        d = self.compute_distances(x)
         cost, _ = self.energy_layer(d)
         self.topology_layer(d)
         return cost
@@ -216,26 +52,26 @@ class GrowingNeuralGas(NeuralGas):
     def __init__(self, hparams, **kwargs):
         super().__init__(hparams, **kwargs)

-        # defaults
+        # Defaults
         self.hparams.setdefault("step_reduction", 0.5)
         self.hparams.setdefault("insert_reduction", 0.1)
         self.hparams.setdefault("insert_freq", 10)

-        self.register_buffer(
-            "errors",
-            torch.zeros(self.hparams.num_prototypes, device=self.device))
+        errors = torch.zeros(self.hparams.num_prototypes, device=self.device)
+        self.register_buffer("errors", errors)

     def training_step(self, train_batch, _batch_idx):
-        # x = train_batch
+        # TODO Check if the batch has labels
         x = train_batch[0]
-        protos = self.proto_layer()
-        d = self.distance_layer(x, protos)
+        d = self.compute_distances(x)
         cost, order = self.energy_layer(d)
         winner = order[:, 0]
         mask = torch.zeros_like(d)
         mask[torch.arange(len(mask)), winner] = 1.0
-        winner_distances = d * mask
+        dp = d * mask

-        self.errors += torch.sum(winner_distances * winner_distances, dim=0)
+        self.errors += torch.sum(dp * dp, dim=0)
         self.errors *= self.hparams.step_reduction

         self.topology_layer(d)
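The masking in `GrowingNeuralGas.training_step` above ensures that only the winner's squared distance feeds its running error:

import torch

d = torch.tensor([[1.0, 3.0], [2.0, 0.5]])  # 2 samples x 2 prototypes
winner = d.argmin(dim=1)  # tensor([0, 1])
mask = torch.zeros_like(d)
mask[torch.arange(len(mask)), winner] = 1.0
dp = d * mask
errors = torch.sum(dp * dp, dim=0)  # tensor([1.0000, 0.2500])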
prototorch/models/vis.py
@@ -140,7 +140,7 @@ class VisGLVQ2D(Vis2DAbstract):
         x = np.vstack((x_train, protos))
         mesh_input, xx, yy = self.get_mesh_input(x)
         _components = pl_module.proto_layer._components
-        mesh_input = torch.Tensor(mesh_input).type_as(_components)
+        mesh_input = torch.from_numpy(mesh_input).type_as(_components)
         y_pred = pl_module.predict(mesh_input)
         y_pred = y_pred.cpu().reshape(xx.shape)
         ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
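Why the `torch.from_numpy` change helps: `torch.Tensor(ndarray)` always copies and coerces to the default float32 dtype, while `torch.from_numpy` is zero-copy and preserves the numpy dtype before `.type_as` aligns it with the components:

import numpy as np
import torch

mesh = np.array([[0.0, 0.0], [1.0, 1.0]], dtype=np.float64)
t = torch.from_numpy(mesh)  # zero-copy, dtype stays float64
print(t.dtype)  # torch.float64; torch.Tensor(mesh) would yield float32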