chore: Move mixins into separate file

Alexander Engelsberger 2021-10-11 16:05:12 +02:00
parent a8336ee213
commit a8829945f5
6 changed files with 33 additions and 33 deletions

abstract.py

@@ -168,32 +168,3 @@ class SupervisedPrototypeModel(PrototypeModel):
         accuracy = torchmetrics.functional.accuracy(preds.int(), targets.int())
         self.log("test_acc", accuracy)
-
-
-class ProtoTorchMixin(object):
-    """All mixins are ProtoTorchMixins."""
-    pass
-
-
-class NonGradientMixin(ProtoTorchMixin):
-    """Mixin for custom non-gradient optimization."""
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.automatic_optimization = False
-
-    def training_step(self, train_batch, batch_idx, optimizer_idx=None):
-        raise NotImplementedError
-
-
-class ImagePrototypesMixin(ProtoTorchMixin):
-    """Mixin for models with image prototypes."""
-    def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
-        """Constrain the components to the range [0, 1] by clamping after updates."""
-        self.proto_layer.components.data.clamp_(0.0, 1.0)
-
-    def get_prototype_grid(self, num_columns=2, return_channels_last=True):
-        from torchvision.utils import make_grid
-        grid = make_grid(self.components, nrow=num_columns)
-        if return_channels_last:
-            grid = grid.permute((1, 2, 0))
-        return grid.cpu()
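The block removed above reappears verbatim as the new mixin.py below. As a hedged usage sketch of the grid helper it carries along: `model` here is a hypothetical prototorch image model that inherits ImagePrototypesMixin, and get_prototype_grid returns a channels-last tensor that matplotlib can display directly.

# Hypothetical usage sketch; `model` stands in for any trained prototorch
# image model that inherits ImagePrototypesMixin (assumption, not shown here).
import matplotlib.pyplot as plt

grid = model.get_prototype_grid(num_columns=4)  # (H, W, C) tensor on the CPU
plt.imshow(grid)  # components are clamped to [0, 1], so no rescaling is needed
plt.axis("off")
plt.show()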

cbc.py

@@ -7,8 +7,8 @@ from prototorch.core.losses import MarginLoss
 from prototorch.core.similarities import euclidean_similarity
 from prototorch.nn.wrappers import LambdaLayer
 
-from .abstract import ImagePrototypesMixin
 from .glvq import SiameseGLVQ
+from .mixin import ImagePrototypesMixin
 
 
 class CBC(SiameseGLVQ):

glvq.py

@@ -9,7 +9,8 @@ from prototorch.core.transforms import LinearTransform
 from prototorch.nn.wrappers import LambdaLayer, LossLayer
 from torch.nn.parameter import Parameter
 
-from .abstract import ImagePrototypesMixin, SupervisedPrototypeModel
+from .abstract import SupervisedPrototypeModel
+from .mixin import ImagePrototypesMixin
 
 
 class GLVQ(SupervisedPrototypeModel):

lvq.py

@@ -4,8 +4,8 @@ from prototorch.core.losses import _get_dp_dm
 from prototorch.nn.activations import get_activation
 from prototorch.nn.wrappers import LambdaLayer
 
-from .abstract import NonGradientMixin
 from .glvq import GLVQ
+from .mixin import NonGradientMixin
 
 
 class LVQ1(NonGradientMixin, GLVQ):

mixin.py

@@ -0,0 +1,27 @@
+class ProtoTorchMixin(object):
+    """All mixins are ProtoTorchMixins."""
+    pass
+
+
+class NonGradientMixin(ProtoTorchMixin):
+    """Mixin for custom non-gradient optimization."""
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.automatic_optimization = False
+
+    def training_step(self, train_batch, batch_idx, optimizer_idx=None):
+        raise NotImplementedError
+
+
+class ImagePrototypesMixin(ProtoTorchMixin):
+    """Mixin for models with image prototypes."""
+    def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
+        """Constrain the components to the range [0, 1] by clamping after updates."""
+        self.proto_layer.components.data.clamp_(0.0, 1.0)
+
+    def get_prototype_grid(self, num_columns=2, return_channels_last=True):
+        from torchvision.utils import make_grid
+        grid = make_grid(self.components, nrow=num_columns)
+        if return_channels_last:
+            grid = grid.permute((1, 2, 0))
+        return grid.cpu()
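For orientation, here is a minimal, self-contained sketch of the pattern NonGradientMixin enables: with automatic_optimization switched off, PyTorch Lightning runs no backward pass or optimizer step, so training_step can move prototypes by a hand-written rule. The Lvq1Sketch class and its winner-take-all update below are illustrative assumptions, not prototorch code.

# Illustrative only: a hand-rolled LVQ1-style update under manual optimization.
import torch
import pytorch_lightning as pl

class Lvq1Sketch(pl.LightningModule):
    def __init__(self, num_protos=3, dim=2, lr=0.1):
        super().__init__()
        self.automatic_optimization = False  # the flag NonGradientMixin sets
        self.protos = torch.nn.Parameter(torch.randn(num_protos, dim))
        self.plabels = torch.arange(num_protos)  # one class per prototype
        self.lr = lr

    def training_step(self, batch, batch_idx):
        x, y = batch
        for xi, yi in zip(x, y):
            dists = torch.cdist(xi.unsqueeze(0), self.protos).squeeze(0)
            winner = torch.argmin(dists)  # closest prototype
            sign = 1.0 if self.plabels[winner] == yi else -1.0  # attract/repel
            with torch.no_grad():
                self.protos[winner] += sign * self.lr * (xi - self.protos[winner])

    def configure_optimizers(self):
        return None  # no gradient-based optimizer is needed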

unsupervised.py

@@ -7,9 +7,10 @@ from prototorch.core.distances import squared_euclidean_distance
 from prototorch.core.losses import NeuralGasEnergy
 from prototorch.nn.wrappers import LambdaLayer
 
-from .abstract import NonGradientMixin, UnsupervisedPrototypeModel
+from .abstract import UnsupervisedPrototypeModel
 from .callbacks import GNGCallback
 from .extras import ConnectionTopology
+from .mixin import NonGradientMixin
 
 
 class KohonenSOM(NonGradientMixin, UnsupervisedPrototypeModel):
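Note that the mixin comes first in every class statement above (LVQ1, KohonenSOM). A tiny illustrative snippet, not prototorch code, of why that order matters: Python resolves attributes left to right along the method resolution order, so the mixin's overrides win while super() still reaches the base model.

# Illustrative only: the left-most base wins attribute lookup along the MRO.
class Model:
    def training_step(self, batch, batch_idx):
        return "gradient step"

class NonGradient:
    def training_step(self, batch, batch_idx):
        return "manual step"

class SOM(NonGradient, Model):
    pass

assert SOM().training_step(None, 0) == "manual step"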