[BUGFIX] GNG Example

This commit is contained in:
Alexander Engelsberger 2021-06-03 15:42:54 +02:00
parent 0bc385fe7b
commit 47db1965ee
2 changed files with 10 additions and 6 deletions

View File

@ -4,7 +4,7 @@ import argparse
import prototorch as pt import prototorch as pt
import pytorch_lightning as pl import pytorch_lightning as pl
from prototorch.components.initializers import SelectionInitializer from prototorch.components.initializers import Zeros
from prototorch.datasets import Iris from prototorch.datasets import Iris
from prototorch.models.unsupervised import GrowingNeuralGas from prototorch.models.unsupervised import GrowingNeuralGas
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
@ -23,12 +23,16 @@ if __name__ == "__main__":
train_loader = DataLoader(train_ds, batch_size=8) train_loader = DataLoader(train_ds, batch_size=8)
# Hyperparameters # Hyperparameters
hparams = dict(num_prototypes=5, hparams = dict(
num_prototypes=5,
lr=0.1, lr=0.1,
prototype_initializer=SelectionInitializer(train_ds.data)) )
# Initialize the model # Initialize the model
model = GrowingNeuralGas(hparams) model = GrowingNeuralGas(
hparams,
prototype_initializer=Zeros(2),
)
# Model summary # Model summary
print(model) print(model)

View File

@ -64,7 +64,7 @@ class GLVQ(AbstractPrototypeModel):
def forward(self, x): def forward(self, x):
distances = self._forward(x) distances = self._forward(x)
y_pred = self.predict_from_distances(distances) y_pred = self.predict_from_distances(distances)
y_pred = torch.eye(self.num_classes, device=self.device)[y_pred.int()] y_pred = torch.eye(self.num_classes, device=self.device)[y_pred.long()]
return y_pred return y_pred
def predict_from_distances(self, distances): def predict_from_distances(self, distances):