[FEATURE] Log prototype win ratios over all training batches

This commit is contained in:
Jensun Ravichandran 2021-06-02 02:32:54 +02:00
parent 7743c50725
commit 8f7deb75dd

View File

@ -45,6 +45,9 @@ class GLVQ(AbstractPrototypeModel):
self.transfer_layer = LambdaLayer(transfer_fn)
self.loss = LambdaLayer(glvq_loss)
# Prototype metrics
self.initialize_prototype_win_ratios()
self.optimizer = kwargs.get("optimizer", torch.optim.Adam)
def prototype_initializer(self, **kwargs):
@ -93,6 +96,25 @@ class GLVQ(AbstractPrototypeModel):
prog_bar=True,
logger=True)
def initialize_prototype_win_ratios(self):
self.prototype_win_ratios = torch.zeros(self.num_prototypes,
device=self.device)
def log_prototype_win_ratios(self, distances):
batch_size = len(distances)
prototype_wc = torch.zeros(self.num_prototypes,
dtype=torch.long,
device=self.device)
wi, wc = torch.unique(distances.min(dim=-1).indices,
sorted=True,
return_counts=True)
prototype_wc[wi] = wc
prototype_wr = prototype_wc / batch_size
self.prototype_win_ratios = torch.vstack([
self.prototype_win_ratios,
prototype_wr,
])
def shared_step(self, batch, batch_idx, optimizer_idx=None):
x, y = batch
out = self._forward(x)
@ -104,6 +126,7 @@ class GLVQ(AbstractPrototypeModel):
def training_step(self, batch, batch_idx, optimizer_idx=None):
out, train_loss = self.shared_step(batch, batch_idx, optimizer_idx)
self.log_prototype_win_ratios(out)
self.log("train_loss", train_loss)
self.log_acc(out, batch[-1], tag="train_acc")
return train_loss
@ -127,6 +150,7 @@ class GLVQ(AbstractPrototypeModel):
test_loss += batch_loss.item()
self.log("test_loss", test_loss)
# TODO
# def predict_step(self, batch, batch_idx, dataloader_idx=None):
# pass