diff --git a/examples/dynamic_components.py b/examples/dynamic_components.py index 7b29e11..6653a31 100644 --- a/examples/dynamic_components.py +++ b/examples/dynamic_components.py @@ -47,17 +47,25 @@ if __name__ == "__main__": prototype_initializer=pt.components.SMI(train_ds), ) + # Summary + print(model) + # Callbacks vis = pt.models.VisGLVQ2D(train_ds) proto_scheduler = PrototypeScheduler(train_ds, 10) # Setup trainer - trainer = pl.Trainer.from_argparse_args(args, - max_epochs=100, - callbacks=[vis, proto_scheduler], - terminate_on_nan=True, - weights_summary=None, - accelerator='ddp') + trainer = pl.Trainer.from_argparse_args( + args, + max_epochs=100, + callbacks=[ + vis, + proto_scheduler, + ], + terminate_on_nan=True, + weights_summary=None, + accelerator="ddp", + ) # Training loop trainer.fit(model, train_loader) diff --git a/prototorch/models/glvq.py b/prototorch/models/glvq.py index 3e4baa5..b2605e7 100644 --- a/prototorch/models/glvq.py +++ b/prototorch/models/glvq.py @@ -9,23 +9,11 @@ from prototorch.functions.distances import (euclidean_distance, omega_distance, sed) from prototorch.functions.helper import get_flat from prototorch.functions.losses import glvq_loss, lvq1_loss, lvq21_loss +from prototorch.modules import LambdaLayer from .abstract import AbstractPrototypeModel, PrototypeImageModel -class FunctionLayer(torch.nn.Module): - def __init__(self, distance_fn): - super().__init__() - self.fn = distance_fn - self.name = distance_fn.__name__ - - def forward(self, *args, **kwargs): - return self.fn(*args, **kwargs) - - def extra_repr(self): - return self.name - - class GLVQ(AbstractPrototypeModel): """Generalized Learning Vector Quantization.""" def __init__(self, hparams, **kwargs): @@ -46,9 +34,9 @@ class GLVQ(AbstractPrototypeModel): distribution=self.hparams.distribution, initializer=self.prototype_initializer(**kwargs)) - self.distance_layer = FunctionLayer(distance_fn) - self.transfer_layer = FunctionLayer(tranfer_fn) - self.loss = FunctionLayer(glvq_loss) + self.distance_layer = LambdaLayer(distance_fn) + self.transfer_layer = LambdaLayer(tranfer_fn) + self.loss = LambdaLayer(glvq_loss) self.optimizer = kwargs.get("optimizer", torch.optim.Adam)