[BUGFIX] GNG Example
This commit is contained in:
parent
0bc385fe7b
commit
47db1965ee
@ -4,7 +4,7 @@ import argparse
|
||||
|
||||
import prototorch as pt
|
||||
import pytorch_lightning as pl
|
||||
from prototorch.components.initializers import SelectionInitializer
|
||||
from prototorch.components.initializers import Zeros
|
||||
from prototorch.datasets import Iris
|
||||
from prototorch.models.unsupervised import GrowingNeuralGas
|
||||
from torch.utils.data import DataLoader
|
||||
@ -23,12 +23,16 @@ if __name__ == "__main__":
|
||||
train_loader = DataLoader(train_ds, batch_size=8)
|
||||
|
||||
# Hyperparameters
|
||||
hparams = dict(num_prototypes=5,
|
||||
lr=0.1,
|
||||
prototype_initializer=SelectionInitializer(train_ds.data))
|
||||
hparams = dict(
|
||||
num_prototypes=5,
|
||||
lr=0.1,
|
||||
)
|
||||
|
||||
# Initialize the model
|
||||
model = GrowingNeuralGas(hparams)
|
||||
model = GrowingNeuralGas(
|
||||
hparams,
|
||||
prototype_initializer=Zeros(2),
|
||||
)
|
||||
|
||||
# Model summary
|
||||
print(model)
|
||||
|
@ -64,7 +64,7 @@ class GLVQ(AbstractPrototypeModel):
|
||||
def forward(self, x):
|
||||
distances = self._forward(x)
|
||||
y_pred = self.predict_from_distances(distances)
|
||||
y_pred = torch.eye(self.num_classes, device=self.device)[y_pred.int()]
|
||||
y_pred = torch.eye(self.num_classes, device=self.device)[y_pred.long()]
|
||||
return y_pred
|
||||
|
||||
def predict_from_distances(self, distances):
|
||||
|
Loading…
Reference in New Issue
Block a user