prototorch_models/prototorch/models/abstract.py

24 lines
767 B
Python
Raw Normal View History

2021-04-29 17:14:33 +00:00
import pytorch_lightning as pl
import torch
2021-05-03 11:20:49 +00:00
from torch.optim.lr_scheduler import ExponentialLR
2021-04-29 17:14:33 +00:00
class AbstractLightningModel(pl.LightningModule):
    """Base LightningModule providing a shared optimizer/scheduler setup.

    Subclasses are expected to provide a learning rate as ``self.hparams.lr``.
    """

    def configure_optimizers(self):
        """Create an Adam optimizer with an exponentially decaying LR.

        Returns:
            A pair ``([optimizer], [scheduler_config])``. The scheduler config
            tells Lightning to step the ``ExponentialLR`` scheduler after every
            training step rather than after every epoch.
        """
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
        # ``last_epoch=-1`` and ``verbose=False`` are the defaults, so they are
        # omitted. Passing ``verbose`` explicitly is deprecated in recent
        # PyTorch releases and only produces a deprecation warning.
        scheduler = ExponentialLR(optimizer, gamma=0.99)
        sch = {
            "scheduler": scheduler,
            # Step after each training step (not each epoch): the LR is
            # multiplied by ``gamma`` once per optimization step.
            "interval": "step",
        }
        return [optimizer], [sch]
2021-04-29 17:14:33 +00:00
class AbstractPrototypeModel(AbstractLightningModel):
    """Base class for prototype-based models.

    Subclasses are expected to define ``self.proto_layer`` exposing a
    ``components`` tensor.
    """

    @property
    def prototypes(self):
        """Return the prototype components as a NumPy array.

        The tensor is detached from the autograd graph and moved to the CPU
        before conversion. ``Tensor.numpy()`` raises for CUDA tensors, so
        without ``.cpu()`` this property failed for GPU-resident models.
        For tensors already on the CPU, ``.cpu()`` returns the tensor itself.
        In that case the resulting array shares memory with the parameter.
        """
        return self.proto_layer.components.detach().cpu().numpy()