[FEATURE] Log prototype win ratios over all training batches
This commit is contained in:
parent
7743c50725
commit
8f7deb75dd
@ -45,6 +45,9 @@ class GLVQ(AbstractPrototypeModel):
|
||||
self.transfer_layer = LambdaLayer(transfer_fn)
|
||||
self.loss = LambdaLayer(glvq_loss)
|
||||
|
||||
# Prototype metrics
|
||||
self.initialize_prototype_win_ratios()
|
||||
|
||||
self.optimizer = kwargs.get("optimizer", torch.optim.Adam)
|
||||
|
||||
def prototype_initializer(self, **kwargs):
|
||||
@ -93,6 +96,25 @@ class GLVQ(AbstractPrototypeModel):
|
||||
prog_bar=True,
|
||||
logger=True)
|
||||
|
||||
def initialize_prototype_win_ratios(self):
|
||||
self.prototype_win_ratios = torch.zeros(self.num_prototypes,
|
||||
device=self.device)
|
||||
|
||||
def log_prototype_win_ratios(self, distances):
|
||||
batch_size = len(distances)
|
||||
prototype_wc = torch.zeros(self.num_prototypes,
|
||||
dtype=torch.long,
|
||||
device=self.device)
|
||||
wi, wc = torch.unique(distances.min(dim=-1).indices,
|
||||
sorted=True,
|
||||
return_counts=True)
|
||||
prototype_wc[wi] = wc
|
||||
prototype_wr = prototype_wc / batch_size
|
||||
self.prototype_win_ratios = torch.vstack([
|
||||
self.prototype_win_ratios,
|
||||
prototype_wr,
|
||||
])
|
||||
|
||||
def shared_step(self, batch, batch_idx, optimizer_idx=None):
|
||||
x, y = batch
|
||||
out = self._forward(x)
|
||||
@ -104,6 +126,7 @@ class GLVQ(AbstractPrototypeModel):
|
||||
|
||||
def training_step(self, batch, batch_idx, optimizer_idx=None):
|
||||
out, train_loss = self.shared_step(batch, batch_idx, optimizer_idx)
|
||||
self.log_prototype_win_ratios(out)
|
||||
self.log("train_loss", train_loss)
|
||||
self.log_acc(out, batch[-1], tag="train_acc")
|
||||
return train_loss
|
||||
@ -127,6 +150,7 @@ class GLVQ(AbstractPrototypeModel):
|
||||
test_loss += batch_loss.item()
|
||||
self.log("test_loss", test_loss)
|
||||
|
||||
# TODO
|
||||
# def predict_step(self, batch, batch_idx, dataloader_idx=None):
|
||||
# pass
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user