Add missing abstract.py file

This commit is contained in:
Jensun Ravichandran 2021-04-29 19:14:33 +02:00
parent fef73e2fbf
commit ccaa52c408

View File

@ -0,0 +1,14 @@
import pytorch_lightning as pl
import torch
class AbstractLightningModel(pl.LightningModule):
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
return optimizer
class AbstractPrototypeModel(AbstractLightningModel):
@property
def prototypes(self):
return self.proto_layer.components.detach().numpy()