[BUGFIX] GNG Example

This commit is contained in:
Alexander Engelsberger 2021-06-03 15:42:54 +02:00
parent 0bc385fe7b
commit 47db1965ee
2 changed files with 10 additions and 6 deletions

View File

@ -4,7 +4,7 @@ import argparse
import prototorch as pt
import pytorch_lightning as pl
from prototorch.components.initializers import SelectionInitializer
from prototorch.components.initializers import Zeros
from prototorch.datasets import Iris
from prototorch.models.unsupervised import GrowingNeuralGas
from torch.utils.data import DataLoader
@ -23,12 +23,16 @@ if __name__ == "__main__":
train_loader = DataLoader(train_ds, batch_size=8)
# Hyperparameters
hparams = dict(num_prototypes=5,
lr=0.1,
prototype_initializer=SelectionInitializer(train_ds.data))
hparams = dict(
num_prototypes=5,
lr=0.1,
)
# Initialize the model
model = GrowingNeuralGas(hparams)
model = GrowingNeuralGas(
hparams,
prototype_initializer=Zeros(2),
)
# Model summary
print(model)

View File

@ -64,7 +64,7 @@ class GLVQ(AbstractPrototypeModel):
def forward(self, x):
distances = self._forward(x)
y_pred = self.predict_from_distances(distances)
y_pred = torch.eye(self.num_classes, device=self.device)[y_pred.int()]
y_pred = torch.eye(self.num_classes, device=self.device)[y_pred.long()]
return y_pred
def predict_from_distances(self, distances):