[BUG] CLI example crashes

Running examples/cli/gmlvq.py crashes with:
```
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
~/miniconda3/envs/py38/lib/python3.8/site-packages/pytorch_lightning/utilities/parsing.py in __getattr__(self, key)
    249         try:
--> 250             return self[key]
    251         except KeyError as exp:

KeyError: 'distribution'

The above exception was the direct cause of the following exception:

AttributeError                            Traceback (most recent call last)
~/work/repos/prototorch_models/examples/cli/gmlvq.py in <module>
     10
     11
---> 12 cli = LightningCLI(GMLVQMNIST)

~/miniconda3/envs/py38/lib/python3.8/site-packages/pytorch_lightning/utilities/cli.py in __init__(self, model_class, datamodule_class, save_config_callback, trainer_class, trainer_defaults, seed_everything_default, description, env_prefix, env_parse, parser_kwargs, subclass_mode_model, subclass_mode_data)
    168             seed_everything(self.config['seed_everything'])
    169         self.before_instantiate_classes()
--> 170         self.instantiate_classes()
    171         self.prepare_fit_kwargs()
    172         self.before_fit()

~/miniconda3/envs/py38/lib/python3.8/site-packages/pytorch_lightning/utilities/cli.py in instantiate_classes(self)
    211         self.config_init = self.parser.instantiate_subclasses(self.config)
    212         self.instantiate_datamodule()
--> 213         self.instantiate_model()
    214         self.instantiate_trainer()
    215

~/miniconda3/envs/py38/lib/python3.8/site-packages/pytorch_lightning/utilities/cli.py in instantiate_model(self)
    228             self.model = self.config_init['model']
    229         else:
--> 230             self.model = self.model_class(**self.config_init.get('model', {}))
    231
    232     def instantiate_trainer(self) -> None:

~/work/repos/prototorch_models/prototorch/models/glvq.py in __init__(self, hparams, **kwargs)
    307     def __init__(self, hparams, **kwargs):
    308         distance_fn = kwargs.pop("distance_fn", omega_distance)
--> 309         super().__init__(hparams, distance_fn=distance_fn, **kwargs)
    310         omega = torch.randn(self.hparams.input_dim,
    311                             self.hparams.latent_dim,

~/work/repos/prototorch_models/prototorch/models/glvq.py in __init__(self, hparams, **kwargs)
     39         # Layers
     40         self.proto_layer = LabeledComponents(
---> 41             distribution=self.hparams.distribution,
     42             initializer=self.prototype_initializer(**kwargs))
     43

~/miniconda3/envs/py38/lib/python3.8/site-packages/pytorch_lightning/utilities/parsing.py in __getattr__(self, key)
    250             return self[key]
    251         except KeyError as exp:
--> 252             raise AttributeError(f'Missing attribute "{key}"') from exp
    253
    254     def __setattr__(self, key, val):

AttributeError: Missing attribute "distribution"
```
This commit is contained in:
Jensun Ravichandran 2021-06-02 13:02:40 +02:00
parent 86688b26b0
commit 64250d0938
2 changed files with 8 additions and 8 deletions

View File

@ -1,9 +1,9 @@
model:
hparams:
input_dim: 784
latent_dim: 784
distribution:
hparams:
distribution:
num_classes: 10
prototypes_per_class: 2
input_dim: 784
latent_dim: 784
proto_lr: 0.01
bb_lr: 0.01

View File

@ -1,12 +1,12 @@
"""GLVQ example using the MNIST dataset."""
"""GMLVQ example using the MNIST dataset."""
from prototorch.models import ImageGLVQ
from prototorch.models import ImageGMLVQ
from prototorch.models.data import train_on_mnist
from pytorch_lightning.utilities.cli import LightningCLI
class GLVQMNIST(train_on_mnist(batch_size=64), ImageGLVQ):
class GMLVQMNIST(train_on_mnist(batch_size=64), ImageGMLVQ):
"""Model Definition."""
cli = LightningCLI(GLVQMNIST)
cli = LightningCLI(GMLVQMNIST)