refactor: clean up abstract classes

This commit is contained in:
Jensun Ravichandran 2021-07-14 19:17:05 +02:00
parent 23a3683860
commit f8ad1d83eb
No known key found for this signature in database
GPG Key ID: 3331B0F18B6D4D93

View File

@ -14,20 +14,8 @@ from ..core.pooling import stratified_min_pooling
from ..nn.wrappers import LambdaLayer
class ProtoTorchMixin(object):
pass
class ProtoTorchBolt(pl.LightningModule):
"""All ProtoTorch models are ProtoTorch Bolts."""
def __repr__(self):
surep = super().__repr__()
indented = "".join([f"\t{line}\n" for line in surep.splitlines()])
wrapped = f"ProtoTorch Bolt(\n{indented})"
return wrapped
class PrototypeModel(ProtoTorchBolt):
def __init__(self, hparams, **kwargs):
super().__init__()
@ -42,22 +30,6 @@ class PrototypeModel(ProtoTorchBolt):
self.lr_scheduler = kwargs.get("lr_scheduler", None)
self.lr_scheduler_kwargs = kwargs.get("lr_scheduler_kwargs", dict())
distance_fn = kwargs.get("distance_fn", euclidean_distance)
self.distance_layer = LambdaLayer(distance_fn)
@property
def num_prototypes(self):
return len(self.proto_layer.components)
@property
def prototypes(self):
return self.proto_layer.components.detach().cpu()
@property
def components(self):
"""Only an alias for the prototypes."""
return self.prototypes
def configure_optimizers(self):
optimizer = self.optimizer(self.parameters(), lr=self.hparams.lr)
if self.lr_scheduler is not None:
@ -75,6 +47,33 @@ class PrototypeModel(ProtoTorchBolt):
def reconfigure_optimizers(self):
self.trainer.accelerator.setup_optimizers(self.trainer)
def __repr__(self):
surep = super().__repr__()
indented = "".join([f"\t{line}\n" for line in surep.splitlines()])
wrapped = f"ProtoTorch Bolt(\n{indented})"
return wrapped
class PrototypeModel(ProtoTorchBolt):
def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs)
distance_fn = kwargs.get("distance_fn", euclidean_distance)
self.distance_layer = LambdaLayer(distance_fn)
@property
def num_prototypes(self):
return len(self.proto_layer.components)
@property
def prototypes(self):
return self.proto_layer.components.detach().cpu()
@property
def components(self):
"""Only an alias for the prototypes."""
return self.prototypes
def add_prototypes(self, *args, **kwargs):
self.proto_layer.add_components(*args, **kwargs)
self.reconfigure_optimizers()
@ -167,6 +166,11 @@ class SupervisedPrototypeModel(PrototypeModel):
logger=True)
class ProtoTorchMixin(object):
"""All mixins are ProtoTorchMixins."""
pass
class NonGradientMixin(ProtoTorchMixin):
"""Mixin for custom non-gradient optimization."""
def __init__(self, *args, **kwargs):