Commit Graph

133 Commits

Author SHA1 Message Date
Jensun Ravichandran
b0df61d1c3 [BUGFIX] Fix examples/ng_iris.py 2021-06-03 16:34:48 +02:00
Alexander Engelsberger
47db1965ee [BUGFIX] GNG Example 2021-06-03 15:42:54 +02:00
Alexander Engelsberger
5918f1cc21 [BUGFIX] CLI example documentation improved 2021-06-03 13:47:20 +02:00
Alexander Engelsberger
3b02d99ebe [BUGFIX] Early stopping example works now 2021-06-03 13:38:16 +02:00
Jensun Ravichandran
64250d0938 [BUG] CLI example crashes
Running examples/cli/gmlvq.py crashes with:
```
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
~/miniconda3/envs/py38/lib/python3.8/site-packages/pytorch_lightning/utilities/parsing.py in __getattr__(self, key)
    249         try:
--> 250             return self[key]
    251         except KeyError as exp:

KeyError: 'distribution'

The above exception was the direct cause of the following exception:

AttributeError                            Traceback (most recent call last)
~/work/repos/prototorch_models/examples/cli/gmlvq.py in <module>
     10
     11
---> 12 cli = LightningCLI(GMLVQMNIST)

~/miniconda3/envs/py38/lib/python3.8/site-packages/pytorch_lightning/utilities/cli.py in __init__(self, model_class, datamodule_class, save_config_callback, trainer_class, trainer_defaults, seed_everything_default, description, env_prefix, env_parse, parser_kwargs, subclass_mode_model, subclass_mode_data)
    168             seed_everything(self.config['seed_everything'])
    169         self.before_instantiate_classes()
--> 170         self.instantiate_classes()
    171         self.prepare_fit_kwargs()
    172         self.before_fit()

~/miniconda3/envs/py38/lib/python3.8/site-packages/pytorch_lightning/utilities/cli.py in instantiate_classes(self)
    211         self.config_init = self.parser.instantiate_subclasses(self.config)
    212         self.instantiate_datamodule()
--> 213         self.instantiate_model()
    214         self.instantiate_trainer()
    215

~/miniconda3/envs/py38/lib/python3.8/site-packages/pytorch_lightning/utilities/cli.py in instantiate_model(self)
    228             self.model = self.config_init['model']
    229         else:
--> 230             self.model = self.model_class(**self.config_init.get('model', {}))
    231
    232     def instantiate_trainer(self) -> None:

~/work/repos/prototorch_models/prototorch/models/glvq.py in __init__(self, hparams, **kwargs)
    307     def __init__(self, hparams, **kwargs):
    308         distance_fn = kwargs.pop("distance_fn", omega_distance)
--> 309         super().__init__(hparams, distance_fn=distance_fn, **kwargs)
    310         omega = torch.randn(self.hparams.input_dim,
    311                             self.hparams.latent_dim,

~/work/repos/prototorch_models/prototorch/models/glvq.py in __init__(self, hparams, **kwargs)
     39         # Layers
     40         self.proto_layer = LabeledComponents(
---> 41             distribution=self.hparams.distribution,
     42             initializer=self.prototype_initializer(**kwargs))
     43

~/miniconda3/envs/py38/lib/python3.8/site-packages/pytorch_lightning/utilities/parsing.py in __getattr__(self, key)
    250             return self[key]
    251         except KeyError as exp:
--> 252             raise AttributeError(f'Missing attribute "{key}"') from exp
    253
    254     def __setattr__(self, key, val):

AttributeError: Missing attribute "distribution"
```
2021-06-02 13:02:40 +02:00
Jensun Ravichandran
ef6bcc1079 [BUG] Early stopping does not seem to work
The early stopping callback does not work as expected, and crashes at the end of
max_epochs with:

```
~/miniconda3/envs/py38/lib/python3.8/site-packages/pytorch_lightning/trainer/callback_hook.py in on_train_end(self)
    155         """Called when the train ends."""
    156         for callback in self.callbacks:
--> 157             callback.on_train_end(self, self.lightning_module)
    158
    159     def on_pretrain_routine_start(self) -> None:

~/work/repos/prototorch_models/prototorch/models/callbacks.py in on_train_end(self, trainer, pl_module)
     18     def on_train_end(self, trainer, pl_module):
     19         # instead, do it at the end of training loop
---> 20         self._run_early_stopping_check(trainer, pl_module)
     21
     22

TypeError: _run_early_stopping_check() takes 2 positional arguments but 3 were given
```
2021-06-02 12:44:34 +02:00
Jensun Ravichandran
bdacc83185 [REFACTOR] Update examples/dynamic_pruning.py 2021-06-02 03:53:21 +02:00
Jensun Ravichandran
a3f5d7d113 Update docstring 2021-06-02 02:40:29 +02:00
Jensun Ravichandran
b2009bb563 [FEATURE] Add example to showcase dynamic pruning of prototypes 2021-06-02 02:36:37 +02:00
Jensun Ravichandran
398431e7ea Remove examples/dynamic_components.py 2021-06-02 02:35:45 +02:00
Jensun Ravichandran
d46fe4a393 [WIP] Update CBC example 2021-06-02 00:45:33 +02:00
Jensun Ravichandran
9eb6476078 [BUG] Training unstable in examples/gng_iris.py 2021-06-02 00:21:42 +02:00
Jensun Ravichandran
7e241ff7d8 [WIP] Update examples/liramlvq_tecator.py 2021-06-02 00:02:31 +02:00
Jensun Ravichandran
930f84d3c7 Remove examples/gmlvq_iris.py 2021-06-01 23:40:15 +02:00
Alexander Engelsberger
9c1a41997b [FEATURE] Add Growing Neural Gas 2021-06-01 17:19:43 +02:00
Jensun Ravichandran
1636c84778 Rename rslvq example 2021-05-31 17:56:45 +02:00
Jensun Ravichandran
27eccf44d4 Use LambdaLayer from ProtoTorch 2021-05-31 16:53:04 +02:00
Alexander Engelsberger
8f4d66edf1 [Bugfix] Fix classcount in LIRAMLVQ example 2021-05-31 11:48:23 +02:00
Alexander Engelsberger
2a218c0ede Add example for dynamic components in callbacks 2021-05-31 11:39:24 +02:00
Alexander Engelsberger
0ac4ced85d [refactor] Use functional variant of accuracy
Prevents Accuracy in `__repr__` of the models.
2021-05-31 11:12:27 +02:00
Jensun Ravichandran
e9d2075fed Sort imports in example scripts 2021-05-31 00:52:16 +02:00
Jensun Ravichandran
7b7bc3693d Merge branch 'dev' of github.com:si-cim/prototorch_models into dev 2021-05-31 00:32:49 +02:00
Jensun Ravichandran
cd73f6c427 Add examples/dynamic_components.py 2021-05-31 00:32:27 +02:00
Alexander Engelsberger
e3392ee952 [refactor] DRY Probabilistic models 2021-05-28 17:13:06 +02:00
Alexander Engelsberger
dade502686 Add MNIST datamodule and training mixin factory. 2021-05-28 16:33:31 +02:00
Jensun Ravichandran
cc49f26b77 Remove normalization transform from cli example 2021-05-25 21:13:37 +02:00
Jensun Ravichandran
db965541fd Update example 2021-05-25 20:57:54 +02:00
Alexander Engelsberger
32d6f95db0 Add RSLVQ and LikelihoodLVQ 2021-05-25 20:26:15 +02:00
Alexander Engelsberger
72e064338c Use 'num_' in all variable names 2021-05-25 15:41:10 +02:00
Alexander Engelsberger
e7e6bf9173 Fix failing example 2021-05-21 18:54:47 +02:00
Alexander Engelsberger
2aa631f4e6 Improve example test script (with failing example) 2021-05-21 18:48:37 +02:00
Alexander Engelsberger
5b12629bd9 All examples use argparse 2021-05-21 17:55:55 +02:00
Alexander Engelsberger
b60db3174a LightningCLI Example. 2021-05-21 17:13:15 +02:00
Jensun Ravichandran
88a34a06ef [WIP] Update CBC implementation to use SiameseGLVQ 2021-05-20 17:36:00 +02:00
Jensun Ravichandran
49f9a12b5f Update mnist example 2021-05-20 17:35:07 +02:00
Jensun Ravichandran
16dc3cf4eb Update image visualization 2021-05-20 16:07:16 +02:00
Jensun Ravichandran
df061cc2ff Refactor code 2021-05-20 14:40:02 +02:00
Jensun Ravichandran
fdf9443a2c Add validation and test logic 2021-05-19 16:30:19 +02:00
Jensun Ravichandran
a14e3aa611 Add argparse to mnist example script 2021-05-18 10:17:51 +02:00
Jensun Ravichandran
00cdacf7ae Fix example script 2021-05-18 10:15:38 +02:00
Jensun Ravichandran
538256dcb7 Small changes 2021-05-17 19:37:42 +02:00
Jensun Ravichandran
d812bb0620 Update examples 2021-05-17 17:03:37 +02:00
Jensun Ravichandran
ebc42a4aa8 Set gpus=0 in examples 2021-05-15 12:43:00 +02:00
Alexander Engelsberger
0eac2ce326 Examples use GPUs if available. 2021-05-13 15:22:01 +02:00
Jensun Ravichandran
ca39aa00d5 Stop passing component initializers as hparams
Pass the component initializer as an hparam slows down the script very much. The
API has now been changed to pass it as a kwarg to the models instead.

The example scripts have also been updated to reflect the new changes.

Also, ImageGMLVQ and an example script `gmlvq_mnist.py` that uses it have also
been added.
2021-05-12 16:36:22 +02:00
Jensun Ravichandran
59b8ab6643 Add knn 2021-05-11 17:22:02 +02:00
Jensun Ravichandran
2a4f184163 Update example scripts 2021-05-11 16:15:08 +02:00
Alexander Engelsberger
3fa6378c4d Add LVQ1 and LVQ2.1 Models. 2021-05-11 13:26:13 +02:00
Jensun Ravichandran
49100f43f5 Example to save and reload a model 2021-05-10 14:30:02 +02:00
Jensun Ravichandran
7d2af9c0ae Update LiRaMLVQ example script 2021-05-09 20:54:40 +02:00
Jensun Ravichandran
dd75fbfff8 Make cbc example reproducible 2021-05-07 15:46:09 +02:00
Jensun Ravichandran
728131e9db Update example scripts 2021-05-07 15:25:04 +02:00
Jensun Ravichandran
f2541acde9 Unclutter the examples folder 2021-05-07 15:21:35 +02:00
Jensun Ravichandran
e87663d10c Make siamese example script reproducible 2021-05-07 13:07:30 +02:00
Jensun Ravichandran
1b9bcf21f6 Fix typo 2021-05-06 18:50:37 +02:00
Alexander Engelsberger
4bbe73e3a9 Add GRLVQ with examples. 2021-05-06 18:42:06 +02:00
Alexander Engelsberger
79e5eaa69a Rename GMLVQ Example. 2021-05-06 18:41:50 +02:00
Alexander Engelsberger
1c3613019b Update Examples to new initializer architecture.
Visualization still borken for some examples.
2021-05-06 14:10:09 +02:00
Jensun Ravichandran
d644114090 Add loss transfer function to glvq 2021-05-04 20:56:16 +02:00
Jensun Ravichandran
f402eea884 Add GMLVQ examples 2021-05-04 15:11:16 +02:00
Jensun Ravichandran
e44516fc49 Update example script 2021-04-29 19:25:08 +02:00
Jensun Ravichandran
fef73e2fbf [BUG] NaN when training with selection initializer
How to reproduce:
Run the `glvq_spiral.py` file under `examples/`.

The error seems to occur when using a lot of prototypes in combination with the
`StratifiedSelectionInitializer`. Using only a prototype per class, or using
another initializer like the `StratifiedMeanInitializer` seems to make the
problem go away.
2021-04-29 19:09:10 +02:00
Jensun Ravichandran
8bad54fc2d Small fix on example script 2021-04-29 17:11:06 +02:00
Jensun Ravichandran
a16bebd0c4 Use Components instead of Prototypes and refactor old examples 2021-04-29 17:05:41 +02:00
Alexander Engelsberger
eeb684b3b6 GLVQ with configurable distance. 2021-04-27 15:41:44 +02:00
Jensun Ravichandran
1fb197077c Add siamese glvq 2021-04-27 14:35:17 +02:00
Jensun Ravichandran
8d57f69c9e Fix bug in visualization callback 2021-04-27 12:49:04 +02:00
Jensun Ravichandran
3148684812 Fix glvq mnist example script 2021-04-23 17:49:29 +02:00
Jensun Ravichandran
688f09ca23 Am I really the only one with OCD? 2021-04-23 17:41:29 +02:00
Jensun Ravichandran
281009ce82 Fix typo 2021-04-23 17:38:29 +02:00
Alexander Engelsberger
466bbe4c63 Add Neural Gas Model. 2021-04-23 17:30:23 +02:00
Alexander Engelsberger
c4c51a16fe Automatic Formating. 2021-04-23 17:27:47 +02:00
Alexander Engelsberger
db4499a103 Add more CBC examples. MNIST is broken. 2021-04-22 17:37:20 +02:00
Jensun Ravichandran
2e2f6707f6 Add partial cbc implementation 2021-04-22 16:01:44 +02:00
Jensun Ravichandran
55cf9b4a39 Add working glvq script as glvq_iris_v1.py 2021-04-22 12:04:56 +02:00
Jensun Ravichandran
03c5160495 Training on checkpointed model fails [BROKEN] 2021-04-22 11:56:54 +02:00
Jensun Ravichandran
fadf8c25bf Add more experimental changes
The code gets very messy very quickly as soon as serialization features are
needed.
2021-04-21 21:59:19 +02:00
Jensun Ravichandran
e5a62bd0fc Fix broken state from previous commit 2021-04-21 21:35:52 +02:00
Jensun Ravichandran
fe36e5fad9 Add partial metric/hparam features [BROKEN STATE] 2021-04-21 19:16:57 +02:00
Jensun Ravichandran
5a1ef841d3 Update mnist example 2021-04-21 16:28:20 +02:00
Jensun Ravichandran
985cdd3120 Update example scripts 2021-04-21 15:52:42 +02:00
Jensun Ravichandran
7263dfed91 Add mnist example 2021-04-21 14:54:14 +02:00
Jensun Ravichandran
984840d262 Add iris example 2021-04-21 14:54:07 +02:00