Update SOM
This commit is contained in:
@@ -10,7 +10,12 @@ from prototorch.functions.distances import euclidean_distance
|
||||
from prototorch.modules import WTAC, LambdaLayer
|
||||
|
||||
|
||||
class ProtoTorchMixin(object):
|
||||
pass
|
||||
|
||||
|
||||
class ProtoTorchBolt(pl.LightningModule):
|
||||
"""All ProtoTorch models are ProtoTorch Bolts."""
|
||||
def __repr__(self):
|
||||
surep = super().__repr__()
|
||||
indented = "".join([f"\t{line}\n" for line in surep.splitlines()])
|
||||
@@ -160,7 +165,7 @@ class SupervisedPrototypeModel(PrototypeModel):
|
||||
logger=True)
|
||||
|
||||
|
||||
class NonGradientMixin():
|
||||
class NonGradientMixin(ProtoTorchMixin):
|
||||
"""Mixin for custom non-gradient optimization."""
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
@@ -170,7 +175,7 @@ class NonGradientMixin():
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class ImagePrototypesMixin(ProtoTorchBolt):
|
||||
class ImagePrototypesMixin(ProtoTorchMixin):
|
||||
"""Mixin for models with image prototypes."""
|
||||
@final
|
||||
def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
|
||||
|
Reference in New Issue
Block a user