Fix broken state from previous commit
@@ -1,3 +1,5 @@
+import argparse
+
 import pytorch_lightning as pl
 import torch
 import torchmetrics
@@ -10,10 +12,21 @@ from prototorch.modules.prototypes import Prototypes1D
 
 class GLVQ(pl.LightningModule):
     """Generalized Learning Vector Quantization."""
-    def __init__(self, hparams):
+    def __init__(self, hparams, input_dim, nclasses, **kwargs):
         super().__init__()
         self.lr = hparams.lr
-        self.proto_layer = Prototypes1D(**kwargs)
+        self.hparams = hparams
+        # self.save_hyperparameters(
+        #     "lr",
+        #     "prototypes_per_class",
+        #     "prototype_initializer",
+        # )
+        self.proto_layer = Prototypes1D(
+            input_dim=input_dim,
+            nclasses=nclasses,
+            prototypes_per_class=hparams.prototypes_per_class,
+            prototype_initializer=hparams.prototype_initializer,
+            **kwargs)
         self.train_acc = torchmetrics.Accuracy()
 
     @property
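For context, a minimal sketch of how the new constructor signature might be called. The hparams namespace and its values here are illustrative assumptions, not part of the commit; only the field names the constructor reads (lr, prototypes_per_class, prototype_initializer) are taken from the diff above.

import argparse

# Hypothetical hparams object; the constructor reads .lr,
# .prototypes_per_class and .prototype_initializer from it.
hparams = argparse.Namespace(
    lr=1e-2,
    prototypes_per_class=1,
    prototype_initializer="zeros",
)

# input_dim and nclasses are now passed explicitly instead of
# being buried in **kwargs, so one hparams object can be reused
# across datasets of different shape.
model = GLVQ(hparams, input_dim=2, nclasses=3)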
@@ -24,15 +37,28 @@ class GLVQ(pl.LightningModule):
     def prototype_labels(self):
         return self.proto_layer.prototype_labels.detach().numpy()
 
+    def configure_optimizers(self):
+        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
+        return optimizer
+
+    @staticmethod
+    def add_model_specific_args(parent_parser):
+        parser = argparse.ArgumentParser(parents=[parent_parser],
+                                         add_help=False)
+        parser.add_argument("--epochs", type=int, default=1)
+        parser.add_argument("--lr", type=float, default=1e-2)
+        parser.add_argument("--batch_size", type=int, default=32)
+        parser.add_argument("--prototypes_per_class", type=int, default=1)
+        parser.add_argument("--prototype_initializer",
+                            type=str,
+                            default="zeros")
+        return parser
+
     def forward(self, x):
         protos = self.proto_layer.prototypes
         dis = euclidean_distance(x, protos)
         return dis
 
-    def configure_optimizers(self):
-        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
-        return optimizer
-
     def training_step(self, train_batch, batch_idx):
         x, y = train_batch
         x = x.view(x.size(0), -1)
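The new static method follows the Lightning-style pattern of letting the model contribute its own CLI flags to a script's parser. A usage sketch under assumptions: the parent parser and its --gpus flag are invented script-level details, not part of the commit.

import argparse

parent = argparse.ArgumentParser()
parent.add_argument("--gpus", type=int, default=0)  # hypothetical script flag

# The model appends its own arguments (--epochs, --lr, --batch_size, ...).
parser = GLVQ.add_model_specific_args(parent)
args = parser.parse_args(["--lr", "0.05", "--prototypes_per_class", "2"])
print(args.lr, args.prototypes_per_class)  # -> 0.05 2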
@@ -44,8 +70,13 @@ class GLVQ(pl.LightningModule):
         with torch.no_grad():
             preds = wtac(dis, plabels)
         # self.train_acc.update(preds.int(), y.int())
-        self.train_acc(preds.int(), y.int())  # FloatTensors are assumed to be class probabilities
-        self.log("Training Accuracy", self.train_acc, on_step=False, on_epoch=True)
+        self.train_acc(
+            preds.int(),
+            y.int())  # FloatTensors are assumed to be class probabilities
+        self.log("Training Accuracy",
+                 self.train_acc,
+                 on_step=False,
+                 on_epoch=True)
         return loss
 
     # def training_epoch_end(self, outs):
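For readers unfamiliar with wtac, a minimal stand-in sketch of winner-takes-all classification as the training step uses it. This mirrors the call shape wtac(dis, plabels) seen above but is an assumption-laden rewrite, not prototorch's implementation.

import torch

def wtac_sketch(distances, prototype_labels):
    # Each sample is assigned the label of its nearest prototype.
    winners = distances.argmin(dim=1)
    return prototype_labels[winners]

dis = torch.tensor([[0.1, 2.0, 3.0],
                    [4.0, 0.5, 2.5]])  # 2 samples x 3 prototypes
plabels = torch.tensor([0, 1, 1])
print(wtac_sketch(dis, plabels))  # tensor([0, 1])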