Commit Graph

209 Commits

Author SHA1 Message Date
Alexander Engelsberger
d7834e2cc0
fix: All examples should work on CPU and GPU now 2021-08-05 11:20:02 +02:00
Alexander Engelsberger
0af8cf36f8
fix: labels where on cpu in forward pass 2021-08-05 09:14:32 +02:00
Jensun Ravichandran
f8ad1d83eb
refactor: clean up abstract classes 2021-07-14 19:17:05 +02:00
Jensun Ravichandran
4be9fb81eb
feat(model): implement MedianLVQ 2021-07-06 17:12:51 +02:00
Jensun Ravichandran
9d38123114
refactor: use GLVQLoss instead of LossLayer 2021-07-06 17:09:21 +02:00
Jensun Ravichandran
09e3ef1d0e
fix: remove deprecated Trainer.accelerator_backend 2021-06-30 16:03:45 +02:00
Alexander Engelsberger
7b9b767113 fix: training loss is a zero dimensional tensor
Should fix the problem with EarlyStopping callback.
2021-06-25 17:07:06 +02:00
Jensun Ravichandran
72af03b991 refactor: use LinearTransform instead of torch.nn.Linear 2021-06-25 17:07:06 +02:00
Alexander Engelsberger
71602bf38a
build: bump version 0.1.8 → 0.2.0 2021-06-21 16:47:17 +02:00
Jensun Ravichandran
d42693a441 refactor(api)!: merge the new api changes into dev 2021-06-20 19:00:12 +02:00
Jensun Ravichandran
e5ac50c9a7 Bump version: 0.1.7 → 0.1.8 2021-06-20 17:56:21 +02:00
Jensun Ravichandran
b9eb88a602 Remove .swp file 2021-06-18 13:46:08 +02:00
Jensun Ravichandran
7eb496110f Use mesh2d from prototorch.utils 2021-06-18 13:43:44 +02:00
danielstaps
0a2da9ae50
Added Vis for GMLVQ with more then 2 dims using PCA (#11)
* Added Vis for GMLVQ with more then 2 dims using PCA

* Added initialization possibility to GMlVQ with PCA and one example with omega init + PCA vis of 3 dims

* test(githooks): Add githooks for automatic commit checks

Co-authored-by: staps@hs-mittweida.de <staps@hs-mittweida.de>
Co-authored-by: Alexander Engelsberger <alexanderengelsberger@gmail.com>
2021-06-18 13:28:11 +02:00
Alexander Engelsberger
8956ee75ad
test(githooks): Add githooks for automatic commit checks 2021-06-16 16:16:34 +02:00
Jensun Ravichandran
a37095409b [BUGFIX] examples/cbc_iris.py works again 2021-06-15 15:59:47 +02:00
Jensun Ravichandran
1b420c1f6b [BUG] LVQ1 is broken 2021-06-14 21:08:05 +02:00
Jensun Ravichandran
a44219ee47 [BUG] PLVQ seems broken 2021-06-14 20:56:38 +02:00
Jensun Ravichandran
1c658cdc1b [FEATURE] Add warm-starting example 2021-06-14 20:42:57 +02:00
Jensun Ravichandran
6197d7d5d6 [BUGFIX] examples/dynamic_pruning.py works again 2021-06-14 20:31:39 +02:00
Jensun Ravichandran
d2856383e2 [BUGFIX] examples/gng_iris.py works again 2021-06-14 20:29:31 +02:00
Jensun Ravichandran
3afced8662 [BUGFIX] examples/glvq_spiral.py works again 2021-06-14 20:19:08 +02:00
Jensun Ravichandran
68034d56f6 [BUGFIX] examples/glvq_iris.py works again 2021-06-14 20:13:25 +02:00
Jensun Ravichandran
97ec15b76a [BUGFIX] KNN works again 2021-06-14 20:09:41 +02:00
Jensun Ravichandran
69e5ff3243 Import from the newly cleaned-up prototorch namespace 2021-06-14 20:08:08 +02:00
Alexander Engelsberger
c87ed5ba8b
[FEATURE] Add PLVQ model 2021-06-12 13:02:26 +02:00
Alexander Engelsberger
fc11d78b38
[REFACTOR] Rename LikelihoodRatioLVQ to SLVQ 2021-06-12 13:02:26 +02:00
Jensun Ravichandran
e62a8e6582 [BUGFIX] Log loss in NG and GNG 2021-06-11 18:50:14 +02:00
Jensun Ravichandran
57f8bec270 Update SOM 2021-06-09 18:21:12 +02:00
Jensun Ravichandran
022d791ea5 Route initialized prototypes 2021-06-07 21:18:08 +02:00
Alexander Engelsberger
43fc7d1678 [QA] Remove unused argument from CBC 2021-06-07 21:00:58 +02:00
Jensun Ravichandran
c7b5c88776 [WIP] Add SOM 2021-06-07 18:44:15 +02:00
Jensun Ravichandran
b031382072 Update NG 2021-06-07 18:35:08 +02:00
Jensun Ravichandran
d558fa6a4a [REFACTOR] Clean up GLVQ-types 2021-06-07 17:00:38 +02:00
Jensun Ravichandran
34ffeb95bc Update how the model is printed 2021-06-07 16:58:00 +02:00
Jensun Ravichandran
bed753a6e9 Minor aesthetic change 2021-06-05 01:23:58 +02:00
Jensun Ravichandran
016fcb4060 [REFACTOR] Major cleanup 2021-06-04 22:20:32 +02:00
Jensun Ravichandran
20471bfb1c [FEATURE] Update pruning callback to re-add pruned prototypes 2021-06-04 15:56:46 +02:00
Jensun Ravichandran
42d974e08c No implicit learning rate scheduling 2021-06-04 15:55:06 +02:00
Alexander Engelsberger
47db1965ee [BUGFIX] GNG Example 2021-06-03 15:42:54 +02:00
Alexander Engelsberger
0bc385fe7b [BUGFIX] Neural Gas gets prototype intiailizer from kwargs 2021-06-03 15:24:17 +02:00
Alexander Engelsberger
358f27257d [REFACTOR] Remove prototype_initializer function from GLVQ
Fixes #9
2021-06-03 15:15:22 +02:00
Alexander Engelsberger
bda88149d4 [BUGFIX] Growing neural gas 2021-06-03 15:13:38 +02:00
Alexander Engelsberger
7379c61966 [BUGFIX] Fix image visualization for some parameter combination
image visualization was broken if add_embeding was False, but data visualization was on.
2021-06-03 15:12:51 +02:00
Alexander Engelsberger
e209bf73d5 [BUGFIX] Pruning example works on GPU now 2021-06-03 14:35:24 +02:00
Alexander Engelsberger
1b09b1d57b [BUGFIX] Probabilistic Models work on GPU now 2021-06-03 14:05:44 +02:00
Alexander Engelsberger
459f7c24be [REFACTOR] Probabilistic loss signs changed 2021-06-03 14:00:47 +02:00
Alexander Engelsberger
3b02d99ebe [BUGFIX] Early stopping example works now 2021-06-03 13:38:16 +02:00
Jensun Ravichandran
ef6bcc1079 [BUG] Early stopping does not seem to work
The early stopping callback does not work as expected, and crashes at the end of
max_epochs with:

```
~/miniconda3/envs/py38/lib/python3.8/site-packages/pytorch_lightning/trainer/callback_hook.py in on_train_end(self)
    155         """Called when the train ends."""
    156         for callback in self.callbacks:
--> 157             callback.on_train_end(self, self.lightning_module)
    158
    159     def on_pretrain_routine_start(self) -> None:

~/work/repos/prototorch_models/prototorch/models/callbacks.py in on_train_end(self, trainer, pl_module)
     18     def on_train_end(self, trainer, pl_module):
     19         # instead, do it at the end of training loop
---> 20         self._run_early_stopping_check(trainer, pl_module)
     21
     22

TypeError: _run_early_stopping_check() takes 2 positional arguments but 3 were given
```
2021-06-02 12:44:34 +02:00
Jensun Ravichandran
8851d1bbc9 [FEATURE] Add PruneLoserPrototypes Callback 2021-06-02 03:52:41 +02:00
Jensun Ravichandran
8f7deb75dd [FEATURE] Log prototype win ratios over all training batches 2021-06-02 02:32:54 +02:00
Jensun Ravichandran
7743c50725 Tweak repr 2021-06-02 01:07:48 +02:00
Jensun Ravichandran
fcf3a4979c Remove unused import 2021-06-02 00:49:36 +02:00
Jensun Ravichandran
d46fe4a393 [WIP] Update CBC example 2021-06-02 00:45:33 +02:00
Jensun Ravichandran
88cfd5762e Remove unused imports in models/cbc.py 2021-06-02 00:44:35 +02:00
Jensun Ravichandran
aa42b9e331 [BUGFIX] Import missing module
models/unsupervised.py uses `pt` in line 37, but `pt` is undefined in the file.
I wonder why Python doesn't complain about this. Perhaps because unsupervised.py
is never run in isolation and `pt` is otherwise available in the namespace of
the example scripts that use unsupervised.py.
2021-06-02 00:31:57 +02:00
Jensun Ravichandran
91b57b01b1 [REFACTOR] neighbour -> neighbor 2021-06-02 00:29:45 +02:00
Jensun Ravichandran
98c198d463 [REFACTOR] Use LambdaLayer instead of EuclideanDistance 2021-06-02 00:21:11 +02:00
Jensun Ravichandran
757f4e980d Add Local-Matrix LVQ
Also remove the use of `self.distance_fn` in favor of `self.distance_layer`.
2021-06-01 23:44:16 +02:00
Jensun Ravichandran
5ec2dd47cd Remove unused import 2021-06-01 23:40:56 +02:00
Jensun Ravichandran
e8cd4d765c Remove unused variable 2021-06-01 23:39:39 +02:00
Jensun Ravichandran
8403b01081 Move CELVQ to probabilistic.py 2021-06-01 23:39:06 +02:00
Jensun Ravichandran
aff6aedd60 Use the add_components API for adding prototypes 2021-06-01 23:37:45 +02:00
Jensun Ravichandran
1b6843dbbb Remove unused imports 2021-06-01 19:31:03 +02:00
Jensun Ravichandran
21023a88d7 [BUGFIX] Fix RSLVQ 2021-06-01 17:44:10 +02:00
Alexander Engelsberger
9c1a41997b [FEATURE] Add Growing Neural Gas 2021-06-01 17:19:43 +02:00
Jensun Ravichandran
1636c84778 Rename rslvq example 2021-05-31 17:56:45 +02:00
Jensun Ravichandran
27eccf44d4 Use LambdaLayer from ProtoTorch 2021-05-31 16:53:04 +02:00
Alexander Engelsberger
2a218c0ede Add example for dynamic components in callbacks 2021-05-31 11:39:24 +02:00
Alexander Engelsberger
db064b5af1 Improvement of model __repr__ 2021-05-31 11:19:06 +02:00
Alexander Engelsberger
0ac4ced85d [refactor] Use functional variant of accuracy
Prevents Accuracy in `__repr__` of the models.
2021-05-31 11:12:27 +02:00
Jensun Ravichandran
7b7bc3693d Merge branch 'dev' of github.com:si-cim/prototorch_models into dev 2021-05-31 00:32:49 +02:00
Jensun Ravichandran
cd73f6c427 Add examples/dynamic_components.py 2021-05-31 00:32:27 +02:00
Alexander Engelsberger
a60337ff27 [refactor] Move probabilistic to Prototorch 2021-05-28 20:39:32 +02:00
Alexander Engelsberger
e3392ee952 [refactor] DRY Probabilistic models 2021-05-28 17:13:06 +02:00
Alexander Engelsberger
dade502686 Add MNIST datamodule and training mixin factory. 2021-05-28 16:33:31 +02:00
Jensun Ravichandran
b7edee02c3 [WIP] Add CELVQ
TODO Ensure that the distances/probs corresponding to the plabels are sorted
like the target labels.
2021-05-27 17:40:16 +02:00
Alexander Engelsberger
66e3e51a52 Add references to the documentation. 2021-05-26 21:20:17 +02:00
Jensun Ravichandran
d411e52be4 Refactor non-gradient-lvq models into lvq.py 2021-05-25 20:37:34 +02:00
Alexander Engelsberger
32d6f95db0 Add RSLVQ and LikelihoodLVQ 2021-05-25 20:26:15 +02:00
Jensun Ravichandran
139109804f [BUGFIX] Use _forward in LVQ1 and LVQ21 2021-05-25 17:43:37 +02:00
Alexander Engelsberger
72e064338c Use 'num_' in all variable names 2021-05-25 15:41:10 +02:00
Alexander Engelsberger
8ce18f83ce Add prototype_initializer function to GLVQ
This allows overwriting it inside subclasses.
2021-05-21 17:13:10 +02:00
Alexander Engelsberger
7b4f7d84e0 Update Documentation
Clean up project
2021-05-21 15:42:45 +02:00
Jensun Ravichandran
a5e086ce0d Refactor code 2021-05-21 13:33:57 +02:00
Jensun Ravichandran
0611f81aba Update models namespace 2021-05-21 13:11:59 +02:00
Jensun Ravichandran
a9382dcd9b Add get_prototype_grid method 2021-05-21 13:11:48 +02:00
Jensun Ravichandran
0933a88a1b Fix ImageCBC bug 2021-05-21 13:11:36 +02:00
Jensun Ravichandran
88a34a06ef [WIP] Update CBC implementation to use SiameseGLVQ 2021-05-20 17:36:00 +02:00
Jensun Ravichandran
de63eaf15a Fix numpy issue in vis.py 2021-05-20 17:33:19 +02:00
Jensun Ravichandran
16dc3cf4eb Update image visualization 2021-05-20 16:07:16 +02:00
Jensun Ravichandran
df061cc2ff Refactor code 2021-05-20 14:40:02 +02:00
Alexander Engelsberger
969fb34cc3 Accumulate test loss 2021-05-20 14:20:23 +02:00
Alexander Engelsberger
0204f5eab6 Log test accuracy. 2021-05-20 14:03:31 +02:00
Alexander Engelsberger
b7fc5df386 Log test loss. 2021-05-20 13:47:20 +02:00
Alexander Engelsberger
faf1a88f99 [Bugfix] Remove optimzer_idx from validation and test. 2021-05-20 13:17:27 +02:00
Jensun Ravichandran
5ffbd43a7c Refactor into shared_step 2021-05-19 16:57:51 +02:00
Jensun Ravichandran
fdf9443a2c Add validation and test logic 2021-05-19 16:30:19 +02:00
Jensun Ravichandran
eefec19c9b Custom non-gradient training 2021-05-18 19:49:16 +02:00
Jensun Ravichandran
4957e821f6 Close matplotlib figure on train end 2021-05-18 10:13:22 +02:00