refactor: clean up abstract classes
This commit is contained in:
parent
23a3683860
commit
f8ad1d83eb
@ -14,20 +14,8 @@ from ..core.pooling import stratified_min_pooling
|
|||||||
from ..nn.wrappers import LambdaLayer
|
from ..nn.wrappers import LambdaLayer
|
||||||
|
|
||||||
|
|
||||||
class ProtoTorchMixin(object):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class ProtoTorchBolt(pl.LightningModule):
|
class ProtoTorchBolt(pl.LightningModule):
|
||||||
"""All ProtoTorch models are ProtoTorch Bolts."""
|
"""All ProtoTorch models are ProtoTorch Bolts."""
|
||||||
def __repr__(self):
|
|
||||||
surep = super().__repr__()
|
|
||||||
indented = "".join([f"\t{line}\n" for line in surep.splitlines()])
|
|
||||||
wrapped = f"ProtoTorch Bolt(\n{indented})"
|
|
||||||
return wrapped
|
|
||||||
|
|
||||||
|
|
||||||
class PrototypeModel(ProtoTorchBolt):
|
|
||||||
def __init__(self, hparams, **kwargs):
|
def __init__(self, hparams, **kwargs):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@ -42,22 +30,6 @@ class PrototypeModel(ProtoTorchBolt):
|
|||||||
self.lr_scheduler = kwargs.get("lr_scheduler", None)
|
self.lr_scheduler = kwargs.get("lr_scheduler", None)
|
||||||
self.lr_scheduler_kwargs = kwargs.get("lr_scheduler_kwargs", dict())
|
self.lr_scheduler_kwargs = kwargs.get("lr_scheduler_kwargs", dict())
|
||||||
|
|
||||||
distance_fn = kwargs.get("distance_fn", euclidean_distance)
|
|
||||||
self.distance_layer = LambdaLayer(distance_fn)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def num_prototypes(self):
|
|
||||||
return len(self.proto_layer.components)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def prototypes(self):
|
|
||||||
return self.proto_layer.components.detach().cpu()
|
|
||||||
|
|
||||||
@property
|
|
||||||
def components(self):
|
|
||||||
"""Only an alias for the prototypes."""
|
|
||||||
return self.prototypes
|
|
||||||
|
|
||||||
def configure_optimizers(self):
|
def configure_optimizers(self):
|
||||||
optimizer = self.optimizer(self.parameters(), lr=self.hparams.lr)
|
optimizer = self.optimizer(self.parameters(), lr=self.hparams.lr)
|
||||||
if self.lr_scheduler is not None:
|
if self.lr_scheduler is not None:
|
||||||
@ -75,6 +47,33 @@ class PrototypeModel(ProtoTorchBolt):
|
|||||||
def reconfigure_optimizers(self):
|
def reconfigure_optimizers(self):
|
||||||
self.trainer.accelerator.setup_optimizers(self.trainer)
|
self.trainer.accelerator.setup_optimizers(self.trainer)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
surep = super().__repr__()
|
||||||
|
indented = "".join([f"\t{line}\n" for line in surep.splitlines()])
|
||||||
|
wrapped = f"ProtoTorch Bolt(\n{indented})"
|
||||||
|
return wrapped
|
||||||
|
|
||||||
|
|
||||||
|
class PrototypeModel(ProtoTorchBolt):
|
||||||
|
def __init__(self, hparams, **kwargs):
|
||||||
|
super().__init__(hparams, **kwargs)
|
||||||
|
|
||||||
|
distance_fn = kwargs.get("distance_fn", euclidean_distance)
|
||||||
|
self.distance_layer = LambdaLayer(distance_fn)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_prototypes(self):
|
||||||
|
return len(self.proto_layer.components)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def prototypes(self):
|
||||||
|
return self.proto_layer.components.detach().cpu()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def components(self):
|
||||||
|
"""Only an alias for the prototypes."""
|
||||||
|
return self.prototypes
|
||||||
|
|
||||||
def add_prototypes(self, *args, **kwargs):
|
def add_prototypes(self, *args, **kwargs):
|
||||||
self.proto_layer.add_components(*args, **kwargs)
|
self.proto_layer.add_components(*args, **kwargs)
|
||||||
self.reconfigure_optimizers()
|
self.reconfigure_optimizers()
|
||||||
@ -167,6 +166,11 @@ class SupervisedPrototypeModel(PrototypeModel):
|
|||||||
logger=True)
|
logger=True)
|
||||||
|
|
||||||
|
|
||||||
|
class ProtoTorchMixin(object):
|
||||||
|
"""All mixins are ProtoTorchMixins."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class NonGradientMixin(ProtoTorchMixin):
|
class NonGradientMixin(ProtoTorchMixin):
|
||||||
"""Mixin for custom non-gradient optimization."""
|
"""Mixin for custom non-gradient optimization."""
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
|
Loading…
Reference in New Issue
Block a user