chore: Move mixins into seperate file
This commit is contained in:
parent
a8336ee213
commit
a8829945f5
@ -168,32 +168,3 @@ class SupervisedPrototypeModel(PrototypeModel):
|
||||
accuracy = torchmetrics.functional.accuracy(preds.int(), targets.int())
|
||||
|
||||
self.log("test_acc", accuracy)
|
||||
|
||||
|
||||
class ProtoTorchMixin(object):
|
||||
"""All mixins are ProtoTorchMixins."""
|
||||
pass
|
||||
|
||||
|
||||
class NonGradientMixin(ProtoTorchMixin):
|
||||
"""Mixin for custom non-gradient optimization."""
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.automatic_optimization = False
|
||||
|
||||
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class ImagePrototypesMixin(ProtoTorchMixin):
|
||||
"""Mixin for models with image prototypes."""
|
||||
def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
|
||||
"""Constrain the components to the range [0, 1] by clamping after updates."""
|
||||
self.proto_layer.components.data.clamp_(0.0, 1.0)
|
||||
|
||||
def get_prototype_grid(self, num_columns=2, return_channels_last=True):
|
||||
from torchvision.utils import make_grid
|
||||
grid = make_grid(self.components, nrow=num_columns)
|
||||
if return_channels_last:
|
||||
grid = grid.permute((1, 2, 0))
|
||||
return grid.cpu()
|
||||
|
@ -7,8 +7,8 @@ from prototorch.core.losses import MarginLoss
|
||||
from prototorch.core.similarities import euclidean_similarity
|
||||
from prototorch.nn.wrappers import LambdaLayer
|
||||
|
||||
from .abstract import ImagePrototypesMixin
|
||||
from .glvq import SiameseGLVQ
|
||||
from .mixin import ImagePrototypesMixin
|
||||
|
||||
|
||||
class CBC(SiameseGLVQ):
|
||||
|
@ -9,7 +9,8 @@ from prototorch.core.transforms import LinearTransform
|
||||
from prototorch.nn.wrappers import LambdaLayer, LossLayer
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from .abstract import ImagePrototypesMixin, SupervisedPrototypeModel
|
||||
from .abstract import SupervisedPrototypeModel
|
||||
from .mixin import ImagePrototypesMixin
|
||||
|
||||
|
||||
class GLVQ(SupervisedPrototypeModel):
|
||||
|
@ -4,8 +4,8 @@ from prototorch.core.losses import _get_dp_dm
|
||||
from prototorch.nn.activations import get_activation
|
||||
from prototorch.nn.wrappers import LambdaLayer
|
||||
|
||||
from .abstract import NonGradientMixin
|
||||
from .glvq import GLVQ
|
||||
from .mixin import NonGradientMixin
|
||||
|
||||
|
||||
class LVQ1(NonGradientMixin, GLVQ):
|
||||
|
27
prototorch/models/mixin.py
Normal file
27
prototorch/models/mixin.py
Normal file
@ -0,0 +1,27 @@
|
||||
class ProtoTorchMixin(object):
|
||||
"""All mixins are ProtoTorchMixins."""
|
||||
pass
|
||||
|
||||
|
||||
class NonGradientMixin(ProtoTorchMixin):
|
||||
"""Mixin for custom non-gradient optimization."""
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.automatic_optimization = False
|
||||
|
||||
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class ImagePrototypesMixin(ProtoTorchMixin):
|
||||
"""Mixin for models with image prototypes."""
|
||||
def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
|
||||
"""Constrain the components to the range [0, 1] by clamping after updates."""
|
||||
self.proto_layer.components.data.clamp_(0.0, 1.0)
|
||||
|
||||
def get_prototype_grid(self, num_columns=2, return_channels_last=True):
|
||||
from torchvision.utils import make_grid
|
||||
grid = make_grid(self.components, nrow=num_columns)
|
||||
if return_channels_last:
|
||||
grid = grid.permute((1, 2, 0))
|
||||
return grid.cpu()
|
@ -7,9 +7,10 @@ from prototorch.core.distances import squared_euclidean_distance
|
||||
from prototorch.core.losses import NeuralGasEnergy
|
||||
from prototorch.nn.wrappers import LambdaLayer
|
||||
|
||||
from .abstract import NonGradientMixin, UnsupervisedPrototypeModel
|
||||
from .abstract import UnsupervisedPrototypeModel
|
||||
from .callbacks import GNGCallback
|
||||
from .extras import ConnectionTopology
|
||||
from .mixin import NonGradientMixin
|
||||
|
||||
|
||||
class KohonenSOM(NonGradientMixin, UnsupervisedPrototypeModel):
|
||||
|
Loading…
Reference in New Issue
Block a user