Fix broken state from previous commit

This commit is contained in:
Jensun Ravichandran
2021-04-21 21:35:52 +02:00
parent fa7b178028
commit e5a62bd0fc
2 changed files with 77 additions and 49 deletions

View File

@@ -1,3 +1,5 @@
import argparse
import pytorch_lightning as pl
import torch
import torchmetrics
@@ -10,10 +12,21 @@ from prototorch.modules.prototypes import Prototypes1D
class GLVQ(pl.LightningModule):
"""Generalized Learning Vector Quantization."""
def __init__(self, hparams):
def __init__(self, hparams, input_dim, nclasses, **kwargs):
super().__init__()
self.lr = hparams.lr
self.proto_layer = Prototypes1D(**kwargs)
self.hparams = hparams
# self.save_hyperparameters(
# "lr",
# "prototypes_per_class",
# "prototype_initializer",
# )
self.proto_layer = Prototypes1D(
input_dim=input_dim,
nclasses=nclasses,
prototypes_per_class=hparams.prototypes_per_class,
prototype_initializer=hparams.prototype_initializer,
**kwargs)
self.train_acc = torchmetrics.Accuracy()
@property
@@ -24,15 +37,28 @@ class GLVQ(pl.LightningModule):
def prototype_labels(self):
return self.proto_layer.prototype_labels.detach().numpy()
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
return optimizer
@staticmethod
def add_model_specific_args(parent_parser):
parser = argparse.ArgumentParser(parents=[parent_parser],
add_help=False)
parser.add_argument("--epochs", type=int, default=1)
parser.add_argument("--lr", type=float, default=1e-2)
parser.add_argument("--batch_size", type=int, default=32)
parser.add_argument("--prototypes_per_class", type=int, default=1)
parser.add_argument("--prototype_initializer",
type=str,
default="zeros")
return parser
def forward(self, x):
protos = self.proto_layer.prototypes
dis = euclidean_distance(x, protos)
return dis
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
return optimizer
def training_step(self, train_batch, batch_idx):
x, y = train_batch
x = x.view(x.size(0), -1)
@@ -44,8 +70,13 @@ class GLVQ(pl.LightningModule):
with torch.no_grad():
preds = wtac(dis, plabels)
# self.train_acc.update(preds.int(), y.int())
self.train_acc(preds.int(), y.int()) # FloatTensors are assumed to be class probabilities
self.log("Training Accuracy", self.train_acc, on_step=False, on_epoch=True)
self.train_acc(
preds.int(),
y.int()) # FloatTensors are assumed to be class probabilities
self.log("Training Accuracy",
self.train_acc,
on_step=False,
on_epoch=True)
return loss
# def training_epoch_end(self, outs):