chore: move mixins to separate file

This commit is contained in:
Alexander Engelsberger 2022-05-17 16:19:47 +02:00
parent d16a0de202
commit e0b92e9ac2
No known key found for this signature in database
GPG Key ID: 72E54A9DAE51EB96
7 changed files with 122 additions and 82 deletions

View File

@@ -22,7 +22,16 @@ from prototorch.nn.wrappers import LambdaLayer
class ProtoTorchBolt(pl.LightningModule):
"""All ProtoTorch models are ProtoTorch Bolts."""
"""All ProtoTorch models are ProtoTorch Bolts.
hparams:
- lr: learning rate
kwargs:
- optimizer: optimizer class
- lr_scheduler: learning rate scheduler class
- lr_scheduler_kwargs: learning rate scheduler kwargs
"""
def __init__(self, hparams, **kwargs):
super().__init__()
@@ -65,6 +74,11 @@ class ProtoTorchBolt(pl.LightningModule):
class PrototypeModel(ProtoTorchBolt):
"""Abstract Prototype Model
kwargs:
- distance_fn: distance function
"""
proto_layer: AbstractComponents
def __init__(self, hparams, **kwargs):
@@ -203,35 +217,3 @@ class SupervisedPrototypeModel(PrototypeModel):
accuracy = torchmetrics.functional.accuracy(preds.int(), targets.int())
self.log("test_acc", accuracy)
class ProtoTorchMixin(object):
    """All mixins are ProtoTorchMixins."""


class NonGradientMixin(ProtoTorchMixin):
    """Mixin for custom non-gradient optimization."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.automatic_optimization = False

    def training_step(self, train_batch, batch_idx, optimizer_idx=None):
        raise NotImplementedError


class ImagePrototypesMixin(ProtoTorchMixin):
    """Mixin for models with image prototypes."""

    proto_layer: Components
    components: torch.Tensor

    def on_train_batch_end(self, outputs, batch, batch_idx):
        """Constrain the components to the range [0, 1] by clamping after updates."""
        self.proto_layer.components.data.clamp_(0.0, 1.0)

    def get_prototype_grid(self, num_columns=2, return_channels_last=True):
        from torchvision.utils import make_grid

        grid = make_grid(self.components, nrow=num_columns)
        if return_channels_last:
            grid = grid.permute((1, 2, 0))
        return grid.cpu()
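An aside on the new ProtoTorchBolt docstring above: a minimal usage sketch of the documented hparams and kwargs. The GLVQ subclass, the "distribution" key, and the omitted prototypes_initializer are assumptions drawn from the wider ProtoTorch codebase, not from this diff, and the exact constructor contract is not verified here.

    import torch
    from prototorch.models import GLVQ  # any ProtoTorchBolt subclass would do

    # "lr" is the documented hparam; "distribution" is a hypothetical
    # model-specific key that prototype models typically expect.
    hparams = dict(distribution=(3, 2), lr=0.05)

    model = GLVQ(
        hparams,
        optimizer=torch.optim.Adam,  # documented kwarg: optimizer class
        lr_scheduler=torch.optim.lr_scheduler.ExponentialLR,  # scheduler class
        lr_scheduler_kwargs=dict(gamma=0.99),  # scheduler kwargs
        # a model-specific prototypes_initializer is usually also required;
        # omitted here for brevity
    )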

View File

@@ -40,8 +40,8 @@ class PruneLoserPrototypes(pl.Callback):
return None
ratios = pl_module.prototype_win_ratios.mean(dim=0)
to_prune = torch.arange(len(ratios))[ratios < self.threshold]
to_prune = to_prune.tolist()
to_prune_tensor = torch.arange(len(ratios))[ratios < self.threshold]
to_prune = to_prune_tensor.tolist()
prune_labels = pl_module.prototype_labels[to_prune]
if self.prune_quota_per_epoch > 0:
to_prune = to_prune[:self.prune_quota_per_epoch]
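The refactor above only separates the tensor of indices from its Python-list form. A toy sketch of what the selection computes (hypothetical numbers):

    import torch

    ratios = torch.tensor([0.50, 0.02, 0.30, 0.01])  # mean win ratio per prototype
    threshold = 0.05
    to_prune_tensor = torch.arange(len(ratios))[ratios < threshold]
    to_prune = to_prune_tensor.tolist()  # [1, 3]: plain ints, usable for label lookup
    prune_quota_per_epoch = 1
    to_prune = to_prune[:prune_quota_per_epoch]  # at most one prototype this epoch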

View File

@@ -1,4 +1,5 @@
import torch
import torch.nn.functional as F
import torchmetrics
from prototorch.core.competitions import CBCC
from prototorch.core.components import ReasoningComponents
@@ -7,12 +8,13 @@ from prototorch.core.losses import MarginLoss
from prototorch.core.similarities import euclidean_similarity
from prototorch.nn.wrappers import LambdaLayer
from .abstract import ImagePrototypesMixin
from .glvq import SiameseGLVQ
from .mixins import ImagePrototypesMixin
class CBC(SiameseGLVQ):
"""Classification-By-Components."""
proto_layer: ReasoningComponents
def __init__(self, hparams, **kwargs):
super().__init__(hparams, skip_proto_layer=True, **kwargs)
@@ -22,7 +24,7 @@ class CBC(SiameseGLVQ):
reasonings_initializer = kwargs.get("reasonings_initializer",
RandomReasoningsInitializer())
self.components_layer = ReasoningComponents(
self.hparams.distribution,
self.hparams["distribution"],
components_initializer=components_initializer,
reasonings_initializer=reasonings_initializer,
)
@@ -32,7 +34,7 @@ class CBC(SiameseGLVQ):
# Namespace hook
self.proto_layer = self.components_layer
self.loss = MarginLoss(self.hparams.margin)
self.loss = MarginLoss(self.hparams["margin"])
def forward(self, x):
components, reasonings = self.components_layer()
@@ -48,7 +50,7 @@ class CBC(SiameseGLVQ):
x, y = batch
y_pred = self(x)
num_classes = self.num_classes
y_true = torch.nn.functional.one_hot(y.long(), num_classes=num_classes)
y_true = F.one_hot(y.long(), num_classes=num_classes)
loss = self.loss(y_pred, y_true).mean()
return y_pred, loss
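The change above only shortens the call through the torch.nn.functional alias F; a quick check of what it produces:

    import torch
    import torch.nn.functional as F

    y = torch.tensor([0, 2, 1])
    y_true = F.one_hot(y.long(), num_classes=3)
    # tensor([[1, 0, 0],
    #         [0, 0, 1],
    #         [0, 1, 0]])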

View File

@@ -17,8 +17,9 @@ from prototorch.core.transforms import LinearTransform
from prototorch.nn.wrappers import LambdaLayer, LossLayer
from torch.nn.parameter import Parameter
from .abstract import ImagePrototypesMixin, SupervisedPrototypeModel
from .abstract import SupervisedPrototypeModel
from .extras import ltangent_distance, orthogonalization
from .mixins import ImagePrototypesMixin
class GLVQ(SupervisedPrototypeModel):
@@ -46,19 +47,24 @@ class GLVQ(SupervisedPrototypeModel):
def initialize_prototype_win_ratios(self):
self.register_buffer(
"prototype_win_ratios",
torch.zeros(self.num_prototypes, device=self.device))
torch.zeros(self.num_prototypes, device=self.device),
)
def on_train_epoch_start(self):
self.initialize_prototype_win_ratios()
def log_prototype_win_ratios(self, distances):
batch_size = len(distances)
prototype_wc = torch.zeros(self.num_prototypes,
prototype_wc = torch.zeros(
self.num_prototypes,
dtype=torch.long,
device=self.device)
wi, wc = torch.unique(distances.min(dim=-1).indices,
device=self.device,
)
wi, wc = torch.unique(
distances.min(dim=-1).indices,
sorted=True,
return_counts=True)
return_counts=True,
)
prototype_wc[wi] = wc
prototype_wr = prototype_wc / batch_size
self.prototype_win_ratios = torch.vstack([
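The win-ratio bookkeeping reformatted in this hunk, condensed into a standalone sketch with toy distances:

    import torch

    distances = torch.rand(8, 4)  # 8 samples, 4 prototypes (toy data)
    winners = distances.min(dim=-1).indices  # winning prototype per sample
    wi, wc = torch.unique(winners, sorted=True, return_counts=True)
    prototype_wc = torch.zeros(4, dtype=torch.long)
    prototype_wc[wi] = wc  # wins per prototype
    prototype_wr = prototype_wc / len(distances)  # win ratios in [0, 1]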
@@ -81,14 +87,12 @@ class GLVQ(SupervisedPrototypeModel):
return train_loss
def validation_step(self, batch, batch_idx):
# `model.eval()` and `torch.no_grad()` handled by pl
out, val_loss = self.shared_step(batch, batch_idx)
self.log("val_loss", val_loss)
self.log_acc(out, batch[-1], tag="val_acc")
return val_loss
def test_step(self, batch, batch_idx):
# `model.eval()` and `torch.no_grad()` handled by pl
out, test_loss = self.shared_step(batch, batch_idx)
self.log_acc(out, batch[-1], tag="test_acc")
return test_loss
@@ -99,10 +103,6 @@ class GLVQ(SupervisedPrototypeModel):
test_loss += batch_loss.item()
self.log("test_loss", test_loss)
# TODO
# def predict_step(self, batch, batch_idx, dataloader_idx=None):
# pass
class SiameseGLVQ(GLVQ):
"""GLVQ in a Siamese setting.
@@ -113,19 +113,23 @@ class SiameseGLVQ(GLVQ):
"""
def __init__(self,
def __init__(
self,
hparams,
backbone=torch.nn.Identity(),
both_path_gradients=False,
**kwargs):
**kwargs,
):
distance_fn = kwargs.pop("distance_fn", squared_euclidean_distance)
super().__init__(hparams, distance_fn=distance_fn, **kwargs)
self.backbone = backbone
self.both_path_gradients = both_path_gradients
def configure_optimizers(self):
proto_opt = self.optimizer(self.proto_layer.parameters(),
lr=self.hparams["proto_lr"])
proto_opt = self.optimizer(
self.proto_layer.parameters(),
lr=self.hparams["proto_lr"],
)
# Only add a backbone optimizer if backbone has trainable parameters
bb_params = list(self.backbone.parameters())
if (bb_params):
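The pattern behind this hunk: one optimizer for the prototypes, and a second backbone optimizer only when the backbone actually has trainable parameters. A standalone sketch with hypothetical layers and learning rates:

    import torch

    proto_layer = torch.nn.Linear(2, 3, bias=False)  # stand-in for the prototype layer
    backbone = torch.nn.Identity()  # the default backbone has no parameters

    optimizers = [torch.optim.Adam(proto_layer.parameters(), lr=0.05)]
    bb_params = list(backbone.parameters())
    if bb_params:  # empty for Identity, so no second optimizer is added
        optimizers.append(torch.optim.Adam(bb_params, lr=0.001))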
@@ -266,13 +270,19 @@ class GMLVQ(GLVQ):
super().__init__(hparams, distance_fn=distance_fn, **kwargs)
# Additional parameters
omega_initializer = kwargs.get("omega_initializer",
EyeLinearTransformInitializer())
omega = omega_initializer.generate(self.hparams["input_dim"],
self.hparams["latent_dim"])
omega_initializer = kwargs.get(
"omega_initializer",
EyeLinearTransformInitializer(),
)
omega = omega_initializer.generate(
self.hparams["input_dim"],
self.hparams["latent_dim"],
)
self.register_parameter("_omega", Parameter(omega))
self.backbone = LambdaLayer(lambda x: x @ self._omega,
name="omega matrix")
self.backbone = LambdaLayer(
lambda x: x @ self._omega,
name="omega matrix",
)
@property
def omega_matrix(self):
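A toy illustration of the omega backbone configured above: an identity-like rectangular matrix (as the initializer's name suggests) mapping inputs into the latent space via x @ omega:

    import torch

    input_dim, latent_dim = 4, 2
    omega = torch.eye(input_dim, latent_dim)  # identity-like, cf. EyeLinearTransformInitializer
    x = torch.rand(8, input_dim)
    projected = x @ omega  # what the LambdaLayer backbone computes; shape (8, 2)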

View File

@@ -1,20 +1,21 @@
"""LVQ models that are optimized using non-gradient methods."""
import logging
from collections import OrderedDict
from prototorch.core.losses import _get_dp_dm
from prototorch.nn.activations import get_activation
from prototorch.nn.wrappers import LambdaLayer
from .abstract import NonGradientMixin
from .glvq import GLVQ
from .mixins import NonGradientMixin
class LVQ1(NonGradientMixin, GLVQ):
"""Learning Vector Quantization 1."""
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
protos, plables = self.proto_layer()
protos, plabels = self.proto_layer()
x, y = train_batch
dis = self.compute_distances(x)
# TODO Vectorized implementation
@@ -28,9 +29,11 @@ class LVQ1(NonGradientMixin, GLVQ):
else:
shift = protos[w] - xi
updated_protos = protos + 0.0
updated_protos[w] = protos[w] + (self.hparams.lr * shift)
self.proto_layer.load_state_dict({"_components": updated_protos},
strict=False)
updated_protos[w] = protos[w] + (self.hparams["lr"] * shift)
self.proto_layer.load_state_dict(
OrderedDict(_components=updated_protos),
strict=False,
)
logging.debug(f"dis={dis}")
logging.debug(f"y={y}")
@@ -58,10 +61,12 @@ class LVQ21(NonGradientMixin, GLVQ):
shiftp = xi - protos[wp]
shiftn = protos[wn] - xi
updated_protos = protos + 0.0
updated_protos[wp] = protos[wp] + (self.hparams.lr * shiftp)
updated_protos[wn] = protos[wn] + (self.hparams.lr * shiftn)
self.proto_layer.load_state_dict({"_components": updated_protos},
strict=False)
updated_protos[wp] = protos[wp] + (self.hparams["lr"] * shiftp)
updated_protos[wn] = protos[wn] + (self.hparams["lr"] * shiftn)
self.proto_layer.load_state_dict(
OrderedDict(_components=updated_protos),
strict=False,
)
# Logging
self.log_acc(dis, y, tag="train_acc")
@@ -80,14 +85,17 @@ class MedianLVQ(NonGradientMixin, GLVQ):
super().__init__(hparams, **kwargs)
self.transfer_layer = LambdaLayer(
get_activation(self.hparams.transfer_fn))
get_activation(self.hparams["transfer_fn"]))
def _f(self, x, y, protos, plabels):
d = self.distance_layer(x, protos)
dp, dm = _get_dp_dm(d, y, plabels)
dp, dm = _get_dp_dm(d, y, plabels, with_indices=False)
mu = (dp - dm) / (dp + dm)
invmu = -1.0 * mu
f = self.transfer_layer(invmu, beta=self.hparams.transfer_beta) + 1.0
negative_mu = -1.0 * mu
f = self.transfer_layer(
negative_mu,
beta=self.hparams["transfer_beta"],
) + 1.0
return f
def expectation(self, x, y, protos, plabels):
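On the _f change above: the classifier function mu = (dp - dm) / (dp + dm) lies in [-1, 1] and is negative exactly when the closest correct prototype beats the closest incorrect one. Toy values:

    import torch

    dp = torch.tensor([0.2, 0.9])  # distance to closest correct prototype
    dm = torch.tensor([0.7, 0.3])  # distance to closest incorrect prototype
    mu = (dp - dm) / (dp + dm)  # tensor([-0.5556, 0.5000])
    negative_mu = -1.0 * mu  # passed through the transfer function, then +1.0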
@@ -118,8 +126,10 @@ class MedianLVQ(NonGradientMixin, GLVQ):
_lower_bound = self.lower_bound(x, y, _protos, plabels, gamma)
if _lower_bound > lower_bound:
logging.debug(f"Updating prototype {i} to data {k}...")
self.proto_layer.load_state_dict({"_components": _protos},
strict=False)
self.proto_layer.load_state_dict(
OrderedDict(_components=_protos),
strict=False,
)
break
# Logging

View File

@@ -0,0 +1,35 @@
import pytorch_lightning as pl
import torch
from prototorch.core.components import Components


class ProtoTorchMixin(pl.LightningModule):
    """All mixins are ProtoTorchMixins."""


class NonGradientMixin(ProtoTorchMixin):
    """Mixin for custom non-gradient optimization."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.automatic_optimization = False

    def training_step(self, train_batch, batch_idx, optimizer_idx=None):
        raise NotImplementedError


class ImagePrototypesMixin(ProtoTorchMixin):
    """Mixin for models with image prototypes."""

    proto_layer: Components
    components: torch.Tensor

    def on_train_batch_end(self, outputs, batch, batch_idx):
        """Constrain the components to the range [0, 1] by clamping after updates."""
        self.proto_layer.components.data.clamp_(0.0, 1.0)

    def get_prototype_grid(self, num_columns=2, return_channels_last=True):
        from torchvision.utils import make_grid

        grid = make_grid(self.components, nrow=num_columns)
        if return_channels_last:
            grid = grid.permute((1, 2, 0))
        return grid.cpu()
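What the two ImagePrototypesMixin methods amount to, sketched on toy image prototypes:

    import torch
    from torchvision.utils import make_grid

    components = torch.rand(4, 1, 28, 28) * 1.2  # toy prototypes; some pixels > 1.0
    components.data.clamp_(0.0, 1.0)  # on_train_batch_end: keep pixels in [0, 1]

    grid = make_grid(components, nrow=2)  # get_prototype_grid with num_columns=2
    grid = grid.permute((1, 2, 0))  # channels-last, e.g. for matplotlib imshow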

View File

@@ -6,9 +6,10 @@ from prototorch.core.competitions import wtac
from prototorch.core.distances import squared_euclidean_distance
from prototorch.core.losses import NeuralGasEnergy
from .abstract import NonGradientMixin, UnsupervisedPrototypeModel
from .abstract import UnsupervisedPrototypeModel
from .callbacks import GNGCallback
from .extras import ConnectionTopology
from .mixins import NonGradientMixin
class KohonenSOM(NonGradientMixin, UnsupervisedPrototypeModel):
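Note that classes like KohonenSOM keep the mixin first in the base list, so the mixin's cooperative super().__init__() still reaches the model class before automatic_optimization is switched off. A stripped-down sketch of that MRO behaviour (plain classes standing in for LightningModule):

    class Model:
        def __init__(self, *args, **kwargs):
            self.automatic_optimization = True

    class NonGradientMixin:
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)  # continues along the MRO to Model
            self.automatic_optimization = False  # then overrides the default

    class NonGradientModel(NonGradientMixin, Model):  # mixin first, as in KohonenSOM
        pass

    assert NonGradientModel().automatic_optimization is False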