chore: Move mixins into seperate file
This commit is contained in:
parent
a8336ee213
commit
a8829945f5
@ -168,32 +168,3 @@ class SupervisedPrototypeModel(PrototypeModel):
|
|||||||
accuracy = torchmetrics.functional.accuracy(preds.int(), targets.int())
|
accuracy = torchmetrics.functional.accuracy(preds.int(), targets.int())
|
||||||
|
|
||||||
self.log("test_acc", accuracy)
|
self.log("test_acc", accuracy)
|
||||||
|
|
||||||
|
|
||||||
class ProtoTorchMixin(object):
|
|
||||||
"""All mixins are ProtoTorchMixins."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class NonGradientMixin(ProtoTorchMixin):
|
|
||||||
"""Mixin for custom non-gradient optimization."""
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
self.automatic_optimization = False
|
|
||||||
|
|
||||||
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
|
|
||||||
class ImagePrototypesMixin(ProtoTorchMixin):
|
|
||||||
"""Mixin for models with image prototypes."""
|
|
||||||
def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
|
|
||||||
"""Constrain the components to the range [0, 1] by clamping after updates."""
|
|
||||||
self.proto_layer.components.data.clamp_(0.0, 1.0)
|
|
||||||
|
|
||||||
def get_prototype_grid(self, num_columns=2, return_channels_last=True):
|
|
||||||
from torchvision.utils import make_grid
|
|
||||||
grid = make_grid(self.components, nrow=num_columns)
|
|
||||||
if return_channels_last:
|
|
||||||
grid = grid.permute((1, 2, 0))
|
|
||||||
return grid.cpu()
|
|
||||||
|
@ -7,8 +7,8 @@ from prototorch.core.losses import MarginLoss
|
|||||||
from prototorch.core.similarities import euclidean_similarity
|
from prototorch.core.similarities import euclidean_similarity
|
||||||
from prototorch.nn.wrappers import LambdaLayer
|
from prototorch.nn.wrappers import LambdaLayer
|
||||||
|
|
||||||
from .abstract import ImagePrototypesMixin
|
|
||||||
from .glvq import SiameseGLVQ
|
from .glvq import SiameseGLVQ
|
||||||
|
from .mixin import ImagePrototypesMixin
|
||||||
|
|
||||||
|
|
||||||
class CBC(SiameseGLVQ):
|
class CBC(SiameseGLVQ):
|
||||||
|
@ -9,7 +9,8 @@ from prototorch.core.transforms import LinearTransform
|
|||||||
from prototorch.nn.wrappers import LambdaLayer, LossLayer
|
from prototorch.nn.wrappers import LambdaLayer, LossLayer
|
||||||
from torch.nn.parameter import Parameter
|
from torch.nn.parameter import Parameter
|
||||||
|
|
||||||
from .abstract import ImagePrototypesMixin, SupervisedPrototypeModel
|
from .abstract import SupervisedPrototypeModel
|
||||||
|
from .mixin import ImagePrototypesMixin
|
||||||
|
|
||||||
|
|
||||||
class GLVQ(SupervisedPrototypeModel):
|
class GLVQ(SupervisedPrototypeModel):
|
||||||
|
@ -4,8 +4,8 @@ from prototorch.core.losses import _get_dp_dm
|
|||||||
from prototorch.nn.activations import get_activation
|
from prototorch.nn.activations import get_activation
|
||||||
from prototorch.nn.wrappers import LambdaLayer
|
from prototorch.nn.wrappers import LambdaLayer
|
||||||
|
|
||||||
from .abstract import NonGradientMixin
|
|
||||||
from .glvq import GLVQ
|
from .glvq import GLVQ
|
||||||
|
from .mixin import NonGradientMixin
|
||||||
|
|
||||||
|
|
||||||
class LVQ1(NonGradientMixin, GLVQ):
|
class LVQ1(NonGradientMixin, GLVQ):
|
||||||
|
27
prototorch/models/mixin.py
Normal file
27
prototorch/models/mixin.py
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
class ProtoTorchMixin(object):
|
||||||
|
"""All mixins are ProtoTorchMixins."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class NonGradientMixin(ProtoTorchMixin):
|
||||||
|
"""Mixin for custom non-gradient optimization."""
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.automatic_optimization = False
|
||||||
|
|
||||||
|
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class ImagePrototypesMixin(ProtoTorchMixin):
|
||||||
|
"""Mixin for models with image prototypes."""
|
||||||
|
def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
|
||||||
|
"""Constrain the components to the range [0, 1] by clamping after updates."""
|
||||||
|
self.proto_layer.components.data.clamp_(0.0, 1.0)
|
||||||
|
|
||||||
|
def get_prototype_grid(self, num_columns=2, return_channels_last=True):
|
||||||
|
from torchvision.utils import make_grid
|
||||||
|
grid = make_grid(self.components, nrow=num_columns)
|
||||||
|
if return_channels_last:
|
||||||
|
grid = grid.permute((1, 2, 0))
|
||||||
|
return grid.cpu()
|
@ -7,9 +7,10 @@ from prototorch.core.distances import squared_euclidean_distance
|
|||||||
from prototorch.core.losses import NeuralGasEnergy
|
from prototorch.core.losses import NeuralGasEnergy
|
||||||
from prototorch.nn.wrappers import LambdaLayer
|
from prototorch.nn.wrappers import LambdaLayer
|
||||||
|
|
||||||
from .abstract import NonGradientMixin, UnsupervisedPrototypeModel
|
from .abstract import UnsupervisedPrototypeModel
|
||||||
from .callbacks import GNGCallback
|
from .callbacks import GNGCallback
|
||||||
from .extras import ConnectionTopology
|
from .extras import ConnectionTopology
|
||||||
|
from .mixin import NonGradientMixin
|
||||||
|
|
||||||
|
|
||||||
class KohonenSOM(NonGradientMixin, UnsupervisedPrototypeModel):
|
class KohonenSOM(NonGradientMixin, UnsupervisedPrototypeModel):
|
||||||
|
Loading…
Reference in New Issue
Block a user