[BUGFIX] Fix examples/ng_iris.py
This commit is contained in:
parent
47db1965ee
commit
b0df61d1c3
@ -2,13 +2,12 @@
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
|
import prototorch as pt
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
from sklearn.datasets import load_iris
|
from sklearn.datasets import load_iris
|
||||||
from sklearn.preprocessing import StandardScaler
|
from sklearn.preprocessing import StandardScaler
|
||||||
|
|
||||||
import prototorch as pt
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Command-line arguments
|
# Command-line arguments
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
@ -31,7 +30,8 @@ if __name__ == "__main__":
|
|||||||
hparams = dict(num_prototypes=30, lr=0.03)
|
hparams = dict(num_prototypes=30, lr=0.03)
|
||||||
|
|
||||||
# Initialize the model
|
# Initialize the model
|
||||||
model = pt.models.NeuralGas(hparams)
|
model = pt.models.NeuralGas(hparams,
|
||||||
|
prototype_initializer=pt.components.Zeros(2))
|
||||||
|
|
||||||
# Model summary
|
# Model summary
|
||||||
print(model)
|
print(model)
|
||||||
|
Loading…
Reference in New Issue
Block a user