refactor: clean up abstract classes
parent 23a3683860
commit f8ad1d83eb
@@ -14,20 +14,8 @@ from ..core.pooling import stratified_min_pooling
 from ..nn.wrappers import LambdaLayer
 
 
-class ProtoTorchMixin(object):
-    pass
-
-
 class ProtoTorchBolt(pl.LightningModule):
     """All ProtoTorch models are ProtoTorch Bolts."""
-    def __repr__(self):
-        surep = super().__repr__()
-        indented = "".join([f"\t{line}\n" for line in surep.splitlines()])
-        wrapped = f"ProtoTorch Bolt(\n{indented})"
-        return wrapped
-
-
-class PrototypeModel(ProtoTorchBolt):
     def __init__(self, hparams, **kwargs):
         super().__init__()
 
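The `__repr__` shown in this hunk indents every line of the parent `repr` by one tab and encloses the result in a `ProtoTorch Bolt(...)` banner; the same method reappears on `ProtoTorchBolt` in a later hunk. A minimal standalone sketch of the wrapping pattern (the stand-in string replaces `super().__repr__()` and is only illustrative):

# Standalone illustration of the repr-wrapping pattern above.
surep = "Linear(in_features=2, out_features=3)\nReLU()"  # stand-in for super().__repr__()
indented = "".join([f"\t{line}\n" for line in surep.splitlines()])
wrapped = f"ProtoTorch Bolt(\n{indented})"
print(wrapped)
# ProtoTorch Bolt(
#         Linear(in_features=2, out_features=3)
#         ReLU()
# )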
@@ -42,22 +30,6 @@ class PrototypeModel(ProtoTorchBolt):
         self.lr_scheduler = kwargs.get("lr_scheduler", None)
         self.lr_scheduler_kwargs = kwargs.get("lr_scheduler_kwargs", dict())
-
-        distance_fn = kwargs.get("distance_fn", euclidean_distance)
-        self.distance_layer = LambdaLayer(distance_fn)
-
-    @property
-    def num_prototypes(self):
-        return len(self.proto_layer.components)
-
-    @property
-    def prototypes(self):
-        return self.proto_layer.components.detach().cpu()
-
-    @property
-    def components(self):
-        """Only an alias for the prototypes."""
-        return self.prototypes
 
     def configure_optimizers(self):
         optimizer = self.optimizer(self.parameters(), lr=self.hparams.lr)
         if self.lr_scheduler is not None:
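In the lines above, `configure_optimizers` instantiates the optimizer class stored on the model with the learning rate from `hparams`, and only builds a scheduler when an `lr_scheduler` was supplied. A rough standalone sketch of that pattern, outside Lightning and with placeholder choices for the optimizer and scheduler classes (not taken from this diff):

import torch
import torch.nn as nn

# Illustration of the kwargs pattern above: classes are stored on the model and
# only instantiated later, inside configure_optimizers().
module = nn.Linear(2, 3)                                    # stand-in for the LightningModule
optimizer_cls = torch.optim.Adam                            # what self.optimizer would hold
lr_scheduler_cls = torch.optim.lr_scheduler.ExponentialLR   # what self.lr_scheduler would hold (or None)
lr_scheduler_kwargs = dict(gamma=0.99)                      # what self.lr_scheduler_kwargs would hold

optimizer = optimizer_cls(module.parameters(), lr=0.01)     # mirrors self.optimizer(self.parameters(), lr=self.hparams.lr)
if lr_scheduler_cls is not None:                            # the hunk is cut off inside this branch
    scheduler = lr_scheduler_cls(optimizer, **lr_scheduler_kwargs)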
@@ -75,6 +47,33 @@ class PrototypeModel(ProtoTorchBolt):
     def reconfigure_optimizers(self):
         self.trainer.accelerator.setup_optimizers(self.trainer)
 
+    def __repr__(self):
+        surep = super().__repr__()
+        indented = "".join([f"\t{line}\n" for line in surep.splitlines()])
+        wrapped = f"ProtoTorch Bolt(\n{indented})"
+        return wrapped
+
+
+class PrototypeModel(ProtoTorchBolt):
+    def __init__(self, hparams, **kwargs):
+        super().__init__(hparams, **kwargs)
+
+        distance_fn = kwargs.get("distance_fn", euclidean_distance)
+        self.distance_layer = LambdaLayer(distance_fn)
+
+    @property
+    def num_prototypes(self):
+        return len(self.proto_layer.components)
+
+    @property
+    def prototypes(self):
+        return self.proto_layer.components.detach().cpu()
+
+    @property
+    def components(self):
+        """Only an alias for the prototypes."""
+        return self.prototypes
+
     def add_prototypes(self, *args, **kwargs):
         self.proto_layer.add_components(*args, **kwargs)
         self.reconfigure_optimizers()
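The three properties added above are thin views on `proto_layer.components` (`num_prototypes` counts them, `prototypes` returns a detached CPU copy, `components` is only an alias), and `add_prototypes` forwards to `proto_layer.add_components` before rebuilding the optimizers. The toy layer below mimics just that `.components` / `add_components` contract so the behaviour can be tried in isolation; the class name is made up and is not part of ProtoTorch:

import torch
import torch.nn as nn

class ToyProtoLayer(nn.Module):
    """Hypothetical stand-in exposing the interface the properties above rely on."""
    def __init__(self, components):
        super().__init__()
        self.components = nn.Parameter(components)

    def add_components(self, new_components):
        merged = torch.cat([self.components.detach(), new_components])
        self.components = nn.Parameter(merged)

layer = ToyProtoLayer(torch.randn(3, 2))
print(len(layer.components))                  # 3 -> what num_prototypes would report
print(layer.components.detach().cpu().shape)  # torch.Size([3, 2]) -> what .prototypes would return
layer.add_components(torch.randn(2, 2))       # what add_prototypes() forwards to
print(len(layer.components))                  # 5 -> the optimizers would now be reconfigured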
@@ -167,6 +166,11 @@ class SupervisedPrototypeModel(PrototypeModel):
                  logger=True)
 
 
+class ProtoTorchMixin(object):
+    """All mixins are ProtoTorchMixins."""
+    pass
+
+
 class NonGradientMixin(ProtoTorchMixin):
     """Mixin for custom non-gradient optimization."""
     def __init__(self, *args, **kwargs):