[BUGFIX] GNG Example
This commit is contained in:
parent
0bc385fe7b
commit
47db1965ee
@ -4,7 +4,7 @@ import argparse
|
|||||||
|
|
||||||
import prototorch as pt
|
import prototorch as pt
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
from prototorch.components.initializers import SelectionInitializer
|
from prototorch.components.initializers import Zeros
|
||||||
from prototorch.datasets import Iris
|
from prototorch.datasets import Iris
|
||||||
from prototorch.models.unsupervised import GrowingNeuralGas
|
from prototorch.models.unsupervised import GrowingNeuralGas
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
@ -23,12 +23,16 @@ if __name__ == "__main__":
|
|||||||
train_loader = DataLoader(train_ds, batch_size=8)
|
train_loader = DataLoader(train_ds, batch_size=8)
|
||||||
|
|
||||||
# Hyperparameters
|
# Hyperparameters
|
||||||
hparams = dict(num_prototypes=5,
|
hparams = dict(
|
||||||
lr=0.1,
|
num_prototypes=5,
|
||||||
prototype_initializer=SelectionInitializer(train_ds.data))
|
lr=0.1,
|
||||||
|
)
|
||||||
|
|
||||||
# Initialize the model
|
# Initialize the model
|
||||||
model = GrowingNeuralGas(hparams)
|
model = GrowingNeuralGas(
|
||||||
|
hparams,
|
||||||
|
prototype_initializer=Zeros(2),
|
||||||
|
)
|
||||||
|
|
||||||
# Model summary
|
# Model summary
|
||||||
print(model)
|
print(model)
|
||||||
|
@ -64,7 +64,7 @@ class GLVQ(AbstractPrototypeModel):
|
|||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
distances = self._forward(x)
|
distances = self._forward(x)
|
||||||
y_pred = self.predict_from_distances(distances)
|
y_pred = self.predict_from_distances(distances)
|
||||||
y_pred = torch.eye(self.num_classes, device=self.device)[y_pred.int()]
|
y_pred = torch.eye(self.num_classes, device=self.device)[y_pred.long()]
|
||||||
return y_pred
|
return y_pred
|
||||||
|
|
||||||
def predict_from_distances(self, distances):
|
def predict_from_distances(self, distances):
|
||||||
|
Loading…
Reference in New Issue
Block a user