Refactor into shared_step
This commit is contained in:
parent
fdf9443a2c
commit
5ffbd43a7c
@ -51,12 +51,11 @@ class GLVQ(AbstractPrototypeModel):
|
|||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
protos, _ = self.proto_layer()
|
protos, _ = self.proto_layer()
|
||||||
dis = self.distance_fn(x, protos)
|
distances = self.distance_fn(x, protos)
|
||||||
return dis
|
return distances
|
||||||
|
|
||||||
def log_acc(self, distances, targets, tag):
|
def log_acc(self, distances, targets, tag):
|
||||||
plabels = self.proto_layer.component_labels
|
plabels = self.proto_layer.component_labels
|
||||||
|
|
||||||
# Compute training accuracy
|
# Compute training accuracy
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
preds = wtac(distances, plabels)
|
preds = wtac(distances, plabels)
|
||||||
@ -71,53 +70,36 @@ class GLVQ(AbstractPrototypeModel):
|
|||||||
prog_bar=True,
|
prog_bar=True,
|
||||||
logger=True)
|
logger=True)
|
||||||
|
|
||||||
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
|
def shared_step(self, batch, batch_idx, optimizer_idx=None):
|
||||||
x, y = train_batch
|
x, y = batch
|
||||||
dis = self(x)
|
out = self(x)
|
||||||
plabels = self.proto_layer.component_labels
|
plabels = self.proto_layer.component_labels
|
||||||
mu = self.loss(dis, y, prototype_labels=plabels)
|
mu = self.loss(out, y, prototype_labels=plabels)
|
||||||
train_batch_loss = self.transfer_fn(mu,
|
batch_loss = self.transfer_fn(mu, beta=self.hparams.transfer_beta)
|
||||||
beta=self.hparams.transfer_beta)
|
loss = batch_loss.sum(dim=0)
|
||||||
train_loss = train_batch_loss.sum(dim=0)
|
return out, loss
|
||||||
|
|
||||||
# Logging
|
def training_step(self, batch, batch_idx, optimizer_idx=None):
|
||||||
|
out, train_loss = self.shared_step(batch, batch_idx, optimizer_idx)
|
||||||
self.log("train_loss", train_loss)
|
self.log("train_loss", train_loss)
|
||||||
self.log_acc(dis, y, tag="train_acc")
|
self.log_acc(out, batch[-1], tag="train_acc")
|
||||||
|
|
||||||
return train_loss
|
return train_loss
|
||||||
|
|
||||||
def validation_step(self, val_batch, batch_idx):
|
def validation_step(self, batch, batch_idx):
|
||||||
# `model.eval()` and `torch.no_grad()` are called automatically for
|
# `model.eval()` and `torch.no_grad()` handled by pl
|
||||||
# validation.
|
out, val_loss = self.shared_step(batch, batch_idx, optimizer_idx)
|
||||||
x, y = val_batch
|
|
||||||
dis = self(x)
|
|
||||||
plabels = self.proto_layer.component_labels
|
|
||||||
mu = self.loss(dis, y, prototype_labels=plabels)
|
|
||||||
val_batch_loss = self.transfer_fn(mu, beta=self.hparams.transfer_beta)
|
|
||||||
val_loss = val_batch_loss.sum(dim=0)
|
|
||||||
|
|
||||||
# Logging
|
|
||||||
self.log("val_loss", val_loss)
|
self.log("val_loss", val_loss)
|
||||||
self.log_acc(dis, y, tag="val_acc")
|
self.log_acc(out, batch[-1], tag="val_acc")
|
||||||
|
|
||||||
return val_loss
|
return val_loss
|
||||||
|
|
||||||
def test_step(self, test_batch, batch_idx):
|
def test_step(self, batch, batch_idx):
|
||||||
# `model.eval()` and `torch.no_grad()` are called automatically for
|
# `model.eval()` and `torch.no_grad()` handled by pl
|
||||||
# testing.
|
out, test_loss = self.shared_step(batch, batch_idx, optimizer_idx)
|
||||||
x, y = test_batch
|
|
||||||
dis = self(x)
|
|
||||||
plabels = self.proto_layer.component_labels
|
|
||||||
mu = self.loss(dis, y, prototype_labels=plabels)
|
|
||||||
test_batch_loss = self.transfer_fn(mu, beta=self.hparams.transfer_beta)
|
|
||||||
test_loss = test_batch_loss.sum(dim=0)
|
|
||||||
|
|
||||||
# Logging
|
|
||||||
self.log("test_loss", test_loss)
|
|
||||||
self.log_acc(dis, y, tag="test_acc")
|
|
||||||
|
|
||||||
return test_loss
|
return test_loss
|
||||||
|
|
||||||
|
# def predict_step(self, batch, batch_idx, dataloader_idx=None):
|
||||||
|
# pass
|
||||||
|
|
||||||
def predict(self, x):
|
def predict(self, x):
|
||||||
self.eval()
|
self.eval()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
|
Loading…
Reference in New Issue
Block a user