Update SOM

Jensun Ravichandran
2021-06-09 18:21:12 +02:00
parent 022d791ea5
commit 57f8bec270
3 changed files with 17 additions and 6 deletions


@@ -10,7 +10,12 @@ from prototorch.functions.distances import euclidean_distance
 from prototorch.modules import WTAC, LambdaLayer
 
 
+class ProtoTorchMixin(object):
+    pass
+
+
 class ProtoTorchBolt(pl.LightningModule):
     """All ProtoTorch models are ProtoTorch Bolts."""
     def __repr__(self):
         surep = super().__repr__()
         indented = "".join([f"\t{line}\n" for line in surep.splitlines()])
@@ -160,7 +165,7 @@ class SupervisedPrototypeModel(PrototypeModel):
                  logger=True)
 
 
-class NonGradientMixin():
+class NonGradientMixin(ProtoTorchMixin):
     """Mixin for custom non-gradient optimization."""
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -170,7 +175,7 @@ class NonGradientMixin():
         raise NotImplementedError
 
 
-class ImagePrototypesMixin(ProtoTorchBolt):
+class ImagePrototypesMixin(ProtoTorchMixin):
     """Mixin for models with image prototypes."""
     @final
     def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
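For context (not part of the diff above): a minimal sketch of how the reparented mixins are intended to compose with a ProtoTorchBolt subclass. The classes below are simplified stand-ins for the ones touched in this commit; ExampleSOM, its training_step body, and the automatic_optimization flag are illustrative assumptions, not code taken from this repository.

```python
"""Sketch: composing a ProtoTorchMixin-based mixin with a ProtoTorch Bolt."""

import pytorch_lightning as pl


class ProtoTorchMixin(object):
    """Plain marker base; deliberately carries no Lightning machinery."""
    pass


class ProtoTorchBolt(pl.LightningModule):
    """All ProtoTorch models are ProtoTorch Bolts."""


class NonGradientMixin(ProtoTorchMixin):
    """Mixin for custom non-gradient optimization."""
    def __init__(self, *args, **kwargs):
        # Cooperative super() call: in a concrete model this reaches
        # LightningModule.__init__ through the MRO.
        super().__init__(*args, **kwargs)
        # Assumption: non-gradient models drive their own prototype updates.
        self.automatic_optimization = False

    def training_step(self, train_batch, batch_idx, optimizer_idx=None):
        raise NotImplementedError


class ExampleSOM(NonGradientMixin, ProtoTorchBolt):
    """Hypothetical model: listing the mixin first places its overrides
    ahead of the LightningModule defaults in the MRO."""
    def training_step(self, train_batch, batch_idx, optimizer_idx=None):
        x = train_batch[0] if isinstance(train_batch, (list, tuple)) else train_batch
        # ... manual (non-gradient) prototype update would go here ...
        return None
```

Rebasing the mixins onto the empty ProtoTorchMixin instead of ProtoTorchBolt presumably keeps them from dragging a second LightningModule ancestry into the class hierarchy; the Lightning base is only merged in by concrete models such as the hypothetical ExampleSOM above.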