feat: Add basic GLVQ with new architecture

This commit is contained in:
Alexander Engelsberger
2021-10-14 15:49:12 +02:00
parent d4448f2bc9
commit 967953442b
7 changed files with 433 additions and 6 deletions

View File

@@ -3,6 +3,7 @@
import argparse
import prototorch as pt
import prototorch.models.expanded
import pytorch_lightning as pl
import torch
from torch.optim.lr_scheduler import ExponentialLR
@@ -29,7 +30,7 @@ if __name__ == "__main__":
)
# Initialize the model
model = pt.models.GLVQ(
model = prototorch.models.expanded.GLVQ(
hparams,
optimizer=torch.optim.Adam,
prototypes_initializer=pt.initializers.SMCI(train_ds),