Refactor into shared_step

Jensun Ravichandran 2021-05-19 16:57:51 +02:00
parent fdf9443a2c
commit 5ffbd43a7c


@@ -51,12 +51,11 @@ class GLVQ(AbstractPrototypeModel):
     def forward(self, x):
         protos, _ = self.proto_layer()
-        dis = self.distance_fn(x, protos)
-        return dis
+        distances = self.distance_fn(x, protos)
+        return distances
 
     def log_acc(self, distances, targets, tag):
         plabels = self.proto_layer.component_labels
-        # Compute training accuracy
         with torch.no_grad():
             preds = wtac(distances, plabels)
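
For context: wtac is winner-takes-all classification, i.e. each sample is assigned the label of its nearest prototype, which is why the raw distance matrix is all that log_acc needs. Below is a minimal sketch of such a helper, assuming distances has shape (batch_size, num_prototypes) and plabels holds one label per prototype; the real implementation ships with prototorch, so this is an illustration, not its code:

import torch

def wtac(distances: torch.Tensor, plabels: torch.Tensor) -> torch.Tensor:
    # Index of the nearest prototype for every sample: shape (batch_size,)
    winners = distances.argmin(dim=1)
    # Winner-takes-all: predicted label = label of the winning prototype
    return plabels[winners]

Accuracy then follows by comparing these predictions against the targets, as the preds line above does inside torch.no_grad().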
@@ -71,53 +70,36 @@ class GLVQ(AbstractPrototypeModel):
                  prog_bar=True,
                  logger=True)
 
-    def training_step(self, train_batch, batch_idx, optimizer_idx=None):
-        x, y = train_batch
-        dis = self(x)
+    def shared_step(self, batch, batch_idx, optimizer_idx=None):
+        x, y = batch
+        out = self(x)
         plabels = self.proto_layer.component_labels
-        mu = self.loss(dis, y, prototype_labels=plabels)
-        train_batch_loss = self.transfer_fn(mu,
-                                            beta=self.hparams.transfer_beta)
-        train_loss = train_batch_loss.sum(dim=0)
+        mu = self.loss(out, y, prototype_labels=plabels)
+        batch_loss = self.transfer_fn(mu, beta=self.hparams.transfer_beta)
+        loss = batch_loss.sum(dim=0)
+        return out, loss
 
-        # Logging
+    def training_step(self, batch, batch_idx, optimizer_idx=None):
+        out, train_loss = self.shared_step(batch, batch_idx, optimizer_idx)
         self.log("train_loss", train_loss)
-        self.log_acc(dis, y, tag="train_acc")
+        self.log_acc(out, batch[-1], tag="train_acc")
         return train_loss
 
-    def validation_step(self, val_batch, batch_idx):
-        # `model.eval()` and `torch.no_grad()` are called automatically for
-        # validation.
-        x, y = val_batch
-        dis = self(x)
-        plabels = self.proto_layer.component_labels
-        mu = self.loss(dis, y, prototype_labels=plabels)
-        val_batch_loss = self.transfer_fn(mu, beta=self.hparams.transfer_beta)
-        val_loss = val_batch_loss.sum(dim=0)
-
-        # Logging
+    def validation_step(self, batch, batch_idx):
+        # `model.eval()` and `torch.no_grad()` handled by pl
+        out, val_loss = self.shared_step(batch, batch_idx)
         self.log("val_loss", val_loss)
-        self.log_acc(dis, y, tag="val_acc")
+        self.log_acc(out, batch[-1], tag="val_acc")
         return val_loss
 
-    def test_step(self, test_batch, batch_idx):
-        # `model.eval()` and `torch.no_grad()` are called automatically for
-        # testing.
-        x, y = test_batch
-        dis = self(x)
-        plabels = self.proto_layer.component_labels
-        mu = self.loss(dis, y, prototype_labels=plabels)
-        test_batch_loss = self.transfer_fn(mu, beta=self.hparams.transfer_beta)
-        test_loss = test_batch_loss.sum(dim=0)
-
-        # Logging
-        self.log("test_loss", test_loss)
-        self.log_acc(dis, y, tag="test_acc")
+    def test_step(self, batch, batch_idx):
+        # `model.eval()` and `torch.no_grad()` handled by pl
+        out, test_loss = self.shared_step(batch, batch_idx)
         return test_loss
 
     # def predict_step(self, batch, batch_idx, dataloader_idx=None):
     #     pass
 
     def predict(self, x):
         self.eval()
         with torch.no_grad():
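
Net effect of the refactor: the loss pipeline (distances, then the GLVQ classifier signal mu, then the transfer function, then the batch sum) now lives in one place, and the three Lightning hooks differ only in what they log. For reference, the textbook GLVQ formulation that self.loss and self.transfer_fn correspond to is mu = (d_correct - d_incorrect) / (d_correct + d_incorrect), optionally squashed by a beta-scaled sigmoid. The sketch below is that textbook version under assumed shapes, not the prototorch implementation:

import torch

def glvq_loss(distances, targets, prototype_labels):
    # Mask of shape (batch, num_protos): does the prototype carry the sample's label?
    matching = targets.unsqueeze(1) == prototype_labels.unsqueeze(0)
    inf = torch.full_like(distances, float("inf"))
    # Distance to the closest correct and the closest incorrect prototype
    d_correct = torch.where(matching, distances, inf).min(dim=1).values
    d_incorrect = torch.where(~matching, distances, inf).min(dim=1).values
    # mu lies in (-1, 1); negative values mean the correct prototype is closer
    return (d_correct - d_incorrect) / (d_correct + d_incorrect)

def sigmoid_transfer(mu, beta=10.0):
    # One common choice of transfer function; beta plays the role of hparams.transfer_beta
    return torch.sigmoid(beta * mu)

Summing sigmoid_transfer(glvq_loss(...)) over the batch mirrors the loss = batch_loss.sum(dim=0) line in shared_step.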