[FEATURE] Add PLVQ model

This commit is contained in:
Alexander Engelsberger
2021-06-08 15:01:08 +02:00
committed by Alexander Engelsberger
parent fc11d78b38
commit c87ed5ba8b
7 changed files with 61 additions and 32 deletions

View File

@@ -5,6 +5,7 @@ import argparse
import prototorch as pt
import pytorch_lightning as pl
import torch
from torchvision.transforms import Lambda
if __name__ == "__main__":
# Command-line arguments
@@ -24,12 +25,15 @@ if __name__ == "__main__":
# Hyperparameters
hparams = dict(
distribution=[2, 2, 3],
lr=0.05,
variance=0.1,
proto_lr=0.05,
lambd=0.1,
input_dim=2,
latent_dim=2,
bb_lr=0.01,
)
# Initialize the model
model = pt.models.probabilistic.RSLVQ(
model = pt.models.probabilistic.PLVQ(
hparams,
optimizer=torch.optim.Adam,
# prototype_initializer=pt.components.SMI(train_ds),
@@ -45,7 +49,7 @@ if __name__ == "__main__":
print(model)
# Callbacks
vis = pt.models.VisGLVQ2D(data=train_ds)
vis = pt.models.VisSiameseGLVQ2D(data=train_ds)
# Setup trainer
trainer = pl.Trainer.from_argparse_args(