Alexander Engelsberger
e209bf73d5
[BUGFIX] Pruning example works on GPU now
2021-06-03 14:35:24 +02:00
Alexander Engelsberger
1b09b1d57b
[BUGFIX] Probabilistic Models work on GPU now
2021-06-03 14:05:44 +02:00
Alexander Engelsberger
459f7c24be
[REFACTOR] Probabilistic loss signs changed
2021-06-03 14:00:47 +02:00
Alexander Engelsberger
3b02d99ebe
[BUGFIX] Early stopping example works now
2021-06-03 13:38:16 +02:00
Jensun Ravichandran
ef6bcc1079
[BUG] Early stopping does not seem to work
...
The early stopping callback does not work as expected, and crashes at the end of
max_epochs with:
```
~/miniconda3/envs/py38/lib/python3.8/site-packages/pytorch_lightning/trainer/callback_hook.py in on_train_end(self)
155 """Called when the train ends."""
156 for callback in self.callbacks:
--> 157 callback.on_train_end(self, self.lightning_module)
158
159 def on_pretrain_routine_start(self) -> None:
~/work/repos/prototorch_models/prototorch/models/callbacks.py in on_train_end(self, trainer, pl_module)
18 def on_train_end(self, trainer, pl_module):
19 # instead, do it at the end of training loop
---> 20 self._run_early_stopping_check(trainer, pl_module)
21
22
TypeError: _run_early_stopping_check() takes 2 positional arguments but 3 were given
```
2021-06-02 12:44:34 +02:00
Jensun Ravichandran
8851d1bbc9
[FEATURE] Add PruneLoserPrototypes Callback
2021-06-02 03:52:41 +02:00
Jensun Ravichandran
8f7deb75dd
[FEATURE] Log prototype win ratios over all training batches
2021-06-02 02:32:54 +02:00
Jensun Ravichandran
7743c50725
Tweak repr
2021-06-02 01:07:48 +02:00
Jensun Ravichandran
fcf3a4979c
Remove unused import
2021-06-02 00:49:36 +02:00
Jensun Ravichandran
d46fe4a393
[WIP] Update CBC example
2021-06-02 00:45:33 +02:00
Jensun Ravichandran
88cfd5762e
Remove unused imports in models/cbc.py
2021-06-02 00:44:35 +02:00
Jensun Ravichandran
aa42b9e331
[BUGFIX] Import missing module
...
models/unsupervised.py uses `pt` in line 37, but `pt` is undefined in the file.
I wonder why Python doesn't complain about this. Perhaps because unsupervised.py
is never run in isolation and `pt` is otherwise available in the namespace of
the example scripts that use unsupervised.py.
2021-06-02 00:31:57 +02:00
Jensun Ravichandran
91b57b01b1
[REFACTOR] neighbour
-> neighbor
2021-06-02 00:29:45 +02:00
Jensun Ravichandran
98c198d463
[REFACTOR] Use LambdaLayer instead of EuclideanDistance
2021-06-02 00:21:11 +02:00
Jensun Ravichandran
757f4e980d
Add Local-Matrix LVQ
...
Also remove the use of `self.distance_fn` in favor of `self.distance_layer`.
2021-06-01 23:44:16 +02:00
Jensun Ravichandran
5ec2dd47cd
Remove unused import
2021-06-01 23:40:56 +02:00
Jensun Ravichandran
e8cd4d765c
Remove unused variable
2021-06-01 23:39:39 +02:00
Jensun Ravichandran
8403b01081
Move CELVQ to probabilistic.py
2021-06-01 23:39:06 +02:00
Jensun Ravichandran
aff6aedd60
Use the add_components
API for adding prototypes
2021-06-01 23:37:45 +02:00
Jensun Ravichandran
1b6843dbbb
Remove unused imports
2021-06-01 19:31:03 +02:00
Jensun Ravichandran
21023a88d7
[BUGFIX] Fix RSLVQ
2021-06-01 17:44:10 +02:00
Alexander Engelsberger
9c1a41997b
[FEATURE] Add Growing Neural Gas
2021-06-01 17:19:43 +02:00
Jensun Ravichandran
1636c84778
Rename rslvq example
2021-05-31 17:56:45 +02:00
Jensun Ravichandran
27eccf44d4
Use LambdaLayer from ProtoTorch
2021-05-31 16:53:04 +02:00
Alexander Engelsberger
2a218c0ede
Add example for dynamic components in callbacks
2021-05-31 11:39:24 +02:00
Alexander Engelsberger
db064b5af1
Improvement of model __repr__
2021-05-31 11:19:06 +02:00
Alexander Engelsberger
0ac4ced85d
[refactor] Use functional variant of accuracy
...
Prevents Accuracy in `__repr__` of the models.
2021-05-31 11:12:27 +02:00
Jensun Ravichandran
7b7bc3693d
Merge branch 'dev' of github.com:si-cim/prototorch_models into dev
2021-05-31 00:32:49 +02:00
Jensun Ravichandran
cd73f6c427
Add examples/dynamic_components.py
2021-05-31 00:32:27 +02:00
Alexander Engelsberger
a60337ff27
[refactor] Move probabilistic to Prototorch
2021-05-28 20:39:32 +02:00
Alexander Engelsberger
e3392ee952
[refactor] DRY Probabilistic models
2021-05-28 17:13:06 +02:00
Alexander Engelsberger
dade502686
Add MNIST datamodule and training mixin factory.
2021-05-28 16:33:31 +02:00
Jensun Ravichandran
b7edee02c3
[WIP] Add CELVQ
...
TODO Ensure that the distances/probs corresponding to the plabels are sorted
like the target labels.
2021-05-27 17:40:16 +02:00
Alexander Engelsberger
66e3e51a52
Add references to the documentation.
2021-05-26 21:20:17 +02:00
Jensun Ravichandran
d411e52be4
Refactor non-gradient-lvq models into lvq.py
2021-05-25 20:37:34 +02:00
Alexander Engelsberger
32d6f95db0
Add RSLVQ and LikelihoodLVQ
2021-05-25 20:26:15 +02:00
Jensun Ravichandran
139109804f
[BUGFIX] Use _forward
in LVQ1 and LVQ21
2021-05-25 17:43:37 +02:00
Alexander Engelsberger
72e064338c
Use 'num_' in all variable names
2021-05-25 15:41:10 +02:00
Alexander Engelsberger
8ce18f83ce
Add prototype_initializer function to GLVQ
...
This allows overwriting it inside subclasses.
2021-05-21 17:13:10 +02:00
Alexander Engelsberger
7b4f7d84e0
Update Documentation
...
Clean up project
2021-05-21 15:42:45 +02:00
Jensun Ravichandran
a5e086ce0d
Refactor code
2021-05-21 13:33:57 +02:00
Jensun Ravichandran
0611f81aba
Update models namespace
2021-05-21 13:11:59 +02:00
Jensun Ravichandran
a9382dcd9b
Add get_prototype_grid method
2021-05-21 13:11:48 +02:00
Jensun Ravichandran
0933a88a1b
Fix ImageCBC bug
2021-05-21 13:11:36 +02:00
Jensun Ravichandran
88a34a06ef
[WIP] Update CBC implementation to use SiameseGLVQ
2021-05-20 17:36:00 +02:00
Jensun Ravichandran
de63eaf15a
Fix numpy issue in vis.py
2021-05-20 17:33:19 +02:00
Jensun Ravichandran
16dc3cf4eb
Update image visualization
2021-05-20 16:07:16 +02:00
Jensun Ravichandran
df061cc2ff
Refactor code
2021-05-20 14:40:02 +02:00
Alexander Engelsberger
969fb34cc3
Accumulate test loss
2021-05-20 14:20:23 +02:00
Alexander Engelsberger
0204f5eab6
Log test accuracy.
2021-05-20 14:03:31 +02:00
Alexander Engelsberger
b7fc5df386
Log test loss.
2021-05-20 13:47:20 +02:00
Alexander Engelsberger
faf1a88f99
[Bugfix] Remove optimzer_idx from validation and test.
2021-05-20 13:17:27 +02:00
Jensun Ravichandran
5ffbd43a7c
Refactor into shared_step
2021-05-19 16:57:51 +02:00
Jensun Ravichandran
fdf9443a2c
Add validation and test logic
2021-05-19 16:30:19 +02:00
Jensun Ravichandran
eefec19c9b
Custom non-gradient training
2021-05-18 19:49:16 +02:00
Jensun Ravichandran
4957e821f6
Close matplotlib figure on train end
2021-05-18 10:13:22 +02:00
Jensun Ravichandran
81346785bd
Cleanup models
...
Siamese architectures no longer accept a `backbone_module`. They have to be
initialized with an pre-initialized backbone object instead. This is so that the
visualization callbacks could use the very same object for visualization
purposes. Also, there's no longer a dependent copy of the backbone. It is
managed simply with `requires_grad` instead.
2021-05-17 17:00:23 +02:00
Jensun Ravichandran
7a87636ad7
Update KNN
2021-05-17 16:59:35 +02:00
Jensun Ravichandran
77b7b59bad
Clean visualization callbacks
2021-05-17 16:59:22 +02:00
Jensun Ravichandran
6e7d80be88
[BUGFIX] Fix siamese visualization callback
2021-05-15 12:52:44 +02:00
Jensun Ravichandran
b7684ae512
predict_latent
no longer returns numpy
2021-05-15 12:52:16 +02:00
Alexander Engelsberger
0eac2ce326
Examples use GPUs if available.
2021-05-13 15:22:01 +02:00
Jensun Ravichandran
8f9c29bd2b
[BUGFIX] Remove incorrect import statement
2021-05-12 16:45:01 +02:00
Jensun Ravichandran
ca39aa00d5
Stop passing component initializers as hparams
...
Pass the component initializer as an hparam slows down the script very much. The
API has now been changed to pass it as a kwarg to the models instead.
The example scripts have also been updated to reflect the new changes.
Also, ImageGMLVQ and an example script `gmlvq_mnist.py` that uses it have also
been added.
2021-05-12 16:36:22 +02:00
Alexander Engelsberger
1498c4bde5
Bump version: 0.1.6 → 0.1.7
2021-05-11 17:18:29 +02:00
Jensun Ravichandran
59b8ab6643
Add knn
2021-05-11 17:22:02 +02:00
Jensun Ravichandran
eab1ec72c2
Change optimizer using kwargs
2021-05-11 16:13:00 +02:00
Jensun Ravichandran
b38acd58a8
[BUGFIX] Fix visualization callbacks bug
2021-05-11 16:09:27 +02:00
Alexander Engelsberger
e87563e10d
Bump version: 0.1.5 → 0.1.6
2021-05-11 13:41:26 +02:00
Alexander Engelsberger
3fa6378c4d
Add LVQ1 and LVQ2.1 Models.
2021-05-11 13:26:13 +02:00
Alexander Engelsberger
30ee287ecc
Bump version: 0.1.4 → 0.1.5
2021-05-10 17:13:00 +02:00
Alexander Engelsberger
f49db0bf2c
Bump version: 0.1.3 → 0.1.4
2021-05-10 17:06:28 +02:00
Alexander Engelsberger
54a8494d86
Bump version: 0.1.2 → 0.1.3
2021-05-10 17:04:20 +02:00
Alexander Engelsberger
bf310be97c
Bump version: 0.1.1 → 0.1.2
2021-05-10 16:47:33 +02:00
Alexander Engelsberger
1ae2b41edd
Bump version: 0.1.0 → 0.1.1
2021-05-10 16:26:21 +02:00
Alexander Engelsberger
f6e3a37e2b
Bump version: 0.0.0 → 0.1.0
2021-05-10 16:01:46 +02:00
Alexander Engelsberger
6873927349
Setup bumpversion
2021-05-10 15:34:43 +02:00
Jensun Ravichandran
49100f43f5
Example to save and reload a model
2021-05-10 14:30:02 +02:00
Jensun Ravichandran
ed03ab168e
[BUGFIX] Fix lambda_matrix property in GMLVQ
2021-05-10 14:09:25 +02:00
Jensun Ravichandran
c6e06ceaa4
Properly initialize prototypes in LVQMLN
2021-05-09 20:55:28 +02:00
Jensun Ravichandran
ca4c9da10a
Add the namespace hook for GMLVQ in the model class
2021-05-09 20:53:31 +02:00
Jensun Ravichandran
ff7a1e93d2
Refactor visualization callbacks
2021-05-09 20:53:03 +02:00
Jensun Ravichandran
11b3e53ecb
Return prototypes as torch tensor
2021-05-07 15:45:37 +02:00
Jensun Ravichandran
d7972a69e8
Update GMLVQ model
2021-05-07 15:24:47 +02:00
Jensun Ravichandran
17315ff242
Add models to the prototorch.models namespace
2021-05-07 15:23:52 +02:00
Jensun Ravichandran
5f937066bf
Move and improve visualization callbacks
2021-05-07 15:22:54 +02:00
Alexander Engelsberger
4bbe73e3a9
Add GRLVQ with examples.
2021-05-06 18:42:06 +02:00
Alexander Engelsberger
3df282a0af
Increase visualization pause.
2021-05-06 18:41:33 +02:00
Alexander Engelsberger
5a2f4f6170
Revert deletion of training accuracy.
2021-05-06 18:02:01 +02:00
Alexander Engelsberger
1c3613019b
Update Examples to new initializer architecture.
...
Visualization still borken for some examples.
2021-05-06 14:10:09 +02:00
Jensun Ravichandran
d644114090
Add loss transfer function to glvq
2021-05-04 20:56:16 +02:00
Jensun Ravichandran
f402eea884
Add GMLVQ examples
2021-05-04 15:11:16 +02:00
Jensun Ravichandran
a1ac5a70c7
Use squared euclidean distance in GMLVQ
2021-05-04 14:34:00 +02:00
Jensun Ravichandran
d8e017ae74
Update SiameseGLVQ
2021-05-03 16:09:22 +02:00
Jensun Ravichandran
96aeaa3448
Add support for multiple optimizers
2021-05-03 13:20:49 +02:00
Jensun Ravichandran
042b3fcaa2
Add tensorboard argument to visualization callbacks
2021-05-03 13:19:23 +02:00
Jensun Ravichandran
6dd9b1492c
Add more models
2021-04-29 23:37:22 +02:00
Jensun Ravichandran
db7bb7619f
Add border argument in visualization callback
2021-04-29 22:36:10 +02:00
Jensun Ravichandran
ccaa52c408
Add missing abstract.py file
2021-04-29 19:14:33 +02:00
Jensun Ravichandran
fef73e2fbf
[BUG] NaN when training with selection initializer
...
How to reproduce:
Run the `glvq_spiral.py` file under `examples/`.
The error seems to occur when using a lot of prototypes in combination with the
`StratifiedSelectionInitializer`. Using only a prototype per class, or using
another initializer like the `StratifiedMeanInitializer` seems to make the
problem go away.
2021-04-29 19:09:10 +02:00
Jensun Ravichandran
a16bebd0c4
Use Components instead of Prototypes and refactor old examples
2021-04-29 17:05:41 +02:00
Alexander Engelsberger
eeb684b3b6
GLVQ with configurable distance.
2021-04-27 15:41:44 +02:00
Jensun Ravichandran
1fb197077c
Add siamese glvq
2021-04-27 14:35:17 +02:00
Alexander Engelsberger
466bbe4c63
Add Neural Gas Model.
2021-04-23 17:30:23 +02:00
Alexander Engelsberger
fd12b18073
Add visualization callback from Protoflow.
2021-04-23 17:28:03 +02:00
Alexander Engelsberger
c4c51a16fe
Automatic Formating.
2021-04-23 17:27:47 +02:00
Alexander Engelsberger
db4499a103
Add more CBC examples. MNIST is broken.
2021-04-22 17:37:20 +02:00
Jensun Ravichandran
2e2f6707f6
Add partial cbc implementation
2021-04-22 16:01:44 +02:00
Jensun Ravichandran
d0d69f610e
Show accuracy in the progress bar
2021-04-21 22:28:36 +02:00
Jensun Ravichandran
fadf8c25bf
Add more experimental changes
...
The code gets very messy very quickly as soon as serialization features are
needed.
2021-04-21 21:59:19 +02:00
Jensun Ravichandran
e5a62bd0fc
Fix broken state from previous commit
2021-04-21 21:35:52 +02:00
Jensun Ravichandran
fe36e5fad9
Add partial metric/hparam features [BROKEN STATE]
2021-04-21 19:16:57 +02:00
Jensun Ravichandran
f6994dfd83
Add glvq model
2021-04-21 14:51:34 +02:00
Jensun Ravichandran
f4e703abee
Use models namespace
2021-04-21 13:30:50 +02:00
Jensun Ravichandran
c8b7ea2e97
Initial commit
2021-04-21 13:13:28 +02:00