[BUG] CLI example crashes
Running examples/cli/gmlvq.py crashes with:
```
---------------------------------------------------------------------------
KeyError Traceback (most recent call last)
~/miniconda3/envs/py38/lib/python3.8/site-packages/pytorch_lightning/utilities/parsing.py in __getattr__(self, key)
249 try:
--> 250 return self[key]
251 except KeyError as exp:
KeyError: 'distribution'
The above exception was the direct cause of the following exception:
AttributeError Traceback (most recent call last)
~/work/repos/prototorch_models/examples/cli/gmlvq.py in <module>
10
11
---> 12 cli = LightningCLI(GMLVQMNIST)
~/miniconda3/envs/py38/lib/python3.8/site-packages/pytorch_lightning/utilities/cli.py in __init__(self, model_class, datamodule_class, save_config_callback, trainer_class, trainer_defaults, seed_everything_default, description, env_prefix, env_parse, parser_kwargs, subclass_mode_model, subclass_mode_data)
168 seed_everything(self.config['seed_everything'])
169 self.before_instantiate_classes()
--> 170 self.instantiate_classes()
171 self.prepare_fit_kwargs()
172 self.before_fit()
~/miniconda3/envs/py38/lib/python3.8/site-packages/pytorch_lightning/utilities/cli.py in instantiate_classes(self)
211 self.config_init = self.parser.instantiate_subclasses(self.config)
212 self.instantiate_datamodule()
--> 213 self.instantiate_model()
214 self.instantiate_trainer()
215
~/miniconda3/envs/py38/lib/python3.8/site-packages/pytorch_lightning/utilities/cli.py in instantiate_model(self)
228 self.model = self.config_init['model']
229 else:
--> 230 self.model = self.model_class(**self.config_init.get('model', {}))
231
232 def instantiate_trainer(self) -> None:
~/work/repos/prototorch_models/prototorch/models/glvq.py in __init__(self, hparams, **kwargs)
307 def __init__(self, hparams, **kwargs):
308 distance_fn = kwargs.pop("distance_fn", omega_distance)
--> 309 super().__init__(hparams, distance_fn=distance_fn, **kwargs)
310 omega = torch.randn(self.hparams.input_dim,
311 self.hparams.latent_dim,
~/work/repos/prototorch_models/prototorch/models/glvq.py in __init__(self, hparams, **kwargs)
39 # Layers
40 self.proto_layer = LabeledComponents(
---> 41 distribution=self.hparams.distribution,
42 initializer=self.prototype_initializer(**kwargs))
43
~/miniconda3/envs/py38/lib/python3.8/site-packages/pytorch_lightning/utilities/parsing.py in __getattr__(self, key)
250 return self[key]
251 except KeyError as exp:
--> 252 raise AttributeError(f'Missing attribute "{key}"') from exp
253
254 def __setattr__(self, key, val):
AttributeError: Missing attribute "distribution"
```
2021-06-02 11:02:40 +00:00
|
|
|
"""GMLVQ example using the MNIST dataset."""
|
2021-05-21 15:10:36 +00:00
|
|
|
|
2021-06-16 14:16:34 +00:00
|
|
|
import torch
|
2021-05-21 15:10:36 +00:00
|
|
|
from pytorch_lightning.utilities.cli import LightningCLI
|
|
|
|
|
2021-06-16 14:16:34 +00:00
|
|
|
import prototorch as pt
|
|
|
|
from prototorch.models import ImageGMLVQ
|
|
|
|
from prototorch.models.abstract import PrototypeModel
|
|
|
|
from prototorch.models.data import MNISTDataModule
|
|
|
|
|
2021-05-21 15:10:36 +00:00
|
|
|
|
2021-06-16 14:16:34 +00:00
|
|
|
class ExperimentClass(ImageGMLVQ):
    """ImageGMLVQ preconfigured for the MNIST example.

    Pins two constructor options and forwards everything else unchanged:

    * ``optimizer`` is always ``torch.optim.Adam``;
    * ``prototype_initializer`` is ``pt.components.zeros(28 * 28)``
      (presumably zero-valued prototypes of flattened 28x28 images).

    Passing ``optimizer`` or ``prototype_initializer`` through ``kwargs``
    raises ``TypeError`` (duplicate keyword argument), exactly as before.
    """

    def __init__(self, hparams, **kwargs):
        flat_image_dim = 28 * 28
        super().__init__(
            hparams,
            optimizer=torch.optim.Adam,
            prototype_initializer=pt.components.zeros(flat_image_dim),
            **kwargs,
        )
|
2021-05-21 15:10:36 +00:00
|
|
|
|
|
|
|
|
2021-06-16 14:16:34 +00:00
|
|
|
# Build the CLI around ExperimentClass (not bare ImageGMLVQ) so the Adam
# optimizer and zeros prototype initializer configured above are actually
# used; previously ExperimentClass was defined but never referenced.
# NOTE(review): the model's ``hparams`` are still taken from the CLI/config,
# and ImageGMLVQ reads ``hparams.distribution`` during construction, so the
# config must provide it (a missing key surfaces as
# ``AttributeError: Missing attribute "distribution"``).
cli = LightningCLI(ExperimentClass, MNISTDataModule)
|