[BUGFIX] GNG Example
This commit is contained in:
@@ -4,7 +4,7 @@ import argparse
|
||||
|
||||
import prototorch as pt
|
||||
import pytorch_lightning as pl
|
||||
from prototorch.components.initializers import SelectionInitializer
|
||||
from prototorch.components.initializers import Zeros
|
||||
from prototorch.datasets import Iris
|
||||
from prototorch.models.unsupervised import GrowingNeuralGas
|
||||
from torch.utils.data import DataLoader
|
||||
@@ -23,12 +23,16 @@ if __name__ == "__main__":
|
||||
train_loader = DataLoader(train_ds, batch_size=8)
|
||||
|
||||
# Hyperparameters
|
||||
hparams = dict(num_prototypes=5,
|
||||
lr=0.1,
|
||||
prototype_initializer=SelectionInitializer(train_ds.data))
|
||||
hparams = dict(
|
||||
num_prototypes=5,
|
||||
lr=0.1,
|
||||
)
|
||||
|
||||
# Initialize the model
|
||||
model = GrowingNeuralGas(hparams)
|
||||
model = GrowingNeuralGas(
|
||||
hparams,
|
||||
prototype_initializer=Zeros(2),
|
||||
)
|
||||
|
||||
# Model summary
|
||||
print(model)
|
||||
|
Reference in New Issue
Block a user