266 Commits

Author SHA1 Message Date
Alexander Engelsberger
ed83138e1f build: bump version 1.0.0a3 → 1.0.0a4 2022-06-12 11:52:06 +02:00
Alexander Engelsberger
1be7d7ec09 fix: don't save component initializer as hparam 2022-06-12 11:40:33 +02:00
Alexander Engelsberger
60d2a1d2c9 fix: don't save prototype initializer in y-arch checkpoint 2022-06-12 11:12:55 +02:00
Alexander Engelsberger
be7d7f43bd fix: fix problems with y architecture and checkpoint 2022-06-12 10:36:15 +02:00
Alexander Engelsberger
fe729781fc build: bump version 1.0.0a2 → 1.0.0a3 2022-06-09 14:59:07 +02:00
Alexander Engelsberger
a7df7be1c8 feat: add confusion matrix callback 2022-06-09 14:55:59 +02:00
Alexander Engelsberger
696719600b build: bump version 1.0.0a1 → 1.0.0a2 2022-06-03 11:52:50 +02:00
Alexander Engelsberger
48e7c029fa fix: Fix __init__.py 2022-06-03 11:40:45 +02:00
Alexander Engelsberger
5de3a480c7 build: bump version 0.5.2 → 1.0.0a1 2022-06-03 11:07:10 +02:00
Alexander Engelsberger
626f51ce80 ci: Add possible prerelease to bumpversion 2022-06-03 11:06:44 +02:00
Alexander Engelsberger
6d7d93c8e8 chore: rename y_arch to y 2022-06-03 10:39:11 +02:00
Jensun Ravichandran
93b1d0bd46 feat(vis): add flag to save visualization frames 2022-06-02 19:55:03 +02:00
Alexander Engelsberger
b7992c01db fix: apply hotfix 2022-06-01 14:26:37 +02:00
Alexander Engelsberger
fcd944d3ff build: bump version 0.5.1 → 0.5.2 2022-06-01 14:25:44 +02:00
Alexander Engelsberger
054720dd7b fix(hotfix): Protobuf error workaround 2022-06-01 14:14:57 +02:00
Alexander Engelsberger
23d1a71b31 feat: distribute GMLVQ into mixins 2022-05-31 17:56:03 +02:00
Alexander Engelsberger
e922aae432 feat: add GMLVQ with new architecture 2022-05-19 16:13:08 +02:00
Alexander Engelsberger
3e50d0d817 chore(protoy): mixin restructuring 2022-05-18 15:43:09 +02:00
Alexander Engelsberger
dc4f31d700 chore: rename clc-lc to proto-Y-architecture 2022-05-18 14:11:46 +02:00
Alexander Engelsberger
02954044d7 chore: improve clc-lc test 2022-05-17 17:26:03 +02:00
Alexander Engelsberger
8f08ba66ea feat: copy old clc-lc implementation 2022-05-17 16:25:43 +02:00
Alexander Engelsberger
e0b92e9ac2 chore: move mixins to separate file 2022-05-17 16:19:47 +02:00
Alexander Engelsberger
d16a0de202 build: bump version 0.5.0 → 0.5.1 2022-05-17 12:04:08 +02:00
Alexander Engelsberger
76fea3f881 chore: update all examples to pytorch 1.6 2022-05-17 12:03:43 +02:00
Alexander Engelsberger
c00513ae0d chore: minor updates and version updates 2022-05-17 12:00:52 +02:00
Alexander Engelsberger
bccef8bef0 chore: replace relative imports 2022-05-16 11:12:53 +02:00
Alexander Engelsberger
29ee326b85 ci: Update PreCommit hooks 2022-05-16 11:11:48 +02:00
Jensun Ravichandran
055568dc86 fix: glvq_iris example works again 2022-05-09 17:33:52 +02:00
Alexander Engelsberger
3a7328e290 chore: small changes 2022-04-27 10:37:12 +02:00
Alexander Engelsberger
d6629c8792 build: bump version 0.4.1 → 0.5.0 2022-04-27 10:28:06 +02:00
Alexander Engelsberger
ef65bd3789 chore: update prototorch dependency 2022-04-27 09:50:48 +02:00
Alexander Engelsberger
d096eba2c9 chore: update pytorch lightning dependency 2022-04-27 09:39:00 +02:00
Alexander Engelsberger
dd34c57e2e ci: fix github action python version 2022-04-27 09:30:07 +02:00
Alexander Engelsberger
5911f4dd90 chore: fix errors for pytorch_lightning>1.6 2022-04-27 09:25:42 +02:00
Alexander Engelsberger
dbfe315f4f ci: add python 3.10 to tests 2022-04-27 09:24:34 +02:00
Jensun Ravichandran
9c90c902dc fix: correct typo 2022-04-04 21:54:04 +02:00
Jensun Ravichandran
7d3f59e54b test: add unit tests 2022-03-30 15:12:33 +02:00
Jensun Ravichandran
9da47b1dba fix: CBC example works again 2022-03-30 15:10:06 +02:00
Alexander Engelsberger
41f0e77fc9 fix: SiameseGLVQ checks requires_grad of backbone
Necessary for runs with different optimizers
2022-03-29 17:08:40 +02:00
Jensun Ravichandran
fab786a07e fix: rename hparam output_dim to latent_dim in SiameseGMLVQ 2022-03-29 15:24:42 +02:00
Jensun Ravichandran
40bd7ed380 docs: update tutorial notebook 2022-03-29 15:04:05 +02:00
Jensun Ravichandran
4941c2b89d feat: data argument optional in some visualizers 2022-03-29 11:26:22 +02:00
Jensun Ravichandran
ce14dec7e9 feat: add VisSpectralProtos 2022-03-10 15:24:44 +01:00
Jensun Ravichandran
b31c8cc707 feat: add xlabel and ylabel arguments to visualizers 2022-03-09 13:59:19 +01:00
Jensun Ravichandran
e21e6c7e02 docs: update tutorial notebook 2022-02-15 14:38:53 +01:00
Jensun Ravichandran
dd696ea1e0 fix: update hparams.distribution as it changes during training 2022-02-02 21:53:03 +01:00
Jensun Ravichandran
15e7232747 fix: ignore prototype_win_ratios by loading with strict=False 2022-02-02 21:52:01 +01:00
Jensun Ravichandran
197b728c63 feat: add visualize method to visualization callbacks
All visualization callbacks now contain a `visualize` method that takes an
appropriate PyTorch-Lightning module and visualizes it without the need for a
Trainer. This is to encourage users to perform one-off visualizations after
training.
2022-02-02 21:45:44 +01:00
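A rough sketch of the intended one-off usage (model setup assumed from the repository's example scripts, not part of this commit):
```python
import prototorch as pt
from prototorch.models import GLVQ, VisGLVQ2D

train_ds = pt.datasets.Iris(dims=[0, 2])
model = GLVQ(
    dict(distribution=[1, 1, 1], lr=0.01),
    prototypes_initializer=pt.initializers.SMCI(train_ds),
)
# ...train with a Trainer as usual, then visualize without one:
VisGLVQ2D(data=train_ds).visualize(model)
```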
Jensun Ravichandran
98892afee0 chore: add example for saving/loading models from checkpoints 2022-02-02 19:02:26 +01:00
Alexander Engelsberger
d5855dbe97 fix: GLVQ can now be restored from checkpoint 2022-02-02 16:17:11 +01:00
Alexander Engelsberger
75a39f5b03 build: bump version 0.4.0 → 0.4.1 2022-01-11 18:29:55 +01:00
Alexander Engelsberger
1a0e697b27 Merge branch 'dev' into main 2022-01-11 18:29:32 +01:00
Alexander Engelsberger
1a17193b35 ci: add github actions (#16)
* chore: update pre-commit versions

* ci: remove old configurations

* ci: copy workflow from prototorch

* ci: run precommit for all files

* ci: add examples CPU test

* ci(test): failing example test

* ci: fix workflow definition

* ci(test): repeat failing example test

* ci: fix workflow definition

* ci(test): repeat failing example test II

* ci: fix test command

* ci: cleanup example test

* ci: remove travis badge
2022-01-11 18:28:50 +01:00
Alexander Engelsberger
aaa3c51e0a build: bump version 0.3.0 → 0.4.0 2021-12-09 15:58:16 +01:00
Jensun Ravichandran
62c5974a85 fix: correct typo in example script 2021-11-17 15:01:38 +01:00
Jensun Ravichandran
1d26226a2f fix(warning): specify dimension explicitly when calling softmin 2021-11-16 10:19:31 +01:00
Christoph
4232d0ed2a fix: spelling issues for previous commits 2021-11-15 11:43:39 +01:00
Christoph
a9edf06507 feat: ImageGTLVQ and SiameseGTLVQ with examples 2021-11-15 11:43:39 +01:00
Christoph
d3bb430104 feat: gtlvq with examples 2021-11-15 11:43:39 +01:00
Alexander Engelsberger
6ffd27d12a chore: Remove PyTorch-Lightning CLI-related code
Could be moved into a separate plugin.
2021-10-11 15:16:12 +02:00
Alexander Engelsberger
859e2cae69 docs(dependencies): Add missing ipykernel dependency for docs 2021-10-11 15:11:53 +02:00
Alexander Engelsberger
d7ea89d47e feat: add simple test step 2021-09-10 19:19:51 +02:00
Jensun Ravichandran
fa928afe2c feat(vis): 2D EV projection for GMLVQ 2021-09-01 10:49:57 +02:00
Alexander Engelsberger
7d4a041df2 build: bump version 0.2.0 → 0.3.0 2021-08-30 20:50:03 +02:00
Alexander Engelsberger
04c51c00c6 ci: separate build step 2021-08-30 20:44:16 +02:00
Alexander Engelsberger
62185b38cf chore: Update prototorch dependency 2021-08-30 20:32:47 +02:00
Alexander Engelsberger
7b93cd4ad5 feat(compatibility): Python3.6 compatibility 2021-08-30 20:32:40 +02:00
Alexander Engelsberger
d7834e2cc0 fix: All examples should work on CPU and GPU now 2021-08-05 11:20:02 +02:00
Alexander Engelsberger
0af8cf36f8 fix: labels were on CPU in forward pass 2021-08-05 09:14:32 +02:00
Jensun Ravichandran
f8ad1d83eb refactor: clean up abstract classes 2021-07-14 19:17:05 +02:00
Jensun Ravichandran
23a3683860 fix(doc): update outdated documentation 2021-07-12 21:21:29 +02:00
Jensun Ravichandran
4be9fb81eb feat(model): implement MedianLVQ 2021-07-06 17:12:51 +02:00
Jensun Ravichandran
9d38123114 refactor: use GLVQLoss instead of LossLayer 2021-07-06 17:09:21 +02:00
Jensun Ravichandran
0f9f24e36a feat: add early-stopping and pruning to examples/warm_starting.py 2021-06-30 16:04:26 +02:00
Jensun Ravichandran
09e3ef1d0e fix: remove deprecated Trainer.accelerator_backend 2021-06-30 16:03:45 +02:00
Alexander Engelsberger
7b9b767113 fix: training loss is a zero dimensional tensor
Should fix the problem with EarlyStopping callback.
2021-06-25 17:07:06 +02:00
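A minimal illustration of the distinction (assumed, not the actual diff): callbacks such as EarlyStopping compare the monitored value as a scalar, so the logged training loss should be zero-dimensional.
```python
import torch

batch_losses = torch.tensor([0.3, 0.7])
loss = batch_losses.sum()  # zero-dimensional tensor, loss.dim() == 0
assert loss.dim() == 0     # safe to log and monitor as a scalar
```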
Jensun Ravichandran
f56ec44afe chore(github): update bug report issue template 2021-06-25 17:07:06 +02:00
Jensun Ravichandran
67a20124e8 chore(github): add issue templates 2021-06-25 17:07:06 +02:00
Jensun Ravichandran
72af03b991 refactor: use LinearTransform instead of torch.nn.Linear 2021-06-25 17:07:06 +02:00
Alexander Engelsberger
71602bf38a build: bump version 0.1.8 → 0.2.0 2021-06-21 16:47:17 +02:00
Jensun Ravichandran
a1d9657b91 test: remove examples/liramlvq_tecator.py temporarily 2021-06-21 16:13:41 +02:00
Jensun Ravichandran
4dc11a3737 chore(setup): require prototorch>=0.6.0 2021-06-21 15:51:07 +02:00
Alexander Engelsberger
2649e3ac31 test(examples): print error message on fail 2021-06-21 15:06:37 +02:00
Alexander Engelsberger
2b2e4a5f37 fix: examples/ng_iris.py 2021-06-21 14:59:54 +02:00
Jensun Ravichandran
72404f7c4e fix: examples/gmlvq_mnist.py 2021-06-21 14:42:28 +02:00
Jensun Ravichandran
612ee8dc6a chore(bumpversion): modify bump message 2021-06-20 19:11:29 +02:00
Jensun Ravichandran
d42693a441 refactor(api)!: merge the new api changes into dev 2021-06-20 19:00:12 +02:00
Jensun Ravichandran
e5ac50c9a7 Bump version: 0.1.7 → 0.1.8 2021-06-20 17:56:21 +02:00
Jensun Ravichandran
561119ef1d fix: python is python3.9 2021-06-20 17:50:09 +02:00
Jensun Ravichandran
f1f0b313c9 Constrain prototorch version 2021-06-20 17:40:07 +02:00
Jensun Ravichandran
b9eb88a602 Remove .swp file 2021-06-18 13:46:08 +02:00
Jensun Ravichandran
7eb496110f Use mesh2d from prototorch.utils 2021-06-18 13:43:44 +02:00
danielstaps
0a2da9ae50 Added Vis for GMLVQ with more than 2 dims using PCA (#11)
* Added Vis for GMLVQ with more than 2 dims using PCA

* Added initialization possibility to GMLVQ with PCA and one example with omega init + PCA vis of 3 dims

* test(githooks): Add githooks for automatic commit checks

Co-authored-by: staps@hs-mittweida.de <staps@hs-mittweida.de>
Co-authored-by: Alexander Engelsberger <alexanderengelsberger@gmail.com>
2021-06-18 13:28:11 +02:00
Jensun Ravichandran
4ab0a5a414 Add setup config 2021-06-17 14:51:09 +02:00
Alexander Engelsberger
8956ee75ad test(githooks): Add githooks for automatic commit checks 2021-06-16 16:16:34 +02:00
Jensun Ravichandran
29063dcec4 Update gitignore 2021-06-16 12:39:39 +02:00
Jensun Ravichandran
a37095409b [BUGFIX] examples/cbc_iris.py works again 2021-06-15 15:59:47 +02:00
Jensun Ravichandran
1b420c1f6b [BUG] LVQ1 is broken 2021-06-14 21:08:05 +02:00
Jensun Ravichandran
7ec5528ade [BUGFIX] examples/lvqmln_iris.py works again 2021-06-14 21:00:26 +02:00
Jensun Ravichandran
a44219ee47 [BUG] PLVQ seems broken 2021-06-14 20:56:38 +02:00
Jensun Ravichandran
24ebfdc667 [BUGFIX] examples/siamese_glvq_iris.py works again 2021-06-14 20:44:36 +02:00
Jensun Ravichandran
1c658cdc1b [FEATURE] Add warm-starting example 2021-06-14 20:42:57 +02:00
Jensun Ravichandran
1911d4b33e [BUGFIX] examples/lgmlvq_moons.py works again 2021-06-14 20:34:46 +02:00
Jensun Ravichandran
6197d7d5d6 [BUGFIX] examples/dynamic_pruning.py works again 2021-06-14 20:31:39 +02:00
Jensun Ravichandran
d2856383e2 [BUGFIX] examples/gng_iris.py works again 2021-06-14 20:29:31 +02:00
Jensun Ravichandran
4eafe88dc4 [BUGFIX] examples/ksom_colors.py works again 2021-06-14 20:23:07 +02:00
Jensun Ravichandran
3afced8662 [BUGFIX] examples/glvq_spiral.py works again 2021-06-14 20:19:08 +02:00
Jensun Ravichandran
68034d56f6 [BUGFIX] examples/glvq_iris.py works again 2021-06-14 20:13:25 +02:00
Jensun Ravichandran
97ec15b76a [BUGFIX] KNN works again 2021-06-14 20:09:41 +02:00
Jensun Ravichandran
69e5ff3243 Import from the newly cleaned-up prototorch namespace 2021-06-14 20:08:08 +02:00
Alexander Engelsberger
c87ed5ba8b [FEATURE] Add PLVQ model 2021-06-12 13:02:26 +02:00
Alexander Engelsberger
fc11d78b38 [REFACTOR] Rename LikelihoodRatioLVQ to SLVQ 2021-06-12 13:02:26 +02:00
Jensun Ravichandran
e62a8e6582 [BUGFIX] Log loss in NG and GNG 2021-06-11 18:50:14 +02:00
Jensun Ravichandran
ea33196a50 Update SOM example 2021-06-11 11:30:27 +02:00
Jensun Ravichandran
4ca846997a [BUGFIX] Set minimum required version of pytorch_lightning
The `check_on_train_epoch_end=True` argument to the EarlyStopping callback from
pytorch_lightning does not seem to be available in pl version 1.2.8.
2021-06-10 01:14:01 +02:00
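For reference, the callback usage in question looks roughly like this (argument name taken from the commit message; it is not available in pl 1.2.8):
```python
from pytorch_lightning.callbacks import EarlyStopping

es = EarlyStopping(
    monitor="train_loss",
    patience=10,
    check_on_train_epoch_end=True,  # requires a newer pytorch_lightning
)
```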
Jensun Ravichandran
57f8bec270 Update SOM 2021-06-09 18:21:12 +02:00
Jensun Ravichandran
022d791ea5 Route initialized prototypes 2021-06-07 21:18:08 +02:00
Alexander Engelsberger
43fc7d1678 [QA] Remove unused argument from CBC 2021-06-07 21:00:58 +02:00
Jensun Ravichandran
c7b5c88776 [WIP] Add SOM 2021-06-07 18:44:15 +02:00
Jensun Ravichandran
b031382072 Update NG 2021-06-07 18:35:08 +02:00
Jensun Ravichandran
d558fa6a4a [REFACTOR] Clean up GLVQ-types 2021-06-07 17:00:38 +02:00
Jensun Ravichandran
34ffeb95bc Update how the model is printed 2021-06-07 16:58:00 +02:00
Jensun Ravichandran
3aa33fd182 Add remarkrc 2021-06-07 15:18:26 +02:00
Jensun Ravichandran
f65a665157 Update readme 2021-06-07 15:18:10 +02:00
Jensun Ravichandran
bed753a6e9 Minor aesthetic change 2021-06-05 01:23:58 +02:00
Jensun Ravichandran
b82bb54dbe Add LGMLVQ example 2021-06-04 22:58:54 +02:00
Jensun Ravichandran
acac39cff6 Fix bold markup 2021-06-04 22:39:27 +02:00
Jensun Ravichandran
19f601fac8 Bump minimum required python version to 3.8 2021-06-04 22:21:41 +02:00
Jensun Ravichandran
5d2a8226ce Update example scripts 2021-06-04 22:21:28 +02:00
Jensun Ravichandran
016fcb4060 [REFACTOR] Major cleanup 2021-06-04 22:20:32 +02:00
Jensun Ravichandran
20471bfb1c [FEATURE] Update pruning callback to re-add pruned prototypes 2021-06-04 15:56:46 +02:00
Jensun Ravichandran
42d974e08c No implicit learning rate scheduling 2021-06-04 15:55:06 +02:00
Jensun Ravichandran
b0df61d1c3 [BUGFIX] Fix examples/ng_iris.py 2021-06-03 16:34:48 +02:00
Alexander Engelsberger
47db1965ee [BUGFIX] GNG Example 2021-06-03 15:42:54 +02:00
Alexander Engelsberger
0bc385fe7b [BUGFIX] Neural Gas gets prototype initializer from kwargs 2021-06-03 15:24:17 +02:00
Alexander Engelsberger
358f27257d [REFACTOR] Remove prototype_initializer function from GLVQ
Fixes #9
2021-06-03 15:15:22 +02:00
Alexander Engelsberger
bda88149d4 [BUGFIX] Growing neural gas 2021-06-03 15:13:38 +02:00
Alexander Engelsberger
7379c61966 [BUGFIX] Fix image visualization for some parameter combinations
Image visualization was broken if add_embeding was False but data visualization was on.
2021-06-03 15:12:51 +02:00
Alexander Engelsberger
e209bf73d5 [BUGFIX] Pruning example works on GPU now 2021-06-03 14:35:24 +02:00
Alexander Engelsberger
1b09b1d57b [BUGFIX] Probabilistic Models work on GPU now 2021-06-03 14:05:44 +02:00
Alexander Engelsberger
459f7c24be [REFACTOR] Probabilistic loss signs changed 2021-06-03 14:00:47 +02:00
Alexander Engelsberger
5918f1cc21 [BUGFIX] CLI example documentation improved 2021-06-03 13:47:20 +02:00
Alexander Engelsberger
3b02d99ebe [BUGFIX] Early stopping example works now 2021-06-03 13:38:16 +02:00
Jensun Ravichandran
64250d0938 [BUG] CLI example crashes
Running examples/cli/gmlvq.py crashes with:
```
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
~/miniconda3/envs/py38/lib/python3.8/site-packages/pytorch_lightning/utilities/parsing.py in __getattr__(self, key)
    249         try:
--> 250             return self[key]
    251         except KeyError as exp:

KeyError: 'distribution'

The above exception was the direct cause of the following exception:

AttributeError                            Traceback (most recent call last)
~/work/repos/prototorch_models/examples/cli/gmlvq.py in <module>
     10
     11
---> 12 cli = LightningCLI(GMLVQMNIST)

~/miniconda3/envs/py38/lib/python3.8/site-packages/pytorch_lightning/utilities/cli.py in __init__(self, model_class, datamodule_class, save_config_callback, trainer_class, trainer_defaults, seed_everything_default, description, env_prefix, env_parse, parser_kwargs, subclass_mode_model, subclass_mode_data)
    168             seed_everything(self.config['seed_everything'])
    169         self.before_instantiate_classes()
--> 170         self.instantiate_classes()
    171         self.prepare_fit_kwargs()
    172         self.before_fit()

~/miniconda3/envs/py38/lib/python3.8/site-packages/pytorch_lightning/utilities/cli.py in instantiate_classes(self)
    211         self.config_init = self.parser.instantiate_subclasses(self.config)
    212         self.instantiate_datamodule()
--> 213         self.instantiate_model()
    214         self.instantiate_trainer()
    215

~/miniconda3/envs/py38/lib/python3.8/site-packages/pytorch_lightning/utilities/cli.py in instantiate_model(self)
    228             self.model = self.config_init['model']
    229         else:
--> 230             self.model = self.model_class(**self.config_init.get('model', {}))
    231
    232     def instantiate_trainer(self) -> None:

~/work/repos/prototorch_models/prototorch/models/glvq.py in __init__(self, hparams, **kwargs)
    307     def __init__(self, hparams, **kwargs):
    308         distance_fn = kwargs.pop("distance_fn", omega_distance)
--> 309         super().__init__(hparams, distance_fn=distance_fn, **kwargs)
    310         omega = torch.randn(self.hparams.input_dim,
    311                             self.hparams.latent_dim,

~/work/repos/prototorch_models/prototorch/models/glvq.py in __init__(self, hparams, **kwargs)
     39         # Layers
     40         self.proto_layer = LabeledComponents(
---> 41             distribution=self.hparams.distribution,
     42             initializer=self.prototype_initializer(**kwargs))
     43

~/miniconda3/envs/py38/lib/python3.8/site-packages/pytorch_lightning/utilities/parsing.py in __getattr__(self, key)
    250             return self[key]
    251         except KeyError as exp:
--> 252             raise AttributeError(f'Missing attribute "{key}"') from exp
    253
    254     def __setattr__(self, key, val):

AttributeError: Missing attribute "distribution"
```
2021-06-02 13:02:40 +02:00
Jensun Ravichandran
86688b26b0 Update setup.py 2021-06-02 13:01:27 +02:00
Jensun Ravichandran
ef6bcc1079 [BUG] Early stopping does not seem to work
The early stopping callback does not work as expected, and crashes at the end of
max_epochs with:

```
~/miniconda3/envs/py38/lib/python3.8/site-packages/pytorch_lightning/trainer/callback_hook.py in on_train_end(self)
    155         """Called when the train ends."""
    156         for callback in self.callbacks:
--> 157             callback.on_train_end(self, self.lightning_module)
    158
    159     def on_pretrain_routine_start(self) -> None:

~/work/repos/prototorch_models/prototorch/models/callbacks.py in on_train_end(self, trainer, pl_module)
     18     def on_train_end(self, trainer, pl_module):
     19         # instead, do it at the end of training loop
---> 20         self._run_early_stopping_check(trainer, pl_module)
     21
     22

TypeError: _run_early_stopping_check() takes 2 positional arguments but 3 were given
```
2021-06-02 12:44:34 +02:00
Jensun Ravichandran
bdacc83185 [REFACTOR] Update examples/dynamic_pruning.py 2021-06-02 03:53:21 +02:00
Jensun Ravichandran
8851d1bbc9 [FEATURE] Add PruneLoserPrototypes Callback 2021-06-02 03:52:41 +02:00
Jensun Ravichandran
a3f5d7d113 Update docstring 2021-06-02 02:40:29 +02:00
Jensun Ravichandran
b2009bb563 [FEATURE] Add example to showcase dynamic pruning of prototypes 2021-06-02 02:36:37 +02:00
Jensun Ravichandran
398431e7ea Remove examples/dynamic_components.py 2021-06-02 02:35:45 +02:00
Jensun Ravichandran
8f7deb75dd [FEATURE] Log prototype win ratios over all training batches 2021-06-02 02:32:54 +02:00
Jensun Ravichandran
7743c50725 Tweak repr 2021-06-02 01:07:48 +02:00
Jensun Ravichandran
fcf3a4979c Remove unused import 2021-06-02 00:49:36 +02:00
Jensun Ravichandran
d46fe4a393 [WIP] Update CBC example 2021-06-02 00:45:33 +02:00
Jensun Ravichandran
88cfd5762e Remove unused imports in models/cbc.py 2021-06-02 00:44:35 +02:00
Jensun Ravichandran
aa42b9e331 [BUGFIX] Import missing module
models/unsupervised.py uses `pt` in line 37, but `pt` is undefined in the file.
I wonder why Python doesn't complain about this. Perhaps because unsupervised.py
is never run in isolation and `pt` is otherwise available in the namespace of
the example scripts that use unsupervised.py.
2021-06-02 00:31:57 +02:00
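A small demonstration of why the error can stay hidden: Python resolves names at call time, so an undefined `pt` only raises once the offending line actually runs (hypothetical snippet, not the actual module):
```python
def make_dataset():
    return pt.datasets.Iris()  # no error at definition time

# make_dataset()  # would raise NameError while `pt` is unbound
import prototorch as pt  # the fix: bind `pt` in the module itself
```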
Jensun Ravichandran
91b57b01b1 [REFACTOR] neighbour -> neighbor 2021-06-02 00:29:45 +02:00
Jensun Ravichandran
9eb6476078 [BUG] Training unstable in examples/gng_iris.py 2021-06-02 00:21:42 +02:00
Jensun Ravichandran
98c198d463 [REFACTOR] Use LambdaLayer instead of EuclideanDistance 2021-06-02 00:21:11 +02:00
Jensun Ravichandran
ef4d70eee0 Update readme 2021-06-02 00:03:35 +02:00
Jensun Ravichandran
7e241ff7d8 [WIP] Update examples/liramlvq_tecator.py 2021-06-02 00:02:31 +02:00
Jensun Ravichandran
757f4e980d Add Local-Matrix LVQ
Also remove the use of `self.distance_fn` in favor of `self.distance_layer`.
2021-06-01 23:44:16 +02:00
Jensun Ravichandran
5ec2dd47cd Remove unused import 2021-06-01 23:40:56 +02:00
Jensun Ravichandran
930f84d3c7 Remove examples/gmlvq_iris.py 2021-06-01 23:40:15 +02:00
Jensun Ravichandran
e8cd4d765c Remove unused variable 2021-06-01 23:39:39 +02:00
Jensun Ravichandran
8403b01081 Move CELVQ to probabilistic.py 2021-06-01 23:39:06 +02:00
Jensun Ravichandran
aff6aedd60 Use the add_components API for adding prototypes 2021-06-01 23:37:45 +02:00
Jensun Ravichandran
1b6843dbbb Remove unused imports 2021-06-01 19:31:03 +02:00
Jensun Ravichandran
21023a88d7 [BUGFIX] Fix RSLVQ 2021-06-01 17:44:10 +02:00
Alexander Engelsberger
9c1a41997b [FEATURE] Add Growing Neural Gas 2021-06-01 17:19:43 +02:00
Jensun Ravichandran
1636c84778 Rename rslvq example 2021-05-31 17:56:45 +02:00
Jensun Ravichandran
27eccf44d4 Use LambdaLayer from ProtoTorch 2021-05-31 16:53:04 +02:00
Alexander Engelsberger
8f4d66edf1 [Bugfix] Fix classcount in LIRAMLVQ example 2021-05-31 11:48:23 +02:00
Alexander Engelsberger
2a218c0ede Add example for dynamic components in callbacks 2021-05-31 11:39:24 +02:00
Alexander Engelsberger
db064b5af1 Improvement of model __repr__ 2021-05-31 11:19:06 +02:00
Alexander Engelsberger
0ac4ced85d [refactor] Use functional variant of accuracy
Prevents the Accuracy module from appearing in the models' `__repr__`.
2021-05-31 11:12:27 +02:00
Jensun Ravichandran
e9d2075fed Sort imports in example scripts 2021-05-31 00:52:16 +02:00
Jensun Ravichandran
7b7bc3693d Merge branch 'dev' of github.com:si-cim/prototorch_models into dev 2021-05-31 00:32:49 +02:00
Jensun Ravichandran
cd73f6c427 Add examples/dynamic_components.py 2021-05-31 00:32:27 +02:00
Alexander Engelsberger
a60337ff27 [refactor] Move probabilistic to Prototorch 2021-05-28 20:39:32 +02:00
Alexander Engelsberger
e3392ee952 [refactor] DRY Probabilistic models 2021-05-28 17:13:06 +02:00
Alexander Engelsberger
dade502686 Add MNIST datamodule and training mixin factory. 2021-05-28 16:33:31 +02:00
Jensun Ravichandran
b7edee02c3 [WIP] Add CELVQ
TODO Ensure that the distances/probs corresponding to the plabels are sorted
like the target labels.
2021-05-27 17:40:16 +02:00
Alexander Engelsberger
41b2a2f496 Add model tree to documentation. 2021-05-26 21:29:37 +02:00
Alexander Engelsberger
66e3e51a52 Add references to the documentation. 2021-05-26 21:20:17 +02:00
Alexander Engelsberger
0c1f7a4772 [BUGFIX] Update paths in documentaion for LVQ 2021-05-26 16:19:10 +02:00
Alexander Engelsberger
663eb12ad7 Dummy test gets detected again 2021-05-25 22:03:38 +02:00
Jensun Ravichandran
3fa1fb54f1 Refactor tests 2021-05-25 21:28:05 +02:00
Jensun Ravichandran
cc49f26b77 Remove normalization transform from cli example 2021-05-25 21:13:37 +02:00
Jensun Ravichandran
db965541fd Update example 2021-05-25 20:57:54 +02:00
Jensun Ravichandran
d091dea6a1 Update tutorial 2021-05-25 20:54:07 +02:00
Jensun Ravichandran
d411e52be4 Refactor non-gradient-lvq models into lvq.py 2021-05-25 20:37:34 +02:00
Alexander Engelsberger
32d6f95db0 Add RSLVQ and LikelihoodLVQ 2021-05-25 20:26:15 +02:00
Jensun Ravichandran
139109804f [BUGFIX] Use _forward in LVQ1 and LVQ21 2021-05-25 17:43:37 +02:00
Alexander Engelsberger
2cc11ae2e3 Ignore .pt files 2021-05-25 16:46:31 +02:00
Alexander Engelsberger
72e064338c Use 'num_' in all variable names 2021-05-25 15:41:10 +02:00
Alexander Engelsberger
e7e6bf9173 Fix failing example 2021-05-21 18:54:47 +02:00
Alexander Engelsberger
2aa631f4e6 Improve example test script (with failing example) 2021-05-21 18:48:37 +02:00
Alexander Engelsberger
c6992da123 Travis didn't use the dev branch 2021-05-21 18:35:00 +02:00
Alexander Engelsberger
dcbd0c1e5c Travis didn't use the dev branch 2021-05-21 18:30:10 +02:00
Alexander Engelsberger
b8bca71206 [Travis] Use dev branch of prototorch for tests. 2021-05-21 18:22:29 +02:00
Alexander Engelsberger
419eca46af Add example test script 2021-05-21 18:16:17 +02:00
Alexander Engelsberger
5b12629bd9 All examples use argparse 2021-05-21 17:55:55 +02:00
Alexander Engelsberger
b60db3174a LightningCLI Example. 2021-05-21 17:13:15 +02:00
Alexander Engelsberger
8ce18f83ce Add prototype_initializer function to GLVQ
This allows overwriting it inside subclasses.
2021-05-21 17:13:10 +02:00
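A sketch of the pattern this enables (method name from the commit message; the exact signature is assumed):
```python
from prototorch.models import GLVQ

class MyGLVQ(GLVQ):
    def prototype_initializer(self, **kwargs):
        # subclasses can now swap in their own initializer here
        return super().prototype_initializer(**kwargs)
```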
Alexander Engelsberger
7b4f7d84e0 Update Documentation
Clean up project
2021-05-21 15:42:45 +02:00
Jensun Ravichandran
a5e086ce0d Refactor code 2021-05-21 13:33:57 +02:00
Jensun Ravichandran
0611f81aba Update models namespace 2021-05-21 13:11:59 +02:00
Jensun Ravichandran
a9382dcd9b Add get_prototype_grid method 2021-05-21 13:11:48 +02:00
Jensun Ravichandran
0933a88a1b Fix ImageCBC bug 2021-05-21 13:11:36 +02:00
Jensun Ravichandran
88a34a06ef [WIP] Update CBC implementation to use SiameseGLVQ 2021-05-20 17:36:00 +02:00
Jensun Ravichandran
49f9a12b5f Update mnist example 2021-05-20 17:35:07 +02:00
Jensun Ravichandran
de63eaf15a Fix numpy issue in vis.py 2021-05-20 17:33:19 +02:00
Jensun Ravichandran
16dc3cf4eb Update image visualization 2021-05-20 16:07:16 +02:00
Jensun Ravichandran
df061cc2ff Refactor code 2021-05-20 14:40:02 +02:00
Alexander Engelsberger
969fb34cc3 Accumulate test loss 2021-05-20 14:20:23 +02:00
Alexander Engelsberger
0204f5eab6 Log test accuracy. 2021-05-20 14:03:31 +02:00
Alexander Engelsberger
b7fc5df386 Log test loss. 2021-05-20 13:47:20 +02:00
Alexander Engelsberger
faf1a88f99 [Bugfix] Remove optimzer_idx from validation and test. 2021-05-20 13:17:27 +02:00
Jensun Ravichandran
5ffbd43a7c Refactor into shared_step 2021-05-19 16:57:51 +02:00
Jensun Ravichandran
fdf9443a2c Add validation and test logic 2021-05-19 16:30:19 +02:00
Jensun Ravichandran
7700bb7f8d [DOC] Ignore Sphinx warnings until prototorch is bumped
readthedocs build fails because of a missing function (`get_flat`) that's not
available via PyPI yet (see
https://readthedocs.org/projects/prototorch-models/builds/13795474/). The
temporary solution until this becomes available is therefore to ignore it and
build the rest of the docs.
2021-05-18 20:08:14 +02:00
Jensun Ravichandran
eefec19c9b Custom non-gradient training 2021-05-18 19:49:16 +02:00
Jensun Ravichandran
246719b837 [DOC] Add tutorial 2021-05-18 19:41:58 +02:00
Jensun Ravichandran
a14e3aa611 Add argparse to mnist example script 2021-05-18 10:17:51 +02:00
Jensun Ravichandran
00cdacf7ae Fix example script 2021-05-18 10:15:38 +02:00
Jensun Ravichandran
4957e821f6 Close matplotlib figure on train end 2021-05-18 10:13:22 +02:00
Jensun Ravichandran
538256dcb7 Small changes 2021-05-17 19:37:42 +02:00
Jensun Ravichandran
d812bb0620 Update examples 2021-05-17 17:03:37 +02:00
Jensun Ravichandran
81346785bd Cleanup models
Siamese architectures no longer accept a `backbone_module`. They have to be
initialized with a pre-initialized backbone object instead. This is so that the
visualization callbacks can use the very same object for visualization
purposes. Also, there is no longer a dependent copy of the backbone; it is
managed simply with `requires_grad` instead.
2021-05-17 17:00:23 +02:00
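A sketch of the new calling convention (dimensions and initializer are assumptions):
```python
import torch
import prototorch as pt

train_ds = pt.datasets.Iris()
backbone = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
model = pt.models.SiameseGLVQ(
    dict(distribution=[1, 1, 1], lr=0.01),
    prototypes_initializer=pt.initializers.SMCI(train_ds),
    backbone=backbone,  # pre-initialized object, shared with vis callbacks
)
```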
Jensun Ravichandran
7a87636ad7 Update KNN 2021-05-17 16:59:35 +02:00
Jensun Ravichandran
77b7b59bad Clean visualization callbacks 2021-05-17 16:59:22 +02:00
Jensun Ravichandran
6e7d80be88 [BUGFIX] Fix siamese visualization callback 2021-05-15 12:52:44 +02:00
Jensun Ravichandran
b7684ae512 predict_latent no longer returns numpy 2021-05-15 12:52:16 +02:00
Jensun Ravichandran
ebc42a4aa8 Set gpus=0 in examples 2021-05-15 12:43:00 +02:00
Jensun Ravichandran
c639836537 Update README.md 2021-05-14 22:19:57 +02:00
Jensun Ravichandran
d36d685115 Add links 2021-05-14 12:58:38 +02:00
Alexander Engelsberger
b341096757 Add basic documentation files. 2021-05-13 15:22:32 +02:00
Alexander Engelsberger
0eac2ce326 Examples use GPUs if available. 2021-05-13 15:22:01 +02:00
Jensun Ravichandran
8f9c29bd2b [BUGFIX] Remove incorrect import statement 2021-05-12 16:45:01 +02:00
Jensun Ravichandran
ca39aa00d5 Stop passing component initializers as hparams
Passing the component initializer as an hparam slows down the script considerably. The
API has now been changed to pass it as a kwarg to the models instead.

The example scripts have been updated to reflect the new changes.

ImageGMLVQ and an example script `gmlvq_mnist.py` that uses it have also
been added.
2021-05-12 16:36:22 +02:00
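A sketch of the resulting API (kwarg name as in later releases; illustrative only):
```python
import prototorch as pt

train_ds = pt.datasets.Iris()
model = pt.models.GLVQ(
    dict(distribution=[1, 1, 1], lr=0.01),  # hparams hold plain values only
    prototypes_initializer=pt.initializers.SMCI(train_ds),  # kwarg, not hparam
)
```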
Alexander Engelsberger
1498c4bde5 Bump version: 0.1.6 → 0.1.7 2021-05-11 17:18:29 +02:00
Jensun Ravichandran
59b8ab6643 Add knn 2021-05-11 17:22:02 +02:00
Jensun Ravichandran
2a4f184163 Update example scripts 2021-05-11 16:15:08 +02:00
Jensun Ravichandran
265e74dd31 Require prototorch>=0.4.2 2021-05-11 16:14:47 +02:00
Jensun Ravichandran
daad018a78 Update readme 2021-05-11 16:14:23 +02:00
Jensun Ravichandran
eab1ec72c2 Change optimizer using kwargs 2021-05-11 16:13:00 +02:00
Jensun Ravichandran
b38acd58a8 [BUGFIX] Fix visualization callbacks bug 2021-05-11 16:09:27 +02:00
Alexander Engelsberger
e87563e10d Bump version: 0.1.5 → 0.1.6 2021-05-11 13:41:26 +02:00
Alexander Engelsberger
767206f905 Define minimum prototorch version in setup 2021-05-11 13:41:09 +02:00
Alexander Engelsberger
3fa6378c4d Add LVQ1 and LVQ2.1 Models. 2021-05-11 13:26:13 +02:00
Alexander Engelsberger
30ee287ecc Bump version: 0.1.4 → 0.1.5 2021-05-10 17:13:00 +02:00
Alexander Engelsberger
e323f9d4ca Fix long description for pypi. 2021-05-10 17:12:54 +02:00
Alexander Engelsberger
f49db0bf2c Bump version: 0.1.3 → 0.1.4 2021-05-10 17:06:28 +02:00
Alexander Engelsberger
db38667306 Fix Travis configuration 2021-05-10 17:06:23 +02:00
Alexander Engelsberger
54a8494d86 Bump version: 0.1.2 → 0.1.3 2021-05-10 17:04:20 +02:00
Alexander Engelsberger
bf310be97c Bump version: 0.1.1 → 0.1.2 2021-05-10 16:47:33 +02:00
Alexander Engelsberger
32ae1b7862 Add Build Badge. 2021-05-10 16:47:28 +02:00
Jensun Ravichandran
dfddb92aba Dummy change 2021-05-10 16:47:37 +02:00
Alexander Engelsberger
4a38bb2bfe Corrected Badge Image Url. 2021-05-10 16:40:08 +02:00
Alexander Engelsberger
6680d4b9df Add PyPi Badge. 2021-05-10 16:39:02 +02:00
Alexander Engelsberger
1ae2b41edd Bump version: 0.1.0 → 0.1.1 2021-05-10 16:26:21 +02:00
Alexander Engelsberger
9300a6d14d Dummy Test to enable CI. 2021-05-10 16:19:51 +02:00
Jensun Ravichandran
3d42876df1 Merge pull request #7 from si-cim/main
Merge branch 'dev' into main
2021-05-10 16:07:08 +02:00
Alexander Engelsberger
fbadacdbca Merge branch 'dev' into main 2021-05-10 16:02:07 +02:00
75 changed files with 5958 additions and 1039 deletions


@@ -1,11 +1,15 @@
[bumpversion]
current_version = 0.1.0
current_version = 1.0.0a4
commit = True
tag = True
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)((?P<release>[a-zA-Z0-9_.-]+))?
serialize =
{major}.{minor}.{patch}-{release}
{major}.{minor}.{patch}
message = build: bump version {current_version} → {new_version}
[bumpversion:file:setup.py]
[bumpversion:file:./prototorch/models/__init__.py]
[bumpversion:file:./docs/source/conf.py]
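To see how the new parse/serialize pair handles a prerelease tag, a quick illustrative check:
```python
import re

# The prerelease-aware pattern from the config above.
parse = re.compile(
    r"(?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)((?P<release>[a-zA-Z0-9_.-]+))?")
m = parse.match("1.0.0a4")
print(m.group("major", "minor", "patch", "release"))  # ('1', '0', '0', 'a4')
```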


@@ -1,15 +0,0 @@
# To validate the contents of your configuration file
# run the following command in the folder where the configuration file is located:
# codacy-analysis-cli validate-configuration --directory `pwd`
# To analyse, run:
# codacy-analysis-cli analyse --tool remark-lint --directory `pwd`
---
engines:
pylintpython3:
exclude_paths:
- config/engines.yml
remark-lint:
exclude_paths:
- config/engines.yml
exclude_paths:
- 'tests/**'


@@ -1,2 +0,0 @@
comment:
require_changes: yes

.github/ISSUE_TEMPLATE/bug_report.md

@@ -0,0 +1,38 @@
---
name: Bug report
about: Create a report to help us improve
title: ''
labels: ''
assignees: ''
---
**Describe the bug**
A clear and concise description of what the bug is.
**Steps to reproduce the behavior**
1. ...
2. Run script '...' or this snippet:
```python
import prototorch as pt
...
```
3. See errors
**Expected behavior**
A clear and concise description of what you expected to happen.
**Observed behavior**
A clear and concise description of what actually happened.
**Screenshots**
If applicable, add screenshots to help explain your problem.
**System and version information**
- OS: [e.g. Ubuntu 20.10]
- ProtoTorch Version: [e.g. 0.4.0]
- Python Version: [e.g. 3.9.5]
**Additional context**
Add any other context about the problem here.


@@ -0,0 +1,20 @@
---
name: Feature request
about: Suggest an idea for this project
title: ''
labels: ''
assignees: ''
---
**Is your feature request related to a problem? Please describe.**
A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
**Describe the solution you'd like**
A clear and concise description of what you want to happen.
**Describe alternatives you've considered**
A clear and concise description of any alternative solutions or features you've considered.
**Additional context**
Add any other context or screenshots about the feature request here.

.github/workflows/examples.yml

@@ -0,0 +1,25 @@
# This workflow will install Python dependencies, run tests and lint with a single version of Python
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
name: examples
on:
push:
paths:
- 'examples/**.py'
jobs:
cpu:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Set up Python 3.10
uses: actions/setup-python@v2
with:
python-version: "3.10"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install .[all]
- name: Run examples
run: |
./tests/test_examples.sh examples/

.github/workflows/pythonapp.yml

@@ -0,0 +1,75 @@
# This workflow will install Python dependencies, run tests and lint with a single version of Python
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
name: tests
on:
push:
pull_request:
branches: [ master ]
jobs:
style:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Set up Python 3.10
uses: actions/setup-python@v2
with:
python-version: "3.10"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install .[all]
- uses: pre-commit/action@v2.0.3
compatibility:
needs: style
strategy:
fail-fast: false
matrix:
python-version: ["3.7", "3.8", "3.9", "3.10"]
os: [ubuntu-latest, windows-latest]
exclude:
- os: windows-latest
python-version: "3.7"
- os: windows-latest
python-version: "3.8"
- os: windows-latest
python-version: "3.9"
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install .[all]
- name: Test with pytest
run: |
pytest
publish_pypi:
if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags')
needs: compatibility
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Set up Python 3.10
uses: actions/setup-python@v2
with:
python-version: "3.10"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install .[all]
pip install wheel
- name: Build package
run: python setup.py sdist bdist_wheel
- name: Publish a Python distribution to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
with:
user: __token__
password: ${{ secrets.PYPI_API_TOKEN }}

.gitignore

@@ -128,8 +128,19 @@ dmypy.json
# Pyre type checker
.pyre/
# Datasets
datasets/
.vscode/
# PyTorch-Lightning
lightning_logs/
# Vim
*~
*.swp
*.swo
# Pytorch Models or Weights
# If necessary make exceptions for single pretrained models
*.pt
# Artifacts created by ProtoTorch Models
datasets/
lightning_logs/
examples/_*.py
examples/_*.ipynb

.pre-commit-config.yaml

@@ -0,0 +1,54 @@
# See https://pre-commit.com for more information
# See https://pre-commit.com/hooks.html for more hooks
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.2.0
hooks:
- id: trailing-whitespace
exclude: (^\.bumpversion\.cfg$|cli_messages\.py)
- id: end-of-file-fixer
- id: check-yaml
- id: check-added-large-files
- id: check-ast
- id: check-case-conflict
- repo: https://github.com/myint/autoflake
rev: v1.4
hooks:
- id: autoflake
- repo: http://github.com/PyCQA/isort
rev: 5.10.1
hooks:
- id: isort
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v0.950
hooks:
- id: mypy
files: prototorch
additional_dependencies: [types-pkg_resources]
- repo: https://github.com/pre-commit/mirrors-yapf
rev: v0.32.0
hooks:
- id: yapf
- repo: https://github.com/pre-commit/pygrep-hooks
rev: v1.9.0
hooks:
- id: python-use-type-annotations
- id: python-no-log-warn
- id: python-check-blanket-noqa
- repo: https://github.com/asottile/pyupgrade
rev: v2.32.1
hooks:
- id: pyupgrade
- repo: https://github.com/si-cim/gitlint
rev: v0.15.2-unofficial
hooks:
- id: gitlint
args: [--contrib=CT1, --ignore=B6, --msg-filename]

.readthedocs.yml

@@ -0,0 +1,27 @@
# .readthedocs.yml
# Read the Docs configuration file
# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
# Required
version: 2
# Build documentation in the docs/ directory with Sphinx
sphinx:
configuration: docs/source/conf.py
fail_on_warning: true
# Build documentation with MkDocs
# mkdocs:
# configuration: mkdocs.yml
# Optionally build your docs in additional formats such as PDF and ePub
formats: all
# Optionally set the version of Python and requirements required to build your docs
python:
version: 3.9
install:
- method: pip
path: .
extra_requirements:
- all

.remarkrc

@@ -0,0 +1,7 @@
{
"plugins": [
"remark-preset-lint-recommended",
["remark-lint-list-item-indent", false],
["no-emphasis-as-header", true]
]
}


@@ -1,21 +0,0 @@
dist: bionic
sudo: false
language: python
python: 3.8
cache:
directories:
- "./tests/artifacts"
install:
- pip install .[all] --progress-bar off
script:
- coverage run -m pytest
after_success:
- bash <(curl -s https://codecov.io/bash)
deploy:
provider: pypi
username: __token__
password:
secure: PDoASdYdVlt1aIROYilAsCW6XpBs/TDel0CSptDzX0CI7i4+ksEW6Jk0JyL58bQt7V4F8PeGty4A8SODzAUIk2d8sty5RI4VJjvXZFCXlUsW+JGUN3EvWNqJLnwN8TDxgu2ENao37GUh0dC6pL8b6bVDGeOLaY1E/YR1jimmTJuxxjKjBIU8ByqTNBnC3rzybMTPU3nRoOM/WMQUyReHrPoUJj685sLqrLruhAqhiYsPbotP8xY6i8+KBbhp5vgiARV2+LkbeGcYZwozCzrEqPKY7YIfVPh895cw0v4NRyFwK1P2jyyIt22Z9Ni0Uy1J5/Qp9Sv6mBPeGjm3pnpDCQyS+2bNIDaj08KUYTIo1mC/Jcu4jQgppZEF+oey9q1tgGo+/JhsTeERKV9BoPF5HDiRArU1s5aWJjFnCsHfu+W1XqX8bwN3aTYsEIaApT3/irc6XyFJIfMN82+z+lUcZ4Y1yAHT3nH1Vif+pZYZB0UOSGrHwuI/UayjKzbCzHMuHWylWB/9ehd4o4YVp6iubVHc7Sj0KQkwBgwgl6TvwNcUuFsplFabCxmX0mVcavXsWiOBc+ivPmU6574zGj0JcEk5ghVgnKH+QS96aVrKOzegwbl4O13jY8dJp+/zgXl0gJOvRKr4BhuBJKcBaMQHdSKUChVsJJtqDyt59GvWcbg=
on:
tags: true
skip_existing: true


@@ -1,27 +1,58 @@
# ProtoTorch Models
[![GitHub tag (latest by date)](https://img.shields.io/github/v/tag/si-cim/prototorch_models?color=yellow&label=version)](https://github.com/si-cim/prototorch_models/releases)
[![PyPI](https://img.shields.io/pypi/v/prototorch_models)](https://pypi.org/project/prototorch_models/)
[![GitHub license](https://img.shields.io/github/license/si-cim/prototorch_models)](https://github.com/si-cim/prototorch_models/blob/master/LICENSE)
Pre-packaged prototype-based machine learning models using ProtoTorch and
PyTorch-Lightning.
## Installation
To install this plugin, first install
[ProtoTorch](https://github.com/si-cim/prototorch) with:
To install this plugin, simply run the following command:
```sh
git clone https://github.com/si-cim/prototorch.git && cd prototorch
pip install -e .
pip install prototorch_models
```
and then install the plugin itself with:
**Installing the models plugin should automatically install a suitable version
of** [ProtoTorch](https://github.com/si-cim/prototorch). The plugin should then
be available for use in your Python environment as `prototorch.models`.
```sh
git clone https://github.com/si-cim/prototorch_models.git && cd prototorch_models
pip install -e .
```
## Available models
The plugin should then be available for use in your Python environment as
`prototorch.models`.
### LVQ Family
- Learning Vector Quantization 1 (LVQ1)
- Generalized Learning Vector Quantization (GLVQ)
- Generalized Relevance Learning Vector Quantization (GRLVQ)
- Generalized Matrix Learning Vector Quantization (GMLVQ)
- Limited-Rank Matrix Learning Vector Quantization (LiRaMLVQ)
- Localized and Generalized Matrix Learning Vector Quantization (LGMLVQ)
- Learning Vector Quantization Multi-Layer Network (LVQMLN)
- Siamese GLVQ
- Cross-Entropy Learning Vector Quantization (CELVQ)
- Soft Learning Vector Quantization (SLVQ)
- Robust Soft Learning Vector Quantization (RSLVQ)
- Probabilistic Learning Vector Quantization (PLVQ)
- Median-LVQ
### Other
- k-Nearest Neighbors (KNN)
- Neural Gas (NG)
- Growing Neural Gas (GNG)
## Work in Progress
- Classification-By-Components Network (CBC)
- Learning Vector Quantization 2.1 (LVQ2.1)
- Self-Organizing-Map (SOM)
## Planned models
- Generalized Tangent Learning Vector Quantization (GTLVQ)
- Self-Incremental Learning Vector Quantization (SILVQ)
## Development setup
@@ -50,31 +81,26 @@ pip install -e .[all] # \[all\] if you are using zsh or MacOS
```
To assist in the development process, you may also find it useful to install
`yapf`, `isort` and `autoflake`. You can install them easily with `pip`.
`yapf`, `isort` and `autoflake`. You can install them easily with `pip`. **Also,
please avoid installing Tensorflow in this environment. It is known to cause
problems with PyTorch-Lightning.**
## Available models
## Contribution
- Generalized Learning Vector Quantization (GLVQ)
- Generalized Relevance Learning Vector Quantization (GRLVQ)
- Generalized Matrix Learning Vector Quantization (GMLVQ)
- Limited-Rank Matrix Learning Vector Quantization (LiRaMLVQ)
- Siamese GLVQ
- Neural Gas (NG)
This repository contains definitions for [git hooks](https://githooks.com).
[Pre-commit](https://pre-commit.com) is automatically installed as a development
dependency with prototorch, or you can install it manually with `pip install
pre-commit`.
## Work in Progress
Please install the hooks by running:
```bash
pre-commit install
pre-commit install --hook-type commit-msg
```
before creating the first commit.
- Classification-By-Components Network (CBC)
- Learning Vector Quantization Multi-Layer Network (LVQMLN)
## Planned models
- Local-Matrix GMLVQ
- Generalized Tangent Learning Vector Quantization (GTLVQ)
- Robust Soft Learning Vector Quantization (RSLVQ)
- Probabilistic Learning Vector Quantization (PLVQ)
- Self-Incremental Learning Vector Quantization (SILVQ)
- K-Nearest Neighbors (KNN)
- Learning Vector Quantization 1 (LVQ1)
The commit will fail if the commit message does not follow the specification
provided [here](https://www.conventionalcommits.org/en/v1.0.0/#specification).
## FAQ

docs/Makefile

@@ -0,0 +1,20 @@
# Minimal makefile for Sphinx documentation
#
# You can set these variables from the command line, and also
# from the environment for the first two.
SPHINXOPTS ?=
SPHINXBUILD ?= python3 -m sphinx
SOURCEDIR = source
BUILDDIR = build
# Put it first so that "make" without argument is like "make help".
help:
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
.PHONY: help Makefile
# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

docs/make.bat

@@ -0,0 +1,35 @@
@ECHO OFF
pushd %~dp0
REM Command file for Sphinx documentation
if "%SPHINXBUILD%" == "" (
set SPHINXBUILD=sphinx-build
)
set SOURCEDIR=source
set BUILDDIR=build
if "%1" == "" goto help
%SPHINXBUILD% >NUL 2>NUL
if errorlevel 9009 (
echo.
echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
echo.installed, then set the SPHINXBUILD environment variable to point
echo.to the full path of the 'sphinx-build' executable. Alternatively you
echo.may add the Sphinx directory to PATH.
echo.
echo.If you don't have Sphinx installed, grab it from
echo.http://sphinx-doc.org/
exit /b 1
)
%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
goto end
:help
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
:end
popd

Binary file not shown (new image, 88 KiB).

Binary file not shown (new image, 52 KiB).

Binary file not shown (new image, 191 KiB).

docs/source/conf.py

@@ -0,0 +1,209 @@
# Configuration file for the Sphinx documentation builder.
#
# This file only contains a selection of the most common options. For a full
# list see the documentation:
# https://www.sphinx-doc.org/en/master/usage/configuration.html
# -- Path setup --------------------------------------------------------------
# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#
import os
import sys
sys.path.insert(0, os.path.abspath("../../"))
# -- Project information -----------------------------------------------------
project = "ProtoTorch Models"
copyright = "2021, Jensun Ravichandran"
author = "Jensun Ravichandran"
# The full version, including alpha/beta/rc tags
#
release = "1.0.0-a4"
# -- General configuration ---------------------------------------------------
# If your documentation needs a minimal Sphinx version, state it here.
#
needs_sphinx = "1.6"
# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named "sphinx.ext.*") or your custom
# ones.
extensions = [
"recommonmark",
"nbsphinx",
"sphinx.ext.autodoc",
"sphinx.ext.autosummary",
"sphinx.ext.doctest",
"sphinx.ext.intersphinx",
"sphinx.ext.todo",
"sphinx.ext.coverage",
"sphinx.ext.napoleon",
"sphinx.ext.viewcode",
"sphinx_rtd_theme",
"sphinxcontrib.katex",
"sphinxcontrib.bibtex",
]
# https://nbsphinx.readthedocs.io/en/0.8.5/custom-css.html#For-All-Pages
nbsphinx_prolog = """
.. raw:: html
<style>
.nbinput .prompt,
.nboutput .prompt {
display: none;
}
</style>
"""
# katex_prerender = True
katex_prerender = False
napoleon_use_ivar = True
# Add any paths that contain templates here, relative to this directory.
templates_path = ["_templates"]
# The suffix(es) of source filenames.
# You can specify multiple suffix as a list of string:
#
source_suffix = [".rst", ".md"]
# The master toctree document.
master_doc = "index"
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = []
# The name of the Pygments (syntax highlighting) style to use. Choose from:
# ["default", "emacs", "friendly", "colorful", "autumn", "murphy", "manni",
# "monokai", "perldoc", "pastie", "borland", "trac", "native", "fruity", "bw",
# "vim", "vs", "tango", "rrt", "xcode", "igor", "paraiso-light", "paraiso-dark",
# "lovelace", "algol", "algol_nu", "arduino", "rainbo w_dash", "abap",
# "solarized-dark", "solarized-light", "sas", "stata", "stata-light",
# "stata-dark", "inkpot"]
pygments_style = "monokai"
# If true, `todo` and `todoList` produce output, else they produce nothing.
todo_include_todos = True
# Disable docstring inheritance
autodoc_inherit_docstrings = False
# -- Options for HTML output -------------------------------------------------
# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
# https://sphinx-themes.org/
html_theme = "sphinx_rtd_theme"
html_logo = "_static/img/logo.png"
html_theme_options = {
"logo_only": True,
"display_version": True,
"prev_next_buttons_location": "bottom",
"style_external_links": False,
"style_nav_header_background": "#ffffff",
# Toc options
"collapse_navigation": True,
"sticky_navigation": True,
"navigation_depth": 4,
"includehidden": True,
"titles_only": False,
}
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ["_static"]
html_css_files = [
"https://cdn.jsdelivr.net/npm/katex@0.11.1/dist/katex.min.css",
]
# -- Options for HTMLHelp output ------------------------------------------
# Output file base name for HTML help builder.
htmlhelp_basename = "protoflowdoc"
# -- Options for LaTeX output ---------------------------------------------
latex_elements = {
# The paper size ("letterpaper" or "a4paper").
#
# "papersize": "letterpaper",
# The font size ("10pt", "11pt" or "12pt").
#
# "pointsize": "10pt",
# Additional stuff for the LaTeX preamble.
#
# "preamble": "",
# Latex figure (float) alignment
#
# "figure_align": "htbp",
}
# Grouping the document tree into LaTeX files. List of tuples
# (source start file, target name, title,
# author, documentclass [howto, manual, or own class]).
latex_documents = [
(
master_doc,
"prototorch.tex",
"ProtoTorch Documentation",
"Jensun Ravichandran",
"manual",
),
]
# -- Options for manual page output ---------------------------------------
# One entry per manual page. List of tuples
# (source start file, name, description, authors, manual section).
man_pages = [(master_doc, "ProtoTorch Models",
"ProtoTorch Models Plugin Documentation", [author], 1)]
# -- Options for Texinfo output -------------------------------------------
# Grouping the document tree into Texinfo files. List of tuples
# (source start file, target name, title, author,
# dir menu entry, description, category)
texinfo_documents = [
(
master_doc,
"prototorch models",
"ProtoTorch Models Plugin Documentation",
author,
"prototorch models",
"Prototype-based machine learning Models in ProtoTorch.",
"Miscellaneous",
),
]
# Example configuration for intersphinx: refer to the Python standard library.
intersphinx_mapping = {
"python": ("https://docs.python.org/3/", None),
"numpy": ("https://numpy.org/doc/stable/", None),
"torch": ('https://pytorch.org/docs/stable/', None),
"pytorch_lightning":
("https://pytorch-lightning.readthedocs.io/en/stable/", None),
}
# -- Options for Epub output ----------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-epub-output
epub_cover = ()
version = release
# -- Options for Bibliography -------------------------------------------
bibtex_bibfiles = ['refs.bib']
bibtex_reference_style = 'author_year'

docs/source/custom.rst

@@ -0,0 +1,7 @@
.. Customize the Models
Abstract Models
========================================
.. automodule:: prototorch.models.abstract
:members:
:undoc-members:

docs/source/index.rst

@@ -0,0 +1,40 @@
.. ProtoTorch Models documentation master file
ProtoTorch Models Plugins
========================================
.. toctree::
:hidden:
:maxdepth: 3
self
tutorial.ipynb
.. toctree::
:hidden:
:maxdepth: 3
:caption: Library
library
.. toctree::
:hidden:
:maxdepth: 3
:caption: Customize
custom
About
-----------------------------------------
`Prototorch Models <https://github.com/si-cim/prototorch_models>`_ is a Plugin
for `Prototorch <https://github.com/si-cim/prototorch>`_. It implements common
prototype-based Machine Learning algorithms using `PyTorch-Lightning
<https://www.pytorchlightning.ai/>`_.
Library
-----------------------------------------
Prototorch Models delivers many application-ready models.
These models have been published in the past and have been adapted to the Prototorch library.
Customizable
-----------------------------------------
Prototorch Models also contains the building blocks to build your own models with PyTorch-Lightning and Prototorch.

docs/source/library.rst

@@ -0,0 +1,117 @@
.. Available Models
Models
========================================
.. image:: _static/img/model_tree.png
:width: 600
Unsupervised Methods
-----------------------------------------
.. autoclass:: prototorch.models.knn.KNN
:members:
.. autoclass:: prototorch.models.unsupervised.NeuralGas
:members:
.. autoclass:: prototorch.models.unsupervised.GrowingNeuralGas
:members:
Classical Learning Vector Quantization
-----------------------------------------
Original LVQ models introduced by :cite:t:`kohonen1989`.
These heuristic algorithms do not use gradient descent.
.. autoclass:: prototorch.models.lvq.LVQ1
:members:
.. autoclass:: prototorch.models.lvq.LVQ21
:members:
It is also possible to use the GLVQ structure as shown by :cite:t:`sato1996` in chapter 4.
This allows the use of gradient descent methods.
.. autoclass:: prototorch.models.glvq.GLVQ1
:members:
.. autoclass:: prototorch.models.glvq.GLVQ21
:members:
Generalized Learning Vector Quantization
-----------------------------------------
:cite:t:`sato1996` presented an LVQ variant with a cost function called GLVQ.
This allows the use of gradient descent methods.
.. autoclass:: prototorch.models.glvq.GLVQ
:members:
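A minimal training sketch for GLVQ, assembled from the repository's example scripts (illustrative; dataset and initializer are assumptions):

.. code-block:: python

   import prototorch as pt
   import pytorch_lightning as pl
   from torch.utils.data import DataLoader

   train_ds = pt.datasets.Iris(dims=[0, 2])
   train_loader = DataLoader(train_ds, batch_size=64)
   model = pt.models.GLVQ(
       dict(distribution=[1, 1, 1], lr=0.01),
       prototypes_initializer=pt.initializers.SMCI(train_ds),
   )
   trainer = pl.Trainer(max_epochs=50)
   trainer.fit(model, train_loader)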
The cost function of GLVQ can be extended by a learnable dissimilarity.
These learnable dissimilarities assign relevances to each data dimension during the learning phase.
Examples are GRLVQ :cite:p:`hammer2002` and GMLVQ :cite:p:`schneider2009`.
.. autoclass:: prototorch.models.glvq.GRLVQ
:members:
.. autoclass:: prototorch.models.glvq.GMLVQ
:members:
The dissimilarity from GMLVQ can be interpreted as a projection into another data space.
Applying this projection only to the data results in LVQMLN.
.. autoclass:: prototorch.models.glvq.LVQMLN
:members:
The projection idea from GMLVQ can be extended to an arbitrary transformation with learnable parameters.
.. autoclass:: prototorch.models.glvq.SiameseGLVQ
:members:
Probabilistic Models
--------------------------------------------
Probabilistic variants assume that the prototypes generate a probability distribution over the classes.
For a test sample they return a distribution instead of a class assignment.
The following two algorithms were presented by :cite:t:`seo2003`.
Every prototype is the center of a Gaussian distribution of its class, generating a mixture model.
.. autoclass:: prototorch.models.probabilistic.SLVQ
:members:
.. autoclass:: prototorch.models.probabilistic.RSLVQ
:members:
:cite:t:`villmann2018` proposed two changes to RSLVQ: first, incorporate the winning rank into the prior probability calculation;
second, use a divergence as the loss function.
.. autoclass:: prototorch.models.probabilistic.PLVQ
:members:
Classification by Components
--------------------------------------------
The Classification-by-Components (CBC) network was introduced by :cite:t:`saralajew2019`.
In a CBC architecture there is no class assigned to the prototypes.
Instead, the dissimilarities are used in a reasoning process that favours or rejects a class by a learnable degree.
The output of a CBC network is a probability distribution over all classes.
.. autoclass:: prototorch.models.cbc.CBC
:members:
.. autoclass:: prototorch.models.cbc.ImageCBC
:members:
Visualization
========================================
Visualization is very specific to its application.
Prototorch Models delivers visualizations for two-dimensional data and image data.
The visualizations can be shown in a separate window and inside TensorBoard.
.. automodule:: prototorch.models.vis
:members:
:undoc-members:
Bibliography
========================================
.. bibliography::

72
docs/source/refs.bib Normal file
View File

@@ -0,0 +1,72 @@
@article{sato1996,
title={Generalized learning vector quantization},
author={Sato, Atsushi and Yamada, Keiji},
journal={Advances in neural information processing systems},
pages={423--429},
year={1996},
publisher={Morgan Kaufmann Publishers},
url={http://papers.nips.cc/paper/1113-generalized-learning-vector-quantization.pdf},
}
@book{kohonen1989,
doi = {10.1007/978-3-642-88163-3},
year = {1989},
publisher = {Springer Berlin Heidelberg},
author = {Teuvo Kohonen},
title = {Self-Organization and Associative Memory}
}
@inproceedings{saralajew2019,
author = {Saralajew, Sascha and Holdijk, Lars and Rees, Maike and Asan, Ebubekir and Villmann, Thomas},
booktitle = {Advances in Neural Information Processing Systems},
title = {Classification-by-Components: Probabilistic Modeling of Reasoning over a Set of Components},
url = {https://proceedings.neurips.cc/paper/2019/file/dca5672ff3444c7e997aa9a2c4eb2094-Paper.pdf},
volume = {32},
year = {2019}
}
@article{seo2003,
author = {Seo, Sambu and Obermayer, Klaus},
title = "{Soft Learning Vector Quantization}",
journal = {Neural Computation},
volume = {15},
number = {7},
pages = {1589-1604},
year = {2003},
month = {07},
doi = {10.1162/089976603321891819},
}
@article{hammer2002,
title = {Generalized relevance learning vector quantization},
journal = {Neural Networks},
volume = {15},
number = {8},
pages = {1059-1068},
year = {2002},
doi = {https://doi.org/10.1016/S0893-6080(02)00079-5},
author = {Barbara Hammer and Thomas Villmann},
}
@article{schneider2009,
author = {Schneider, Petra and Biehl, Michael and Hammer, Barbara},
title = "{Adaptive Relevance Matrices in Learning Vector Quantization}",
journal = {Neural Computation},
volume = {21},
number = {12},
pages = {3532-3561},
year = {2009},
month = {12},
doi = {10.1162/neco.2009.11-08-908},
}
@InProceedings{villmann2018,
author="Villmann, Andrea
and Kaden, Marika
and Saralajew, Sascha
and Villmann, Thomas",
title="Probabilistic Learning Vector Quantization with Cross-Entropy for Probabilistic Class Assignments in Classification Learning",
booktitle="Artificial Intelligence and Soft Computing",
year="2018",
publisher="Springer International Publishing",
}

645
docs/source/tutorial.ipynb Normal file
View File

@@ -0,0 +1,645 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "7ac5eff0",
"metadata": {},
"source": [
"# A short tutorial for the `prototorch.models` plugin"
]
},
{
"cell_type": "markdown",
"id": "beb83780",
"metadata": {},
"source": [
"## Introduction"
]
},
{
"cell_type": "markdown",
"id": "43b74278",
"metadata": {},
"source": [
"This is a short tutorial for the [models](https://github.com/si-cim/prototorch_models) plugin of the [ProtoTorch](https://github.com/si-cim/prototorch) framework. This is by no means a comprehensive look at all the features that the framework has to offer, but it should help you get up and running.\n",
"\n",
"[ProtoTorch](https://github.com/si-cim/prototorch) provides [torch.nn](https://pytorch.org/docs/stable/nn.html) modules and utilities to implement prototype-based models. However, it is up to the user to put these modules together into models and handle the training of these models. Expert machine-learning practioners and researchers sometimes prefer this level of control. However, this leads to a lot of boilerplate code that is essentially same across many projects. Needless to say, this is a source of a lot of frustration. [PyTorch-Lightning](https://pytorch-lightning.readthedocs.io/en/latest/) is a framework that helps avoid a lot of this frustration by handling the boilerplate code for you so you don't have to reinvent the wheel every time you need to implement a new model.\n",
"\n",
"With the [prototorch.models](https://github.com/si-cim/prototorch_models) plugin, we've gone one step further and pre-packaged commonly used prototype-models like GMLVQ as [Lightning-Modules](https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.core.lightning.html?highlight=lightning%20module#pytorch_lightning.core.lightning.LightningModule). With only a few lines to code, it is now possible to build and train prototype-models. It quite simply cannot get any simpler than this."
]
},
{
"cell_type": "markdown",
"id": "4e5d1fad",
"metadata": {},
"source": [
"## Basics"
]
},
{
"cell_type": "markdown",
"id": "1244b66b",
"metadata": {},
"source": [
"First things first. When working with the models plugin, you'll probably need `torch`, `prototorch` and `pytorch_lightning`. So, we recommend that you import all three like so:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dcb88e8a",
"metadata": {},
"outputs": [],
"source": [
"import prototorch as pt\n",
"import pytorch_lightning as pl\n",
"import torch"
]
},
{
"cell_type": "markdown",
"id": "1adbe2f8",
"metadata": {},
"source": [
"### Building Models"
]
},
{
"cell_type": "markdown",
"id": "96663ab1",
"metadata": {},
"source": [
"Let's start by building a `GLVQ` model. It is one of the simplest models to build. The only requirements are a prototype distribution and an initializer."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "819ba756",
"metadata": {},
"outputs": [],
"source": [
"model = pt.models.GLVQ(\n",
" hparams=dict(distribution=[1, 1, 1]),\n",
" prototypes_initializer=pt.initializers.ZerosCompInitializer(2),\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1b37e97c",
"metadata": {},
"outputs": [],
"source": [
"print(model)"
]
},
{
"cell_type": "markdown",
"id": "d2c86903",
"metadata": {},
"source": [
"The key `distribution` in the `hparams` argument describes the prototype distribution. If it is a Python [list](https://docs.python.org/3/tutorial/datastructures.html), it is assumed that there are as many entries in this list as there are classes, and the number at each location of this list describes the number of prototypes to be used for that particular class. So, `[1, 1, 1]` implies that we have three classes with one prototype per class. If it is a Python [tuple](https://docs.python.org/3/tutorial/datastructures.html), a shorthand of `(num_classes, prototypes_per_class)` is assumed. If it is a Python [dictionary](https://docs.python.org/3/tutorial/datastructures.html), the key-value pairs describe the class label and the number of prototypes for that class respectively. So, `{0: 2, 1: 2, 2: 2}` implies that we have three classes with labels `{1, 2, 3}`, each equipped with two prototypes. If however, the dictionary contains the keys `\"num_classes\"` and `\"per_class\"`, they are parsed to use their values as one might expect.\n",
"\n",
"The `prototypes_initializer` argument describes how the prototypes are meant to be initialized. This argument has to be an instantiated object of some kind of [AbstractComponentsInitializer](https://github.com/si-cim/prototorch/blob/dev/prototorch/components/initializers.py#L18). If this is a [ShapeAwareCompInitializer](https://github.com/si-cim/prototorch/blob/dev/prototorch/components/initializers.py#L41), this only requires a `shape` arugment that describes the shape of the prototypes. So, `pt.initializers.ZerosCompInitializer(3)` creates 3d-vector prototypes all initialized to zeros."
]
},
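{
"cell_type": "code",
"execution_count": null,
"id": "added-distribution-sketch",
"metadata": {},
"outputs": [],
"source": [
"# A minimal sketch added for illustration (not part of the original notebook):\n",
"# the same GLVQ model specified with the tuple and dictionary forms of `distribution`\n",
"# described above.\n",
"model_tuple = pt.models.GLVQ(\n",
"    hparams=dict(distribution=(3, 1)),  # 3 classes, 1 prototype per class\n",
"    prototypes_initializer=pt.initializers.ZerosCompInitializer(2),\n",
")\n",
"\n",
"model_dict = pt.models.GLVQ(\n",
"    hparams=dict(distribution={\"num_classes\": 3, \"per_class\": 1}),\n",
"    prototypes_initializer=pt.initializers.ZerosCompInitializer(2),\n",
")"
]
},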
{
"cell_type": "markdown",
"id": "45806052",
"metadata": {},
"source": [
"### Data"
]
},
{
"cell_type": "markdown",
"id": "9d62c4c6",
"metadata": {},
"source": [
"The preferred way to working with data in `torch` is to use the [Dataset and Dataloader API](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html). There a few pre-packaged datasets available under `prototorch.datasets`. See [here](https://prototorch.readthedocs.io/en/latest/api.html#module-prototorch.datasets) for a full list of available datasets."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "504df02c",
"metadata": {},
"outputs": [],
"source": [
"train_ds = pt.datasets.Iris(dims=[0, 2])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3b8e7756",
"metadata": {},
"outputs": [],
"source": [
"type(train_ds)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bce43afa",
"metadata": {},
"outputs": [],
"source": [
"train_ds.data.shape, train_ds.targets.shape"
]
},
{
"cell_type": "markdown",
"id": "26a83328",
"metadata": {},
"source": [
"Once we have such a dataset, we could wrap it in a `Dataloader` to load the data in batches, and possibly apply some transformations on the fly."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "67b80fbe",
"metadata": {},
"outputs": [],
"source": [
"train_loader = torch.utils.data.DataLoader(train_ds, batch_size=2)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c1185f31",
"metadata": {},
"outputs": [],
"source": [
"type(train_loader)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9b5a8963",
"metadata": {},
"outputs": [],
"source": [
"x_batch, y_batch = next(iter(train_loader))\n",
"print(f\"{x_batch=}, {y_batch=}\")"
]
},
{
"cell_type": "markdown",
"id": "dd492ee2",
"metadata": {},
"source": [
"This perhaps seems like a lot of work for a small dataset that fits completely in memory. However, this comes in very handy when dealing with huge datasets that can only be processed in batches."
]
},
{
"cell_type": "markdown",
"id": "5176b055",
"metadata": {},
"source": [
"### Training"
]
},
{
"cell_type": "markdown",
"id": "46a7a506",
"metadata": {},
"source": [
"If you're familiar with other deep learning frameworks, you might perhaps expect a `.fit(...)` or `.train(...)` method. However, in PyTorch-Lightning, this is done slightly differently. We first create a trainer and then pass the model and the Dataloader to `trainer.fit(...)` instead. So, it is more functional in style than object-oriented."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "279e75b7",
"metadata": {},
"outputs": [],
"source": [
"trainer = pl.Trainer(max_epochs=2, weights_summary=None)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e496b492",
"metadata": {},
"outputs": [],
"source": [
"trainer.fit(model, train_loader)"
]
},
{
"cell_type": "markdown",
"id": "497fbff6",
"metadata": {},
"source": [
"### From data to a trained model - a very minimal example"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ab069c5d",
"metadata": {},
"outputs": [],
"source": [
"train_ds = pt.datasets.Iris(dims=[0, 2])\n",
"train_loader = torch.utils.data.DataLoader(train_ds, batch_size=32)\n",
"\n",
"model = pt.models.GLVQ(\n",
" dict(distribution=(3, 2), lr=0.1),\n",
" prototypes_initializer=pt.initializers.SMCI(train_ds),\n",
")\n",
"\n",
"trainer = pl.Trainer(max_epochs=50, weights_summary=None)\n",
"trainer.fit(model, train_loader)"
]
},
{
"cell_type": "markdown",
"id": "30c71a93",
"metadata": {},
"source": [
"### Saving/Loading trained models"
]
},
{
"cell_type": "markdown",
"id": "f74ed2c1",
"metadata": {},
"source": [
"Pytorch Lightning can automatically checkpoint the model during various stages of training, but it also possible to manually save a checkpoint after training."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3156658d",
"metadata": {},
"outputs": [],
"source": [
"ckpt_path = \"./checkpoints/glvq_iris.ckpt\"\n",
"trainer.save_checkpoint(ckpt_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c1c34055",
"metadata": {},
"outputs": [],
"source": [
"loaded_model = pt.models.GLVQ.load_from_checkpoint(ckpt_path, strict=False)"
]
},
{
"cell_type": "markdown",
"id": "bbbb08e9",
"metadata": {},
"source": [
"### Visualizing decision boundaries in 2D"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "53ca52dc",
"metadata": {},
"outputs": [],
"source": [
"pt.models.VisGLVQ2D(data=train_ds).visualize(loaded_model)"
]
},
{
"cell_type": "markdown",
"id": "8373531f",
"metadata": {},
"source": [
"### Saving/Loading trained weights"
]
},
{
"cell_type": "markdown",
"id": "937bc458",
"metadata": {},
"source": [
"In most cases, the checkpointing workflow is sufficient. In some cases however, one might want to only save the trained weights from the model. The disadvantage of this method is that the model has be re-created using compatible initialization parameters before the weights could be loaded."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1f2035af",
"metadata": {},
"outputs": [],
"source": [
"ckpt_path = \"./checkpoints/glvq_iris_weights.pth\"\n",
"torch.save(model.state_dict(), ckpt_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1206021a",
"metadata": {},
"outputs": [],
"source": [
"model = pt.models.GLVQ(\n",
" dict(distribution=(3, 2)),\n",
" prototypes_initializer=pt.initializers.ZerosCompInitializer(2),\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9f2a4beb",
"metadata": {},
"outputs": [],
"source": [
"pt.models.VisGLVQ2D(data=train_ds, title=\"Before loading the weights\").visualize(model)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "528d2fc2",
"metadata": {},
"outputs": [],
"source": [
"torch.load(ckpt_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ec817e6b",
"metadata": {},
"outputs": [],
"source": [
"model.load_state_dict(torch.load(ckpt_path), strict=False)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a208eab7",
"metadata": {},
"outputs": [],
"source": [
"pt.models.VisGLVQ2D(data=train_ds, title=\"After loading the weights\").visualize(model)"
]
},
{
"cell_type": "markdown",
"id": "f8de748f",
"metadata": {},
"source": [
"## Advanced"
]
},
{
"cell_type": "markdown",
"id": "53a64063",
"metadata": {},
"source": [
"### Warm-start a model with prototypes learned from another model"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3177c277",
"metadata": {},
"outputs": [],
"source": [
"trained_model = pt.models.GLVQ.load_from_checkpoint(\"./checkpoints/glvq_iris.ckpt\", strict=False)\n",
"model = pt.models.SiameseGMLVQ(\n",
" dict(input_dim=2,\n",
" latent_dim=2,\n",
" distribution=(3, 2),\n",
" proto_lr=0.0001,\n",
" bb_lr=0.0001),\n",
" optimizer=torch.optim.Adam,\n",
" prototypes_initializer=pt.initializers.LCI(trained_model.prototypes),\n",
" labels_initializer=pt.initializers.LLI(trained_model.prototype_labels),\n",
" omega_initializer=pt.initializers.LLTI(torch.tensor([[0., 1.], [1., 0.]])), # permute axes\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8baee9a2",
"metadata": {},
"outputs": [],
"source": [
"print(model)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cc203088",
"metadata": {},
"outputs": [],
"source": [
"pt.models.VisSiameseGLVQ2D(data=train_ds, title=\"GMLVQ - Warm-start state\").visualize(model)"
]
},
{
"cell_type": "markdown",
"id": "1f6a33a5",
"metadata": {},
"source": [
"### Initializing prototypes with a subset of a dataset (along with transformations)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "946ce341",
"metadata": {},
"outputs": [],
"source": [
"import prototorch as pt\n",
"import pytorch_lightning as pl\n",
"import torch\n",
"from torchvision import transforms\n",
"from torchvision.datasets import MNIST\n",
"from torchvision.utils import make_grid"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "510d9bd4",
"metadata": {},
"outputs": [],
"source": [
"from matplotlib import pyplot as plt"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ea7c1228",
"metadata": {},
"outputs": [],
"source": [
"train_ds = MNIST(\n",
" \"~/datasets\",\n",
" train=True,\n",
" download=True,\n",
" transform=transforms.Compose([\n",
" transforms.RandomHorizontalFlip(p=1.0),\n",
" transforms.RandomVerticalFlip(p=1.0),\n",
" transforms.ToTensor(),\n",
" ]),\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1b9eaf5c",
"metadata": {},
"outputs": [],
"source": [
"s = int(0.05 * len(train_ds))\n",
"init_ds, rest_ds = torch.utils.data.random_split(train_ds, [s, len(train_ds) - s])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8c32c9f2",
"metadata": {},
"outputs": [],
"source": [
"init_ds"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "68a9a8b9",
"metadata": {},
"outputs": [],
"source": [
"model = pt.models.ImageGLVQ(\n",
" dict(distribution=(10, 1)),\n",
" prototypes_initializer=pt.initializers.SMCI(init_ds),\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6f23df86",
"metadata": {},
"outputs": [],
"source": [
"plt.imshow(model.get_prototype_grid(num_columns=5))"
]
},
{
"cell_type": "markdown",
"id": "1c23c7b2",
"metadata": {},
"source": [
"We could, of course, just use the initializers in isolation. For example, we could quickly obtain a stratified selection from the data like so:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "30780927",
"metadata": {},
"outputs": [],
"source": [
"protos, plabels = pt.components.LabeledComponents(\n",
" distribution=(10, 5),\n",
" components_initializer=pt.initializers.SMCI(init_ds),\n",
" labels_initializer=pt.initializers.LabelsInitializer(),\n",
")()\n",
"plt.imshow(make_grid(protos, 10).permute(1, 2, 0)[:, :, 0], cmap=\"jet\")"
]
},
{
"cell_type": "markdown",
"id": "4fa69f92",
"metadata": {},
"source": [
"## FAQs"
]
},
{
"cell_type": "markdown",
"id": "fa20f9ac",
"metadata": {},
"source": [
"### How do I Retrieve the prototypes and their respective labels from the model?\n",
"\n",
"For prototype models, the prototypes can be retrieved (as `torch.tensor`) as `model.prototypes`. You can convert it to a NumPy Array by calling `.numpy()` on the tensor if required.\n",
"\n",
"```python\n",
">>> model.prototypes.numpy()\n",
"```\n",
"\n",
"Similarly, the labels of the prototypes can be retrieved via `model.prototype_labels`.\n",
"\n",
"```python\n",
">>> model.prototype_labels\n",
"```"
]
},
{
"cell_type": "markdown",
"id": "ba8215bf",
"metadata": {},
"source": [
"### How do I make inferences/predictions/recall with my trained model?\n",
"\n",
"The models under [prototorch.models](https://github.com/si-cim/prototorch_models) provide a `.predict(x)` method for making predictions. This returns the predicted class labels. It is essential that the input to this method is a `torch.tensor` and not a NumPy array. Model instances are also callable. So, you could also just say `model(x)` as if `model` were just a function. However, this returns a (pseudo)-probability distribution over the classes.\n",
"\n",
"#### Example\n",
"\n",
"```python\n",
">>> y_pred = model.predict(torch.Tensor(x_train)) # returns class labels\n",
"```\n",
"or, simply\n",
"```python\n",
">>> y_pred = model(torch.Tensor(x_train)) # returns probabilities\n",
"```"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@@ -1,46 +1,66 @@
"""CBC example using the Iris dataset."""
import argparse
import warnings
import prototorch as pt
import pytorch_lightning as pl
import torch
from prototorch.models import CBC, VisCBC2D
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader
warnings.filterwarnings("ignore", category=PossibleUserWarning)
warnings.filterwarnings("ignore", category=UserWarning)
if __name__ == "__main__":
# Dataset
from sklearn.datasets import load_iris
x_train, y_train = load_iris(return_X_y=True)
x_train = x_train[:, [0, 2]]
train_ds = pt.datasets.NumpyDataset(x_train, y_train)
# Reproducibility
pl.utilities.seed.seed_everything(seed=2)
seed_everything(seed=4)
# Command-line arguments
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
args = parser.parse_args()
# Dataset
train_ds = pt.datasets.Iris(dims=[0, 2])
# Dataloaders
train_loader = torch.utils.data.DataLoader(train_ds,
num_workers=0,
batch_size=150)
train_loader = DataLoader(train_ds, batch_size=32)
# Hyperparameters
hparams = dict(
input_dim=x_train.shape[1],
nclasses=3,
num_components=5,
component_initializer=pt.components.SSI(train_ds, noise=0.01),
lr=0.01,
distribution=[1, 0, 3],
margin=0.1,
proto_lr=0.01,
bb_lr=0.01,
)
# Initialize the model
model = pt.models.CBC(hparams)
model = CBC(
hparams,
components_initializer=pt.initializers.SSCI(train_ds, noise=0.1),
reasonings_initializer=pt.initializers.
PurePositiveReasoningsInitializer(),
)
# Callbacks
dvis = pt.models.VisCBC2D(data=(x_train, y_train),
title="CBC Iris Example")
vis = VisCBC2D(
data=train_ds,
title="CBC Iris Example",
resolution=100,
axis_off=True,
)
# Setup trainer
trainer = pl.Trainer(
max_epochs=200,
trainer = pl.Trainer.from_argparse_args(
args,
callbacks=[
dvis,
vis,
],
detect_anomaly=True,
log_every_n_steps=1,
max_epochs=1000,
)
# Training loop

View File

@@ -0,0 +1,99 @@
"""Dynamically prune 'loser' prototypes in GLVQ-type models."""
import argparse
import logging
import warnings
import prototorch as pt
import pytorch_lightning as pl
import torch
from prototorch.models import (
CELVQ,
PruneLoserPrototypes,
VisGLVQ2D,
)
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader
warnings.filterwarnings("ignore", category=PossibleUserWarning)
warnings.filterwarnings("ignore", category=UserWarning)
if __name__ == "__main__":
# Reproducibility
seed_everything(seed=4)
# Command-line arguments
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
args = parser.parse_args()
# Dataset
num_classes = 4
num_features = 2
num_clusters = 1
train_ds = pt.datasets.Random(
num_samples=500,
num_classes=num_classes,
num_features=num_features,
num_clusters=num_clusters,
separation=3.0,
seed=42,
)
# Dataloaders
train_loader = DataLoader(train_ds, batch_size=256)
# Hyperparameters
prototypes_per_class = num_clusters * 5
hparams = dict(
distribution=(num_classes, prototypes_per_class),
lr=0.2,
)
# Initialize the model
model = CELVQ(
hparams,
prototypes_initializer=pt.initializers.FVCI(2, 3.0),
)
# Compute intermediate input and output sizes
model.example_input_array = torch.zeros(4, 2)
# Summary
logging.info(model)
# Callbacks
vis = VisGLVQ2D(train_ds)
pruning = PruneLoserPrototypes(
threshold=0.01, # prune prototype if it wins less than 1%
idle_epochs=20, # pruning too early may cause problems
prune_quota_per_epoch=2, # prune at most 2 prototypes per epoch
frequency=1, # prune every epoch
verbose=True,
)
es = EarlyStopping(
monitor="train_loss",
min_delta=0.001,
patience=20,
mode="min",
verbose=True,
check_on_train_epoch_end=True,
)
# Setup trainer
trainer = pl.Trainer.from_argparse_args(
args,
callbacks=[
vis,
pruning,
es,
],
detect_anomaly=True,
log_every_n_steps=1,
max_epochs=1000,
)
# Training loop
trainer.fit(model, train_loader)

View File

@@ -1,40 +1,79 @@
"""GLVQ example using the Iris dataset."""
import argparse
import logging
import warnings
import prototorch as pt
import pytorch_lightning as pl
import torch
from prototorch.models import GLVQ, VisGLVQ2D
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.data import DataLoader
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=PossibleUserWarning)
if __name__ == "__main__":
# Reproducibility
seed_everything(seed=4)
# Command-line arguments
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
args = parser.parse_args()
# Dataset
from sklearn.datasets import load_iris
x_train, y_train = load_iris(return_X_y=True)
x_train = x_train[:, [0, 2]]
train_ds = pt.datasets.NumpyDataset(x_train, y_train)
train_ds = pt.datasets.Iris(dims=[0, 2])
# Dataloaders
train_loader = torch.utils.data.DataLoader(train_ds,
num_workers=0,
batch_size=150)
train_loader = DataLoader(train_ds, batch_size=64, num_workers=4)
# Hyperparameters
hparams = dict(
nclasses=3,
prototypes_per_class=2,
prototype_initializer=pt.components.SMI(train_ds),
distribution={
"num_classes": 3,
"per_class": 4
},
lr=0.01,
)
# Initialize the model
model = pt.models.GLVQ(hparams)
model = GLVQ(
hparams,
optimizer=torch.optim.Adam,
prototypes_initializer=pt.initializers.SMCI(train_ds),
lr_scheduler=ExponentialLR,
lr_scheduler_kwargs=dict(gamma=0.99, verbose=False),
)
# Compute intermediate input and output sizes
model.example_input_array = torch.zeros(4, 2)
# Callbacks
vis = pt.models.VisGLVQ2D(data=(x_train, y_train))
vis = VisGLVQ2D(data=train_ds)
# Setup trainer
trainer = pl.Trainer(
max_epochs=50,
callbacks=[vis],
trainer = pl.Trainer.from_argparse_args(
args,
callbacks=[
vis,
],
max_epochs=100,
log_every_n_steps=1,
detect_anomaly=True,
)
# Training loop
trainer.fit(model, train_loader)
# Manual save
trainer.save_checkpoint("./glvq_iris.ckpt")
# Load saved model
new_model = GLVQ.load_from_checkpoint(
checkpoint_path="./glvq_iris.ckpt",
strict=False,
)
logging.info(new_model)

View File

@@ -1,51 +0,0 @@
"""GLVQ example using the spiral dataset."""
import prototorch as pt
import pytorch_lightning as pl
import torch
class StopOnNaN(pl.Callback):
def __init__(self, param):
super().__init__()
self.param = param
def on_epoch_end(self, trainer, pl_module, logs={}):
if torch.isnan(self.param).any():
raise ValueError("NaN encountered. Stopping.")
if __name__ == "__main__":
# Dataset
train_ds = pt.datasets.Spiral(n_samples=600, noise=0.6)
# Dataloaders
train_loader = torch.utils.data.DataLoader(train_ds,
num_workers=0,
batch_size=256)
# Hyperparameters
hparams = dict(
nclasses=2,
prototypes_per_class=20,
prototype_initializer=pt.components.SSI(train_ds, noise=1e-7),
transfer_function="sigmoid_beta",
transfer_beta=10.0,
lr=0.01,
)
# Initialize the model
model = pt.models.GLVQ(hparams)
# Callbacks
vis = pt.models.VisGLVQ2D(train_ds, show_last_only=True, block=True)
snan = StopOnNaN(model.proto_layer.components)
# Setup trainer
trainer = pl.Trainer(
max_epochs=200,
callbacks=[vis, snan],
)
# Training loop
trainer.fit(model, train_loader)

View File

@@ -1,37 +1,73 @@
"""GMLVQ example using all four dimensions of the Iris dataset."""
"""GMLVQ example using the Iris dataset."""
import argparse
import warnings
import prototorch as pt
import pytorch_lightning as pl
import torch
from prototorch.models import GMLVQ, VisGMLVQ2D
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.data import DataLoader
warnings.filterwarnings("ignore", category=PossibleUserWarning)
warnings.filterwarnings("ignore", category=UserWarning)
if __name__ == "__main__":
# Reproducibility
seed_everything(seed=4)
# Command-line arguments
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
args = parser.parse_args()
# Dataset
from sklearn.datasets import load_iris
x_train, y_train = load_iris(return_X_y=True)
train_ds = pt.datasets.NumpyDataset(x_train, y_train)
train_ds = pt.datasets.Iris()
# Dataloaders
train_loader = torch.utils.data.DataLoader(train_ds,
num_workers=0,
batch_size=150)
train_loader = DataLoader(train_ds, batch_size=64)
# Hyperparameters
hparams = dict(
nclasses=3,
prototypes_per_class=1,
input_dim=x_train.shape[1],
latent_dim=x_train.shape[1],
prototype_initializer=pt.components.SMI(train_ds),
lr=0.01,
input_dim=4,
latent_dim=4,
distribution={
"num_classes": 3,
"per_class": 2
},
proto_lr=0.01,
bb_lr=0.01,
)
# Initialize the model
model = pt.models.GMLVQ(hparams)
model = GMLVQ(
hparams,
optimizer=torch.optim.Adam,
prototypes_initializer=pt.initializers.SMCI(train_ds),
lr_scheduler=ExponentialLR,
lr_scheduler_kwargs=dict(gamma=0.99, verbose=False),
)
# Compute intermediate input and output sizes
model.example_input_array = torch.zeros(4, 4)
# Callbacks
vis = VisGMLVQ2D(data=train_ds)
# Setup trainer
trainer = pl.Trainer(max_epochs=100)
trainer = pl.Trainer.from_argparse_args(
args,
callbacks=[
vis,
],
max_epochs=100,
log_every_n_steps=1,
detect_anomaly=True,
)
# Training loop
trainer.fit(model, train_loader)
# Display the Lambda matrix
model.show_lambda()

112
examples/gmlvq_mnist.py Normal file
View File

@@ -0,0 +1,112 @@
"""GMLVQ example using the MNIST dataset."""
import argparse
import warnings
import prototorch as pt
import pytorch_lightning as pl
import torch
from prototorch.models import (
ImageGMLVQ,
PruneLoserPrototypes,
VisImgComp,
)
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
warnings.filterwarnings("ignore", category=PossibleUserWarning)
warnings.filterwarnings("ignore", category=UserWarning)
if __name__ == "__main__":
# Reproducibility
seed_everything(seed=4)
# Command-line arguments
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
args = parser.parse_args()
# Dataset
train_ds = MNIST(
"~/datasets",
train=True,
download=True,
transform=transforms.Compose([
transforms.ToTensor(),
]),
)
test_ds = MNIST(
"~/datasets",
train=False,
download=True,
transform=transforms.Compose([
transforms.ToTensor(),
]),
)
# Dataloaders
train_loader = DataLoader(train_ds, num_workers=4, batch_size=256)
test_loader = DataLoader(test_ds, num_workers=4, batch_size=256)
# Hyperparameters
num_classes = 10
prototypes_per_class = 10
hparams = dict(
input_dim=28 * 28,
latent_dim=28 * 28,
distribution=(num_classes, prototypes_per_class),
proto_lr=0.01,
bb_lr=0.01,
)
# Initialize the model
model = ImageGMLVQ(
hparams,
optimizer=torch.optim.Adam,
prototypes_initializer=pt.initializers.SMCI(train_ds),
)
# Callbacks
vis = VisImgComp(
data=train_ds,
num_columns=10,
show=False,
tensorboard=True,
random_data=100,
add_embedding=True,
embedding_data=200,
flatten_data=False,
)
pruning = PruneLoserPrototypes(
threshold=0.01,
idle_epochs=1,
prune_quota_per_epoch=10,
frequency=1,
verbose=True,
)
es = EarlyStopping(
monitor="train_loss",
min_delta=0.001,
patience=15,
mode="min",
check_on_train_epoch_end=True,
)
# Setup trainer
trainer = pl.Trainer.from_argparse_args(
args,
callbacks=[
vis,
pruning,
es,
],
max_epochs=1000,
log_every_n_steps=1,
detect_anomaly=True,
)
# Training loop
trainer.fit(model, train_loader)

94
examples/gmlvq_spiral.py Normal file
View File

@@ -0,0 +1,94 @@
"""GMLVQ example using the spiral dataset."""
import argparse
import warnings
import prototorch as pt
import pytorch_lightning as pl
import torch
from prototorch.models import (
GMLVQ,
PruneLoserPrototypes,
VisGLVQ2D,
)
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader
warnings.filterwarnings("ignore", category=PossibleUserWarning)
warnings.filterwarnings("ignore", category=UserWarning)
if __name__ == "__main__":
# Reproducibility
seed_everything(seed=4)
# Command-line arguments
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
args = parser.parse_args()
# Dataset
train_ds = pt.datasets.Spiral(num_samples=500, noise=0.5)
# Dataloaders
train_loader = DataLoader(train_ds, batch_size=256)
# Hyperparameters
num_classes = 2
prototypes_per_class = 10
hparams = dict(
distribution=(num_classes, prototypes_per_class),
transfer_function="swish_beta",
transfer_beta=10.0,
proto_lr=0.1,
bb_lr=0.1,
input_dim=2,
latent_dim=2,
)
# Initialize the model
model = GMLVQ(
hparams,
optimizer=torch.optim.Adam,
prototypes_initializer=pt.initializers.SSCI(train_ds, noise=1e-2),
)
# Callbacks
vis = VisGLVQ2D(
train_ds,
show_last_only=False,
block=False,
)
pruning = PruneLoserPrototypes(
threshold=0.01,
idle_epochs=10,
prune_quota_per_epoch=5,
frequency=5,
replace=True,
prototypes_initializer=pt.initializers.SSCI(train_ds, noise=1e-1),
verbose=True,
)
es = EarlyStopping(
monitor="train_loss",
min_delta=1.0,
patience=5,
mode="min",
check_on_train_epoch_end=True,
)
# Setup trainer
trainer = pl.Trainer.from_argparse_args(
args,
callbacks=[
vis,
es,
pruning,
],
max_epochs=1000,
log_every_n_steps=1,
detect_anomaly=True,
)
# Training loop
trainer.fit(model, train_loader)

65
examples/gng_iris.py Normal file
View File

@@ -0,0 +1,65 @@
"""Growing Neural Gas example using the Iris dataset."""
import argparse
import logging
import warnings
import prototorch as pt
import pytorch_lightning as pl
import torch
from prototorch.models import GrowingNeuralGas, VisNG2D
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader
warnings.filterwarnings("ignore", category=PossibleUserWarning)
warnings.filterwarnings("ignore", category=UserWarning)
if __name__ == "__main__":
# Command-line arguments
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
args = parser.parse_args()
# Reproducibility
seed_everything(seed=42)
# Prepare the data
train_ds = pt.datasets.Iris(dims=[0, 2])
train_loader = DataLoader(train_ds, batch_size=64)
# Hyperparameters
hparams = dict(
num_prototypes=5,
input_dim=2,
lr=0.1,
)
# Initialize the model
model = GrowingNeuralGas(
hparams,
prototypes_initializer=pt.initializers.ZCI(2),
)
# Compute intermediate input and output sizes
model.example_input_array = torch.zeros(4, 2)
# Model summary
logging.info(model)
# Callbacks
vis = VisNG2D(data=train_loader)
# Setup trainer
trainer = pl.Trainer.from_argparse_args(
args,
callbacks=[
vis,
],
max_epochs=100,
log_every_n_steps=1,
detect_anomaly=True,
)
# Training loop
trainer.fit(model, train_loader)

116
examples/gtlvq_mnist.py Normal file
View File

@@ -0,0 +1,116 @@
"""GTLVQ example using the MNIST dataset."""
import argparse
import warnings
import prototorch as pt
import pytorch_lightning as pl
import torch
from prototorch.models import (
ImageGTLVQ,
PruneLoserPrototypes,
VisImgComp,
)
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
warnings.filterwarnings("ignore", category=PossibleUserWarning)
warnings.filterwarnings("ignore", category=UserWarning)
if __name__ == "__main__":
# Reproducibility
seed_everything(seed=4)
# Command-line arguments
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
args = parser.parse_args()
# Dataset
train_ds = MNIST(
"~/datasets",
train=True,
download=True,
transform=transforms.Compose([
transforms.ToTensor(),
]),
)
test_ds = MNIST(
"~/datasets",
train=False,
download=True,
transform=transforms.Compose([
transforms.ToTensor(),
]),
)
# Dataloaders
train_loader = DataLoader(train_ds, num_workers=0, batch_size=256)
test_loader = DataLoader(test_ds, num_workers=0, batch_size=256)
# Hyperparameters
num_classes = 10
prototypes_per_class = 1
hparams = dict(
input_dim=28 * 28,
latent_dim=28,
distribution=(num_classes, prototypes_per_class),
proto_lr=0.01,
bb_lr=0.01,
)
# Initialize the model
model = ImageGTLVQ(
hparams,
optimizer=torch.optim.Adam,
prototypes_initializer=pt.initializers.SMCI(train_ds),
# Use one batch of data for the subspace initializer.
omega_initializer=pt.initializers.PCALinearTransformInitializer(
next(iter(train_loader))[0].reshape(256, 28 * 28)))
# Callbacks
vis = VisImgComp(
data=train_ds,
num_columns=10,
show=False,
tensorboard=True,
random_data=100,
add_embedding=True,
embedding_data=200,
flatten_data=False,
)
pruning = PruneLoserPrototypes(
threshold=0.01,
idle_epochs=1,
prune_quota_per_epoch=10,
frequency=1,
verbose=True,
)
es = EarlyStopping(
monitor="train_loss",
min_delta=0.001,
patience=15,
mode="min",
check_on_train_epoch_end=True,
)
# Setup trainer
# using GPUs here is strongly recommended!
trainer = pl.Trainer.from_argparse_args(
args,
callbacks=[
vis,
pruning,
es,
],
max_epochs=1000,
log_every_n_steps=1,
detect_anomaly=True,
)
# Training loop
trainer.fit(model, train_loader)

76
examples/gtlvq_moons.py Normal file
View File

@@ -0,0 +1,76 @@
"""Localized-GTLVQ example using the Moons dataset."""
import argparse
import logging
import warnings
import prototorch as pt
import pytorch_lightning as pl
import torch
from prototorch.models import GTLVQ, VisGLVQ2D
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader
warnings.filterwarnings("ignore", category=PossibleUserWarning)
warnings.filterwarnings("ignore", category=UserWarning)
if __name__ == "__main__":
# Command-line arguments
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
args = parser.parse_args()
# Reproducibility
seed_everything(seed=2)
# Dataset
train_ds = pt.datasets.Moons(num_samples=300, noise=0.2, seed=42)
# Dataloaders
train_loader = DataLoader(
train_ds,
batch_size=256,
shuffle=True,
)
# Hyperparameters
# Latent_dim should be lower than input dim.
hparams = dict(distribution=[1, 3], input_dim=2, latent_dim=1)
# Initialize the model
model = GTLVQ(hparams,
prototypes_initializer=pt.initializers.SMCI(train_ds))
# Compute intermediate input and output sizes
model.example_input_array = torch.zeros(4, 2)
# Summary
logging.info(model)
# Callbacks
vis = VisGLVQ2D(data=train_ds)
es = EarlyStopping(
monitor="train_acc",
min_delta=0.001,
patience=20,
mode="max",
verbose=False,
check_on_train_epoch_end=True,
)
# Setup trainer
trainer = pl.Trainer.from_argparse_args(
args,
callbacks=[
vis,
es,
],
max_epochs=1000,
log_every_n_steps=1,
detect_anomaly=True,
)
# Training loop
trainer.fit(model, train_loader)

81
examples/knn_iris.py Normal file
View File

@@ -0,0 +1,81 @@
"""k-NN example using the Iris dataset from scikit-learn."""
import argparse
import logging
import warnings
import prototorch as pt
import pytorch_lightning as pl
import torch
from prototorch.models import KNN, VisGLVQ2D
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
warnings.filterwarnings("ignore", category=PossibleUserWarning)
if __name__ == "__main__":
# Command-line arguments
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
args = parser.parse_args()
# Dataset
X, y = load_iris(return_X_y=True)
X = X[:, 0:3:2]
X_train, X_test, y_train, y_test = train_test_split(
X,
y,
test_size=0.5,
random_state=42,
)
train_ds = pt.datasets.NumpyDataset(X_train, y_train)
test_ds = pt.datasets.NumpyDataset(X_test, y_test)
# Dataloaders
train_loader = DataLoader(train_ds, batch_size=16)
test_loader = DataLoader(test_ds, batch_size=16)
# Hyperparameters
hparams = dict(k=5)
# Initialize the model
model = KNN(hparams, data=train_ds)
# Compute intermediate input and output sizes
model.example_input_array = torch.zeros(4, 2)
# Summary
logging.info(model)
# Callbacks
vis = VisGLVQ2D(
data=(X_train, y_train),
resolution=200,
block=True,
)
# Setup trainer
trainer = pl.Trainer.from_argparse_args(
args,
max_epochs=1,
callbacks=[
vis,
],
log_every_n_steps=1,
detect_anomaly=True,
)
# Training loop
# This is only for visualization. k-NN has no training phase.
trainer.fit(model, train_loader)
# Recall
y_pred = model.predict(torch.tensor(X_train))
logging.info(y_pred)
# Test
trainer.test(model, dataloaders=test_loader)

118
examples/ksom_colors.py Normal file
View File

@@ -0,0 +1,118 @@
"""Kohonen Self Organizing Map."""
import argparse
import logging
import warnings
import prototorch as pt
import pytorch_lightning as pl
import torch
from matplotlib import pyplot as plt
from prototorch.models import KohonenSOM
from prototorch.utils.colors import hex_to_rgb
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader, TensorDataset
warnings.filterwarnings("ignore", category=PossibleUserWarning)
warnings.filterwarnings("ignore", category=UserWarning)
class Vis2DColorSOM(pl.Callback):
def __init__(self, data, title="Color SOM", pause_time=0.1):
super().__init__()
self.title = title
self.fig = plt.figure(self.title)
self.data = data
self.pause_time = pause_time
def on_train_epoch_end(self, trainer, pl_module: KohonenSOM):
ax = self.fig.gca()
ax.cla()
ax.set_title(self.title)
h, w = pl_module._grid.shape[:2]
protos = pl_module.prototypes.view(h, w, 3)
ax.imshow(protos)
ax.axis("off")
# Overlay color names
d = pl_module.compute_distances(self.data)
wp = pl_module.predict_from_distances(d)
for i, iloc in enumerate(wp):
plt.text(
iloc[1],
iloc[0],
color_names[i],
ha="center",
va="center",
bbox=dict(facecolor="white", alpha=0.5, lw=0),
)
if trainer.current_epoch != trainer.max_epochs - 1:
plt.pause(self.pause_time)
else:
plt.show(block=True)
if __name__ == "__main__":
# Command-line arguments
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
args = parser.parse_args()
# Reproducibility
seed_everything(seed=42)
# Prepare the data
hex_colors = [
"#000000", "#0000ff", "#00007f", "#1f86ff", "#5466aa", "#997fff",
"#00ff00", "#ff0000", "#00ffff", "#ff00ff", "#ffff00", "#ffffff",
"#545454", "#7f7f7f", "#a8a8a8", "#808000", "#800080", "#ffa500"
]
color_names = [
"black", "blue", "darkblue", "skyblue", "greyblue", "lilac", "green",
"red", "cyan", "magenta", "yellow", "white", "darkgrey", "mediumgrey",
"lightgrey", "olive", "purple", "orange"
]
colors = list(hex_to_rgb(hex_colors))
data = torch.Tensor(colors) / 255.0
train_ds = TensorDataset(data)
train_loader = DataLoader(train_ds, batch_size=8)
# Hyperparameters
hparams = dict(
shape=(18, 32),
alpha=1.0,
sigma=16,
lr=0.1,
)
# Initialize the model
model = KohonenSOM(
hparams,
prototypes_initializer=pt.initializers.RNCI(3),
)
# Compute intermediate input and output sizes
model.example_input_array = torch.zeros(4, 3)
# Model summary
logging.info(model)
# Callbacks
vis = Vis2DColorSOM(data=data)
# Setup trainer
trainer = pl.Trainer.from_argparse_args(
args,
max_epochs=500,
callbacks=[
vis,
],
log_every_n_steps=1,
detect_anomaly=True,
)
# Training loop
trainer.fit(model, train_loader)

77
examples/lgmlvq_moons.py Normal file
View File

@@ -0,0 +1,77 @@
"""Localized-GMLVQ example using the Moons dataset."""
import argparse
import logging
import warnings
import prototorch as pt
import pytorch_lightning as pl
import torch
from prototorch.models import LGMLVQ, VisGLVQ2D
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader
warnings.filterwarnings("ignore", category=PossibleUserWarning)
warnings.filterwarnings("ignore", category=UserWarning)
if __name__ == "__main__":
# Command-line arguments
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
args = parser.parse_args()
# Reproducibility
seed_everything(seed=2)
# Dataset
train_ds = pt.datasets.Moons(num_samples=300, noise=0.2, seed=42)
# Dataloaders
train_loader = DataLoader(train_ds, batch_size=256, shuffle=True)
# Hyperparameters
hparams = dict(
distribution=[1, 3],
input_dim=2,
latent_dim=2,
)
# Initialize the model
model = LGMLVQ(
hparams,
prototypes_initializer=pt.initializers.SMCI(train_ds),
)
# Compute intermediate input and output sizes
model.example_input_array = torch.zeros(4, 2)
# Summary
logging.info(model)
# Callbacks
vis = VisGLVQ2D(data=train_ds)
es = EarlyStopping(
monitor="train_acc",
min_delta=0.001,
patience=20,
mode="max",
verbose=False,
check_on_train_epoch_end=True,
)
# Setup trainer
trainer = pl.Trainer.from_argparse_args(
args,
callbacks=[
vis,
es,
],
log_every_n_steps=1,
max_epochs=1000,
detect_anomaly=True,
)
# Training loop
trainer.fit(model, train_loader)

View File

@@ -1,48 +0,0 @@
"""Limited Rank Matrix LVQ example using the Tecator dataset."""
import prototorch as pt
import pytorch_lightning as pl
import torch
if __name__ == "__main__":
# Dataset
train_ds = pt.datasets.Tecator(root="~/datasets/", train=True)
# Reproducibility
pl.utilities.seed.seed_everything(seed=42)
# Dataloaders
train_loader = torch.utils.data.DataLoader(train_ds,
num_workers=0,
batch_size=32)
# Hyperparameters
hparams = dict(
nclasses=2,
prototypes_per_class=2,
input_dim=100,
latent_dim=2,
prototype_initializer=pt.components.SMI(train_ds),
lr=0.001,
)
# Initialize the model
model = pt.models.GMLVQ(hparams)
# Callbacks
vis = pt.models.VisSiameseGLVQ2D(train_ds, border=0.1)
# Setup trainer
trainer = pl.Trainer(max_epochs=200, callbacks=[vis])
# Training loop
trainer.fit(model, train_loader)
# Save the model
torch.save(model, "liramlvq_tecator.pt")
# Load a saved model
saved_model = torch.load("liramlvq_tecator.pt")
# Display the Lambda matrix
saved_model.show_lambda()

103
examples/lvqmln_iris.py Normal file
View File

@@ -0,0 +1,103 @@
"""LVQMLN example using all four dimensions of the Iris dataset."""
import argparse
import warnings
import prototorch as pt
import pytorch_lightning as pl
import torch
from prototorch.models import (
LVQMLN,
PruneLoserPrototypes,
VisSiameseGLVQ2D,
)
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader
warnings.filterwarnings("ignore", category=PossibleUserWarning)
warnings.filterwarnings("ignore", category=UserWarning)
class Backbone(torch.nn.Module):
def __init__(self, input_size=4, hidden_size=10, latent_size=2):
super().__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.latent_size = latent_size
self.dense1 = torch.nn.Linear(self.input_size, self.hidden_size)
self.dense2 = torch.nn.Linear(self.hidden_size, self.latent_size)
self.activation = torch.nn.Sigmoid()
def forward(self, x):
x = self.activation(self.dense1(x))
out = self.activation(self.dense2(x))
return out
if __name__ == "__main__":
# Command-line arguments
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
args = parser.parse_args()
# Dataset
train_ds = pt.datasets.Iris()
# Reproducibility
seed_everything(seed=42)
# Dataloaders
train_loader = DataLoader(train_ds, batch_size=150)
# Hyperparameters
hparams = dict(
distribution=[3, 4, 5],
proto_lr=0.001,
bb_lr=0.001,
)
# Initialize the backbone
backbone = Backbone()
# Initialize the model
model = LVQMLN(
hparams,
prototypes_initializer=pt.initializers.SSCI(
train_ds,
transform=backbone,
),
backbone=backbone,
)
# Callbacks
vis = VisSiameseGLVQ2D(
data=train_ds,
map_protos=False,
border=0.1,
resolution=500,
axis_off=True,
)
pruning = PruneLoserPrototypes(
threshold=0.01,
idle_epochs=20,
prune_quota_per_epoch=2,
frequency=10,
verbose=True,
)
# Setup trainer
trainer = pl.Trainer.from_argparse_args(
args,
callbacks=[
vis,
pruning,
],
log_every_n_steps=1,
max_epochs=1000,
detect_anomaly=True,
)
# Training loop
trainer.fit(model, train_loader)

View File

@@ -0,0 +1,68 @@
"""Median-LVQ example using the Iris dataset."""
import argparse
import warnings
import prototorch as pt
import pytorch_lightning as pl
import torch
from prototorch.models import MedianLVQ, VisGLVQ2D
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader
warnings.filterwarnings("ignore", category=PossibleUserWarning)
warnings.filterwarnings("ignore", category=UserWarning)
if __name__ == "__main__":
# Reproducibility
seed_everything(seed=4)
# Command-line arguments
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
args = parser.parse_args()
# Dataset
train_ds = pt.datasets.Iris(dims=[0, 2])
# Dataloaders
train_loader = DataLoader(
train_ds,
batch_size=len(train_ds), # MedianLVQ cannot handle mini-batches
)
# Initialize the model
model = MedianLVQ(
hparams=dict(distribution=(3, 2), lr=0.01),
prototypes_initializer=pt.initializers.SSCI(train_ds),
)
# Compute intermediate input and output sizes
model.example_input_array = torch.zeros(4, 2)
# Callbacks
vis = VisGLVQ2D(data=train_ds)
es = EarlyStopping(
monitor="train_acc",
min_delta=0.01,
patience=5,
mode="max",
verbose=True,
check_on_train_epoch_end=True,
)
# Setup trainer
trainer = pl.Trainer.from_argparse_args(
args,
callbacks=[
vis,
es,
],
max_epochs=1000,
log_every_n_steps=1,
detect_anomaly=True,
)
# Training loop
trainer.fit(model, train_loader)

View File

@@ -1,15 +1,34 @@
"""Neural Gas example using the Iris dataset."""
import argparse
import warnings
import prototorch as pt
import pytorch_lightning as pl
import torch
from prototorch.models import NeuralGas, VisNG2D
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler
from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.data import DataLoader
warnings.filterwarnings("ignore", category=PossibleUserWarning)
warnings.filterwarnings("ignore", category=UserWarning)
if __name__ == "__main__":
# Reproducibility
seed_everything(seed=4)
# Command-line arguments
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
args = parser.parse_args()
# Prepare and pre-process the dataset
from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler
x_train, y_train = load_iris(return_X_y=True)
x_train = x_train[:, [0, 2]]
x_train = x_train[:, 0:3:2]
scaler = StandardScaler()
scaler.fit(x_train)
x_train = scaler.transform(x_train)
@@ -17,24 +36,39 @@ if __name__ == "__main__":
train_ds = pt.datasets.NumpyDataset(x_train, y_train)
# Dataloaders
train_loader = torch.utils.data.DataLoader(train_ds,
num_workers=0,
batch_size=150)
train_loader = DataLoader(train_ds, batch_size=150)
# Hyperparameters
hparams = dict(num_prototypes=30, lr=0.03)
hparams = dict(
num_prototypes=30,
input_dim=2,
lr=0.03,
)
# Initialize the model
model = pt.models.NeuralGas(hparams)
model = NeuralGas(
hparams,
prototypes_initializer=pt.core.ZCI(2),
lr_scheduler=ExponentialLR,
lr_scheduler_kwargs=dict(gamma=0.99, verbose=False),
)
# Model summary
print(model)
# Compute intermediate input and output sizes
model.example_input_array = torch.zeros(4, 2)
# Callbacks
vis = pt.models.VisNG2D(data=train_ds)
vis = VisNG2D(data=train_ds)
# Setup trainer
trainer = pl.Trainer(max_epochs=200, callbacks=[vis])
trainer = pl.Trainer.from_argparse_args(
args,
callbacks=[
vis,
],
max_epochs=1000,
log_every_n_steps=1,
detect_anomaly=True,
)
# Training loop
trainer.fit(model, train_loader)

68
examples/rslvq_iris.py Normal file
View File

@@ -0,0 +1,68 @@
"""RSLVQ example using the Iris dataset."""
import argparse
import warnings
import prototorch as pt
import pytorch_lightning as pl
import torch
from prototorch.models import RSLVQ, VisGLVQ2D
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader
warnings.filterwarnings("ignore", category=PossibleUserWarning)
warnings.filterwarnings("ignore", category=UserWarning)
if __name__ == "__main__":
# Command-line arguments
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
args = parser.parse_args()
# Reproducibility
seed_everything(seed=42)
# Dataset
train_ds = pt.datasets.Iris(dims=[0, 2])
# Dataloaders
train_loader = DataLoader(train_ds, batch_size=64)
# Hyperparameters
hparams = dict(
distribution=[2, 2, 3],
proto_lr=0.05,
lambd=0.1,
variance=1.0,
input_dim=2,
latent_dim=2,
bb_lr=0.01,
)
# Initialize the model
model = RSLVQ(
hparams,
optimizer=torch.optim.Adam,
prototypes_initializer=pt.initializers.SSCI(train_ds, noise=0.2),
)
# Compute intermediate input and output sizes
model.example_input_array = torch.zeros(4, 2)
# Callbacks
vis = VisGLVQ2D(data=train_ds)
# Setup trainer
trainer = pl.Trainer.from_argparse_args(
args,
callbacks=[
vis,
],
detect_anomaly=True,
max_epochs=100,
log_every_n_steps=1,
)
# Training loop
trainer.fit(model, train_loader)

View File

@@ -1,12 +1,22 @@
"""Siamese GLVQ example using all four dimensions of the Iris dataset."""
import argparse
import warnings
import prototorch as pt
import pytorch_lightning as pl
import torch
from prototorch.models import SiameseGLVQ, VisSiameseGLVQ2D
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader
warnings.filterwarnings("ignore", category=PossibleUserWarning)
warnings.filterwarnings("ignore", category=UserWarning)
class Backbone(torch.nn.Module):
"""Two fully connected layers with ReLU activation."""
def __init__(self, input_size=4, hidden_size=10, latent_size=2):
super().__init__()
self.input_size = input_size
@@ -14,51 +24,60 @@ class Backbone(torch.nn.Module):
self.latent_size = latent_size
self.dense1 = torch.nn.Linear(self.input_size, self.hidden_size)
self.dense2 = torch.nn.Linear(self.hidden_size, self.latent_size)
self.relu = torch.nn.ReLU()
self.activation = torch.nn.Sigmoid()
def forward(self, x):
x = self.relu(self.dense1(x))
out = self.relu(self.dense2(x))
x = self.activation(self.dense1(x))
out = self.activation(self.dense2(x))
return out
if __name__ == "__main__":
# Command-line arguments
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
args = parser.parse_args()
# Dataset
from sklearn.datasets import load_iris
x_train, y_train = load_iris(return_X_y=True)
train_ds = pt.datasets.NumpyDataset(x_train, y_train)
train_ds = pt.datasets.Iris()
# Reproducibility
pl.utilities.seed.seed_everything(seed=2)
seed_everything(seed=2)
# Dataloaders
train_loader = torch.utils.data.DataLoader(train_ds,
num_workers=0,
batch_size=150)
train_loader = DataLoader(train_ds, batch_size=150)
# Hyperparameters
hparams = dict(
nclasses=3,
prototypes_per_class=2,
prototype_initializer=pt.components.SMI((x_train, y_train)),
proto_lr=0.001,
bb_lr=0.001,
distribution=[1, 2, 3],
proto_lr=0.01,
bb_lr=0.01,
)
# Initialize the backbone
backbone = Backbone()
# Initialize the model
model = pt.models.SiameseGLVQ(
model = SiameseGLVQ(
hparams,
backbone_module=Backbone,
prototypes_initializer=pt.initializers.SMCI(train_ds),
backbone=backbone,
both_path_gradients=False,
)
# Model summary
print(model)
# Callbacks
vis = pt.models.VisSiameseGLVQ2D(data=(x_train, y_train), border=0.1)
vis = VisSiameseGLVQ2D(data=train_ds, border=0.1)
# Setup trainer
trainer = pl.Trainer(max_epochs=100, callbacks=[vis])
trainer = pl.Trainer.from_argparse_args(
args,
callbacks=[
vis,
],
max_epochs=1000,
log_every_n_steps=1,
detect_anomaly=True,
)
# Training loop
trainer.fit(model, train_loader)

View File

@@ -0,0 +1,85 @@
"""Siamese GTLVQ example using all four dimensions of the Iris dataset."""
import argparse
import warnings
import prototorch as pt
import pytorch_lightning as pl
import torch
from prototorch.models import SiameseGTLVQ, VisSiameseGLVQ2D
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader
warnings.filterwarnings("ignore", category=PossibleUserWarning)
warnings.filterwarnings("ignore", category=UserWarning)
class Backbone(torch.nn.Module):
def __init__(self, input_size=4, hidden_size=10, latent_size=2):
super().__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.latent_size = latent_size
self.dense1 = torch.nn.Linear(self.input_size, self.hidden_size)
self.dense2 = torch.nn.Linear(self.hidden_size, self.latent_size)
self.activation = torch.nn.Sigmoid()
def forward(self, x):
x = self.activation(self.dense1(x))
out = self.activation(self.dense2(x))
return out
if __name__ == "__main__":
# Command-line arguments
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
args = parser.parse_args()
# Dataset
train_ds = pt.datasets.Iris()
# Reproducibility
seed_everything(seed=2)
# Dataloaders
train_loader = DataLoader(train_ds, batch_size=150)
# Hyperparameters
hparams = dict(
distribution=[1, 2, 3],
proto_lr=0.01,
bb_lr=0.01,
input_dim=2,
latent_dim=1,
)
# Initialize the backbone
backbone = Backbone(latent_size=hparams["input_dim"])
# Initialize the model
model = SiameseGTLVQ(
hparams,
prototypes_initializer=pt.initializers.SMCI(train_ds),
backbone=backbone,
both_path_gradients=False,
)
# Callbacks
vis = VisSiameseGLVQ2D(data=train_ds, border=0.1)
# Setup trainer
trainer = pl.Trainer.from_argparse_args(
args,
callbacks=[
vis,
],
max_epochs=1000,
log_every_n_steps=1,
detect_anomaly=True,
)
# Training loop
trainer.fit(model, train_loader)

examples/warm_starting.py (new file, 124 lines)

@@ -0,0 +1,124 @@
"""Warm-starting GLVQ with prototypes from Growing Neural Gas."""
import argparse
import warnings
import prototorch as pt
import pytorch_lightning as pl
import torch
from prototorch.models import (
GLVQ,
KNN,
GrowingNeuralGas,
PruneLoserPrototypes,
VisGLVQ2D,
)
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.data import DataLoader
warnings.filterwarnings("ignore", category=PossibleUserWarning)
if __name__ == "__main__":
# Reproducibility
seed_everything(seed=4)
# Command-line arguments
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
args = parser.parse_args()
# Prepare the data
train_ds = pt.datasets.Iris(dims=[0, 2])
train_loader = DataLoader(train_ds, batch_size=64, num_workers=0)
# Initialize the gng
gng = GrowingNeuralGas(
hparams=dict(num_prototypes=5, insert_freq=2, lr=0.1),
prototypes_initializer=pt.initializers.ZCI(2),
lr_scheduler=ExponentialLR,
lr_scheduler_kwargs=dict(gamma=0.99, verbose=False),
)
# Callbacks
es = EarlyStopping(
monitor="loss",
min_delta=0.001,
patience=20,
mode="min",
verbose=False,
check_on_train_epoch_end=True,
)
# Setup trainer for GNG
trainer = pl.Trainer(
max_epochs=1000,
callbacks=[
es,
],
log_every_n_steps=1,
detect_anomaly=True,
)
# Training loop
trainer.fit(gng, train_loader)
# Hyperparameters
hparams = dict(
distribution=[],
lr=0.01,
)
# Warm-start prototypes
knn = KNN(dict(k=1), data=train_ds)
prototypes = gng.prototypes
plabels = knn.predict(prototypes)
# Initialize the model
model = GLVQ(
hparams,
optimizer=torch.optim.Adam,
prototypes_initializer=pt.initializers.LCI(prototypes),
labels_initializer=pt.initializers.LLI(plabels),
lr_scheduler=ExponentialLR,
lr_scheduler_kwargs=dict(gamma=0.99, verbose=False),
)
# Compute intermediate input and output sizes
model.example_input_array = torch.zeros(4, 2)
# Callbacks
vis = VisGLVQ2D(data=train_ds)
pruning = PruneLoserPrototypes(
threshold=0.02,
idle_epochs=2,
prune_quota_per_epoch=5,
frequency=1,
verbose=True,
)
es = EarlyStopping(
monitor="train_loss",
min_delta=0.001,
patience=10,
mode="min",
verbose=True,
check_on_train_epoch_end=True,
)
# Setup trainer
trainer = pl.Trainer.from_argparse_args(
args,
callbacks=[
vis,
pruning,
es,
],
max_epochs=1000,
log_every_n_steps=1,
detect_anomaly=True,
)
# Training loop
trainer.fit(model, train_loader)


@@ -0,0 +1,100 @@
import prototorch as pt
import pytorch_lightning as pl
import torchmetrics
from prototorch.core import SMCI
from prototorch.y.callbacks import (
LogTorchmetricCallback,
PlotLambdaMatrixToTensorboard,
VisGMLVQ2D,
)
from prototorch.y.library.gmlvq import GMLVQ
from pytorch_lightning.callbacks import EarlyStopping
from torch.utils.data import DataLoader
# ##############################################################################
def main():
# ------------------------------------------------------------
# DATA
# ------------------------------------------------------------
# Dataset
train_ds = pt.datasets.Iris()
# Dataloader
train_loader = DataLoader(
train_ds,
batch_size=32,
num_workers=0,
shuffle=True,
)
# ------------------------------------------------------------
# HYPERPARAMETERS
# ------------------------------------------------------------
# Select Initializer
components_initializer = SMCI(train_ds)
# Define Hyperparameters
hyperparameters = GMLVQ.HyperParameters(
lr=dict(components_layer=0.1, _omega=0),
input_dim=4,
distribution=dict(
num_classes=3,
per_class=1,
),
component_initializer=components_initializer,
)
# Create Model
model = GMLVQ(hyperparameters)
print(model.hparams)
# ------------------------------------------------------------
# TRAINING
# ------------------------------------------------------------
# Controlling Callbacks
stopping_criterion = LogTorchmetricCallback(
'recall',
torchmetrics.Recall,
num_classes=3,
)
es = EarlyStopping(
monitor=stopping_criterion.name,
mode="max",
patience=10,
)
# Visualization Callback
vis = VisGMLVQ2D(data=train_ds)
# Define trainer
trainer = pl.Trainer(callbacks=[
vis,
stopping_criterion,
es,
PlotLambdaMatrixToTensorboard(),
], )
# Train
trainer.fit(model, train_loader)
# Manual save
trainer.save_checkpoint("./y_arch.ckpt")
# Load saved model
new_model = GMLVQ.load_from_checkpoint(
checkpoint_path="./y_arch.ckpt",
strict=True,
)
print(new_model.hparams)
if __name__ == "__main__":
main()


@@ -1,8 +1,39 @@
from importlib.metadata import PackageNotFoundError, version
"""`models` plugin for the `prototorch` package."""
from .cbc import CBC
from .glvq import GLVQ, GMLVQ, GRLVQ, LVQMLN, ImageGLVQ, SiameseGLVQ
from .neural_gas import NeuralGas
from .callbacks import PrototypeConvergence, PruneLoserPrototypes
from .cbc import CBC, ImageCBC
from .glvq import (
GLVQ,
GLVQ1,
GLVQ21,
GMLVQ,
GRLVQ,
GTLVQ,
LGMLVQ,
LVQMLN,
ImageGLVQ,
ImageGMLVQ,
ImageGTLVQ,
SiameseGLVQ,
SiameseGMLVQ,
SiameseGTLVQ,
)
from .knn import KNN
from .lvq import (
LVQ1,
LVQ21,
MedianLVQ,
)
from .probabilistic import (
CELVQ,
RSLVQ,
SLVQ,
)
from .unsupervised import (
GrowingNeuralGas,
KohonenSOM,
NeuralGas,
)
from .vis import *
__version__ = "0.1.0"
__version__ = "1.0.0-a4"


@@ -1,23 +1,219 @@
"""Abstract classes to be inherited by prototorch models."""
import logging
import pytorch_lightning as pl
import torch
from torch.optim.lr_scheduler import ExponentialLR
import torch.nn.functional as F
import torchmetrics
from prototorch.core.competitions import WTAC
from prototorch.core.components import (
AbstractComponents,
Components,
LabeledComponents,
)
from prototorch.core.distances import euclidean_distance
from prototorch.core.initializers import (
LabelsInitializer,
ZerosCompInitializer,
)
from prototorch.core.pooling import stratified_min_pooling
from prototorch.nn.wrappers import LambdaLayer
class AbstractLightningModel(pl.LightningModule):
class ProtoTorchBolt(pl.LightningModule):
"""All ProtoTorch models are ProtoTorch Bolts.
hparams:
- lr: learning rate
kwargs:
- optimizer: optimizer class
- lr_scheduler: learning rate scheduler class
- lr_scheduler_kwargs: learning rate scheduler kwargs
"""
def __init__(self, hparams, **kwargs):
super().__init__()
# Hyperparameters
self.save_hyperparameters(hparams)
# Default hparams
self.hparams.setdefault("lr", 0.01)
# Default config
self.optimizer = kwargs.get("optimizer", torch.optim.Adam)
self.lr_scheduler = kwargs.get("lr_scheduler", None)
self.lr_scheduler_kwargs = kwargs.get("lr_scheduler_kwargs", dict())
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
scheduler = ExponentialLR(optimizer,
gamma=0.99,
last_epoch=-1,
verbose=False)
sch = {
"scheduler": scheduler,
"interval": "step",
} # called after each training step
return [optimizer], [sch]
optimizer = self.optimizer(self.parameters(), lr=self.hparams["lr"])
if self.lr_scheduler is not None:
scheduler = self.lr_scheduler(optimizer,
**self.lr_scheduler_kwargs)
sch = {
"scheduler": scheduler,
"interval": "step",
} # called after each training step
return [optimizer], [sch]
else:
return optimizer
def reconfigure_optimizers(self):
if self.trainer:
self.trainer.strategy.setup_optimizers(self.trainer)
else:
logging.warning("No trainer to reconfigure optimizers!")
def __repr__(self):
surep = super().__repr__()
indented = "".join([f"\t{line}\n" for line in surep.splitlines()])
wrapped = f"ProtoTorch Bolt(\n{indented})"
return wrapped
class AbstractPrototypeModel(AbstractLightningModel):
class PrototypeModel(ProtoTorchBolt):
"""Abstract Prototype Model
kwargs:
- distance_fn: distance function
"""
proto_layer: AbstractComponents
def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs)
distance_fn = kwargs.get("distance_fn", euclidean_distance)
self.distance_layer = LambdaLayer(distance_fn)
@property
def num_prototypes(self):
return len(self.proto_layer.components)
@property
def prototypes(self):
return self.proto_layer.components.detach().cpu()
@property
def components(self):
"""Only an alias for the prototypes."""
return self.prototypes
def add_prototypes(self, *args, **kwargs):
self.proto_layer.add_components(*args, **kwargs)
self.hparams["distribution"] = self.proto_layer.distribution
self.reconfigure_optimizers()
def remove_prototypes(self, indices):
self.proto_layer.remove_components(indices)
self.hparams["distribution"] = self.proto_layer.distribution
self.reconfigure_optimizers()
class UnsupervisedPrototypeModel(PrototypeModel):
proto_layer: Components
def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs)
# Layers
prototypes_initializer = kwargs.get("prototypes_initializer", None)
if prototypes_initializer is not None:
self.proto_layer = Components(
self.hparams["num_prototypes"],
initializer=prototypes_initializer,
)
def compute_distances(self, x):
protos = self.proto_layer().type_as(x)
distances = self.distance_layer(x, protos)
return distances
def forward(self, x):
distances = self.compute_distances(x)
return distances
class SupervisedPrototypeModel(PrototypeModel):
proto_layer: LabeledComponents
def __init__(self, hparams, skip_proto_layer=False, **kwargs):
super().__init__(hparams, **kwargs)
# Layers
distribution = hparams.get("distribution", None)
prototypes_initializer = kwargs.get("prototypes_initializer", None)
labels_initializer = kwargs.get("labels_initializer",
LabelsInitializer())
if not skip_proto_layer:
# when subclasses do not need a customized prototype layer
if prototypes_initializer is not None:
# when building a new model
self.proto_layer = LabeledComponents(
distribution=distribution,
components_initializer=prototypes_initializer,
labels_initializer=labels_initializer,
)
proto_shape = self.proto_layer.components.shape[1:]
self.hparams["initialized_proto_shape"] = proto_shape
else:
# when restoring a checkpointed model
self.proto_layer = LabeledComponents(
distribution=distribution,
components_initializer=ZerosCompInitializer(
self.hparams["initialized_proto_shape"]),
)
self.competition_layer = WTAC()
@property
def prototype_labels(self):
return self.proto_layer.labels.detach().cpu()
@property
def num_classes(self):
return self.proto_layer.num_classes
def compute_distances(self, x):
protos, _ = self.proto_layer()
distances = self.distance_layer(x, protos)
return distances
def forward(self, x):
distances = self.compute_distances(x)
_, plabels = self.proto_layer()
winning = stratified_min_pooling(distances, plabels)
y_pred = F.softmin(winning, dim=1)
return y_pred
def predict_from_distances(self, distances):
with torch.no_grad():
_, plabels = self.proto_layer()
y_pred = self.competition_layer(distances, plabels)
return y_pred
def predict(self, x):
with torch.no_grad():
distances = self.compute_distances(x)
y_pred = self.predict_from_distances(distances)
return y_pred
def log_acc(self, distances, targets, tag):
preds = self.predict_from_distances(distances)
accuracy = torchmetrics.functional.accuracy(preds.int(), targets.int())
# `.int()` because FloatTensors are assumed to be class probabilities
self.log(tag,
accuracy,
on_step=False,
on_epoch=True,
prog_bar=True,
logger=True)
def test_step(self, batch, batch_idx):
x, targets = batch
preds = self.predict(x)
accuracy = torchmetrics.functional.accuracy(preds.int(), targets.int())
self.log("test_acc", accuracy)


@@ -0,0 +1,152 @@
"""Lightning Callbacks."""
import logging
from typing import TYPE_CHECKING
import pytorch_lightning as pl
import torch
from prototorch.core.initializers import LiteralCompInitializer
from .extras import ConnectionTopology
if TYPE_CHECKING:
from prototorch.models import GLVQ, GrowingNeuralGas
class PruneLoserPrototypes(pl.Callback):
def __init__(
self,
threshold=0.01,
idle_epochs=10,
prune_quota_per_epoch=-1,
frequency=1,
replace=False,
prototypes_initializer=None,
verbose=False,
):
self.threshold = threshold # minimum win ratio
self.idle_epochs = idle_epochs # epochs to wait before pruning
self.prune_quota_per_epoch = prune_quota_per_epoch
self.frequency = frequency
self.replace = replace
self.verbose = verbose
self.prototypes_initializer = prototypes_initializer
def on_train_epoch_end(self, trainer, pl_module: "GLVQ"):
if (trainer.current_epoch + 1) < self.idle_epochs:
return None
if (trainer.current_epoch + 1) % self.frequency:
return None
ratios = pl_module.prototype_win_ratios.mean(dim=0)
to_prune_tensor = torch.arange(len(ratios))[ratios < self.threshold]
to_prune = to_prune_tensor.tolist()
prune_labels = pl_module.prototype_labels[to_prune]
if self.prune_quota_per_epoch > 0:
to_prune = to_prune[:self.prune_quota_per_epoch]
prune_labels = prune_labels[:self.prune_quota_per_epoch]
if len(to_prune) > 0:
logging.debug(f"\nPrototype win ratios: {ratios}")
logging.debug(f"Pruning prototypes at: {to_prune}")
logging.debug(f"Corresponding labels are: {prune_labels.tolist()}")
cur_num_protos = pl_module.num_prototypes
pl_module.remove_prototypes(indices=to_prune)
if self.replace:
labels, counts = torch.unique(prune_labels,
sorted=True,
return_counts=True)
distribution = dict(zip(labels.tolist(), counts.tolist()))
logging.info(f"Re-adding pruned prototypes...")
logging.debug(f"distribution={distribution}")
pl_module.add_prototypes(
distribution=distribution,
components_initializer=self.prototypes_initializer)
new_num_protos = pl_module.num_prototypes
logging.info(f"`num_prototypes` changed from {cur_num_protos} "
f"to {new_num_protos}.")
return True
class PrototypeConvergence(pl.Callback):
def __init__(self, min_delta=0.01, idle_epochs=10, verbose=False):
self.min_delta = min_delta
self.idle_epochs = idle_epochs # epochs to wait
self.verbose = verbose
def on_train_epoch_end(self, trainer, pl_module):
if (trainer.current_epoch + 1) < self.idle_epochs:
return None
logging.info("Stopping...")
# TODO
return True
class GNGCallback(pl.Callback):
"""GNG Callback.
Applies growing algorithm based on accumulated error and topology.
Based on "A Growing Neural Gas Network Learns Topologies" by Bernd Fritzke.
"""
def __init__(self, reduction=0.1, freq=10):
self.reduction = reduction
self.freq = freq
def on_train_epoch_end(
self,
trainer: pl.Trainer,
pl_module: "GrowingNeuralGas",
):
if (trainer.current_epoch + 1) % self.freq == 0:
# Get information
errors = pl_module.errors
topology: ConnectionTopology = pl_module.topology_layer
components = pl_module.proto_layer.components
# Insertion point
worst = torch.argmax(errors)
neighbors = topology.get_neighbors(worst)[0]
if len(neighbors) == 0:
logging.log(level=20, msg="No neighbor-pairs found!")
return
neighbors_errors = errors[neighbors]
worst_neighbor = neighbors[torch.argmax(neighbors_errors)]
# New Prototype
new_component = 0.5 * (components[worst] +
components[worst_neighbor])
# Add component
pl_module.proto_layer.add_components(
1,
initializer=LiteralCompInitializer(new_component.unsqueeze(0)),
)
# Adjust Topology
topology.add_prototype()
topology.add_connection(worst, -1)
topology.add_connection(worst_neighbor, -1)
topology.remove_connection(worst, worst_neighbor)
# New errors
worst_error = errors[worst].unsqueeze(0)
pl_module.errors = torch.cat([pl_module.errors, worst_error])
pl_module.errors[worst] = errors[worst] * self.reduction
pl_module.errors[
worst_neighbor] = errors[worst_neighbor] * self.reduction
trainer.strategy.setup_optimizers(trainer)


@@ -1,165 +1,80 @@
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
import torchmetrics
from prototorch.components.components import Components
from prototorch.functions.distances import euclidean_distance
from prototorch.functions.similarities import cosine_similarity
from prototorch.core.competitions import CBCC
from prototorch.core.components import ReasoningComponents
from prototorch.core.initializers import RandomReasoningsInitializer
from prototorch.core.losses import MarginLoss
from prototorch.core.similarities import euclidean_similarity
from prototorch.nn.wrappers import LambdaLayer
from .glvq import SiameseGLVQ
from .mixins import ImagePrototypesMixin
def rescaled_cosine_similarity(x, y):
"""Cosine Similarity rescaled to [0, 1]."""
similarities = cosine_similarity(x, y)
return (similarities + 1.0) / 2.0
def shift_activation(x):
return (x + 1.0) / 2.0
def euclidean_similarity(x, y):
d = euclidean_distance(x, y)
return torch.exp(-d * 3)
class CosineSimilarity(torch.nn.Module):
def __init__(self, activation=shift_activation):
super().__init__()
self.activation = activation
def forward(self, x, y):
epsilon = torch.finfo(x.dtype).eps
normed_x = (x / x.pow(2).sum(dim=tuple(range(
1, x.ndim)), keepdim=True).clamp(min=epsilon).sqrt()).flatten(
start_dim=1)
normed_y = (y / y.pow(2).sum(dim=tuple(range(
1, y.ndim)), keepdim=True).clamp(min=epsilon).sqrt()).flatten(
start_dim=1)
# normed_x = (x / torch.linalg.norm(x, dim=1))
diss = torch.inner(normed_x, normed_y)
return self.activation(diss)
class MarginLoss(torch.nn.modules.loss._Loss):
def __init__(self,
margin=0.3,
size_average=None,
reduce=None,
reduction="mean"):
super().__init__(size_average, reduce, reduction)
self.margin = margin
def forward(self, input_, target):
dp = torch.sum(target * input_, dim=-1)
dm = torch.max(input_ - target, dim=-1).values
return torch.nn.functional.relu(dm - dp + self.margin)
class ReasoningLayer(torch.nn.Module):
def __init__(self, n_components, n_classes, n_replicas=1):
super().__init__()
self.n_replicas = n_replicas
self.n_classes = n_classes
probabilities_init = torch.zeros(2, 1, n_components, self.n_classes)
probabilities_init.uniform_(0.4, 0.6)
self.reasoning_probabilities = torch.nn.Parameter(probabilities_init)
@property
def reasonings(self):
pk = self.reasoning_probabilities[0]
nk = (1 - pk) * self.reasoning_probabilities[1]
ik = 1 - pk - nk
img = torch.cat([pk, nk, ik], dim=0).permute(1, 0, 2)
return img.unsqueeze(1)
def forward(self, detections):
pk = self.reasoning_probabilities[0].clamp(0, 1)
nk = (1 - pk) * self.reasoning_probabilities[1].clamp(0, 1)
epsilon = torch.finfo(pk.dtype).eps
numerator = (detections @ (pk - nk)) + nk.sum(1)
probs = numerator / (pk + nk).sum(1)
probs = probs.squeeze(0)
return probs
class CBC(pl.LightningModule):
class CBC(SiameseGLVQ):
"""Classification-By-Components."""
def __init__(self,
hparams,
margin=0.1,
backbone_class=torch.nn.Identity,
similarity=euclidean_similarity,
**kwargs):
super().__init__()
self.save_hyperparameters(hparams)
self.margin = margin
self.component_layer = Components(self.hparams.num_components,
self.hparams.component_initializer)
# self.similarity = CosineSimilarity()
self.similarity = similarity
self.backbone = backbone_class()
self.backbone_dependent = backbone_class().requires_grad_(False)
n_components = self.components.shape[0]
self.reasoning_layer = ReasoningLayer(n_components=n_components,
n_classes=self.hparams.nclasses)
self.train_acc = torchmetrics.Accuracy()
proto_layer: ReasoningComponents
@property
def components(self):
return self.component_layer.components.detach().cpu()
def __init__(self, hparams, **kwargs):
super().__init__(hparams, skip_proto_layer=True, **kwargs)
@property
def reasonings(self):
return self.reasoning_layer.reasonings.cpu()
similarity_fn = kwargs.get("similarity_fn", euclidean_similarity)
components_initializer = kwargs.get("components_initializer", None)
reasonings_initializer = kwargs.get("reasonings_initializer",
RandomReasoningsInitializer())
self.components_layer = ReasoningComponents(
self.hparams["distribution"],
components_initializer=components_initializer,
reasonings_initializer=reasonings_initializer,
)
self.similarity_layer = LambdaLayer(similarity_fn)
self.competition_layer = CBCC()
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
return optimizer
# Namespace hook
self.proto_layer = self.components_layer
def sync_backbones(self):
master_state = self.backbone.state_dict()
self.backbone_dependent.load_state_dict(master_state, strict=True)
self.loss = MarginLoss(self.hparams["margin"])
def forward(self, x):
self.sync_backbones()
protos = self.component_layer()
components, reasonings = self.components_layer()
latent_x = self.backbone(x)
latent_protos = self.backbone_dependent(protos)
detections = self.similarity(latent_x, latent_protos)
probs = self.reasoning_layer(detections)
self.backbone.requires_grad_(self.both_path_gradients)
latent_components = self.backbone(components)
self.backbone.requires_grad_(True)
detections = self.similarity_layer(latent_x, latent_components)
probs = self.competition_layer(detections, reasonings)
return probs
def training_step(self, train_batch, batch_idx):
x, y = train_batch
x = x.view(x.size(0), -1)
def shared_step(self, batch, batch_idx, optimizer_idx=None):
x, y = batch
y_pred = self(x)
nclasses = self.reasoning_layer.n_classes
y_true = torch.nn.functional.one_hot(y.long(), num_classes=nclasses)
loss = MarginLoss(self.margin)(y_pred, y_true).mean(dim=0)
self.log("train_loss", loss)
self.train_acc(y_pred, y_true)
self.log(
"acc",
self.train_acc,
on_step=False,
on_epoch=True,
prog_bar=True,
logger=True,
)
return loss
num_classes = self.num_classes
y_true = F.one_hot(y.long(), num_classes=num_classes)
loss = self.loss(y_pred, y_true).mean()
return y_pred, loss
def training_step(self, batch, batch_idx, optimizer_idx=None):
y_pred, train_loss = self.shared_step(batch, batch_idx, optimizer_idx)
preds = torch.argmax(y_pred, dim=1)
accuracy = torchmetrics.functional.accuracy(preds.int(),
batch[1].int())
self.log("train_acc",
accuracy,
on_step=False,
on_epoch=True,
prog_bar=True,
logger=True)
return train_loss
def predict(self, x):
with torch.no_grad():
y_pred = self(x)
y_pred = torch.argmax(y_pred, dim=1)
return y_pred.numpy()
return y_pred
class ImageCBC(CBC):
class ImageCBC(ImagePrototypesMixin, CBC):
"""CBC model that constrains the components to the range [0, 1] by
clamping after updates.
"""
def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
# super().on_train_batch_end(outputs, batch, batch_idx, dataloader_idx)
self.component_layer.prototypes.data.clamp_(0.0, 1.0)
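A minimal sketch of how the reworked CBC is constructed with the new kwargs (similarity_fn, components_initializer, reasonings_initializer); the dataset and hyperparameter values are illustrative, and reasonings_initializer falls back to RandomReasoningsInitializer when omitted:

import prototorch as pt
from prototorch.models import CBC

train_ds = pt.datasets.Iris()
model = CBC(
    dict(distribution=[1, 2, 3], margin=0.1, proto_lr=0.01, bb_lr=0.01),
    components_initializer=pt.initializers.SMCI(train_ds),
)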

prototorch/models/extras.py (new file, 130 lines)

@@ -0,0 +1,130 @@
"""prototorch.models.extras
Modules not yet available in prototorch go here temporarily.
"""
import torch
from prototorch.core.similarities import gaussian
def rank_scaled_gaussian(distances, lambd):
order = torch.argsort(distances, dim=1)
ranks = torch.argsort(order, dim=1)
return torch.exp(-torch.exp(-ranks / lambd) * distances)
def orthogonalization(tensors):
"""Orthogonalization via polar decomposition """
u, _, v = torch.svd(tensors, compute_uv=True)
u_shape = tuple(list(u.shape))
v_shape = tuple(list(v.shape))
# reshape to (num x N x M)
u = torch.reshape(u, (-1, u_shape[-2], u_shape[-1]))
v = torch.reshape(v, (-1, v_shape[-2], v_shape[-1]))
out = u @ v.permute([0, 2, 1])
out = torch.reshape(out, u_shape[:-1] + (v_shape[-2], ))
return out
def ltangent_distance(x, y, omegas):
r"""Localized Tangent distance.
Compute Orthogonal Complement: :math:`\bm P_k = \bm I - \Omega_k \Omega_k^T`
Compute Tangent Distance: :math:`{\| \bm P \bm x - \bm P_k \bm y_k \|}_2`
:param `torch.tensor` omegas: Three dimensional matrix
:rtype: `torch.tensor`
"""
x, y = [arr.view(arr.size(0), -1) for arr in (x, y)]
p = torch.eye(omegas.shape[-2], device=omegas.device) - torch.bmm(
omegas, omegas.permute([0, 2, 1]))
projected_x = x @ p
projected_y = torch.diagonal(y @ p).T
expanded_y = torch.unsqueeze(projected_y, dim=1)
batchwise_difference = expanded_y - projected_x
differences_squared = batchwise_difference**2
distances = torch.sqrt(torch.sum(differences_squared, dim=2))
distances = distances.permute(1, 0)
return distances
class GaussianPrior(torch.nn.Module):
def __init__(self, variance):
super().__init__()
self.variance = variance
def forward(self, distances):
return gaussian(distances, self.variance)
class RankScaledGaussianPrior(torch.nn.Module):
def __init__(self, lambd):
super().__init__()
self.lambd = lambd
def forward(self, distances):
return rank_scaled_gaussian(distances, self.lambd)
class ConnectionTopology(torch.nn.Module):
def __init__(self, agelimit, num_prototypes):
super().__init__()
self.agelimit = agelimit
self.num_prototypes = num_prototypes
self.cmat = torch.zeros((self.num_prototypes, self.num_prototypes))
self.age = torch.zeros_like(self.cmat)
def forward(self, d):
order = torch.argsort(d, dim=1)
for element in order:
i0, i1 = element[0], element[1]
self.cmat[i0][i1] = 1
self.cmat[i1][i0] = 1
self.age[i0][i1] = 0
self.age[i1][i0] = 0
self.age[i0][self.cmat[i0] == 1] += 1
self.age[i1][self.cmat[i1] == 1] += 1
self.cmat[i0][self.age[i0] > self.agelimit] = 0
self.cmat[i1][self.age[i1] > self.agelimit] = 0
def get_neighbors(self, position):
return torch.where(self.cmat[position])
def add_prototype(self):
new_cmat = torch.zeros([dim + 1 for dim in self.cmat.shape])
new_cmat[:-1, :-1] = self.cmat
self.cmat = new_cmat
new_age = torch.zeros([dim + 1 for dim in self.age.shape])
new_age[:-1, :-1] = self.age
self.age = new_age
def add_connection(self, a, b):
self.cmat[a][b] = 1
self.cmat[b][a] = 1
self.age[a][b] = 0
self.age[b][a] = 0
def remove_connection(self, a, b):
self.cmat[a][b] = 0
self.cmat[b][a] = 0
self.age[a][b] = 0
self.age[b][a] = 0
def extra_repr(self):
return f"(agelimit): ({self.agelimit})"


@@ -1,90 +1,107 @@
"""Models based on the GLVQ framework."""
import torch
import torchmetrics
from prototorch.components import LabeledComponents
from prototorch.functions.activations import get_activation
from prototorch.functions.competitions import wtac
from prototorch.functions.distances import (euclidean_distance, omega_distance,
squared_euclidean_distance)
from prototorch.functions.losses import glvq_loss
from prototorch.core.competitions import wtac
from prototorch.core.distances import (
lomega_distance,
omega_distance,
squared_euclidean_distance,
)
from prototorch.core.initializers import EyeLinearTransformInitializer
from prototorch.core.losses import (
GLVQLoss,
lvq1_loss,
lvq21_loss,
)
from prototorch.core.transforms import LinearTransform
from prototorch.nn.wrappers import LambdaLayer, LossLayer
from torch.nn.parameter import Parameter
from .abstract import AbstractPrototypeModel
from .abstract import SupervisedPrototypeModel
from .extras import ltangent_distance, orthogonalization
from .mixins import ImagePrototypesMixin
class GLVQ(AbstractPrototypeModel):
class GLVQ(SupervisedPrototypeModel):
"""Generalized Learning Vector Quantization."""
def __init__(self, hparams, **kwargs):
super().__init__()
super().__init__(hparams, **kwargs)
self.save_hyperparameters(hparams)
# Default Values
self.hparams.setdefault("distance", euclidean_distance)
self.hparams.setdefault("optimizer", torch.optim.Adam)
self.hparams.setdefault("transfer_function", "identity")
# Default hparams
self.hparams.setdefault("margin", 0.0)
self.hparams.setdefault("transfer_fn", "identity")
self.hparams.setdefault("transfer_beta", 10.0)
self.proto_layer = LabeledComponents(
labels=(self.hparams.nclasses, self.hparams.prototypes_per_class),
initializer=self.hparams.prototype_initializer)
# Loss
self.loss = GLVQLoss(
margin=self.hparams["margin"],
transfer_fn=self.hparams["transfer_fn"],
beta=self.hparams["transfer_beta"],
)
self.transfer_function = get_activation(self.hparams.transfer_function)
self.train_acc = torchmetrics.Accuracy()
# def on_save_checkpoint(self, checkpoint):
# if "prototype_win_ratios" in checkpoint["state_dict"]:
# del checkpoint["state_dict"]["prototype_win_ratios"]
@property
def prototype_labels(self):
return self.proto_layer.component_labels.detach().cpu()
def initialize_prototype_win_ratios(self):
self.register_buffer(
"prototype_win_ratios",
torch.zeros(self.num_prototypes, device=self.device),
)
def forward(self, x):
protos, _ = self.proto_layer()
dis = self.hparams.distance(x, protos)
return dis
def on_train_epoch_start(self):
self.initialize_prototype_win_ratios()
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
x, y = train_batch
x = x.view(x.size(0), -1) # flatten
dis = self(x)
plabels = self.proto_layer.component_labels
mu = glvq_loss(dis, y, prototype_labels=plabels)
batch_loss = self.transfer_function(mu,
beta=self.hparams.transfer_beta)
loss = batch_loss.sum(dim=0)
def log_prototype_win_ratios(self, distances):
batch_size = len(distances)
prototype_wc = torch.zeros(
self.num_prototypes,
dtype=torch.long,
device=self.device,
)
wi, wc = torch.unique(
distances.min(dim=-1).indices,
sorted=True,
return_counts=True,
)
prototype_wc[wi] = wc
prototype_wr = prototype_wc / batch_size
self.prototype_win_ratios = torch.vstack([
self.prototype_win_ratios,
prototype_wr,
])
# Compute training accuracy
with torch.no_grad():
preds = wtac(dis, plabels)
def shared_step(self, batch, batch_idx, optimizer_idx=None):
x, y = batch
out = self.compute_distances(x)
_, plabels = self.proto_layer()
loss = self.loss(out, y, plabels)
return out, loss
self.train_acc(preds.int(), y.int())
# `.int()` because FloatTensors are assumed to be class probabilities
def training_step(self, batch, batch_idx, optimizer_idx=None):
out, train_loss = self.shared_step(batch, batch_idx, optimizer_idx)
self.log_prototype_win_ratios(out)
self.log("train_loss", train_loss)
self.log_acc(out, batch[-1], tag="train_acc")
return train_loss
# Logging
self.log("train_loss", loss)
self.log("acc",
self.train_acc,
on_step=False,
on_epoch=True,
prog_bar=True,
logger=True)
def validation_step(self, batch, batch_idx):
out, val_loss = self.shared_step(batch, batch_idx)
self.log("val_loss", val_loss)
self.log_acc(out, batch[-1], tag="val_acc")
return val_loss
return loss
def test_step(self, batch, batch_idx):
out, test_loss = self.shared_step(batch, batch_idx)
self.log_acc(out, batch[-1], tag="test_acc")
return test_loss
def predict(self, x):
# model.eval() # ?!
with torch.no_grad():
d = self(x)
plabels = self.proto_layer.component_labels
y_pred = wtac(d, plabels)
return y_pred.numpy()
class ImageGLVQ(GLVQ):
"""GLVQ for training on image data.
GLVQ model that constrains the prototypes to the range [0, 1] by clamping
after updates.
"""
def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
self.proto_layer.components.data.clamp_(0.0, 1.0)
def test_epoch_end(self, outputs):
test_loss = 0.0
for batch_loss in outputs:
test_loss += batch_loss.item()
self.log("test_loss", test_loss)
class SiameseGLVQ(GLVQ):
@@ -95,149 +112,73 @@ class SiameseGLVQ(GLVQ):
transformation pipeline are only learned from the inputs.
"""
def __init__(self,
hparams,
backbone_module=torch.nn.Identity,
backbone_params={},
sync=True,
**kwargs):
super().__init__(hparams, **kwargs)
self.backbone = backbone_module(**backbone_params)
self.backbone_dependent = backbone_module(
**backbone_params).requires_grad_(False)
self.sync = sync
def sync_backbones(self):
master_state = self.backbone.state_dict()
self.backbone_dependent.load_state_dict(master_state, strict=True)
def __init__(
self,
hparams,
backbone=torch.nn.Identity(),
both_path_gradients=False,
**kwargs,
):
distance_fn = kwargs.pop("distance_fn", squared_euclidean_distance)
super().__init__(hparams, distance_fn=distance_fn, **kwargs)
self.backbone = backbone
self.both_path_gradients = both_path_gradients
def configure_optimizers(self):
optim = self.hparams.optimizer
proto_opt = optim(self.proto_layer.parameters(),
lr=self.hparams.proto_lr)
if list(self.backbone.parameters()):
# only add an optimizer if the backbone has trainable parameters
# otherwise, the next line fails
bb_opt = optim(self.backbone.parameters(), lr=self.hparams.bb_lr)
return proto_opt, bb_opt
proto_opt = self.optimizer(
self.proto_layer.parameters(),
lr=self.hparams["proto_lr"],
)
# Only add a backbone optimizer if backbone has trainable parameters
bb_params = list(self.backbone.parameters())
if (bb_params):
bb_opt = self.optimizer(bb_params, lr=self.hparams["bb_lr"])
optimizers = [proto_opt, bb_opt]
else:
return proto_opt
optimizers = [proto_opt]
if self.lr_scheduler is not None:
schedulers = []
for optimizer in optimizers:
scheduler = self.lr_scheduler(optimizer,
**self.lr_scheduler_kwargs)
schedulers.append(scheduler)
return optimizers, schedulers
else:
return optimizers
def forward(self, x):
if self.sync:
self.sync_backbones()
def compute_distances(self, x):
protos, _ = self.proto_layer()
x, protos = [arr.view(arr.size(0), -1) for arr in (x, protos)]
latent_x = self.backbone(x)
latent_protos = self.backbone_dependent(protos)
dis = euclidean_distance(latent_x, latent_protos)
return dis
def predict_latent(self, x):
bb_grad = any([el.requires_grad for el in self.backbone.parameters()])
self.backbone.requires_grad_(bb_grad and self.both_path_gradients)
latent_protos = self.backbone(protos)
self.backbone.requires_grad_(bb_grad)
distances = self.distance_layer(latent_x, latent_protos)
return distances
def predict_latent(self, x, map_protos=True):
"""Predict `x` assuming it is already embedded in the latent space.
Only the prototypes are embedded in the latent space using the
backbone.
"""
# model.eval() # ?!
self.eval()
with torch.no_grad():
protos, plabels = self.proto_layer()
latent_protos = self.backbone_dependent(protos)
d = euclidean_distance(x, latent_protos)
if map_protos:
protos = self.backbone(protos)
d = self.distance_layer(x, protos)
y_pred = wtac(d, plabels)
return y_pred.numpy()
return y_pred
class GRLVQ(GLVQ):
"""Generalized Relevance Learning Vector Quantization."""
def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs)
self.relevances = torch.nn.parameter.Parameter(
torch.ones(self.hparams.input_dim))
def forward(self, x):
protos, _ = self.proto_layer()
dis = omega_distance(x, protos, torch.diag(self.relevances))
return dis
def backbone(self, x):
return x @ torch.diag(self.relevances)
@property
def relevance_profile(self):
return self.relevances.detach().cpu()
def predict_latent(self, x):
"""Predict `x` assuming it is already embedded in the latent space.
Only the prototypes are embedded in the latent space using the
backbone.
"""
# model.eval() # ?!
with torch.no_grad():
protos, plabels = self.proto_layer()
latent_protos = protos @ torch.diag(self.relevances)
d = squared_euclidean_distance(x, latent_protos)
y_pred = wtac(d, plabels)
return y_pred.numpy()
class GMLVQ(GLVQ):
"""Generalized Matrix Learning Vector Quantization."""
def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs)
self.omega_layer = torch.nn.Linear(self.hparams.input_dim,
self.hparams.latent_dim,
bias=False)
# Namespace hook for the visualization callbacks to work
self.backbone = self.omega_layer
@property
def omega_matrix(self):
return self.omega_layer.weight.detach().cpu()
@property
def lambda_matrix(self):
omega = self.omega_layer.weight # (latent_dim, input_dim)
lam = omega.T @ omega
return lam.detach().cpu()
def show_lambda(self):
import matplotlib.pyplot as plt
title = "Lambda matrix"
plt.figure(title)
plt.title(title)
plt.imshow(self.lambda_matrix, cmap="gray")
plt.axis("off")
plt.colorbar()
plt.show(block=True)
def forward(self, x):
protos, _ = self.proto_layer()
latent_x = self.omega_layer(x)
latent_protos = self.omega_layer(protos)
dis = squared_euclidean_distance(latent_x, latent_protos)
return dis
def predict_latent(self, x):
"""Predict `x` assuming it is already embedded in the latent space.
Only the prototypes are embedded in the latent space using the
backbone.
"""
# model.eval() # ?!
with torch.no_grad():
protos, plabels = self.proto_layer()
latent_protos = self.omega_layer(protos)
d = squared_euclidean_distance(x, latent_protos)
y_pred = wtac(d, plabels)
return y_pred.numpy()
class LVQMLN(GLVQ):
class LVQMLN(SiameseGLVQ):
"""Learning Vector Quantization Multi-Layer Network.
GLVQ model that applies an arbitrary transformation on the inputs, BUT NOT
@@ -246,27 +187,228 @@ class LVQMLN(GLVQ):
rather in the embedding space.
"""
def __init__(self,
hparams,
backbone_module=torch.nn.Identity,
backbone_params={},
**kwargs):
super().__init__(hparams, **kwargs)
self.backbone = backbone_module(**backbone_params)
with torch.no_grad():
protos = self.backbone(self.proto_layer()[0])
self.proto_layer.load_state_dict({"_components": protos}, strict=False)
def forward(self, x):
def compute_distances(self, x):
latent_protos, _ = self.proto_layer()
latent_x = self.backbone(x)
dis = euclidean_distance(latent_x, latent_protos)
return dis
distances = self.distance_layer(latent_x, latent_protos)
return distances
def predict_latent(self, x):
"""Predict `x` assuming it is already embedded in the latent space."""
class GRLVQ(SiameseGLVQ):
"""Generalized Relevance Learning Vector Quantization.
Implemented as a Siamese network with a linear transformation backbone.
TODO Make a RelevanceLayer. `bb_lr` is ignored otherwise.
"""
_relevances: torch.Tensor
def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs)
# Additional parameters
relevances = torch.ones(self.hparams["input_dim"], device=self.device)
self.register_parameter("_relevances", Parameter(relevances))
# Override the backbone
self.backbone = LambdaLayer(lambda x: x @ torch.diag(self._relevances),
name="relevance scaling")
@property
def relevance_profile(self):
return self._relevances.detach().cpu()
def extra_repr(self):
return f"(relevances): (shape: {tuple(self._relevances.shape)})"
class SiameseGMLVQ(SiameseGLVQ):
"""Generalized Matrix Learning Vector Quantization.
Implemented as a Siamese network with a linear transformation backbone.
"""
def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs)
# Override the backbone
omega_initializer = kwargs.get("omega_initializer",
EyeLinearTransformInitializer())
self.backbone = LinearTransform(
self.hparams["input_dim"],
self.hparams["latent_dim"],
initializer=omega_initializer,
)
@property
def omega_matrix(self):
return self.backbone.weights
@property
def lambda_matrix(self):
omega = self.backbone.weights # (input_dim, latent_dim)
lam = omega @ omega.T
return lam.detach().cpu()
class GMLVQ(GLVQ):
"""Generalized Matrix Learning Vector Quantization.
Implemented as a regular GLVQ network that simply uses a different distance
function. This makes it easier to implement a localized variant.
"""
# Parameters
_omega: torch.Tensor
def __init__(self, hparams, **kwargs):
distance_fn = kwargs.pop("distance_fn", omega_distance)
super().__init__(hparams, distance_fn=distance_fn, **kwargs)
# Additional parameters
omega_initializer = kwargs.get(
"omega_initializer",
EyeLinearTransformInitializer(),
)
omega = omega_initializer.generate(
self.hparams["input_dim"],
self.hparams["latent_dim"],
)
self.register_parameter("_omega", Parameter(omega))
self.backbone = LambdaLayer(
lambda x: x @ self._omega,
name="omega matrix",
)
@property
def omega_matrix(self):
return self._omega.detach().cpu()
@property
def lambda_matrix(self):
omega = self._omega.detach() # (input_dim, latent_dim)
lam = omega @ omega.T
return lam.detach().cpu()
def compute_distances(self, x):
protos, _ = self.proto_layer()
distances = self.distance_layer(x, protos, self._omega)
return distances
def extra_repr(self):
return f"(omega): (shape: {tuple(self._omega.shape)})"
class LGMLVQ(GMLVQ):
"""Localized and Generalized Matrix Learning Vector Quantization."""
def __init__(self, hparams, **kwargs):
distance_fn = kwargs.pop("distance_fn", lomega_distance)
super().__init__(hparams, distance_fn=distance_fn, **kwargs)
# Re-register `_omega` to override the one from the super class.
omega = torch.randn(
self.num_prototypes,
self.hparams["input_dim"],
self.hparams["latent_dim"],
device=self.device,
)
self.register_parameter("_omega", Parameter(omega))
class GTLVQ(LGMLVQ):
"""Localized and Generalized Tangent Learning Vector Quantization."""
def __init__(self, hparams, **kwargs):
distance_fn = kwargs.pop("distance_fn", ltangent_distance)
super().__init__(hparams, distance_fn=distance_fn, **kwargs)
omega_initializer = kwargs.get("omega_initializer")
if omega_initializer is not None:
subspace = omega_initializer.generate(
self.hparams["input_dim"],
self.hparams["latent_dim"],
)
omega = torch.repeat_interleave(
subspace.unsqueeze(0),
self.num_prototypes,
dim=0,
)
else:
omega = torch.rand(
self.num_prototypes,
self.hparams["input_dim"],
self.hparams["latent_dim"],
device=self.device,
)
# Re-register `_omega` to override the one from the super class.
self.register_parameter("_omega", Parameter(omega))
def on_train_batch_end(self, outputs, batch, batch_idx):
with torch.no_grad():
latent_protos, plabels = self.proto_layer()
d = euclidean_distance(x, latent_protos)
y_pred = wtac(d, plabels)
return y_pred.numpy()
self._omega.copy_(orthogonalization(self._omega))
class SiameseGTLVQ(SiameseGLVQ, GTLVQ):
"""Generalized Tangent Learning Vector Quantization.
Implemented as a Siamese network with a linear transformation backbone.
"""
class GLVQ1(GLVQ):
"""Generalized Learning Vector Quantization 1."""
def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs)
self.loss = LossLayer(lvq1_loss)
self.optimizer = torch.optim.SGD
class GLVQ21(GLVQ):
"""Generalized Learning Vector Quantization 2.1."""
def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs)
self.loss = LossLayer(lvq21_loss)
self.optimizer = torch.optim.SGD
class ImageGLVQ(ImagePrototypesMixin, GLVQ):
"""GLVQ for training on image data.
GLVQ model that constrains the prototypes to the range [0, 1] by clamping
after updates.
"""
class ImageGMLVQ(ImagePrototypesMixin, GMLVQ):
"""GMLVQ for training on image data.
GMLVQ model that constrains the prototypes to the range [0, 1] by clamping
after updates.
"""
class ImageGTLVQ(ImagePrototypesMixin, GTLVQ):
"""GTLVQ for training on image data.
GTLVQ model that constrains the prototypes to the range [0, 1] by clamping
after updates.
"""
def on_train_batch_end(self, outputs, batch, batch_idx):
"""Constrain the components to the range [0, 1] by clamping after updates."""
self.proto_layer.components.data.clamp_(0.0, 1.0)
with torch.no_grad():
self._omega.copy_(orthogonalization(self._omega))
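A minimal sketch of constructing the GMLVQ defined above (not the y-architecture variant); hyperparameter values are illustrative and train_ds stands for any labeled dataset such as pt.datasets.Iris():

import prototorch as pt
from prototorch.models import GMLVQ

train_ds = pt.datasets.Iris()
hparams = dict(distribution=[1, 2, 3], input_dim=4, latent_dim=2, lr=0.01)
model = GMLVQ(hparams, prototypes_initializer=pt.initializers.SMCI(train_ds))
print(model.lambda_matrix.shape)  # omega @ omega.T -> torch.Size([4, 4])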

prototorch/models/knn.py (new file, 45 lines)

@@ -0,0 +1,45 @@
"""ProtoTorch KNN model."""
import warnings
from prototorch.core.competitions import KNNC
from prototorch.core.components import LabeledComponents
from prototorch.core.initializers import (
LiteralCompInitializer,
LiteralLabelsInitializer,
)
from prototorch.utils.utils import parse_data_arg
from .abstract import SupervisedPrototypeModel
class KNN(SupervisedPrototypeModel):
"""K-Nearest-Neighbors classification algorithm."""
def __init__(self, hparams, **kwargs):
super().__init__(hparams, skip_proto_layer=True, **kwargs)
# Default hparams
self.hparams.setdefault("k", 1)
data = kwargs.get("data", None)
if data is None:
raise ValueError("KNN requires data, but was not provided!")
data, targets = parse_data_arg(data)
# Layers
self.proto_layer = LabeledComponents(
distribution=len(data) * [1],
components_initializer=LiteralCompInitializer(data),
labels_initializer=LiteralLabelsInitializer(targets))
self.competition_layer = KNNC(k=self.hparams.k)
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
return 1 # skip training step
def on_train_batch_start(self, train_batch, batch_idx):
warnings.warn("k-NN has no training, skipping!")
return -1
def configure_optimizers(self):
return None
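A usage sketch for the new KNN model, mirroring the warm-starting example; the value of k and the query tensor are illustrative:

import prototorch as pt
import torch
from prototorch.models import KNN

train_ds = pt.datasets.Iris()
knn = KNN(dict(k=3), data=train_ds)
y_pred = knn.predict(torch.rand(5, 4))  # five 4-dimensional query points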

prototorch/models/lvq.py (new file, 138 lines)

@@ -0,0 +1,138 @@
"""LVQ models that are optimized using non-gradient methods."""
import logging
from collections import OrderedDict
from prototorch.core.losses import _get_dp_dm
from prototorch.nn.activations import get_activation
from prototorch.nn.wrappers import LambdaLayer
from .glvq import GLVQ
from .mixins import NonGradientMixin
class LVQ1(NonGradientMixin, GLVQ):
"""Learning Vector Quantization 1."""
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
protos, plabels = self.proto_layer()
x, y = train_batch
dis = self.compute_distances(x)
# TODO Vectorized implementation
for xi, yi in zip(x, y):
d = self.compute_distances(xi.view(1, -1))
preds = self.competition_layer(d, plabels)
w = d.argmin(1)
if yi == preds:
shift = xi - protos[w]
else:
shift = protos[w] - xi
updated_protos = protos + 0.0
updated_protos[w] = protos[w] + (self.hparams["lr"] * shift)
self.proto_layer.load_state_dict(
OrderedDict(_components=updated_protos),
strict=False,
)
logging.debug(f"dis={dis}")
logging.debug(f"y={y}")
# Logging
self.log_acc(dis, y, tag="train_acc")
return None
class LVQ21(NonGradientMixin, GLVQ):
"""Learning Vector Quantization 2.1."""
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
protos, plabels = self.proto_layer()
x, y = train_batch
dis = self.compute_distances(x)
# TODO Vectorized implementation
for xi, yi in zip(x, y):
xi = xi.view(1, -1)
yi = yi.view(1, )
d = self.compute_distances(xi)
(_, wp), (_, wn) = _get_dp_dm(d, yi, plabels, with_indices=True)
shiftp = xi - protos[wp]
shiftn = protos[wn] - xi
updated_protos = protos + 0.0
updated_protos[wp] = protos[wp] + (self.hparams["lr"] * shiftp)
updated_protos[wn] = protos[wn] + (self.hparams["lr"] * shiftn)
self.proto_layer.load_state_dict(
OrderedDict(_components=updated_protos),
strict=False,
)
# Logging
self.log_acc(dis, y, tag="train_acc")
return None
class MedianLVQ(NonGradientMixin, GLVQ):
"""Median LVQ
# TODO Avoid computing distances over and over
"""
def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs)
self.transfer_layer = LambdaLayer(
get_activation(self.hparams["transfer_fn"]))
def _f(self, x, y, protos, plabels):
d = self.distance_layer(x, protos)
dp, dm = _get_dp_dm(d, y, plabels, with_indices=False)
mu = (dp - dm) / (dp + dm)
negative_mu = -1.0 * mu
f = self.transfer_layer(
negative_mu,
beta=self.hparams["transfer_beta"],
) + 1.0
return f
def expectation(self, x, y, protos, plabels):
f = self._f(x, y, protos, plabels)
gamma = f / f.sum()
return gamma
def lower_bound(self, x, y, protos, plabels, gamma):
f = self._f(x, y, protos, plabels)
lower_bound = (gamma * f.log()).sum()
return lower_bound
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
protos, plabels = self.proto_layer()
x, y = train_batch
dis = self.compute_distances(x)
for i, _ in enumerate(protos):
# Expectation step
gamma = self.expectation(x, y, protos, plabels)
lower_bound = self.lower_bound(x, y, protos, plabels, gamma)
# Maximization step
_protos = protos + 0
for k, xk in enumerate(x):
_protos[i] = xk
_lower_bound = self.lower_bound(x, y, _protos, plabels, gamma)
if _lower_bound > lower_bound:
logging.debug(f"Updating prototype {i} to data {k}...")
self.proto_layer.load_state_dict(
OrderedDict(_components=_protos),
strict=False,
)
break
# Logging
self.log_acc(dis, y, tag="train_acc")
return None


@@ -0,0 +1,35 @@
import pytorch_lightning as pl
import torch
from prototorch.core.components import Components
class ProtoTorchMixin(pl.LightningModule):
"""All mixins are ProtoTorchMixins."""
class NonGradientMixin(ProtoTorchMixin):
"""Mixin for custom non-gradient optimization."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.automatic_optimization = False
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
raise NotImplementedError
class ImagePrototypesMixin(ProtoTorchMixin):
"""Mixin for models with image prototypes."""
proto_layer: Components
components: torch.Tensor
def on_train_batch_end(self, outputs, batch, batch_idx):
"""Constrain the components to the range [0, 1] by clamping after updates."""
self.proto_layer.components.data.clamp_(0.0, 1.0)
def get_prototype_grid(self, num_columns=2, return_channels_last=True):
from torchvision.utils import make_grid
grid = make_grid(self.components, nrow=num_columns)
if return_channels_last:
grid = grid.permute((1, 2, 0))
return grid.cpu()
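A usage sketch for get_prototype_grid; the trained image model and the plotting calls are illustrative assumptions:

import matplotlib.pyplot as plt

grid = model.get_prototype_grid(num_columns=10)  # model: e.g. a trained ImageGLVQ
plt.imshow(grid)  # grid is returned channels-last and on the CPU
plt.show()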


@@ -1,69 +0,0 @@
import torch
from prototorch.components import Components
from prototorch.components import initializers as cinit
from prototorch.functions.distances import euclidean_distance
from prototorch.modules.losses import NeuralGasEnergy
from .abstract import AbstractPrototypeModel
class EuclideanDistance(torch.nn.Module):
def forward(self, x, y):
return euclidean_distance(x, y)
class ConnectionTopology(torch.nn.Module):
def __init__(self, agelimit, num_prototypes):
super().__init__()
self.agelimit = agelimit
self.num_prototypes = num_prototypes
self.cmat = torch.zeros((self.num_prototypes, self.num_prototypes))
self.age = torch.zeros_like(self.cmat)
def forward(self, d):
order = torch.argsort(d, dim=1)
for element in order:
i0, i1 = element[0], element[1]
self.cmat[i0][i1] = 1
self.age[i0][i1] = 0
self.age[i0][self.cmat[i0] == 1] += 1
self.cmat[i0][self.age[i0] > self.agelimit] = 0
def extra_repr(self):
return f"agelimit: {self.agelimit}"
class NeuralGas(AbstractPrototypeModel):
def __init__(self, hparams, **kwargs):
super().__init__()
self.save_hyperparameters(hparams)
# Default Values
self.hparams.setdefault("input_dim", 2)
self.hparams.setdefault("agelimit", 10)
self.hparams.setdefault("lm", 1)
self.hparams.setdefault("prototype_initializer",
cinit.ZerosInitializer(self.hparams.input_dim))
self.proto_layer = Components(
self.hparams.num_prototypes,
initializer=self.hparams.prototype_initializer)
self.distance_layer = EuclideanDistance()
self.energy_layer = NeuralGasEnergy(lm=self.hparams.lm)
self.topology_layer = ConnectionTopology(
agelimit=self.hparams.agelimit,
num_prototypes=self.hparams.num_prototypes,
)
def training_step(self, train_batch, batch_idx):
x = train_batch[0]
protos = self.proto_layer()
d = self.distance_layer(x, protos)
cost, order = self.energy_layer(d)
self.topology_layer(d)
return cost


@@ -0,0 +1,131 @@
"""Probabilistic GLVQ methods"""
import torch
from prototorch.core.losses import nllr_loss, rslvq_loss
from prototorch.core.pooling import (
stratified_min_pooling,
stratified_sum_pooling,
)
from prototorch.nn.wrappers import LossLayer
from .extras import GaussianPrior, RankScaledGaussianPrior
from .glvq import GLVQ, SiameseGMLVQ
class CELVQ(GLVQ):
"""Cross-Entropy Learning Vector Quantization."""
def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs)
# Loss
self.loss = torch.nn.CrossEntropyLoss()
def shared_step(self, batch, batch_idx, optimizer_idx=None):
x, y = batch
out = self.compute_distances(x) # [None, num_protos]
_, plabels = self.proto_layer()
winning = stratified_min_pooling(out, plabels) # [None, num_classes]
probs = -1.0 * winning
batch_loss = self.loss(probs, y.long())
loss = batch_loss.sum()
return out, loss
class ProbabilisticLVQ(GLVQ):
def __init__(self, hparams, rejection_confidence=0.0, **kwargs):
super().__init__(hparams, **kwargs)
self.rejection_confidence = rejection_confidence
self._conditional_distribution = None
def forward(self, x):
distances = self.compute_distances(x)
conditional = self.conditional_distribution(distances)
prior = (1. / self.num_prototypes) * torch.ones(self.num_prototypes,
device=self.device)
posterior = conditional * prior
plabels = self.proto_layer._labels
if isinstance(plabels, torch.LongTensor) or isinstance(
plabels, torch.cuda.LongTensor): # type: ignore
y_pred = stratified_sum_pooling(posterior, plabels) # type: ignore
else:
raise ValueError("Labels must be LongTensor.")
return y_pred
def predict(self, x):
y_pred = self.forward(x)
confidence, prediction = torch.max(y_pred, dim=1)
prediction[confidence < self.rejection_confidence] = -1
return prediction
def training_step(self, batch, batch_idx, optimizer_idx=None):
x, y = batch
out = self.forward(x)
_, plabels = self.proto_layer()
batch_loss = self.loss(out, y, plabels)
loss = batch_loss.sum()
return loss
def conditional_distribution(self, distances):
"""Conditional distribution of distances."""
if self._conditional_distribution is None:
raise ValueError("Conditional distribution is not set.")
return self._conditional_distribution(distances)
class SLVQ(ProbabilisticLVQ):
"""Soft Learning Vector Quantization."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Default hparams
self.hparams.setdefault("variance", 1.0)
variance = self.hparams.get("variance")
self._conditional_distribution = GaussianPrior(variance)
self.loss = LossLayer(nllr_loss)
class RSLVQ(ProbabilisticLVQ):
"""Robust Soft Learning Vector Quantization."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Default hparams
self.hparams.setdefault("variance", 1.0)
variance = self.hparams.get("variance")
self._conditional_distribution = GaussianPrior(variance)
self.loss = LossLayer(rslvq_loss)
class PLVQ(ProbabilisticLVQ, SiameseGMLVQ):
"""Probabilistic Learning Vector Quantization.
TODO: Use Backbone LVQ instead
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Default hparams
self.hparams.setdefault("lambda", 1.0)
lam = self.hparams.get("lambda", 1.0)
self.conditional_distribution = RankScaledGaussianPrior(lam)
self.loss = torch.nn.KLDivLoss()
# FIXME
# def training_step(self, batch, batch_idx, optimizer_idx=None):
# x, y = batch
# y_pred = self(x)
# batch_loss = self.loss(y_pred, y)
# loss = batch_loss.sum()
# return loss
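A minimal sketch for the probabilistic variants; the variance hparam defaults to 1.0 as set above, and the remaining values and dataset are illustrative:

import prototorch as pt
from prototorch.models import RSLVQ

train_ds = pt.datasets.Iris()
model = RSLVQ(
    dict(distribution=[2, 2, 2], variance=0.5, lr=0.05),
    prototypes_initializer=pt.initializers.SMCI(train_ds),
)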


@@ -0,0 +1,155 @@
"""Unsupervised prototype learning algorithms."""
import numpy as np
import torch
from prototorch.core.competitions import wtac
from prototorch.core.distances import squared_euclidean_distance
from prototorch.core.losses import NeuralGasEnergy
from .abstract import UnsupervisedPrototypeModel
from .callbacks import GNGCallback
from .extras import ConnectionTopology
from .mixins import NonGradientMixin
class KohonenSOM(NonGradientMixin, UnsupervisedPrototypeModel):
"""Kohonen Self-Organizing-Map.
TODO Allow non-2D grids
"""
_grid: torch.Tensor
def __init__(self, hparams, **kwargs):
h, w = hparams.get("shape")
# Ignore `num_prototypes`
hparams["num_prototypes"] = h * w
distance_fn = kwargs.pop("distance_fn", squared_euclidean_distance)
super().__init__(hparams, distance_fn=distance_fn, **kwargs)
# Hyperparameters
self.save_hyperparameters(hparams)
# Default hparams
self.hparams.setdefault("alpha", 0.3)
self.hparams.setdefault("sigma", max(h, w) / 2.0)
# Additional parameters
x, y = torch.arange(h), torch.arange(w)
grid = torch.stack(torch.meshgrid(x, y, indexing="ij"), dim=-1)
self.register_buffer("_grid", grid)
self._sigma = self.hparams.sigma
self._lr = self.hparams.lr
def predict_from_distances(self, distances):
grid = self._grid.view(-1, 2)
wp = wtac(distances, grid)
return wp
def training_step(self, train_batch, batch_idx):
# x = train_batch
# TODO Check if the batch has labels
x = train_batch[0]
d = self.compute_distances(x)
wp = self.predict_from_distances(d)
grid = self._grid.view(-1, 2)
gd = squared_euclidean_distance(wp, grid)
nh = torch.exp(-gd / self._sigma**2)
protos = self.proto_layer()
diff = x.unsqueeze(dim=1) - protos
delta = self._lr * self.hparams.alpha * nh.unsqueeze(-1) * diff
updated_protos = protos + delta.sum(dim=0)
self.proto_layer.load_state_dict(
{"_components": updated_protos},
strict=False,
)
def training_epoch_end(self, training_step_outputs):
self._sigma = self.hparams.sigma * np.exp(
-self.current_epoch / self.trainer.max_epochs)
def extra_repr(self):
return f"(grid): (shape: {tuple(self._grid.shape)})"
class HeskesSOM(UnsupervisedPrototypeModel):
def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs)
def training_step(self, train_batch, batch_idx):
# TODO Implement me!
raise NotImplementedError()
class NeuralGas(UnsupervisedPrototypeModel):
def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs)
# Hyperparameters
self.save_hyperparameters(hparams)
# Default hparams
self.hparams.setdefault("age_limit", 10)
self.hparams.setdefault("lm", 1)
self.energy_layer = NeuralGasEnergy(lm=self.hparams["lm"])
self.topology_layer = ConnectionTopology(
agelimit=self.hparams["age_limit"],
num_prototypes=self.hparams["num_prototypes"],
)
def training_step(self, train_batch, batch_idx):
# x = train_batch
# TODO Check if the batch has labels
x = train_batch[0]
d = self.compute_distances(x)
loss, _ = self.energy_layer(d)
self.topology_layer(d)
self.log("loss", loss)
return loss
class GrowingNeuralGas(NeuralGas):
errors: torch.Tensor
def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs)
# Defaults
self.hparams.setdefault("step_reduction", 0.5)
self.hparams.setdefault("insert_reduction", 0.1)
self.hparams.setdefault("insert_freq", 10)
errors = torch.zeros(
self.hparams["num_prototypes"],
device=self.device,
)
self.register_buffer("errors", errors)
def training_step(self, train_batch, _batch_idx):
# x = train_batch
# TODO Check if the batch has labels
x = train_batch[0]
d = self.compute_distances(x)
loss, order = self.energy_layer(d)
winner = order[:, 0]
mask = torch.zeros_like(d)
mask[torch.arange(len(mask)), winner] = 1.0
dp = d * mask
        self.errors += torch.sum(dp * dp, dim=0)  # accumulate error per winning prototype
self.errors *= self.hparams["step_reduction"]
self.topology_layer(d)
self.log("loss", loss)
return loss
def configure_callbacks(self):
return [
GNGCallback(
reduction=self.hparams["insert_reduction"],
freq=self.hparams["insert_freq"],
)
]
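A usage sketch in the same style (dataset choice and hyperparameter values are assumptions): GrowingNeuralGas registers its GNGCallback itself via configure_callbacks, so prototype insertion needs no extra wiring.

# Hedged sketch: fit the GrowingNeuralGas defined above on 2D blob data.
import prototorch as pt
import pytorch_lightning as pl
import torch

train_ds = pt.datasets.Blobs(num_samples=300)  # assumption: any 2D dataset works
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=64)

gng = pt.models.GrowingNeuralGas(
    {"num_prototypes": 5, "insert_freq": 5},
    prototypes_initializer=pt.initializers.RNCI(2),
)
pl.Trainer(max_epochs=20).fit(gng, train_loader)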


@@ -1,324 +1,111 @@
"""Visualization Callbacks."""
import os
import warnings
from typing import Sized
import numpy as np
import pytorch_lightning as pl
import torch
import torchvision
from matplotlib import pyplot as plt
from matplotlib.offsetbox import AnchoredText
from prototorch.utils.celluloid import Camera
from prototorch.utils.colors import color_scheme
from prototorch.utils.utils import (gif_from_dir, make_directory,
                                    prettify_string)
from prototorch.utils.colors import get_colors, get_legend_handles
from prototorch.utils.utils import mesh2d
from pytorch_lightning.loggers import TensorBoardLogger
from torch.utils.data import DataLoader, Dataset
class VisWeights(pl.Callback):
"""Abstract weight visualization callback."""
def __init__(
self,
data=None,
ignore_last_output_row=False,
label_map=None,
project_mesh=False,
project_protos=False,
voronoi=False,
axis_off=True,
cmap="viridis",
show=True,
display_logs=True,
display_logs_settings={},
pause_time=0.5,
border=1,
resolution=10,
interval=False,
save=False,
snap=True,
save_dir="./img",
make_gif=False,
make_mp4=False,
verbose=True,
dpi=500,
fps=5,
figsize=(11, 8.5), # standard paper in inches
prefix="",
distance_layer_index=-1,
**kwargs,
):
super().__init__(**kwargs)
self.data = data
self.ignore_last_output_row = ignore_last_output_row
self.label_map = label_map
self.voronoi = voronoi
        self.axis_off = axis_off
self.project_mesh = project_mesh
self.project_protos = project_protos
self.cmap = cmap
self.show = show
self.display_logs = display_logs
self.display_logs_settings = display_logs_settings
self.pause_time = pause_time
self.border = border
self.resolution = resolution
self.interval = interval
self.save = save
self.snap = snap
self.save_dir = save_dir
self.make_gif = make_gif
self.make_mp4 = make_mp4
self.verbose = verbose
self.dpi = dpi
self.fps = fps
self.figsize = figsize
self.prefix = prefix
self.distance_layer_index = distance_layer_index
self.title = "Weights Visualization"
make_directory(self.save_dir)
def _skip_epoch(self, epoch):
if self.interval:
if epoch % self.interval != 0:
return True
return False
def _clean_and_setup_ax(self):
ax = self.ax
if not self.snap:
ax.cla()
ax.set_title(self.title)
if self.axis_off:
ax.axis("off")
def _savefig(self, fignum, orientation="horizontal"):
figname = f"{self.save_dir}/{self.prefix}{fignum:05d}.png"
figsize = self.figsize
if orientation == "vertical":
figsize = figsize[::-1]
elif orientation == "horizontal":
pass
else:
pass
self.fig.set_size_inches(figsize, forward=False)
self.fig.savefig(figname, dpi=self.dpi)
def _show_and_save(self, epoch):
if self.show:
plt.pause(self.pause_time)
if self.save:
self._savefig(epoch)
if self.snap:
self.camera.snap()
def _display_logs(self, ax, epoch, logs):
if self.display_logs:
settings = dict(
loc="lower right",
# padding between the text and bounding box
pad=0.5,
# padding between the bounding box and the axes
borderpad=1.0,
# https://matplotlib.org/api/text_api.html#matplotlib.text.Text
prop=dict(
fontfamily="monospace",
fontweight="medium",
fontsize=12,
),
)
# Override settings with self.display_logs_settings.
settings = {**settings, **self.display_logs_settings}
log_string = f"""Epoch: {epoch:04d},
val_loss: {logs.get('val_loss', np.nan):.03f},
val_acc: {logs.get('val_acc', np.nan):.03f},
loss: {logs.get('loss', np.nan):.03f},
acc: {logs.get('acc', np.nan):.03f}
"""
log_string = prettify_string(log_string, end="")
# https://matplotlib.org/api/offsetbox_api.html#matplotlib.offsetbox.AnchoredText
anchored_text = AnchoredText(log_string, **settings)
self.ax.add_artist(anchored_text)
def on_train_start(self, trainer, pl_module, logs={}):
self.fig = plt.figure(self.title)
self.fig.set_size_inches(self.figsize, forward=False)
self.ax = self.fig.add_subplot(111)
self.camera = Camera(self.fig)
def on_train_end(self, trainer, pl_module, logs={}):
if self.make_gif:
gif_from_dir(directory=self.save_dir,
prefix=self.prefix,
duration=1.0 / self.fps)
if self.snap and self.make_mp4:
animation = self.camera.animate()
vid = os.path.join(self.save_dir, f"{self.prefix}animation.mp4")
if self.verbose:
print(f"Saving mp4 under {vid}.")
animation.save(vid, fps=self.fps, dpi=self.dpi)
class VisPointProtos(VisWeights):
"""Visualization of prototypes.
.. TODO::
Still in Progress.
"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.title = "Point Prototypes Visualization"
self.data_scatter_settings = {
"marker": "o",
"s": 30,
"edgecolor": "k",
"cmap": self.cmap,
}
self.protos_scatter_settings = {
"marker": "D",
"s": 50,
"edgecolor": "k",
"cmap": self.cmap,
}
def on_epoch_start(self, trainer, pl_module, logs={}):
epoch = trainer.current_epoch
if self._skip_epoch(epoch):
return True
self._clean_and_setup_ax()
protos = pl_module.prototypes
labels = pl_module.proto_layer.prototype_labels.detach().cpu().numpy()
if self.project_protos:
            protos = pl_module.projection(protos).numpy()
color_map = color_scheme(n=len(set(labels)),
cmap=self.cmap,
zero_indexed=True)
# TODO Get rid of the assumption y values in [0, num_of_classes]
label_colors = [color_map[l] for l in labels]
if self.data is not None:
x, y = self.data
# TODO Get rid of the assumption y values in [0, num_of_classes]
y_colors = [color_map[l] for l in y]
# x = self.model.projection(x)
if not isinstance(x, np.ndarray):
x = x.numpy()
# Plot data points.
self.ax.scatter(x[:, 0],
x[:, 1],
c=y_colors,
**self.data_scatter_settings)
# Paint decision regions.
if self.voronoi:
border = self.border
resolution = self.resolution
x = np.vstack((x, protos))
x_min, x_max = x[:, 0].min(), x[:, 0].max()
y_min, y_max = x[:, 1].min(), x[:, 1].max()
x_min, x_max = x_min - border, x_max + border
y_min, y_max = y_min - border, y_max + border
try:
xx, yy = np.meshgrid(
np.arange(x_min, x_max, (x_max - x_min) / resolution),
np.arange(y_min, y_max, (x_max - x_min) / resolution),
)
except ValueError as ve:
print(ve)
raise ValueError(f"x_min: {x_min}, x_max: {x_max}. "
f"x_min - x_max is {x_max - x_min}.")
except MemoryError as me:
print(me)
raise ValueError("Too many points. "
"Try reducing the resolution.")
mesh_input = np.c_[xx.ravel(), yy.ravel()]
# Predict mesh labels.
if self.project_mesh:
                    mesh_input = pl_module.projection(mesh_input)
y_pred = pl_module.predict(torch.Tensor(mesh_input))
y_pred = y_pred.reshape(xx.shape)
# Plot voronoi regions.
self.ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
self.ax.set_xlim(left=x_min + 0, right=x_max - 0)
self.ax.set_ylim(bottom=y_min + 0, top=y_max - 0)
# Plot prototypes.
self.ax.scatter(protos[:, 0],
protos[:, 1],
c=label_colors,
**self.protos_scatter_settings)
# self._show_and_save(epoch)
def on_epoch_end(self, trainer, pl_module, logs={}):
epoch = trainer.current_epoch
self._display_logs(self.ax, epoch, logs)
self._show_and_save(epoch)
class Vis2DAbstract(pl.Callback):
    def __init__(self,
                 data=None,
                 title="Prototype Visualization",
                 cmap="viridis",
                 xlabel="Data dimension 1",
                 ylabel="Data dimension 2",
                 legend_labels=None,
                 border=0.1,
                 resolution=100,
                 flatten_data=True,
                 axis_off=False,
                 show_protos=True,
                 show=True,
                 tensorboard=False,
                 show_last_only=False,
                 pause_time=0.1,
                 save=False,
                 save_dir="./img",
                 fig_size=(5, 4),
                 dpi=500,
                 block=False):
        super().__init__()
        if data:
            if isinstance(data, Dataset):
                if isinstance(data, Sized):
                    x, y = next(iter(DataLoader(data, batch_size=len(data))))
                else:
                    # TODO: Add support for non-sized datasets
                    raise NotImplementedError(
                        "Data must be a dataset with a __len__ method.")
            elif isinstance(data, DataLoader):
                x = torch.tensor([])
                y = torch.tensor([])
                for x_b, y_b in data:
                    x = torch.cat([x, x_b])
                    y = torch.cat([y, y_b])
            else:
                x, y = data
            if flatten_data:
                x = x.reshape(len(x), -1)
            self.x_train = x
            self.y_train = y
        else:
            self.x_train = None
            self.y_train = None
self.title = title
self.xlabel = xlabel
self.ylabel = ylabel
self.legend_labels = legend_labels
self.fig = plt.figure(self.title)
self.cmap = cmap
self.border = border
self.resolution = resolution
self.axis_off = axis_off
self.show_protos = show_protos
self.show = show
self.tensorboard = tensorboard
self.show_last_only = show_last_only
self.pause_time = pause_time
self.save = save
self.save_dir = save_dir
self.fig_size = fig_size
self.dpi = dpi
self.block = block
if save:
if not os.path.exists(save_dir):
os.makedirs(save_dir)
    def precheck(self, trainer):
        if self.show_last_only:
            if trainer.current_epoch != trainer.max_epochs - 1:
                return False
        return True
    def setup_ax(self):
        ax = self.fig.gca()
        ax.cla()
        ax.set_title(self.title)
        ax.set_xlabel(self.xlabel)
        ax.set_ylabel(self.ylabel)
        if self.axis_off:
            ax.axis("off")
return ax
def get_mesh_input(self, x):
x_min, x_max = x[:, 0].min() - self.border, x[:, 0].max() + self.border
y_min, y_max = x[:, 1].min() - self.border, x[:, 1].max() + self.border
xx, yy = np.meshgrid(np.arange(x_min, x_max, 1 / self.resolution),
np.arange(y_min, y_max, 1 / self.resolution))
mesh_input = np.c_[xx.ravel(), yy.ravel()]
return mesh_input, xx, yy
def plot_data(self, ax, x, y):
ax.scatter(
x[:, 0],
@@ -351,94 +138,140 @@ class Vis2DAbstract(pl.Callback):
    def log_and_display(self, trainer, pl_module):
        if self.tensorboard:
            self.add_to_tensorboard(trainer, pl_module)
        if self.save:
            plt.tight_layout()
            self.fig.set_size_inches(*self.fig_size, forward=False)
            plt.savefig(f"{self.save_dir}/{trainer.current_epoch}.png",
                        dpi=self.dpi)
        if self.show:
            if not self.block:
                plt.pause(self.pause_time)
            else:
                plt.show(block=self.block)
def on_train_epoch_end(self, trainer, pl_module):
if not self.precheck(trainer):
return True
self.visualize(pl_module)
self.log_and_display(trainer, pl_module)
def on_train_end(self, trainer, pl_module):
plt.show()
plt.close()
def visualize(self, pl_module):
raise NotImplementedError
class VisGLVQ2D(Vis2DAbstract):

    def visualize(self, pl_module):
        protos = pl_module.prototypes
        plabels = pl_module.prototype_labels
        x_train, y_train = self.x_train, self.y_train
        ax = self.setup_ax()
        self.plot_protos(ax, protos, plabels)
        if x_train is not None:
            self.plot_data(ax, x_train, y_train)
            mesh_input, xx, yy = mesh2d(np.vstack([x_train, protos]),
                                        self.border, self.resolution)
        else:
            mesh_input, xx, yy = mesh2d(protos, self.border, self.resolution)
        _components = pl_module.proto_layer._components
        mesh_input = torch.from_numpy(mesh_input).type_as(_components)
        y_pred = pl_module.predict(mesh_input)
        y_pred = y_pred.cpu().reshape(xx.shape)
        ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
class VisSiameseGLVQ2D(Vis2DAbstract):

    def __init__(self, *args, map_protos=True, **kwargs):
        super().__init__(*args, **kwargs)
        self.map_protos = map_protos

    def visualize(self, pl_module):
        protos = pl_module.prototypes
        plabels = pl_module.prototype_labels
        x_train, y_train = self.x_train, self.y_train
        device = pl_module.device
        with torch.no_grad():
            x_train = pl_module.backbone(torch.Tensor(x_train).to(device))
            x_train = x_train.cpu().detach()
        if self.map_protos:
            with torch.no_grad():
                protos = pl_module.backbone(torch.Tensor(protos).to(device))
                protos = protos.cpu().detach()
        ax = self.setup_ax()
        self.plot_data(ax, x_train, y_train)
        if self.show_protos:
            self.plot_protos(ax, protos, plabels)
            x = np.vstack((x_train, protos))
            mesh_input, xx, yy = mesh2d(x, self.border, self.resolution)
        else:
            mesh_input, xx, yy = mesh2d(x_train, self.border, self.resolution)
        _components = pl_module.proto_layer._components
        mesh_input = torch.Tensor(mesh_input).type_as(_components)
        y_pred = pl_module.predict_latent(mesh_input,
                                          map_protos=self.map_protos)
        y_pred = y_pred.cpu().reshape(xx.shape)
        ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
class VisGMLVQ2D(Vis2DAbstract):
def __init__(self, *args, ev_proj=True, **kwargs):
super().__init__(*args, **kwargs)
self.ev_proj = ev_proj
def visualize(self, pl_module):
protos = pl_module.prototypes
plabels = pl_module.prototype_labels
x_train, y_train = self.x_train, self.y_train
device = pl_module.device
omega = pl_module._omega.detach()
lam = omega @ omega.T
u, _, _ = torch.pca_lowrank(lam, q=2)
with torch.no_grad():
x_train = torch.Tensor(x_train).to(device)
x_train = x_train @ u
x_train = x_train.cpu().detach()
if self.show_protos:
with torch.no_grad():
protos = torch.Tensor(protos).to(device)
protos = protos @ u
protos = protos.cpu().detach()
ax = self.setup_ax()
self.plot_data(ax, x_train, y_train)
if self.show_protos:
self.plot_protos(ax, protos, plabels)
class VisCBC2D(Vis2DAbstract):

    def visualize(self, pl_module):
        x_train, y_train = self.x_train, self.y_train
        protos = pl_module.components
        ax = self.setup_ax()
        self.plot_data(ax, x_train, y_train)
        self.plot_protos(ax, protos, "w")
        x = np.vstack((x_train, protos))
        mesh_input, xx, yy = mesh2d(x, self.border, self.resolution)
        _components = pl_module.components_layer._components
        y_pred = pl_module.predict(
            torch.Tensor(mesh_input).type_as(_components))
        y_pred = y_pred.cpu().reshape(xx.shape)
        ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
class VisNG2D(Vis2DAbstract):

    def visualize(self, pl_module):
        x_train, y_train = self.x_train, self.y_train
        protos = pl_module.prototypes
        cmat = pl_module.topology_layer.cmat.cpu().numpy()
        ax = self.setup_ax()
        self.plot_data(ax, x_train, y_train)
        self.plot_protos(ax, protos, "w")
@@ -452,4 +285,97 @@ class VisNG2D(Vis2DAbstract):
"k-",
)
self.log_and_display(trainer, pl_module)
class VisSpectralProtos(Vis2DAbstract):
def visualize(self, pl_module):
protos = pl_module.prototypes
plabels = pl_module.prototype_labels
ax = self.setup_ax()
colors = get_colors(vmax=max(plabels), vmin=min(plabels))
for p, pl in zip(protos, plabels):
ax.plot(p, c=colors[int(pl)])
if self.legend_labels:
handles = get_legend_handles(
colors,
self.legend_labels,
marker="lines",
)
ax.legend(handles=handles)
class VisImgComp(Vis2DAbstract):
def __init__(self,
*args,
random_data=0,
dataformats="CHW",
num_columns=2,
add_embedding=False,
embedding_data=100,
**kwargs):
super().__init__(*args, **kwargs)
self.random_data = random_data
self.dataformats = dataformats
self.num_columns = num_columns
self.add_embedding = add_embedding
self.embedding_data = embedding_data
def on_train_start(self, _, pl_module):
if isinstance(pl_module.logger, TensorBoardLogger):
tb = pl_module.logger.experiment
# Add embedding
if self.add_embedding:
if self.x_train is not None and self.y_train is not None:
ind = np.random.choice(len(self.x_train),
size=self.embedding_data,
replace=False)
data = self.x_train[ind]
tb.add_embedding(data.view(len(ind), -1),
label_img=data,
global_step=None,
tag="Data Embedding",
metadata=self.y_train[ind],
metadata_header=None)
else:
raise ValueError("No data for add embedding flag")
# Random Data
if self.random_data:
if self.x_train is not None:
ind = np.random.choice(len(self.x_train),
size=self.random_data,
replace=False)
data = self.x_train[ind]
grid = torchvision.utils.make_grid(data,
nrow=self.num_columns)
tb.add_image(tag="Data",
img_tensor=grid,
global_step=None,
dataformats=self.dataformats)
else:
raise ValueError("No data for random data flag")
else:
warnings.warn(
f"TensorBoardLogger is required, got {type(pl_module.logger)}")
def add_to_tensorboard(self, trainer, pl_module):
tb = pl_module.logger.experiment
components = pl_module.components
grid = torchvision.utils.make_grid(components, nrow=self.num_columns)
tb.add_image(
tag="Components",
img_tensor=grid,
global_step=trainer.current_epoch,
dataformats=self.dataformats,
)
def visualize(self, pl_module):
if self.show:
components = pl_module.components
grid = torchvision.utils.make_grid(components,
nrow=self.num_columns)
plt.imshow(grid.permute((1, 2, 0)).cpu(), cmap=self.cmap)
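For orientation, a hedged sketch of how these callbacks are attached to a trainer (VisGLVQ2D as rewritten above; model, dataset, and initializer choices are assumptions):

# Hedged sketch: pass a visualization callback to the Lightning Trainer.
import prototorch as pt
import pytorch_lightning as pl
import torch
from prototorch.models.vis import VisGLVQ2D

train_ds = pt.datasets.Iris(dims=[0, 2])
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=64)

model = pt.models.GLVQ(
    {"distribution": (3, 2), "lr": 0.01},
    prototypes_initializer=pt.initializers.SMCI(train_ds),
)
vis = VisGLVQ2D(data=train_ds, show_last_only=False, block=False)
pl.Trainer(max_epochs=50, callbacks=[vis]).fit(model, train_loader)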

23
prototorch/y/__init__.py Normal file

@@ -0,0 +1,23 @@
from .architectures.base import BaseYArchitecture
from .architectures.comparison import (
OmegaComparisonMixin,
SimpleComparisonMixin,
)
from .architectures.competition import WTACompetitionMixin
from .architectures.components import SupervisedArchitecture
from .architectures.loss import GLVQLossMixin
from .architectures.optimization import (
MultipleLearningRateMixin,
SingleLearningRateMixin,
)
__all__ = [
'BaseYArchitecture',
"OmegaComparisonMixin",
"SimpleComparisonMixin",
"SingleLearningRateMixin",
"MultipleLearningRateMixin",
"SupervisedArchitecture",
"WTACompetitionMixin",
"GLVQLossMixin",
]


@@ -0,0 +1,226 @@
"""
Proto Y Architecture
Network architecture for component-based learning.
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Callable
import pytorch_lightning as pl
import torch
from torchmetrics import Metric
class BaseYArchitecture(pl.LightningModule):
@dataclass
class HyperParameters:
...
# Fields
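    # NOTE: class-level registries, shared across all instances of the class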
registered_metrics: dict[type[Metric], Metric] = {}
registered_metric_callbacks: dict[type[Metric], set[Callable]] = {}
# Type Hints for Necessary Fields
components_layer: torch.nn.Module
def __init__(self, hparams) -> None:
if type(hparams) is dict:
self.save_hyperparameters(hparams)
# TODO: => Move into Component Child
del hparams["initialized_proto_shape"]
hparams = self.HyperParameters(**hparams)
else:
self.save_hyperparameters(
hparams.__dict__,
ignore=["component_initializer"],
)
super().__init__()
# Common Steps
self.init_components(hparams)
self.init_latent(hparams)
self.init_comparison(hparams)
self.init_competition(hparams)
# Train Steps
self.init_loss(hparams)
# Inference Steps
self.init_inference(hparams)
# external API
def get_competition(self, batch, components):
latent_batch, latent_components = self.latent(batch, components)
# TODO: => Latent Hook
comparison_tensor = self.comparison(latent_batch, latent_components)
# TODO: => Comparison Hook
return comparison_tensor
def forward(self, batch):
if isinstance(batch, torch.Tensor):
batch = (batch, None)
# TODO: manage different datatypes?
components = self.components_layer()
# TODO: => Component Hook
comparison_tensor = self.get_competition(batch, components)
# TODO: => Competition Hook
return self.inference(comparison_tensor, components)
def predict(self, batch):
"""
Alias for forward
"""
return self.forward(batch)
def forward_comparison(self, batch):
if isinstance(batch, torch.Tensor):
batch = (batch, None)
# TODO: manage different datatypes?
components = self.components_layer()
# TODO: => Component Hook
return self.get_competition(batch, components)
def loss_forward(self, batch):
# TODO: manage different datatypes?
components = self.components_layer()
# TODO: => Component Hook
comparison_tensor = self.get_competition(batch, components)
# TODO: => Competition Hook
return self.loss(comparison_tensor, batch, components)
# Empty Initialization
# TODO: Docs
def init_components(self, hparams: HyperParameters) -> None:
...
def init_latent(self, hparams: HyperParameters) -> None:
...
def init_comparison(self, hparams: HyperParameters) -> None:
...
def init_competition(self, hparams: HyperParameters) -> None:
...
def init_loss(self, hparams: HyperParameters) -> None:
...
def init_inference(self, hparams: HyperParameters) -> None:
...
# Empty Steps
# TODO: Type hints
def components(self):
"""
This step has no input.
It returns the components.
"""
raise NotImplementedError(
"The components step has no reasonable default.")
def latent(self, batch, components):
"""
The latent step receives the data batch and the components.
It can transform both by an arbitrary function.
It returns the transformed batch and components, each of the same length as the original input.
"""
return batch, components
def comparison(self, batch, components):
"""
Takes a batch of size N and the component set of size M.
It returns an NxMxD tensor containing D (usually 1) pairwise comparison measures.
"""
raise NotImplementedError(
"The comparison step has no reasonable default.")
def competition(self, comparison_measures, components):
"""
Takes the tensor of comparison measures.
Assigns a competition vector to each class.
"""
raise NotImplementedError(
"The competition step has no reasonable default.")
def loss(self, comparison_measures, batch, components):
"""
        Takes the tensor of comparison measures.
        Calculates a single loss value.
"""
raise NotImplementedError("The loss step has no reasonable default.")
def inference(self, comparison_measures, components):
"""
        Takes the tensor of comparison measures.
Returns the inferred vector.
"""
raise NotImplementedError(
"The inference step has no reasonable default.")
# Y Architecture Hooks
# internal API, called by models and callbacks
def register_torchmetric(
self,
name: Callable,
metric: type[Metric],
**metric_kwargs,
):
if metric not in self.registered_metrics:
self.registered_metrics[metric] = metric(**metric_kwargs)
self.registered_metric_callbacks[metric] = {name}
else:
self.registered_metric_callbacks[metric].add(name)
def update_metrics_step(self, batch):
# Prediction Metrics
preds = self(batch)
x, y = batch
for metric in self.registered_metrics:
instance = self.registered_metrics[metric].to(self.device)
            instance(preds, y)  # torchmetrics metrics expect (preds, target)
def update_metrics_epoch(self):
for metric in self.registered_metrics:
instance = self.registered_metrics[metric].to(self.device)
value = instance.compute()
for callback in self.registered_metric_callbacks[metric]:
callback(value, self)
instance.reset()
# Lightning Hooks
# Steps
def training_step(self, batch, batch_idx, optimizer_idx=None):
self.update_metrics_step([torch.clone(el) for el in batch])
return self.loss_forward(batch)
def validation_step(self, batch, batch_idx):
return self.loss_forward(batch)
def test_step(self, batch, batch_idx):
return self.loss_forward(batch)
# Other Hooks
def training_epoch_end(self, outs) -> None:
self.update_metrics_epoch()
def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None:
checkpoint["hyper_parameters"] = {
'hparams': checkpoint["hyper_parameters"]
}
return super().on_save_checkpoint(checkpoint)
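Because every step above defaults to either a no-op or a NotImplementedError, concrete models are assembled from mixins that each fill in one step. A minimal sketch of such a mixin (the class name is hypothetical; squared_euclidean_distance comes from prototorch.core):

from dataclasses import dataclass

from prototorch.core.distances import squared_euclidean_distance
from prototorch.y.architectures.base import BaseYArchitecture


class SquaredEuclideanComparisonMixin(BaseYArchitecture):
    """Hypothetical mixin that fills in only the comparison step."""

    @dataclass
    class HyperParameters(BaseYArchitecture.HyperParameters):
        ...

    def comparison(self, batch, components):
        batch_tensor, _ = batch        # (N, D) data tensor and labels
        comp_tensor, _ = components    # (M, D) component tensor and labels
        return squared_euclidean_distance(batch_tensor, comp_tensor)  # (N, M)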


@@ -0,0 +1,112 @@
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Callable
import torch
from prototorch.core.distances import euclidean_distance
from prototorch.core.initializers import (
AbstractLinearTransformInitializer,
EyeLinearTransformInitializer,
)
from prototorch.nn.wrappers import LambdaLayer
from prototorch.y.architectures.base import BaseYArchitecture
from torch import Tensor
from torch.nn.parameter import Parameter
class SimpleComparisonMixin(BaseYArchitecture):
"""
Simple Comparison
A comparison layer that only uses the positions of the components and the batch for dissimilarity computation.
"""
# HyperParameters
# ----------------------------------------------------------------------------------------------------
@dataclass
class HyperParameters(BaseYArchitecture.HyperParameters):
"""
comparison_fn: The comparison / dissimilarity function to use. Default: euclidean_distance.
comparison_args: Keyword arguments for the comparison function. Default: {}.
"""
comparison_fn: Callable = euclidean_distance
comparison_args: dict = field(default_factory=lambda: dict())
comparison_parameters: dict = field(default_factory=lambda: dict())
# Steps
# ----------------------------------------------------------------------------------------------------
def init_comparison(self, hparams: HyperParameters):
self.comparison_layer = LambdaLayer(
fn=hparams.comparison_fn,
**hparams.comparison_args,
)
self.comparison_kwargs: dict[str, Tensor] = dict()
def comparison(self, batch, components):
comp_tensor, _ = components
batch_tensor, _ = batch
comp_tensor = comp_tensor.unsqueeze(1)
distances = self.comparison_layer(
batch_tensor,
comp_tensor,
**self.comparison_kwargs,
)
return distances
class OmegaComparisonMixin(SimpleComparisonMixin):
"""
Omega Comparison
    A comparison layer that maps the batch and the components through a learned linear transform (omega) before computing dissimilarities.
"""
_omega: torch.Tensor
# HyperParameters
# ----------------------------------------------------------------------------------------------------
@dataclass
class HyperParameters(SimpleComparisonMixin.HyperParameters):
"""
        input_dim: The dimensionality of the input. Required; no default.
latent_dim: The dimensionality of the latent space. Default: 2.
omega_initializer: The initializer to use for the omega matrix. Default: EyeLinearTransformInitializer.
"""
input_dim: int | None = None
latent_dim: int = 2
omega_initializer: type[
AbstractLinearTransformInitializer] = EyeLinearTransformInitializer
# Steps
# ----------------------------------------------------------------------------------------------------
def init_comparison(self, hparams: HyperParameters) -> None:
super().init_comparison(hparams)
# Initialize the omega matrix
if hparams.input_dim is None:
raise ValueError("input_dim must be specified.")
else:
omega = hparams.omega_initializer().generate(
hparams.input_dim,
hparams.latent_dim,
)
self.register_parameter("_omega", Parameter(omega))
self.comparison_kwargs = dict(omega=self._omega)
# Properties
# ----------------------------------------------------------------------------------------------------
@property
def omega_matrix(self):
return self._omega.detach().cpu()
@property
def lambda_matrix(self):
omega = self._omega.detach()
lam = omega @ omega.T
return lam.detach().cpu()
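A small numeric sketch of the two properties above (the identity-like omega mirrors what EyeLinearTransformInitializer is expected to produce):

import torch

omega = torch.eye(4, 2)   # assumed shape: input_dim x latent_dim
lam = omega @ omega.T     # relevance ("lambda") matrix, input_dim x input_dim
print(lam.shape)          # torch.Size([4, 4])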


@@ -0,0 +1,29 @@
from dataclasses import dataclass
from prototorch.core.competitions import WTAC
from prototorch.y.architectures.base import BaseYArchitecture
class WTACompetitionMixin(BaseYArchitecture):
"""
Winner Take All Competition
A competition layer that uses the winner-take-all strategy.
"""
# HyperParameters
# ----------------------------------------------------------------------------------------------------
@dataclass
class HyperParameters(BaseYArchitecture.HyperParameters):
"""
No hyperparameters.
"""
# Steps
# ----------------------------------------------------------------------------------------------------
def init_inference(self, hparams: HyperParameters):
self.competition_layer = WTAC()
def inference(self, comparison_measures, components):
comp_labels = components[1]
return self.competition_layer(comparison_measures, comp_labels)
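A hedged numeric sketch of this inference step: the component with the smallest comparison measure wins, and its label is returned.

import torch
from prototorch.core.competitions import WTAC

wtac = WTAC()
distances = torch.tensor([[0.3, 0.1, 2.0]])  # one sample vs. three components
comp_labels = torch.tensor([0, 1, 1])
print(wtac(distances, comp_labels))          # expected: tensor([1])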


@@ -0,0 +1,64 @@
from dataclasses import dataclass
from prototorch.core.components import LabeledComponents
from prototorch.core.initializers import (
AbstractComponentsInitializer,
LabelsInitializer,
ZerosCompInitializer,
)
from prototorch.y import BaseYArchitecture
class SupervisedArchitecture(BaseYArchitecture):
"""
Supervised Architecture
    An architecture that uses LabeledComponents as its component layer.
"""
components_layer: LabeledComponents
# HyperParameters
# ----------------------------------------------------------------------------------------------------
@dataclass
class HyperParameters:
"""
distribution: A valid prototype distribution. No default possible.
        component_initializer: An implementation of AbstractComponentsInitializer. No default possible.
"""
distribution: "dict[str, int]"
component_initializer: AbstractComponentsInitializer
# Steps
# ----------------------------------------------------------------------------------------------------
def init_components(self, hparams: HyperParameters):
if hparams.component_initializer is not None:
self.components_layer = LabeledComponents(
distribution=hparams.distribution,
components_initializer=hparams.component_initializer,
labels_initializer=LabelsInitializer(),
)
proto_shape = self.components_layer.components.shape[1:]
self.hparams["initialized_proto_shape"] = proto_shape
else:
# when restoring a checkpointed model
self.components_layer = LabeledComponents(
distribution=hparams.distribution,
components_initializer=ZerosCompInitializer(
self.hparams["initialized_proto_shape"]),
)
# Properties
# ----------------------------------------------------------------------------------------------------
@property
def prototypes(self):
"""
Returns the position of the prototypes.
"""
return self.components_layer.components.detach().cpu()
@property
def prototype_labels(self):
"""
Returns the labels of the prototypes.
"""
return self.components_layer.labels.detach().cpu()
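A hedged sketch of filling these two fields (the distribution format follows prototorch's component distributions; RNCI is the random-normal initializer used in the tests):

import prototorch as pt

hparams = SupervisedArchitecture.HyperParameters(
    distribution={"num_classes": 3, "per_class": 2},
    component_initializer=pt.initializers.RNCI(4),
)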


@@ -0,0 +1,42 @@
from dataclasses import dataclass, field
from prototorch.core.losses import GLVQLoss
from prototorch.y.architectures.base import BaseYArchitecture
class GLVQLossMixin(BaseYArchitecture):
"""
GLVQ Loss
A loss layer that uses the Generalized Learning Vector Quantization (GLVQ) loss.
"""
# HyperParameters
# ----------------------------------------------------------------------------------------------------
@dataclass
class HyperParameters(BaseYArchitecture.HyperParameters):
"""
margin: The margin of the GLVQ loss. Default: 0.0.
transfer_fn: Transfer function to use. Default: sigmoid_beta.
transfer_args: Keyword arguments for the transfer function. Default: {beta: 10.0}.
"""
margin: float = 0.0
transfer_fn: str = "sigmoid_beta"
transfer_args: dict = field(default_factory=lambda: dict(beta=10.0))
# Steps
# ----------------------------------------------------------------------------------------------------
def init_loss(self, hparams: HyperParameters):
self.loss_layer = GLVQLoss(
margin=hparams.margin,
transfer_fn=hparams.transfer_fn,
**hparams.transfer_args,
)
def loss(self, comparison_measures, batch, components):
target = batch[1]
comp_labels = components[1]
loss = self.loss_layer(comparison_measures, target, comp_labels)
self.log('loss', loss)
return loss
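A numeric sketch of the loss step above; the underlying GLVQ classifier function is mu = (d_correct - d_incorrect) / (d_correct + d_incorrect), squashed by the transfer function (the values here are illustrative):

import torch
from prototorch.core.losses import GLVQLoss

loss_layer = GLVQLoss(margin=0.0, transfer_fn="sigmoid_beta", beta=10.0)
distances = torch.tensor([[0.2, 0.8]])  # one sample vs. two prototypes
target = torch.tensor([0])
comp_labels = torch.tensor([0, 1])
print(loss_layer(distances, target, comp_labels))  # mu = -0.6 -> small loss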


@@ -0,0 +1,73 @@
from dataclasses import dataclass, field
from typing import Type
import torch
from prototorch.y import BaseYArchitecture
from torch.nn.parameter import Parameter
class SingleLearningRateMixin(BaseYArchitecture):
"""
Single Learning Rate
All parameters are updated with a single learning rate.
"""
# HyperParameters
# ----------------------------------------------------------------------------------------------------
@dataclass
class HyperParameters(BaseYArchitecture.HyperParameters):
"""
lr: The learning rate. Default: 0.1.
optimizer: The optimizer to use. Default: torch.optim.Adam.
"""
lr: float = 0.1
optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam
# Hooks
# ----------------------------------------------------------------------------------------------------
def configure_optimizers(self):
return self.hparams.optimizer(self.parameters(),
lr=self.hparams.lr) # type: ignore
class MultipleLearningRateMixin(BaseYArchitecture):
"""
Multiple Learning Rates
    Defines different learning rates for different parts of the model.
"""
# HyperParameters
# ----------------------------------------------------------------------------------------------------
@dataclass
class HyperParameters(BaseYArchitecture.HyperParameters):
"""
        lr: A mapping from attribute names to their learning rates. Default: {}.
optimizer: The optimizer to use. Default: torch.optim.Adam.
"""
lr: dict = field(default_factory=lambda: dict())
optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam
# Hooks
# ----------------------------------------------------------------------------------------------------
def configure_optimizers(self):
optimizers = []
for name, lr in self.hparams.lr.items():
if not hasattr(self, name):
raise ValueError(f"{name} is not a parameter of {self}")
else:
model_part = getattr(self, name)
if isinstance(model_part, Parameter):
optimizers.append(
self.hparams.optimizer(
[model_part],
lr=lr, # type: ignore
))
elif hasattr(model_part, "parameters"):
optimizers.append(
self.hparams.optimizer(
model_part.parameters(),
lr=lr, # type: ignore
))
return optimizers

218
prototorch/y/callbacks.py Normal file

@@ -0,0 +1,218 @@
import warnings
from typing import Optional, Type
import numpy as np
import pytorch_lightning as pl
import torch
import torchmetrics
from matplotlib import pyplot as plt
from prototorch.models.vis import Vis2DAbstract
from prototorch.utils.utils import mesh2d
from prototorch.y.architectures.base import BaseYArchitecture
from prototorch.y.library.gmlvq import GMLVQ
from pytorch_lightning.loggers import TensorBoardLogger
DIVERGING_COLOR_MAPS = [
'PiYG',
'PRGn',
'BrBG',
'PuOr',
'RdGy',
'RdBu',
'RdYlBu',
'RdYlGn',
'Spectral',
'coolwarm',
'bwr',
'seismic',
]
class LogTorchmetricCallback(pl.Callback):
def __init__(
self,
name,
metric: Type[torchmetrics.Metric],
on="prediction",
**metric_kwargs,
) -> None:
self.name = name
self.metric = metric
self.metric_kwargs = metric_kwargs
self.on = on
def setup(
self,
trainer: pl.Trainer,
pl_module: BaseYArchitecture,
stage: Optional[str] = None,
) -> None:
if self.on == "prediction":
pl_module.register_torchmetric(
self,
self.metric,
**self.metric_kwargs,
)
else:
raise ValueError(f"{self.on} is no valid metric hook")
def __call__(self, value, pl_module: BaseYArchitecture):
pl_module.log(self.name, value)
class LogConfusionMatrix(LogTorchmetricCallback):
def __init__(
self,
num_classes,
name="confusion",
on='prediction',
**kwargs,
):
super().__init__(
name,
torchmetrics.ConfusionMatrix,
on=on,
num_classes=num_classes,
**kwargs,
)
def __call__(self, value, pl_module: BaseYArchitecture):
fig, ax = plt.subplots()
ax.imshow(value.detach().cpu().numpy())
# Show all ticks and label them with the respective list entries
# ax.set_xticks(np.arange(len(farmers)), labels=farmers)
# ax.set_yticks(np.arange(len(vegetables)), labels=vegetables)
# Rotate the tick labels and set their alignment.
plt.setp(
ax.get_xticklabels(),
rotation=45,
ha="right",
rotation_mode="anchor",
)
# Loop over data dimensions and create text annotations.
for i in range(len(value)):
for j in range(len(value)):
text = ax.text(
j,
i,
value[i, j].item(),
ha="center",
va="center",
color="w",
)
ax.set_title(self.name)
fig.tight_layout()
pl_module.logger.experiment.add_figure(
tag=self.name,
figure=fig,
close=True,
global_step=pl_module.global_step,
)
class VisGLVQ2D(Vis2DAbstract):
def visualize(self, pl_module):
protos = pl_module.prototypes
plabels = pl_module.prototype_labels
x_train, y_train = self.x_train, self.y_train
ax = self.setup_ax()
self.plot_protos(ax, protos, plabels)
if x_train is not None:
self.plot_data(ax, x_train, y_train)
mesh_input, xx, yy = mesh2d(
np.vstack([x_train, protos]),
self.border,
self.resolution,
)
else:
mesh_input, xx, yy = mesh2d(protos, self.border, self.resolution)
_components = pl_module.components_layer.components
mesh_input = torch.from_numpy(mesh_input).type_as(_components)
y_pred = pl_module.predict(mesh_input)
y_pred = y_pred.cpu().reshape(xx.shape)
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
class VisGMLVQ2D(Vis2DAbstract):
def __init__(self, *args, ev_proj=True, **kwargs):
super().__init__(*args, **kwargs)
self.ev_proj = ev_proj
def visualize(self, pl_module):
protos = pl_module.prototypes
plabels = pl_module.prototype_labels
x_train, y_train = self.x_train, self.y_train
device = pl_module.device
omega = pl_module._omega.detach()
lam = omega @ omega.T
u, _, _ = torch.pca_lowrank(lam, q=2)
with torch.no_grad():
x_train = torch.Tensor(x_train).to(device)
x_train = x_train @ u
x_train = x_train.cpu().detach()
if self.show_protos:
with torch.no_grad():
protos = torch.Tensor(protos).to(device)
protos = protos @ u
protos = protos.cpu().detach()
ax = self.setup_ax()
self.plot_data(ax, x_train, y_train)
if self.show_protos:
self.plot_protos(ax, protos, plabels)
class PlotLambdaMatrixToTensorboard(pl.Callback):
def __init__(self, cmap='seismic') -> None:
super().__init__()
self.cmap = cmap
if self.cmap not in DIVERGING_COLOR_MAPS and type(self.cmap) is str:
warnings.warn(
f"{self.cmap} is not a diverging color map. We recommend to use one of the following: {DIVERGING_COLOR_MAPS}"
)
def on_train_start(self, trainer, pl_module: GMLVQ):
self.plot_lambda(trainer, pl_module)
def on_train_epoch_end(self, trainer, pl_module: GMLVQ):
self.plot_lambda(trainer, pl_module)
def plot_lambda(self, trainer, pl_module: GMLVQ):
self.fig, self.ax = plt.subplots(1, 1)
# plot lambda matrix
l_matrix = pl_module.lambda_matrix
# normalize lambda matrix
l_matrix = l_matrix / torch.max(torch.abs(l_matrix))
# plot lambda matrix
self.ax.imshow(l_matrix.detach().numpy(), self.cmap, vmin=-1, vmax=1)
self.fig.colorbar(self.ax.images[-1])
# add title
self.ax.set_title('Lambda Matrix')
# add to tensorboard
if isinstance(trainer.logger, TensorBoardLogger):
trainer.logger.experiment.add_figure(
f"lambda_matrix",
self.fig,
trainer.global_step,
)
else:
warnings.warn(
f"{self.__class__.__name__} is not compatible with {trainer.logger.__class__.__name__} as logger. Use TensorBoardLogger instead."
)
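A hedged wiring sketch for the metric callback above (the torchmetrics.Accuracy constructor arguments depend on the installed torchmetrics version):

import pytorch_lightning as pl
import torchmetrics
from prototorch.y.callbacks import LogConfusionMatrix, LogTorchmetricCallback

accuracy = LogTorchmetricCallback(
    "accuracy",
    torchmetrics.Accuracy,
    num_classes=3,
)
# trainer = pl.Trainer(callbacks=[accuracy, LogConfusionMatrix(num_classes=3)])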


@@ -0,0 +1,7 @@
from .glvq import GLVQ
from .gmlvq import GMLVQ
__all__ = [
"GLVQ",
"GMLVQ",
]


@@ -0,0 +1,35 @@
from dataclasses import dataclass
from prototorch.y import (
SimpleComparisonMixin,
SingleLearningRateMixin,
SupervisedArchitecture,
WTACompetitionMixin,
)
from prototorch.y.architectures.loss import GLVQLossMixin
class GLVQ(
SupervisedArchitecture,
SimpleComparisonMixin,
GLVQLossMixin,
WTACompetitionMixin,
SingleLearningRateMixin,
):
"""
Generalized Learning Vector Quantization (GLVQ)
A GLVQ architecture that uses the winner-take-all strategy and the GLVQ loss.
"""
@dataclass
class HyperParameters(
SimpleComparisonMixin.HyperParameters,
SingleLearningRateMixin.HyperParameters,
GLVQLossMixin.HyperParameters,
WTACompetitionMixin.HyperParameters,
SupervisedArchitecture.HyperParameters,
):
"""
        No additional hyperparameters; all fields are inherited from the mixins.
"""


@@ -0,0 +1,50 @@
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Callable
import torch
from prototorch.core.distances import omega_distance
from prototorch.y import (
GLVQLossMixin,
MultipleLearningRateMixin,
OmegaComparisonMixin,
SupervisedArchitecture,
WTACompetitionMixin,
)
class GMLVQ(
SupervisedArchitecture,
OmegaComparisonMixin,
GLVQLossMixin,
WTACompetitionMixin,
MultipleLearningRateMixin,
):
"""
Generalized Matrix Learning Vector Quantization (GMLVQ)
A GMLVQ architecture that uses the winner-take-all strategy and the GLVQ loss.
"""
# HyperParameters
# ----------------------------------------------------------------------------------------------------
@dataclass
class HyperParameters(
MultipleLearningRateMixin.HyperParameters,
OmegaComparisonMixin.HyperParameters,
GLVQLossMixin.HyperParameters,
WTACompetitionMixin.HyperParameters,
SupervisedArchitecture.HyperParameters,
):
"""
comparison_fn: The comparison / dissimilarity function to use. Override Default: omega_distance.
comparison_args: Keyword arguments for the comparison function. Override Default: {}.
"""
comparison_fn: Callable = omega_distance
comparison_args: dict = field(default_factory=lambda: dict())
optimizer: type[torch.optim.Optimizer] = torch.optim.Adam
lr: dict = field(default_factory=lambda: dict(
components_layer=0.1,
_omega=0.5,
))
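Finally, a hedged end-to-end sketch for this Y-architecture GMLVQ (the dataset and the SMCI stratified-mean initializer are assumptions mirroring the tests):

import prototorch as pt
import pytorch_lightning as pl
import torch
from prototorch.y.library.gmlvq import GMLVQ

train_ds = pt.datasets.Iris()
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=64)

hparams = GMLVQ.HyperParameters(
    distribution={"num_classes": 3, "per_class": 2},
    component_initializer=pt.initializers.SMCI(train_ds),
    input_dim=4,
)
model = GMLVQ(hparams)
pl.Trainer(max_epochs=50).fit(model, train_loader)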

23
setup.cfg Normal file

@@ -0,0 +1,23 @@
[yapf]
based_on_style = pep8
spaces_before_comment = 2
split_before_logical_operator = true
[pylint]
disable =
too-many-arguments,
too-few-public-methods,
fixme,
[pycodestyle]
max-line-length = 79
[isort]
profile = hug
src_paths = isort, test
multi_line_output = 3
include_trailing_comma = True
force_grid_wrap = 3
use_parentheses = True
line_length = 79


@@ -1,10 +1,12 @@
"""
_____ _ _______ _
| __ \ | | |__ __| | |
| |__) | __ ___ | |_ ___ | | ___ _ __ ___| |__
| ___/ '__/ _ \| __/ _ \| |/ _ \| '__/ __| '_ \
| | | | | (_) | || (_) | | (_) | | | (__| | | |
|_| |_| \___/ \__\___/|_|\___/|_| \___|_| |_|Plugin
######
# # ##### #### ##### #### ##### #### ##### #### # #
# # # # # # # # # # # # # # # # # #
###### # # # # # # # # # # # # # ######
# ##### # # # # # # # # ##### # # #
# # # # # # # # # # # # # # # # #
# # # #### # #### # #### # # #### # #Plugin
ProtoTorch models Plugin Package
"""
@@ -19,23 +21,51 @@ DOWNLOAD_URL = "https://github.com/si-cim/prototorch_models.git"
with open("README.md", "r") as fh:
long_description = fh.read()
INSTALL_REQUIRES = ["prototorch", "pytorch_lightning", "torchmetrics"]
DEV = ["bumpversion"]
EXAMPLES = ["matplotlib", "scikit-learn"]
TESTS = ["codecov", "pytest"]
ALL = DEV + EXAMPLES + TESTS
INSTALL_REQUIRES = [
"prototorch>=0.7.3",
"pytorch_lightning>=1.6.0",
"torchmetrics",
"protobuf<3.20.0",
]
CLI = [
"jsonargparse",
]
DEV = [
"bumpversion",
"pre-commit",
]
DOCS = [
"recommonmark",
"sphinx",
"nbsphinx",
"ipykernel",
"sphinx_rtd_theme",
"sphinxcontrib-katex",
"sphinxcontrib-bibtex",
]
EXAMPLES = [
"matplotlib",
"scikit-learn",
]
TESTS = [
"codecov",
"pytest",
]
ALL = CLI + DEV + DOCS + EXAMPLES + TESTS
setup(
name=safe_name("prototorch_" + PLUGIN_NAME),
version="0.1.0",
version="1.0.0-a4",
description="Pre-packaged prototype-based "
"machine learning models using ProtoTorch and PyTorch-Lightning.",
long_description=long_description,
long_description_content_type="text/markdown",
author="Alexander Engelsberger",
author_email="engelsbe@hs-mittweida.de",
url=PROJECT_URL,
download_url=DOWNLOAD_URL,
license="MIT",
python_requires=">=3.7",
install_requires=INSTALL_REQUIRES,
extras_require={
"dev": DEV,
@@ -51,10 +81,11 @@ setup(
"Intended Audience :: Science/Research",
"License :: OSI Approved :: MIT License",
"Natural Language :: English",
"Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.7",
"Operating System :: OS Independent",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Software Development :: Libraries",

0
tests/__init__.py Normal file

35
tests/test_examples.sh Executable file

@@ -0,0 +1,35 @@
#! /bin/bash
# Read Flags
gpu=0
while [ -n "$1" ]; do
case "$1" in
--gpu) gpu=1;;
-g) gpu=1;;
*) path=$1;;
esac
shift
done
python --version
echo "Using GPU: " $gpu
# Loop
failed=0
for example in $(find $path -maxdepth 1 -name "*.py")
do
echo -n "$x" $example '... '
export DISPLAY= && python $example --fast_dev_run 1 --gpus $gpu &> run_log.txt
if [[ $? -ne 0 ]]; then
echo "FAILED!!"
cat run_log.txt
failed=1
else
echo "SUCCESS!"
fi
rm run_log.txt
done
exit $failed

195
tests/test_models.py Normal file

@@ -0,0 +1,195 @@
"""prototorch.models test suite."""
import prototorch as pt
import pytest
import torch
def test_glvq_model_build():
model = pt.models.GLVQ(
{"distribution": (3, 2)},
prototypes_initializer=pt.initializers.RNCI(2),
)
def test_glvq1_model_build():
model = pt.models.GLVQ1(
{"distribution": (3, 2)},
prototypes_initializer=pt.initializers.RNCI(2),
)
def test_glvq21_model_build():
    model = pt.models.GLVQ21(
{"distribution": (3, 2)},
prototypes_initializer=pt.initializers.RNCI(2),
)
def test_gmlvq_model_build():
model = pt.models.GMLVQ(
{
"distribution": (3, 2),
"input_dim": 2,
"latent_dim": 2,
},
prototypes_initializer=pt.initializers.RNCI(2),
)
def test_grlvq_model_build():
model = pt.models.GRLVQ(
{
"distribution": (3, 2),
"input_dim": 2,
},
prototypes_initializer=pt.initializers.RNCI(2),
)
def test_gtlvq_model_build():
model = pt.models.GTLVQ(
{
"distribution": (3, 2),
"input_dim": 4,
"latent_dim": 2,
},
prototypes_initializer=pt.initializers.RNCI(2),
)
def test_lgmlvq_model_build():
model = pt.models.LGMLVQ(
{
"distribution": (3, 2),
"input_dim": 4,
"latent_dim": 2,
},
prototypes_initializer=pt.initializers.RNCI(2),
)
def test_image_glvq_model_build():
model = pt.models.ImageGLVQ(
{"distribution": (3, 2)},
prototypes_initializer=pt.initializers.RNCI(16),
)
def test_image_gmlvq_model_build():
model = pt.models.ImageGMLVQ(
{
"distribution": (3, 2),
"input_dim": 16,
"latent_dim": 2,
},
prototypes_initializer=pt.initializers.RNCI(16),
)
def test_image_gtlvq_model_build():
    model = pt.models.ImageGTLVQ(
{
"distribution": (3, 2),
"input_dim": 16,
"latent_dim": 2,
},
prototypes_initializer=pt.initializers.RNCI(16),
)
def test_siamese_glvq_model_build():
model = pt.models.SiameseGLVQ(
{"distribution": (3, 2)},
prototypes_initializer=pt.initializers.RNCI(4),
)
def test_siamese_gmlvq_model_build():
model = pt.models.SiameseGMLVQ(
{
"distribution": (3, 2),
"input_dim": 4,
"latent_dim": 2,
},
prototypes_initializer=pt.initializers.RNCI(4),
)
def test_siamese_gtlvq_model_build():
model = pt.models.SiameseGTLVQ(
{
"distribution": (3, 2),
"input_dim": 4,
"latent_dim": 2,
},
prototypes_initializer=pt.initializers.RNCI(4),
)
def test_knn_model_build():
train_ds = pt.datasets.Iris(dims=[0, 2])
model = pt.models.KNN(dict(k=3), data=train_ds)
def test_lvq1_model_build():
model = pt.models.LVQ1(
{"distribution": (3, 2)},
prototypes_initializer=pt.initializers.RNCI(2),
)
def test_lvq21_model_build():
model = pt.models.LVQ21(
{"distribution": (3, 2)},
prototypes_initializer=pt.initializers.RNCI(2),
)
def test_median_lvq_model_build():
model = pt.models.MedianLVQ(
{"distribution": (3, 2)},
prototypes_initializer=pt.initializers.RNCI(2),
)
def test_celvq_model_build():
model = pt.models.CELVQ(
{"distribution": (3, 2)},
prototypes_initializer=pt.initializers.RNCI(2),
)
def test_rslvq_model_build():
model = pt.models.RSLVQ(
{"distribution": (3, 2)},
prototypes_initializer=pt.initializers.RNCI(2),
)
def test_slvq_model_build():
model = pt.models.SLVQ(
{"distribution": (3, 2)},
prototypes_initializer=pt.initializers.RNCI(2),
)
def test_growing_neural_gas_model_build():
model = pt.models.GrowingNeuralGas(
{"num_prototypes": 5},
prototypes_initializer=pt.initializers.RNCI(2),
)
def test_kohonen_som_model_build():
model = pt.models.KohonenSOM(
{"shape": (3, 2)},
prototypes_initializer=pt.initializers.RNCI(2),
)
def test_neural_gas_model_build():
model = pt.models.NeuralGas(
{"num_prototypes": 5},
prototypes_initializer=pt.initializers.RNCI(2),
)