Refactor into shared_step

Jensun Ravichandran 2021-05-19 16:57:51 +02:00
parent fdf9443a2c
commit 5ffbd43a7c


@@ -51,12 +51,11 @@ class GLVQ(AbstractPrototypeModel):
     def forward(self, x):
         protos, _ = self.proto_layer()
-        dis = self.distance_fn(x, protos)
-        return dis
+        distances = self.distance_fn(x, protos)
+        return distances
 
     def log_acc(self, distances, targets, tag):
         plabels = self.proto_layer.component_labels
         # Compute training accuracy
         with torch.no_grad():
             preds = wtac(distances, plabels)
@@ -71,53 +70,36 @@ class GLVQ(AbstractPrototypeModel):
                  prog_bar=True,
                  logger=True)
 
-    def training_step(self, train_batch, batch_idx, optimizer_idx=None):
-        x, y = train_batch
-        dis = self(x)
+    def shared_step(self, batch, batch_idx, optimizer_idx=None):
+        x, y = batch
+        out = self(x)
         plabels = self.proto_layer.component_labels
-        mu = self.loss(dis, y, prototype_labels=plabels)
-        train_batch_loss = self.transfer_fn(mu,
-                                            beta=self.hparams.transfer_beta)
-        train_loss = train_batch_loss.sum(dim=0)
-        # Logging
+        mu = self.loss(out, y, prototype_labels=plabels)
+        batch_loss = self.transfer_fn(mu, beta=self.hparams.transfer_beta)
+        loss = batch_loss.sum(dim=0)
+        return out, loss
+
+    def training_step(self, batch, batch_idx, optimizer_idx=None):
+        out, train_loss = self.shared_step(batch, batch_idx, optimizer_idx)
         self.log("train_loss", train_loss)
-        self.log_acc(dis, y, tag="train_acc")
+        self.log_acc(out, batch[-1], tag="train_acc")
         return train_loss
 
-    def validation_step(self, val_batch, batch_idx):
-        # `model.eval()` and `torch.no_grad()` are called automatically for
-        # validation.
-        x, y = val_batch
-        dis = self(x)
-        plabels = self.proto_layer.component_labels
-        mu = self.loss(dis, y, prototype_labels=plabels)
-        val_batch_loss = self.transfer_fn(mu, beta=self.hparams.transfer_beta)
-        val_loss = val_batch_loss.sum(dim=0)
-        # Logging
+    def validation_step(self, batch, batch_idx):
+        # `model.eval()` and `torch.no_grad()` handled by pl
+        out, val_loss = self.shared_step(batch, batch_idx)
         self.log("val_loss", val_loss)
-        self.log_acc(dis, y, tag="val_acc")
+        self.log_acc(out, batch[-1], tag="val_acc")
         return val_loss
 
-    def test_step(self, test_batch, batch_idx):
-        # `model.eval()` and `torch.no_grad()` are called automatically for
-        # testing.
-        x, y = test_batch
-        dis = self(x)
-        plabels = self.proto_layer.component_labels
-        mu = self.loss(dis, y, prototype_labels=plabels)
-        test_batch_loss = self.transfer_fn(mu, beta=self.hparams.transfer_beta)
-        test_loss = test_batch_loss.sum(dim=0)
-        # Logging
+    def test_step(self, batch, batch_idx):
+        # `model.eval()` and `torch.no_grad()` handled by pl
+        out, test_loss = self.shared_step(batch, batch_idx)
         self.log("test_loss", test_loss)
-        self.log_acc(dis, y, tag="test_acc")
+        self.log_acc(out, batch[-1], tag="test_acc")
         return test_loss
 
+    # def predict_step(self, batch, batch_idx, dataloader_idx=None):
+    #     pass
+
     def predict(self, x):
         self.eval()
         with torch.no_grad():
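
The change above is the usual PyTorch Lightning "shared step" refactor: the loss computation that was duplicated across training_step, validation_step, and test_step moves into a single shared_step helper, and the three hooks only differ in what they log. Below is a minimal, self-contained sketch of that pattern; the toy model, its linear layer, and the cross-entropy loss are illustrative placeholders and are not part of this commit or of the GLVQ code.

# Sketch of the shared_step pattern (assumes torch and pytorch_lightning are
# installed); the model and loss here are placeholders, not the GLVQ code.
import torch
import torch.nn.functional as F
import pytorch_lightning as pl


class ToyClassifier(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 3)  # stand-in for the prototype layer

    def forward(self, x):
        return self.layer(x)

    def shared_step(self, batch, batch_idx):
        # Single place that turns a batch into (output, loss); every hook
        # below reuses it instead of repeating the computation.
        x, y = batch
        out = self(x)
        loss = F.cross_entropy(out, y)
        return out, loss

    def training_step(self, batch, batch_idx):
        out, loss = self.shared_step(batch, batch_idx)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        # eval mode and no_grad are handled by Lightning for this hook
        out, loss = self.shared_step(batch, batch_idx)
        self.log("val_loss", loss)
        return loss

    def test_step(self, batch, batch_idx):
        out, loss = self.shared_step(batch, batch_idx)
        self.log("test_loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

Keeping the per-step hooks this thin also leaves room for a later predict_step, as the commented-out stub in the diff hints.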