prototorch_models/prototorch/models/abstract.py

42 lines
1.3 KiB
Python
Raw Normal View History

2021-04-29 17:14:33 +00:00
import pytorch_lightning as pl
2021-05-03 11:20:49 +00:00
from torch.optim.lr_scheduler import ExponentialLR
2021-04-29 17:14:33 +00:00
2021-05-11 14:13:00 +00:00
class AbstractPrototypeModel(pl.LightningModule):
    """Shared base for prototype-based Lightning models.

    The members below read ``self.proto_layer``, ``self.optimizer`` and
    ``self.hparams.lr``; none of them is defined in this class, so concrete
    subclasses are expected to provide them.
    """

    @property
    def num_prototypes(self) -> int:
        """Return the number of entries in ``proto_layer.components``."""
        return len(self.proto_layer.components)

    @property
    def prototypes(self):
        """Return the prototype tensor, detached from autograd and on the CPU."""
        return self.proto_layer.components.detach().cpu()

    @property
    def components(self):
        """Only an alias for the prototypes."""
        return self.prototypes

    def configure_optimizers(self):
        """Build the optimizer and a per-step exponential LR decay schedule.

        Returns:
            A pair ``([optimizer], [scheduler_config])`` where the scheduler
            config is a dict with the keys ``"scheduler"`` and ``"interval"``.
        """
        optimizer = self.optimizer(self.parameters(), lr=self.hparams.lr)
        # ``last_epoch=-1`` and ``verbose=False`` are the scheduler's
        # defaults, so they are left out. Passing ``verbose`` explicitly is
        # deprecated in recent PyTorch releases and triggers a warning.
        scheduler = ExponentialLR(optimizer, gamma=0.99)
        sch = {
            "scheduler": scheduler,
            "interval": "step",
        }  # called after each training step
        return [optimizer], [sch]


class PrototypeImageModel(pl.LightningModule):
    """Helpers for models whose prototypes are handled as images.

    The methods read ``self.proto_layer`` and ``self.components``, neither of
    which is defined in this class; they are expected to come from the
    concrete model (NOTE(review): presumably via ``AbstractPrototypeModel``
    when used together -- confirm against the subclasses).
    """

    def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx=0):
        """Clamp the prototype values to ``[0.0, 1.0]`` after every batch.

        ``dataloader_idx`` defaults to ``0`` so the hook can be invoked both
        with and without that fourth argument; newer Lightning releases no
        longer pass it. None of the arguments is used by the method itself.
        """
        # In-place clamp on ``.data`` so the update is not tracked by autograd.
        self.proto_layer.components.data.clamp_(0.0, 1.0)

    def get_prototype_grid(self, num_columns=2, return_channels_last=True):
        """Arrange the prototypes into a single image grid.

        Args:
            num_columns: Forwarded to ``make_grid`` as ``nrow`` (the number
                of images placed in each row of the grid).
            return_channels_last: If true, permute the grid axes with
                ``(1, 2, 0)`` so that the first axis is moved to the end.

        Returns:
            The grid tensor on the CPU.
        """
        # Imported lazily so torchvision is only needed when a grid is built.
        from torchvision.utils import make_grid

        grid = make_grid(self.components, nrow=num_columns)
        if return_channels_last:
            grid = grid.permute((1, 2, 0))
        return grid.cpu()