[BUGFIX] Fix examples/ng_iris.py

This commit is contained in:
Jensun Ravichandran 2021-06-03 16:34:48 +02:00
parent 47db1965ee
commit b0df61d1c3

View File

@ -2,13 +2,12 @@
import argparse
import prototorch as pt
import pytorch_lightning as pl
import torch
from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler
import prototorch as pt
if __name__ == "__main__":
# Command-line arguments
parser = argparse.ArgumentParser()
@ -31,7 +30,8 @@ if __name__ == "__main__":
hparams = dict(num_prototypes=30, lr=0.03)
# Initialize the model
model = pt.models.NeuralGas(hparams)
model = pt.models.NeuralGas(hparams,
prototype_initializer=pt.components.Zeros(2))
# Model summary
print(model)