feat: distribute GMLVQ into mixins

This commit is contained in:
Alexander Engelsberger
2022-05-31 17:56:03 +02:00
parent e922aae432
commit 23d1a71b31
14 changed files with 211 additions and 152 deletions

View File

@@ -2,12 +2,12 @@ import prototorch as pt
import pytorch_lightning as pl
import torchmetrics
from prototorch.core import SMCI
from prototorch.models.y_arch.callbacks import (
from prototorch.y_arch.callbacks import (
LogTorchmetricCallback,
PlotLambdaMatrixToTensorboard,
VisGMLVQ2D,
)
from prototorch.models.y_arch.library.gmlvq import GMLVQ
from prototorch.y_arch.library.gmlvq import GMLVQ
from pytorch_lightning.callbacks import EarlyStopping
from torch.utils.data import DataLoader
@@ -39,8 +39,7 @@ if __name__ == "__main__":
# Define Hyperparameters
hyperparameters = GMLVQ.HyperParameters(
lr=0.1,
backbone_lr=5,
lr=dict(components_layer=0.1, _omega=0),
input_dim=4,
distribution=dict(
num_classes=3,