chore: move mixins to separate file

Alexander Engelsberger 2022-05-17 16:19:47 +02:00
parent d16a0de202
commit e0b92e9ac2
No known key found for this signature in database
GPG Key ID: 72E54A9DAE51EB96
7 changed files with 122 additions and 82 deletions
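In short: the ProtoTorchMixin base class and the NonGradientMixin / ImagePrototypesMixin helpers move out of models/abstract.py into a new models/mixins.py, and every consumer switches its import from `.abstract` to `.mixins`. A minimal sketch of the resulting usage pattern, assuming the package layout is `prototorch.models` at this commit; the class name below is illustrative and not part of the diff:

# Illustrative only: combining a relocated mixin with a base model class.
# MyImageModel is a hypothetical name; the import paths follow the new
# layout introduced by this commit.
from prototorch.models.abstract import SupervisedPrototypeModel
from prototorch.models.mixins import ImagePrototypesMixin


class MyImageModel(ImagePrototypesMixin, SupervisedPrototypeModel):
    # The mixin is listed first so its hooks (e.g. the on_train_batch_end
    # clamping of components to [0, 1]) take effect via the MRO.
    pass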

View File

@@ -22,7 +22,16 @@ from prototorch.nn.wrappers import LambdaLayer
 class ProtoTorchBolt(pl.LightningModule):
-    """All ProtoTorch models are ProtoTorch Bolts."""
+    """All ProtoTorch models are ProtoTorch Bolts.
+
+    hparams:
+    - lr: learning rate
+
+    kwargs:
+    - optimizer: optimizer class
+    - lr_scheduler: learning rate scheduler class
+    - lr_scheduler_kwargs: learning rate scheduler kwargs
+    """
 
     def __init__(self, hparams, **kwargs):
         super().__init__()

@@ -65,6 +74,11 @@ class ProtoTorchBolt(pl.LightningModule):
 class PrototypeModel(ProtoTorchBolt):
+    """Abstract Prototype Model
+
+    kwargs:
+    - distance_fn: distance function
+    """
 
     proto_layer: AbstractComponents
 
     def __init__(self, hparams, **kwargs):

@@ -203,35 +217,3 @@ class SupervisedPrototypeModel(PrototypeModel):
         accuracy = torchmetrics.functional.accuracy(preds.int(), targets.int())
         self.log("test_acc", accuracy)
-
-
-class ProtoTorchMixin(object):
-    """All mixins are ProtoTorchMixins."""
-
-
-class NonGradientMixin(ProtoTorchMixin):
-    """Mixin for custom non-gradient optimization."""
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.automatic_optimization = False
-
-    def training_step(self, train_batch, batch_idx, optimizer_idx=None):
-        raise NotImplementedError
-
-
-class ImagePrototypesMixin(ProtoTorchMixin):
-    """Mixin for models with image prototypes."""
-
-    proto_layer: Components
-    components: torch.Tensor
-
-    def on_train_batch_end(self, outputs, batch, batch_idx):
-        """Constrain the components to the range [0, 1] by clamping after updates."""
-        self.proto_layer.components.data.clamp_(0.0, 1.0)
-
-    def get_prototype_grid(self, num_columns=2, return_channels_last=True):
-        from torchvision.utils import make_grid
-        grid = make_grid(self.components, nrow=num_columns)
-        if return_channels_last:
-            grid = grid.permute((1, 2, 0))
-        return grid.cpu()
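The docstrings added above enumerate the hyperparameters and keyword arguments the base classes consume. A hedged sketch of how they might be assembled, with illustrative values; the commented-out constructor call is an assumption based on the GLVQ module touched later in this commit, not part of this hunk:

# Illustrative values only -- a sketch of the hparams/kwargs described in the
# new ProtoTorchBolt and PrototypeModel docstrings above.
import torch
from prototorch.core.distances import squared_euclidean_distance

hparams = dict(
    lr=0.01,  # learning rate (ProtoTorchBolt)
)
kwargs = dict(
    optimizer=torch.optim.Adam,                    # optimizer class
    lr_scheduler=torch.optim.lr_scheduler.StepLR,  # lr scheduler class
    lr_scheduler_kwargs=dict(step_size=10, gamma=0.5),
    distance_fn=squared_euclidean_distance,        # PrototypeModel kwarg
)
# model = GLVQ(hparams, **kwargs)  # hypothetical; GLVQ needs further hparams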

View File

@@ -40,8 +40,8 @@ class PruneLoserPrototypes(pl.Callback):
             return None
 
         ratios = pl_module.prototype_win_ratios.mean(dim=0)
-        to_prune = torch.arange(len(ratios))[ratios < self.threshold]
-        to_prune = to_prune.tolist()
+        to_prune_tensor = torch.arange(len(ratios))[ratios < self.threshold]
+        to_prune = to_prune_tensor.tolist()
         prune_labels = pl_module.prototype_labels[to_prune]
         if self.prune_quota_per_epoch > 0:
             to_prune = to_prune[:self.prune_quota_per_epoch]

View File

@@ -1,4 +1,5 @@
 import torch
+import torch.nn.functional as F
 import torchmetrics
 from prototorch.core.competitions import CBCC
 from prototorch.core.components import ReasoningComponents

@@ -7,12 +8,13 @@ from prototorch.core.losses import MarginLoss
 from prototorch.core.similarities import euclidean_similarity
 from prototorch.nn.wrappers import LambdaLayer
 
-from .abstract import ImagePrototypesMixin
 from .glvq import SiameseGLVQ
+from .mixins import ImagePrototypesMixin
 
 
 class CBC(SiameseGLVQ):
     """Classification-By-Components."""
 
+    proto_layer: ReasoningComponents
 
     def __init__(self, hparams, **kwargs):
         super().__init__(hparams, skip_proto_layer=True, **kwargs)

@@ -22,7 +24,7 @@ class CBC(SiameseGLVQ):
         reasonings_initializer = kwargs.get("reasonings_initializer",
                                             RandomReasoningsInitializer())
         self.components_layer = ReasoningComponents(
-            self.hparams.distribution,
+            self.hparams["distribution"],
             components_initializer=components_initializer,
             reasonings_initializer=reasonings_initializer,
         )

@@ -32,7 +34,7 @@ class CBC(SiameseGLVQ):
         # Namespace hook
         self.proto_layer = self.components_layer
 
-        self.loss = MarginLoss(self.hparams.margin)
+        self.loss = MarginLoss(self.hparams["margin"])
 
     def forward(self, x):
         components, reasonings = self.components_layer()

@@ -48,7 +50,7 @@ class CBC(SiameseGLVQ):
         x, y = batch
         y_pred = self(x)
         num_classes = self.num_classes
-        y_true = torch.nn.functional.one_hot(y.long(), num_classes=num_classes)
+        y_true = F.one_hot(y.long(), num_classes=num_classes)
         loss = self.loss(y_pred, y_true).mean()
         return y_pred, loss

View File

@@ -17,8 +17,9 @@ from prototorch.core.transforms import LinearTransform
 from prototorch.nn.wrappers import LambdaLayer, LossLayer
 from torch.nn.parameter import Parameter
 
-from .abstract import ImagePrototypesMixin, SupervisedPrototypeModel
+from .abstract import SupervisedPrototypeModel
 from .extras import ltangent_distance, orthogonalization
+from .mixins import ImagePrototypesMixin
 
 
 class GLVQ(SupervisedPrototypeModel):

@@ -46,19 +47,24 @@ class GLVQ(SupervisedPrototypeModel):
     def initialize_prototype_win_ratios(self):
         self.register_buffer(
             "prototype_win_ratios",
-            torch.zeros(self.num_prototypes, device=self.device))
+            torch.zeros(self.num_prototypes, device=self.device),
+        )
 
     def on_train_epoch_start(self):
         self.initialize_prototype_win_ratios()
 
     def log_prototype_win_ratios(self, distances):
         batch_size = len(distances)
-        prototype_wc = torch.zeros(self.num_prototypes,
-                                   dtype=torch.long,
-                                   device=self.device)
-        wi, wc = torch.unique(distances.min(dim=-1).indices,
-                              sorted=True,
-                              return_counts=True)
+        prototype_wc = torch.zeros(
+            self.num_prototypes,
+            dtype=torch.long,
+            device=self.device,
+        )
+        wi, wc = torch.unique(
+            distances.min(dim=-1).indices,
+            sorted=True,
+            return_counts=True,
+        )
         prototype_wc[wi] = wc
         prototype_wr = prototype_wc / batch_size
         self.prototype_win_ratios = torch.vstack([

@@ -81,14 +87,12 @@ class GLVQ(SupervisedPrototypeModel):
         return train_loss
 
     def validation_step(self, batch, batch_idx):
-        # `model.eval()` and `torch.no_grad()` handled by pl
         out, val_loss = self.shared_step(batch, batch_idx)
         self.log("val_loss", val_loss)
         self.log_acc(out, batch[-1], tag="val_acc")
         return val_loss
 
     def test_step(self, batch, batch_idx):
-        # `model.eval()` and `torch.no_grad()` handled by pl
         out, test_loss = self.shared_step(batch, batch_idx)
         self.log_acc(out, batch[-1], tag="test_acc")
         return test_loss

@@ -99,10 +103,6 @@ class GLVQ(SupervisedPrototypeModel):
             test_loss += batch_loss.item()
         self.log("test_loss", test_loss)
 
-    # TODO
-    # def predict_step(self, batch, batch_idx, dataloader_idx=None):
-    #     pass
-
 
 class SiameseGLVQ(GLVQ):
     """GLVQ in a Siamese setting.

@@ -113,19 +113,23 @@ class SiameseGLVQ(GLVQ):
     """
 
-    def __init__(self,
-                 hparams,
-                 backbone=torch.nn.Identity(),
-                 both_path_gradients=False,
-                 **kwargs):
+    def __init__(
+        self,
+        hparams,
+        backbone=torch.nn.Identity(),
+        both_path_gradients=False,
+        **kwargs,
+    ):
         distance_fn = kwargs.pop("distance_fn", squared_euclidean_distance)
         super().__init__(hparams, distance_fn=distance_fn, **kwargs)
         self.backbone = backbone
         self.both_path_gradients = both_path_gradients
 
     def configure_optimizers(self):
-        proto_opt = self.optimizer(self.proto_layer.parameters(),
-                                   lr=self.hparams["proto_lr"])
+        proto_opt = self.optimizer(
+            self.proto_layer.parameters(),
+            lr=self.hparams["proto_lr"],
+        )
         # Only add a backbone optimizer if backbone has trainable parameters
         bb_params = list(self.backbone.parameters())
         if (bb_params):

@@ -266,13 +270,19 @@ class GMLVQ(GLVQ):
         super().__init__(hparams, distance_fn=distance_fn, **kwargs)
 
         # Additional parameters
-        omega_initializer = kwargs.get("omega_initializer",
-                                       EyeLinearTransformInitializer())
-        omega = omega_initializer.generate(self.hparams["input_dim"],
-                                           self.hparams["latent_dim"])
+        omega_initializer = kwargs.get(
+            "omega_initializer",
+            EyeLinearTransformInitializer(),
+        )
+        omega = omega_initializer.generate(
+            self.hparams["input_dim"],
+            self.hparams["latent_dim"],
+        )
         self.register_parameter("_omega", Parameter(omega))
-        self.backbone = LambdaLayer(lambda x: x @ self._omega,
-                                    name="omega matrix")
+        self.backbone = LambdaLayer(
+            lambda x: x @ self._omega,
+            name="omega matrix",
+        )
 
     @property
     def omega_matrix(self):

View File

@@ -1,20 +1,21 @@
 """LVQ models that are optimized using non-gradient methods."""
 
 import logging
+from collections import OrderedDict
 
 from prototorch.core.losses import _get_dp_dm
 from prototorch.nn.activations import get_activation
 from prototorch.nn.wrappers import LambdaLayer
 
-from .abstract import NonGradientMixin
 from .glvq import GLVQ
+from .mixins import NonGradientMixin
 
 
 class LVQ1(NonGradientMixin, GLVQ):
     """Learning Vector Quantization 1."""
 
     def training_step(self, train_batch, batch_idx, optimizer_idx=None):
-        protos, plables = self.proto_layer()
+        protos, plabels = self.proto_layer()
         x, y = train_batch
         dis = self.compute_distances(x)
         # TODO Vectorized implementation

@@ -28,9 +29,11 @@ class LVQ1(NonGradientMixin, GLVQ):
             else:
                 shift = protos[w] - xi
             updated_protos = protos + 0.0
-            updated_protos[w] = protos[w] + (self.hparams.lr * shift)
-            self.proto_layer.load_state_dict({"_components": updated_protos},
-                                             strict=False)
+            updated_protos[w] = protos[w] + (self.hparams["lr"] * shift)
+            self.proto_layer.load_state_dict(
+                OrderedDict(_components=updated_protos),
+                strict=False,
+            )
 
         logging.debug(f"dis={dis}")
         logging.debug(f"y={y}")

@@ -58,10 +61,12 @@ class LVQ21(NonGradientMixin, GLVQ):
             shiftp = xi - protos[wp]
             shiftn = protos[wn] - xi
             updated_protos = protos + 0.0
-            updated_protos[wp] = protos[wp] + (self.hparams.lr * shiftp)
-            updated_protos[wn] = protos[wn] + (self.hparams.lr * shiftn)
-            self.proto_layer.load_state_dict({"_components": updated_protos},
-                                             strict=False)
+            updated_protos[wp] = protos[wp] + (self.hparams["lr"] * shiftp)
+            updated_protos[wn] = protos[wn] + (self.hparams["lr"] * shiftn)
+            self.proto_layer.load_state_dict(
+                OrderedDict(_components=updated_protos),
+                strict=False,
+            )
 
         # Logging
         self.log_acc(dis, y, tag="train_acc")

@@ -80,14 +85,17 @@ class MedianLVQ(NonGradientMixin, GLVQ):
         super().__init__(hparams, **kwargs)
 
         self.transfer_layer = LambdaLayer(
-            get_activation(self.hparams.transfer_fn))
+            get_activation(self.hparams["transfer_fn"]))
 
     def _f(self, x, y, protos, plabels):
         d = self.distance_layer(x, protos)
-        dp, dm = _get_dp_dm(d, y, plabels)
+        dp, dm = _get_dp_dm(d, y, plabels, with_indices=False)
         mu = (dp - dm) / (dp + dm)
-        invmu = -1.0 * mu
-        f = self.transfer_layer(invmu, beta=self.hparams.transfer_beta) + 1.0
+        negative_mu = -1.0 * mu
+        f = self.transfer_layer(
+            negative_mu,
+            beta=self.hparams["transfer_beta"],
+        ) + 1.0
         return f
 
     def expectation(self, x, y, protos, plabels):

@@ -118,8 +126,10 @@ class MedianLVQ(NonGradientMixin, GLVQ):
                 _lower_bound = self.lower_bound(x, y, _protos, plabels, gamma)
                 if _lower_bound > lower_bound:
                     logging.debug(f"Updating prototype {i} to data {k}...")
-                    self.proto_layer.load_state_dict({"_components": _protos},
-                                                     strict=False)
+                    self.proto_layer.load_state_dict(
+                        OrderedDict(_components=_protos),
+                        strict=False,
+                    )
                     break
 
         # Logging

View File

@@ -0,0 +1,35 @@
+import pytorch_lightning as pl
+import torch
+from prototorch.core.components import Components
+
+
+class ProtoTorchMixin(pl.LightningModule):
+    """All mixins are ProtoTorchMixins."""
+
+
+class NonGradientMixin(ProtoTorchMixin):
+    """Mixin for custom non-gradient optimization."""
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.automatic_optimization = False
+
+    def training_step(self, train_batch, batch_idx, optimizer_idx=None):
+        raise NotImplementedError
+
+
+class ImagePrototypesMixin(ProtoTorchMixin):
+    """Mixin for models with image prototypes."""
+    proto_layer: Components
+    components: torch.Tensor
+
+    def on_train_batch_end(self, outputs, batch, batch_idx):
+        """Constrain the components to the range [0, 1] by clamping after updates."""
+        self.proto_layer.components.data.clamp_(0.0, 1.0)
+
+    def get_prototype_grid(self, num_columns=2, return_channels_last=True):
+        from torchvision.utils import make_grid
+        grid = make_grid(self.components, nrow=num_columns)
+        if return_channels_last:
+            grid = grid.permute((1, 2, 0))
+        return grid.cpu()
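With return_channels_last=True, get_prototype_grid returns a channels-last (H, W, C) tensor on the CPU, which is convenient for plotting. A hypothetical usage sketch; matplotlib and the trained `model` object are assumptions, not part of this commit:

# Hypothetical usage of ImagePrototypesMixin.get_prototype_grid; `model` is
# assumed to be a trained image-prototype model.
import matplotlib.pyplot as plt

grid = model.get_prototype_grid(num_columns=4)  # (H, W, C) tensor on CPU
plt.imshow(grid.numpy())
plt.axis("off")
plt.show()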

View File

@@ -6,9 +6,10 @@ from prototorch.core.competitions import wtac
 from prototorch.core.distances import squared_euclidean_distance
 from prototorch.core.losses import NeuralGasEnergy
 
-from .abstract import NonGradientMixin, UnsupervisedPrototypeModel
+from .abstract import UnsupervisedPrototypeModel
 from .callbacks import GNGCallback
 from .extras import ConnectionTopology
+from .mixins import NonGradientMixin
 
 
 class KohonenSOM(NonGradientMixin, UnsupervisedPrototypeModel):