fix: examples/ng_iris.py

This commit is contained in:
Alexander Engelsberger 2021-06-21 14:59:54 +02:00
parent 72404f7c4e
commit 2b2e4a5f37
No known key found for this signature in database
GPG Key ID: BE3F5909FF0D83E3

View File

@ -2,14 +2,13 @@
import argparse import argparse
import prototorch as pt
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from sklearn.datasets import load_iris from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler from sklearn.preprocessing import StandardScaler
from torch.optim.lr_scheduler import ExponentialLR from torch.optim.lr_scheduler import ExponentialLR
import prototorch as pt
if __name__ == "__main__": if __name__ == "__main__":
# Command-line arguments # Command-line arguments
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
@ -38,7 +37,7 @@ if __name__ == "__main__":
# Initialize the model # Initialize the model
model = pt.models.NeuralGas( model = pt.models.NeuralGas(
hparams, hparams,
prototype_initializer=pt.components.Zeros(2), prototypes_initializer=pt.core.ZCI(2),
lr_scheduler=ExponentialLR, lr_scheduler=ExponentialLR,
lr_scheduler_kwargs=dict(gamma=0.99, verbose=False), lr_scheduler_kwargs=dict(gamma=0.99, verbose=False),
) )