Add more experimental changes

The code gets very messy very quickly as soon as serialization features are
needed.
Jensun Ravichandran 2021-04-21 21:59:19 +02:00
parent e5a62bd0fc
commit fadf8c25bf
2 changed files with 11 additions and 32 deletions

@@ -25,6 +25,8 @@ class GLVQIris(GLVQ):
         parser.add_argument("--epochs", type=int, default=1)
         parser.add_argument("--lr", type=float, default=1e-1)
         parser.add_argument("--batch_size", type=int, default=150)
+        parser.add_argument("--input_dim", type=int, default=2)
+        parser.add_argument("--nclasses", type=int, default=3)
         parser.add_argument("--prototypes_per_class", type=int, default=3)
         parser.add_argument("--prototype_initializer",
                             type=str,
@@ -101,6 +103,7 @@ if __name__ == "__main__":
     # Setup trainer
     trainer = pl.Trainer.from_argparse_args(
         parser,
+        max_epochs=10,
         callbacks=[
             vis,  # comment this line out to disable the visualization
         ],
@@ -109,12 +112,7 @@ if __name__ == "__main__":
     # Initialize the model
     args = parser.parse_args()
-    model = GLVQIris(
-        args,
-        input_dim=x_train.shape[1],
-        nclasses=3,
-        data=[x_train, y_train],
-    )
+    model = GLVQIris(args, data=[x_train, y_train])
 
     # Model summary
     print(model)
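
Taken together, the changes to the example script move every model-defining setting (input_dim, nclasses, prototypes_per_class, ...) into the argparse Namespace, so the constructor call collapses to GLVQIris(args, data=[x_train, y_train]). Below is a minimal sketch of that flow, using only the argument names and defaults visible in the diff; the GLVQIris class and the Iris arrays x_train/y_train are assumed from the rest of the example script.

    import argparse

    parser = argparse.ArgumentParser()
    # Same knobs the example registers above; defaults copied from the diff.
    parser.add_argument("--lr", type=float, default=1e-1)
    parser.add_argument("--input_dim", type=int, default=2)
    parser.add_argument("--nclasses", type=int, default=3)
    parser.add_argument("--prototypes_per_class", type=int, default=3)
    args = parser.parse_args([])  # empty list keeps the defaults; a real CLI run would use parse_args()

    # With the architecture described entirely by args, the model needs no extra keywords:
    # model = GLVQIris(args, data=[x_train, y_train])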

@@ -12,20 +12,14 @@ from prototorch.modules.prototypes import Prototypes1D
 class GLVQ(pl.LightningModule):
     """Generalized Learning Vector Quantization."""
 
-    def __init__(self, hparams, input_dim, nclasses, **kwargs):
+    def __init__(self, hparams, **kwargs):
         super().__init__()
-        self.lr = hparams.lr
-        self.hparams = hparams
-        # self.save_hyperparameters(
-        #     "lr",
-        #     "prototypes_per_class",
-        #     "prototype_initializer",
-        # )
+        self.save_hyperparameters(hparams)
         self.proto_layer = Prototypes1D(
-            input_dim=input_dim,
-            nclasses=nclasses,
-            prototypes_per_class=hparams.prototypes_per_class,
-            prototype_initializer=hparams.prototype_initializer,
+            input_dim=self.hparams.input_dim,
+            nclasses=self.hparams.nclasses,
+            prototypes_per_class=self.hparams.prototypes_per_class,
+            prototype_initializer=self.hparams.prototype_initializer,
             **kwargs)
         self.train_acc = torchmetrics.Accuracy()
@@ -38,22 +32,9 @@ class GLVQ(pl.LightningModule):
         return self.proto_layer.prototype_labels.detach().numpy()
 
     def configure_optimizers(self):
-        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
+        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
         return optimizer
 
-    @staticmethod
-    def add_model_specific_args(parent_parser):
-        parser = argparse.ArgumentParser(parents=[parent_parser],
-                                         add_help=False)
-        parser.add_argument("--epochs", type=int, default=1)
-        parser.add_argument("--lr", type=float, default=1e-2)
-        parser.add_argument("--batch_size", type=int, default=32)
-        parser.add_argument("--prototypes_per_class", type=int, default=1)
-        parser.add_argument("--prototype_initializer",
-                            type=str,
-                            default="zeros")
-        return parser
-
     def forward(self, x):
         protos = self.proto_layer.prototypes
         dis = euclidean_distance(x, protos)
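
The switch from manual attribute assignment to self.save_hyperparameters(hparams) is what the commit message is pointing at: PyTorch Lightning then writes the hyperparameters into every checkpoint, so a model can be restored later without hand-rebuilding the argparse Namespace. A rough sketch of that round trip follows; the checkpoint path and train_loader are hypothetical, and the exact restore behavior depends on the PyTorch Lightning version in use (roughly 1.2 here).

    import pytorch_lightning as pl

    trainer = pl.Trainer(max_epochs=10)
    # trainer.fit(model, train_loader)  # hparams (lr, input_dim, nclasses, ...) are embedded in each .ckpt

    # Later, possibly in a fresh process, the stored hparams rebuild the model directly:
    # restored = GLVQ.load_from_checkpoint("glvq.ckpt")   # "glvq.ckpt" is a hypothetical path
    # print(restored.hparams.prototypes_per_class)        # round-tripped through the checkpoint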