[FEATURE] Log prototype win ratios over all training batches
This commit is contained in:
parent
7743c50725
commit
8f7deb75dd
@ -45,6 +45,9 @@ class GLVQ(AbstractPrototypeModel):
|
|||||||
self.transfer_layer = LambdaLayer(transfer_fn)
|
self.transfer_layer = LambdaLayer(transfer_fn)
|
||||||
self.loss = LambdaLayer(glvq_loss)
|
self.loss = LambdaLayer(glvq_loss)
|
||||||
|
|
||||||
|
# Prototype metrics
|
||||||
|
self.initialize_prototype_win_ratios()
|
||||||
|
|
||||||
self.optimizer = kwargs.get("optimizer", torch.optim.Adam)
|
self.optimizer = kwargs.get("optimizer", torch.optim.Adam)
|
||||||
|
|
||||||
def prototype_initializer(self, **kwargs):
|
def prototype_initializer(self, **kwargs):
|
||||||
@ -93,6 +96,25 @@ class GLVQ(AbstractPrototypeModel):
|
|||||||
prog_bar=True,
|
prog_bar=True,
|
||||||
logger=True)
|
logger=True)
|
||||||
|
|
||||||
|
def initialize_prototype_win_ratios(self):
|
||||||
|
self.prototype_win_ratios = torch.zeros(self.num_prototypes,
|
||||||
|
device=self.device)
|
||||||
|
|
||||||
|
def log_prototype_win_ratios(self, distances):
|
||||||
|
batch_size = len(distances)
|
||||||
|
prototype_wc = torch.zeros(self.num_prototypes,
|
||||||
|
dtype=torch.long,
|
||||||
|
device=self.device)
|
||||||
|
wi, wc = torch.unique(distances.min(dim=-1).indices,
|
||||||
|
sorted=True,
|
||||||
|
return_counts=True)
|
||||||
|
prototype_wc[wi] = wc
|
||||||
|
prototype_wr = prototype_wc / batch_size
|
||||||
|
self.prototype_win_ratios = torch.vstack([
|
||||||
|
self.prototype_win_ratios,
|
||||||
|
prototype_wr,
|
||||||
|
])
|
||||||
|
|
||||||
def shared_step(self, batch, batch_idx, optimizer_idx=None):
|
def shared_step(self, batch, batch_idx, optimizer_idx=None):
|
||||||
x, y = batch
|
x, y = batch
|
||||||
out = self._forward(x)
|
out = self._forward(x)
|
||||||
@ -104,6 +126,7 @@ class GLVQ(AbstractPrototypeModel):
|
|||||||
|
|
||||||
def training_step(self, batch, batch_idx, optimizer_idx=None):
|
def training_step(self, batch, batch_idx, optimizer_idx=None):
|
||||||
out, train_loss = self.shared_step(batch, batch_idx, optimizer_idx)
|
out, train_loss = self.shared_step(batch, batch_idx, optimizer_idx)
|
||||||
|
self.log_prototype_win_ratios(out)
|
||||||
self.log("train_loss", train_loss)
|
self.log("train_loss", train_loss)
|
||||||
self.log_acc(out, batch[-1], tag="train_acc")
|
self.log_acc(out, batch[-1], tag="train_acc")
|
||||||
return train_loss
|
return train_loss
|
||||||
@ -127,6 +150,7 @@ class GLVQ(AbstractPrototypeModel):
|
|||||||
test_loss += batch_loss.item()
|
test_loss += batch_loss.item()
|
||||||
self.log("test_loss", test_loss)
|
self.log("test_loss", test_loss)
|
||||||
|
|
||||||
|
# TODO
|
||||||
# def predict_step(self, batch, batch_idx, dataloader_idx=None):
|
# def predict_step(self, batch, batch_idx, dataloader_idx=None):
|
||||||
# pass
|
# pass
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user