[BUG] CLI example crashes
Running examples/cli/gmlvq.py crashes with: ``` --------------------------------------------------------------------------- KeyError Traceback (most recent call last) ~/miniconda3/envs/py38/lib/python3.8/site-packages/pytorch_lightning/utilities/parsing.py in __getattr__(self, key) 249 try: --> 250 return self[key] 251 except KeyError as exp: KeyError: 'distribution' The above exception was the direct cause of the following exception: AttributeError Traceback (most recent call last) ~/work/repos/prototorch_models/examples/cli/gmlvq.py in <module> 10 11 ---> 12 cli = LightningCLI(GMLVQMNIST) ~/miniconda3/envs/py38/lib/python3.8/site-packages/pytorch_lightning/utilities/cli.py in __init__(self, model_class, datamodule_class, save_config_callback, trainer_class, trainer_defaults, seed_everything_default, description, env_prefix, env_parse, parser_kwargs, subclass_mode_model, subclass_mode_data) 168 seed_everything(self.config['seed_everything']) 169 self.before_instantiate_classes() --> 170 self.instantiate_classes() 171 self.prepare_fit_kwargs() 172 self.before_fit() ~/miniconda3/envs/py38/lib/python3.8/site-packages/pytorch_lightning/utilities/cli.py in instantiate_classes(self) 211 self.config_init = self.parser.instantiate_subclasses(self.config) 212 self.instantiate_datamodule() --> 213 self.instantiate_model() 214 self.instantiate_trainer() 215 ~/miniconda3/envs/py38/lib/python3.8/site-packages/pytorch_lightning/utilities/cli.py in instantiate_model(self) 228 self.model = self.config_init['model'] 229 else: --> 230 self.model = self.model_class(**self.config_init.get('model', {})) 231 232 def instantiate_trainer(self) -> None: ~/work/repos/prototorch_models/prototorch/models/glvq.py in __init__(self, hparams, **kwargs) 307 def __init__(self, hparams, **kwargs): 308 distance_fn = kwargs.pop("distance_fn", omega_distance) --> 309 super().__init__(hparams, distance_fn=distance_fn, **kwargs) 310 omega = torch.randn(self.hparams.input_dim, 311 self.hparams.latent_dim, ~/work/repos/prototorch_models/prototorch/models/glvq.py in __init__(self, hparams, **kwargs) 39 # Layers 40 self.proto_layer = LabeledComponents( ---> 41 distribution=self.hparams.distribution, 42 initializer=self.prototype_initializer(**kwargs)) 43 ~/miniconda3/envs/py38/lib/python3.8/site-packages/pytorch_lightning/utilities/parsing.py in __getattr__(self, key) 250 return self[key] 251 except KeyError as exp: --> 252 raise AttributeError(f'Missing attribute "{key}"') from exp 253 254 def __setattr__(self, key, val): AttributeError: Missing attribute "distribution" ```
This commit is contained in:
parent
86688b26b0
commit
64250d0938
@ -1,9 +1,9 @@
|
||||
model:
|
||||
hparams:
|
||||
input_dim: 784
|
||||
latent_dim: 784
|
||||
distribution:
|
||||
num_classes: 10
|
||||
prototypes_per_class: 2
|
||||
input_dim: 784
|
||||
latent_dim: 784
|
||||
proto_lr: 0.01
|
||||
bb_lr: 0.01
|
||||
|
@ -1,12 +1,12 @@
|
||||
"""GLVQ example using the MNIST dataset."""
|
||||
"""GMLVQ example using the MNIST dataset."""
|
||||
|
||||
from prototorch.models import ImageGLVQ
|
||||
from prototorch.models import ImageGMLVQ
|
||||
from prototorch.models.data import train_on_mnist
|
||||
from pytorch_lightning.utilities.cli import LightningCLI
|
||||
|
||||
|
||||
class GLVQMNIST(train_on_mnist(batch_size=64), ImageGLVQ):
|
||||
class GMLVQMNIST(train_on_mnist(batch_size=64), ImageGMLVQ):
|
||||
"""Model Definition."""
|
||||
|
||||
|
||||
cli = LightningCLI(GLVQMNIST)
|
||||
cli = LightningCLI(GMLVQMNIST)
|
||||
|
Loading…
Reference in New Issue
Block a user