21 Commits

Author SHA1 Message Date
Alexander Engelsberger
d6629c8792 build: bump version 0.4.1 → 0.5.0 2022-04-27 10:28:06 +02:00
Alexander Engelsberger
ef65bd3789 chore: update prototorch dependency 2022-04-27 09:50:48 +02:00
Alexander Engelsberger
d096eba2c9 chore: update pytorch lightning dependency 2022-04-27 09:39:00 +02:00
Alexander Engelsberger
dd34c57e2e ci: fix github action python version 2022-04-27 09:30:07 +02:00
Alexander Engelsberger
5911f4dd90 chore: fix errors for pytorch_lightning>1.6 2022-04-27 09:25:42 +02:00
Alexander Engelsberger
dbfe315f4f ci: add python 3.10 to tests 2022-04-27 09:24:34 +02:00
Jensun Ravichandran
9c90c902dc fix: correct typo 2022-04-04 21:54:04 +02:00
Jensun Ravichandran
7d3f59e54b test: add unit tests 2022-03-30 15:12:33 +02:00
Jensun Ravichandran
9da47b1dba fix: CBC example works again 2022-03-30 15:10:06 +02:00
Alexander Engelsberger
41f0e77fc9 fix: siameseGLVQ checks requires_grad of backbone
Necessary for different optimizer runs
2022-03-29 17:08:40 +02:00
Jensun Ravichandran
fab786a07e fix: rename hparam output_dimlatent_dim in SiameseGMLVQ 2022-03-29 15:24:42 +02:00
Jensun Ravichandran
40bd7ed380 docs: update tutorial notebook 2022-03-29 15:04:05 +02:00
Jensun Ravichandran
4941c2b89d feat: data argument optional in some visualizers 2022-03-29 11:26:22 +02:00
Jensun Ravichandran
ce14dec7e9 feat: add VisSpectralProtos 2022-03-10 15:24:44 +01:00
Jensun Ravichandran
b31c8cc707 feat: add xlabel and ylabel arguments to visualizers 2022-03-09 13:59:19 +01:00
Jensun Ravichandran
e21e6c7e02 docs: update tutorial notebook 2022-02-15 14:38:53 +01:00
Jensun Ravichandran
dd696ea1e0 fix: update hparams.distribution as it changes during training 2022-02-02 21:53:03 +01:00
Jensun Ravichandran
15e7232747 fix: ignore prototype_win_ratios by loading with strict=False 2022-02-02 21:52:01 +01:00
Jensun Ravichandran
197b728c63 feat: add visualize method to visualization callbacks
All visualization callbacks now contain a `visualize` method that takes an
appropriate PyTorchLightning Module and visualizes it without the need for a
Trainer. This is to encourage users to perform one-off visualizations after
training.
2022-02-02 21:45:44 +01:00
Jensun Ravichandran
98892afee0 chore: add example for saving/loading models from checkpoints 2022-02-02 19:02:26 +01:00
Alexander Engelsberger
d5855dbe97 fix: GLVQ can now be restored from checkpoint 2022-02-02 16:17:11 +01:00
18 changed files with 713 additions and 233 deletions

View File

@@ -1,5 +1,5 @@
[bumpversion]
current_version = 0.4.1
current_version = 0.5.0
commit = True
tag = True
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)

View File

@@ -12,10 +12,10 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Set up Python 3.9
- name: Set up Python 3.10
uses: actions/setup-python@v2
with:
python-version: 3.9
python-version: "3.10"
- name: Install dependencies
run: |
python -m pip install --upgrade pip

View File

@@ -13,10 +13,10 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Set up Python 3.9
- name: Set up Python 3.10
uses: actions/setup-python@v2
with:
python-version: 3.9
python-version: "3.10"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
@@ -27,13 +27,15 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.7", "3.8", "3.9"]
python-version: ["3.7", "3.8", "3.9", "3.10"]
os: [ubuntu-latest, windows-latest]
exclude:
- os: windows-latest
python-version: "3.7"
- os: windows-latest
python-version: "3.8"
- os: windows-latest
python-version: "3.9"
runs-on: ${{ matrix.os }}
steps:
@@ -55,10 +57,10 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Set up Python 3.9
- name: Set up Python 3.10
uses: actions/setup-python@v2
with:
python-version: "3.9"
python-version: "3.10"
- name: Install dependencies
run: |
python -m pip install --upgrade pip

View File

@@ -23,7 +23,7 @@ author = "Jensun Ravichandran"
# The full version, including alpha/beta/rc tags
#
release = "0.4.1"
release = "0.5.0"
# -- General configuration ---------------------------------------------------

View File

@@ -2,223 +2,252 @@
"cells": [
{
"cell_type": "markdown",
"id": "7ac5eff0",
"metadata": {},
"source": [
"# A short tutorial for the `prototorch.models` plugin"
],
"metadata": {}
]
},
{
"cell_type": "markdown",
"id": "beb83780",
"metadata": {},
"source": [
"## Introduction"
],
"metadata": {}
]
},
{
"cell_type": "markdown",
"id": "43b74278",
"metadata": {},
"source": [
"This is a short tutorial for the [models](https://github.com/si-cim/prototorch_models) plugin of the [ProtoTorch](https://github.com/si-cim/prototorch) framework.\n",
"This is a short tutorial for the [models](https://github.com/si-cim/prototorch_models) plugin of the [ProtoTorch](https://github.com/si-cim/prototorch) framework. This is by no means a comprehensive look at all the features that the framework has to offer, but it should help you get up and running.\n",
"\n",
"[ProtoTorch](https://github.com/si-cim/prototorch) provides [torch.nn](https://pytorch.org/docs/stable/nn.html) modules and utilities to implement prototype-based models. However, it is up to the user to put these modules together into models and handle the training of these models. Expert machine-learning practioners and researchers sometimes prefer this level of control. However, this leads to a lot of boilerplate code that is essentially same across many projects. Needless to say, this is a source of a lot of frustration. [PyTorch-Lightning](https://pytorch-lightning.readthedocs.io/en/latest/) is a framework that helps avoid a lot of this frustration by handling the boilerplate code for you so you don't have to reinvent the wheel every time you need to implement a new model.\n",
"\n",
"With the [prototorch.models](https://github.com/si-cim/prototorch_models) plugin, we've gone one step further and pre-packaged commonly used prototype-models like GMLVQ as [Lightning-Modules](https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.core.lightning.html?highlight=lightning%20module#pytorch_lightning.core.lightning.LightningModule). With only a few lines to code, it is now possible to build and train prototype-models. It quite simply cannot get any simpler than this."
],
"metadata": {}
]
},
{
"cell_type": "markdown",
"id": "4e5d1fad",
"metadata": {},
"source": [
"## Basics"
],
"metadata": {}
]
},
{
"cell_type": "markdown",
"id": "1244b66b",
"metadata": {},
"source": [
"First things first. When working with the models plugin, you'll probably need `torch`, `prototorch` and `pytorch_lightning`. So, we recommend that you import all three like so:"
],
"metadata": {}
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dcb88e8a",
"metadata": {},
"outputs": [],
"source": [
"import prototorch as pt\n",
"import pytorch_lightning as pl\n",
"import torch"
],
"outputs": [],
"metadata": {}
]
},
{
"cell_type": "markdown",
"id": "1adbe2f8",
"metadata": {},
"source": [
"### Building Models"
],
"metadata": {}
]
},
{
"cell_type": "markdown",
"id": "96663ab1",
"metadata": {},
"source": [
"Let's start by building a `GLVQ` model. It is one of the simplest models to build. The only requirements are a prototype distribution and an initializer."
],
"metadata": {}
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "819ba756",
"metadata": {},
"outputs": [],
"source": [
"model = pt.models.GLVQ(\n",
" hparams=dict(distribution=[1, 1, 1]),\n",
" prototypes_initializer=pt.initializers.ZerosCompInitializer(2),\n",
")"
],
"outputs": [],
"metadata": {}
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1b37e97c",
"metadata": {},
"outputs": [],
"source": [
"print(model)"
],
"outputs": [],
"metadata": {}
]
},
{
"cell_type": "markdown",
"id": "d2c86903",
"metadata": {},
"source": [
"The key `distribution` in the `hparams` argument describes the prototype distribution. If it is a Python [list](https://docs.python.org/3/tutorial/datastructures.html), it is assumed that there are as many entries in this list as there are classes, and the number at each location of this list describes the number of prototypes to be used for that particular class. So, `[1, 1, 1]` implies that we have three classes with one prototype per class. If it is a Python [tuple](https://docs.python.org/3/tutorial/datastructures.html), a shorthand of `(num_classes, prototypes_per_class)` is assumed. If it is a Python [dictionary](https://docs.python.org/3/tutorial/datastructures.html), the key-value pairs describe the class label and the number of prototypes for that class respectively. So, `{0: 2, 1: 2, 2: 2}` implies that we have three classes with labels `{1, 2, 3}`, each equipped with two prototypes. If however, the dictionary contains the keys `\"num_classes\"` and `\"per_class\"`, they are parsed to use their values as one might expect.\n",
"\n",
"The `prototypes_initializer` argument describes how the prototypes are meant to be initialized. This argument has to be an instantiated object of some kind of [AbstractComponentsInitializer](https://github.com/si-cim/prototorch/blob/dev/prototorch/components/initializers.py#L18). If this is a [ShapeAwareCompInitializer](https://github.com/si-cim/prototorch/blob/dev/prototorch/components/initializers.py#L41), this only requires a `shape` arugment that describes the shape of the prototypes. So, `pt.initializers.ZerosCompInitializer(3)` creates 3d-vector prototypes all initialized to zeros."
],
"metadata": {}
]
},
{
"cell_type": "markdown",
"id": "45806052",
"metadata": {},
"source": [
"### Data"
],
"metadata": {}
]
},
{
"cell_type": "markdown",
"id": "9d62c4c6",
"metadata": {},
"source": [
"The preferred way to working with data in `torch` is to use the [Dataset and Dataloader API](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html). There a few pre-packaged datasets available under `prototorch.datasets`. See [here](https://prototorch.readthedocs.io/en/latest/api.html#module-prototorch.datasets) for a full list of available datasets."
],
"metadata": {}
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "504df02c",
"metadata": {},
"outputs": [],
"source": [
"train_ds = pt.datasets.Iris(dims=[0, 2])"
],
"outputs": [],
"metadata": {}
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3b8e7756",
"metadata": {},
"outputs": [],
"source": [
"type(train_ds)"
],
"outputs": [],
"metadata": {}
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bce43afa",
"metadata": {},
"outputs": [],
"source": [
"train_ds.data.shape, train_ds.targets.shape"
],
"outputs": [],
"metadata": {}
]
},
{
"cell_type": "markdown",
"id": "26a83328",
"metadata": {},
"source": [
"Once we have such a dataset, we could wrap it in a `Dataloader` to load the data in batches, and possibly apply some transformations on the fly."
],
"metadata": {}
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "67b80fbe",
"metadata": {},
"outputs": [],
"source": [
"train_loader = torch.utils.data.DataLoader(train_ds, batch_size=2)"
],
"outputs": [],
"metadata": {}
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c1185f31",
"metadata": {},
"outputs": [],
"source": [
"type(train_loader)"
],
"outputs": [],
"metadata": {}
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9b5a8963",
"metadata": {},
"outputs": [],
"source": [
"x_batch, y_batch = next(iter(train_loader))\n",
"print(f\"{x_batch=}, {y_batch=}\")"
],
"outputs": [],
"metadata": {}
]
},
{
"cell_type": "markdown",
"id": "dd492ee2",
"metadata": {},
"source": [
"This perhaps seems like a lot of work for a small dataset that fits completely in memory. However, this comes in very handy when dealing with huge datasets that can only be processed in batches."
],
"metadata": {}
]
},
{
"cell_type": "markdown",
"id": "5176b055",
"metadata": {},
"source": [
"### Training"
],
"metadata": {}
]
},
{
"cell_type": "markdown",
"id": "46a7a506",
"metadata": {},
"source": [
"If you're familiar with other deep learning frameworks, you might perhaps expect a `.fit(...)` or `.train(...)` method. However, in PyTorch-Lightning, this is done slightly differently. We first create a trainer and then pass the model and the Dataloader to `trainer.fit(...)` instead. So, it is more functional in style than object-oriented."
],
"metadata": {}
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "279e75b7",
"metadata": {},
"outputs": [],
"source": [
"trainer = pl.Trainer(max_epochs=2, weights_summary=None)"
],
"outputs": [],
"metadata": {}
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e496b492",
"metadata": {},
"outputs": [],
"source": [
"trainer.fit(model, train_loader)"
],
"outputs": [],
"metadata": {}
]
},
{
"cell_type": "markdown",
"id": "497fbff6",
"metadata": {},
"source": [
"### From data to a trained model - a very minimal example"
],
"metadata": {}
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ab069c5d",
"metadata": {},
"outputs": [],
"source": [
"train_ds = pt.datasets.Iris(dims=[0, 2])\n",
"train_loader = torch.utils.data.DataLoader(train_ds, batch_size=32)\n",
@@ -230,49 +259,239 @@
"\n",
"trainer = pl.Trainer(max_epochs=50, weights_summary=None)\n",
"trainer.fit(model, train_loader)"
],
"outputs": [],
"metadata": {}
]
},
{
"cell_type": "markdown",
"id": "30c71a93",
"metadata": {},
"source": [
"## Advanced"
],
"metadata": {}
"### Saving/Loading trained models"
]
},
{
"cell_type": "markdown",
"id": "f74ed2c1",
"metadata": {},
"source": [
"### Initializing prototypes with a subset of a dataset (along with transformations)"
],
"metadata": {}
"Pytorch Lightning can automatically checkpoint the model during various stages of training, but it also possible to manually save a checkpoint after training."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3156658d",
"metadata": {},
"outputs": [],
"source": [
"ckpt_path = \"./checkpoints/glvq_iris.ckpt\"\n",
"trainer.save_checkpoint(ckpt_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c1c34055",
"metadata": {},
"outputs": [],
"source": [
"loaded_model = pt.models.GLVQ.load_from_checkpoint(ckpt_path, strict=False)"
]
},
{
"cell_type": "markdown",
"id": "bbbb08e9",
"metadata": {},
"source": [
"### Visualizing decision boundaries in 2D"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "53ca52dc",
"metadata": {},
"outputs": [],
"source": [
"pt.models.VisGLVQ2D(data=train_ds).visualize(loaded_model)"
]
},
{
"cell_type": "markdown",
"id": "8373531f",
"metadata": {},
"source": [
"### Saving/Loading trained weights"
]
},
{
"cell_type": "markdown",
"id": "937bc458",
"metadata": {},
"source": [
"In most cases, the checkpointing workflow is sufficient. In some cases however, one might want to only save the trained weights from the model. The disadvantage of this method is that the model has be re-created using compatible initialization parameters before the weights could be loaded."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1f2035af",
"metadata": {},
"outputs": [],
"source": [
"ckpt_path = \"./checkpoints/glvq_iris_weights.pth\"\n",
"torch.save(model.state_dict(), ckpt_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1206021a",
"metadata": {},
"outputs": [],
"source": [
"model = pt.models.GLVQ(\n",
" dict(distribution=(3, 2)),\n",
" prototypes_initializer=pt.initializers.ZerosCompInitializer(2),\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9f2a4beb",
"metadata": {},
"outputs": [],
"source": [
"pt.models.VisGLVQ2D(data=train_ds, title=\"Before loading the weights\").visualize(model)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "528d2fc2",
"metadata": {},
"outputs": [],
"source": [
"torch.load(ckpt_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ec817e6b",
"metadata": {},
"outputs": [],
"source": [
"model.load_state_dict(torch.load(ckpt_path), strict=False)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a208eab7",
"metadata": {},
"outputs": [],
"source": [
"pt.models.VisGLVQ2D(data=train_ds, title=\"After loading the weights\").visualize(model)"
]
},
{
"cell_type": "markdown",
"id": "f8de748f",
"metadata": {},
"source": [
"## Advanced"
]
},
{
"cell_type": "markdown",
"id": "53a64063",
"metadata": {},
"source": [
"### Warm-start a model with prototypes learned from another model"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3177c277",
"metadata": {},
"outputs": [],
"source": [
"trained_model = pt.models.GLVQ.load_from_checkpoint(\"./checkpoints/glvq_iris.ckpt\", strict=False)\n",
"model = pt.models.SiameseGMLVQ(\n",
" dict(input_dim=2,\n",
" latent_dim=2,\n",
" distribution=(3, 2),\n",
" proto_lr=0.0001,\n",
" bb_lr=0.0001),\n",
" optimizer=torch.optim.Adam,\n",
" prototypes_initializer=pt.initializers.LCI(trained_model.prototypes),\n",
" labels_initializer=pt.initializers.LLI(trained_model.prototype_labels),\n",
" omega_initializer=pt.initializers.LLTI(torch.tensor([[0., 1.], [1., 0.]])), # permute axes\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8baee9a2",
"metadata": {},
"outputs": [],
"source": [
"print(model)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cc203088",
"metadata": {},
"outputs": [],
"source": [
"pt.models.VisSiameseGLVQ2D(data=train_ds, title=\"GMLVQ - Warm-start state\").visualize(model)"
]
},
{
"cell_type": "markdown",
"id": "1f6a33a5",
"metadata": {},
"source": [
"### Initializing prototypes with a subset of a dataset (along with transformations)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "946ce341",
"metadata": {},
"outputs": [],
"source": [
"import prototorch as pt\n",
"import pytorch_lightning as pl\n",
"import torch\n",
"from torchvision import transforms\n",
"from torchvision.datasets import MNIST"
],
"outputs": [],
"metadata": {}
"from torchvision.datasets import MNIST\n",
"from torchvision.utils import make_grid"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "510d9bd4",
"metadata": {},
"outputs": [],
"source": [
"from matplotlib import pyplot as plt"
],
"outputs": [],
"metadata": {}
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ea7c1228",
"metadata": {},
"outputs": [],
"source": [
"train_ds = MNIST(\n",
" \"~/datasets\",\n",
@@ -284,59 +503,87 @@
" transforms.ToTensor(),\n",
" ]),\n",
")"
],
"outputs": [],
"metadata": {}
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1b9eaf5c",
"metadata": {},
"outputs": [],
"source": [
"s = int(0.05 * len(train_ds))\n",
"init_ds, rest_ds = torch.utils.data.random_split(train_ds, [s, len(train_ds) - s])"
],
"outputs": [],
"metadata": {}
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8c32c9f2",
"metadata": {},
"outputs": [],
"source": [
"init_ds"
],
"outputs": [],
"metadata": {}
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "68a9a8b9",
"metadata": {},
"outputs": [],
"source": [
"model = pt.models.ImageGLVQ(\n",
" dict(distribution=(10, 5)),\n",
" dict(distribution=(10, 1)),\n",
" prototypes_initializer=pt.initializers.SMCI(init_ds),\n",
")"
],
"outputs": [],
"metadata": {}
]
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"plt.imshow(model.get_prototype_grid(num_columns=10))"
],
"id": "6f23df86",
"metadata": {},
"outputs": [],
"metadata": {}
"source": [
"plt.imshow(model.get_prototype_grid(num_columns=5))"
]
},
{
"cell_type": "markdown",
"id": "1c23c7b2",
"metadata": {},
"source": [
"We could, of course, just use the initializers in isolation. For example, we could quickly obtain a stratified selection from the data like so:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "30780927",
"metadata": {},
"outputs": [],
"source": [
"protos, plabels = pt.components.LabeledComponents(\n",
" distribution=(10, 5),\n",
" components_initializer=pt.initializers.SMCI(init_ds),\n",
" labels_initializer=pt.initializers.LabelsInitializer(),\n",
")()\n",
"plt.imshow(make_grid(protos, 10).permute(1, 2, 0)[:, :, 0], cmap=\"jet\")"
]
},
{
"cell_type": "markdown",
"id": "4fa69f92",
"metadata": {},
"source": [
"## FAQs"
],
"metadata": {}
]
},
{
"cell_type": "markdown",
"id": "fa20f9ac",
"metadata": {},
"source": [
"### How do I Retrieve the prototypes and their respective labels from the model?\n",
"\n",
@@ -351,11 +598,12 @@
"```python\n",
">>> model.prototype_labels\n",
"```"
],
"metadata": {}
]
},
{
"cell_type": "markdown",
"id": "ba8215bf",
"metadata": {},
"source": [
"### How do I make inferences/predictions/recall with my trained model?\n",
"\n",
@@ -370,13 +618,12 @@
"```python\n",
">>> y_pred = model(torch.Tensor(x_train)) # returns probabilities\n",
"```"
],
"metadata": {}
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
@@ -390,7 +637,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.4"
"version": "3.9.12"
}
},
"nbformat": 4,

View File

@@ -53,3 +53,13 @@ if __name__ == "__main__":
# Training loop
trainer.fit(model, train_loader)
# Manual save
trainer.save_checkpoint("./glvq_iris.ckpt")
# Load saved model
new_model = pt.models.GLVQ.load_from_checkpoint(
checkpoint_path="./glvq_iris.ckpt",
strict=False,
)
print(new_model)

View File

@@ -26,16 +26,14 @@ from .lvq import (
)
from .probabilistic import (
CELVQ,
PLVQ,
RSLVQ,
SLVQ,
)
from .unsupervised import (
GrowingNeuralGas,
HeskesSOM,
KohonenSOM,
NeuralGas,
)
from .vis import *
__version__ = "0.4.1"
__version__ = "0.5.0"

View File

@@ -7,7 +7,7 @@ import torchmetrics
from ..core.competitions import WTAC
from ..core.components import Components, LabeledComponents
from ..core.distances import euclidean_distance
from ..core.initializers import LabelsInitializer
from ..core.initializers import LabelsInitializer, ZerosCompInitializer
from ..core.pooling import stratified_min_pooling
from ..nn.wrappers import LambdaLayer
@@ -43,7 +43,7 @@ class ProtoTorchBolt(pl.LightningModule):
return optimizer
def reconfigure_optimizers(self):
self.trainer.accelerator.setup_optimizers(self.trainer)
self.trainer.strategy.setup_optimizers(self.trainer)
def __repr__(self):
surep = super().__repr__()
@@ -75,10 +75,12 @@ class PrototypeModel(ProtoTorchBolt):
def add_prototypes(self, *args, **kwargs):
self.proto_layer.add_components(*args, **kwargs)
self.hparams.distribution = self.proto_layer.distribution
self.reconfigure_optimizers()
def remove_prototypes(self, indices):
self.proto_layer.remove_components(indices)
self.hparams.distribution = self.proto_layer.distribution
self.reconfigure_optimizers()
@@ -107,19 +109,32 @@ class UnsupervisedPrototypeModel(PrototypeModel):
class SupervisedPrototypeModel(PrototypeModel):
def __init__(self, hparams, **kwargs):
def __init__(self, hparams, skip_proto_layer=False, **kwargs):
super().__init__(hparams, **kwargs)
# Layers
distribution = hparams.get("distribution", None)
prototypes_initializer = kwargs.get("prototypes_initializer", None)
labels_initializer = kwargs.get("labels_initializer",
LabelsInitializer())
if prototypes_initializer is not None:
self.proto_layer = LabeledComponents(
distribution=self.hparams.distribution,
components_initializer=prototypes_initializer,
labels_initializer=labels_initializer,
)
if not skip_proto_layer:
# when subclasses do not need a customized prototype layer
if prototypes_initializer is not None:
# when building a new model
self.proto_layer = LabeledComponents(
distribution=distribution,
components_initializer=prototypes_initializer,
labels_initializer=labels_initializer,
)
proto_shape = self.proto_layer.components.shape[1:]
self.hparams.initialized_proto_shape = proto_shape
else:
# when restoring a checkpointed model
self.proto_layer = LabeledComponents(
distribution=distribution,
components_initializer=ZerosCompInitializer(
self.hparams.initialized_proto_shape),
)
self.competition_layer = WTAC()
@property
@@ -177,7 +192,6 @@ class SupervisedPrototypeModel(PrototypeModel):
class ProtoTorchMixin(object):
"""All mixins are ProtoTorchMixins."""
pass
class NonGradientMixin(ProtoTorchMixin):

View File

@@ -137,4 +137,4 @@ class GNGCallback(pl.Callback):
pl_module.errors[
worst_neighbor] = errors[worst_neighbor] * self.reduction
trainer.accelerator.setup_optimizers(trainer)
trainer.strategy.setup_optimizers(trainer)

View File

@@ -15,7 +15,7 @@ class CBC(SiameseGLVQ):
"""Classification-By-Components."""
def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs)
super().__init__(hparams, skip_proto_layer=True, **kwargs)
similarity_fn = kwargs.get("similarity_fn", euclidean_similarity)
components_initializer = kwargs.get("components_initializer", None)

View File

@@ -9,7 +9,7 @@ from ..core.distances import (
omega_distance,
squared_euclidean_distance,
)
from ..core.initializers import EyeTransformInitializer
from ..core.initializers import EyeLinearTransformInitializer
from ..core.losses import (
GLVQLoss,
lvq1_loss,
@@ -39,6 +39,10 @@ class GLVQ(SupervisedPrototypeModel):
beta=self.hparams.transfer_beta,
)
# def on_save_checkpoint(self, checkpoint):
# if "prototype_win_ratios" in checkpoint["state_dict"]:
# del checkpoint["state_dict"]["prototype_win_ratios"]
def initialize_prototype_win_ratios(self):
self.register_buffer(
"prototype_win_ratios",
@@ -143,9 +147,13 @@ class SiameseGLVQ(GLVQ):
protos, _ = self.proto_layer()
x, protos = [arr.view(arr.size(0), -1) for arr in (x, protos)]
latent_x = self.backbone(x)
self.backbone.requires_grad_(self.both_path_gradients)
bb_grad = any([el.requires_grad for el in self.backbone.parameters()])
self.backbone.requires_grad_(bb_grad and self.both_path_gradients)
latent_protos = self.backbone(protos)
self.backbone.requires_grad_(True)
self.backbone.requires_grad_(bb_grad)
distances = self.distance_layer(latent_x, latent_protos)
return distances
@@ -223,10 +231,10 @@ class SiameseGMLVQ(SiameseGLVQ):
# Override the backbone
omega_initializer = kwargs.get("omega_initializer",
EyeTransformInitializer())
EyeLinearTransformInitializer())
self.backbone = LinearTransform(
self.hparams.input_dim,
self.hparams.output_dim,
self.hparams.latent_dim,
initializer=omega_initializer,
)
@@ -255,7 +263,7 @@ class GMLVQ(GLVQ):
# Additional parameters
omega_initializer = kwargs.get("omega_initializer",
EyeTransformInitializer())
EyeLinearTransformInitializer())
omega = omega_initializer.generate(self.hparams.input_dim,
self.hparams.latent_dim)
self.register_parameter("_omega", Parameter(omega))

View File

@@ -16,7 +16,7 @@ class KNN(SupervisedPrototypeModel):
"""K-Nearest-Neighbors classification algorithm."""
def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs)
super().__init__(hparams, skip_proto_layer=True, **kwargs)
# Default hparams
self.hparams.setdefault("k", 1)
@@ -28,7 +28,7 @@ class KNN(SupervisedPrototypeModel):
# Layers
self.proto_layer = LabeledComponents(
distribution=[],
distribution=len(data) * [1],
components_initializer=LiteralCompInitializer(data),
labels_initializer=LiteralLabelsInitializer(targets))
self.competition_layer = KNNC(k=self.hparams.k)

View File

@@ -67,8 +67,13 @@ class SLVQ(ProbabilisticLVQ):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Default hparams
self.hparams.setdefault("variance", 1.0)
variance = self.hparams.get("variance")
self.conditional_distribution = GaussianPrior(variance)
self.loss = LossLayer(nllr_loss)
self.conditional_distribution = GaussianPrior(self.hparams.variance)
class RSLVQ(ProbabilisticLVQ):
@@ -76,8 +81,13 @@ class RSLVQ(ProbabilisticLVQ):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Default hparams
self.hparams.setdefault("variance", 1.0)
variance = self.hparams.get("variance")
self.conditional_distribution = GaussianPrior(variance)
self.loss = LossLayer(rslvq_loss)
self.conditional_distribution = GaussianPrior(self.hparams.variance)
class PLVQ(ProbabilisticLVQ, SiameseGMLVQ):
@@ -88,8 +98,12 @@ class PLVQ(ProbabilisticLVQ, SiameseGMLVQ):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.conditional_distribution = RankScaledGaussianPrior(
self.hparams.lambd)
# Default hparams
self.hparams.setdefault("lambda", 1.0)
lam = self.hparams.get("lambda", 1.0)
self.conditional_distribution = RankScaledGaussianPrior(lam)
self.loss = torch.nn.KLDivLoss()
# FIXME

View File

@@ -35,7 +35,7 @@ class KohonenSOM(NonGradientMixin, UnsupervisedPrototypeModel):
# Additional parameters
x, y = torch.arange(h), torch.arange(w)
grid = torch.stack(torch.meshgrid(x, y), dim=-1)
grid = torch.stack(torch.meshgrid(x, y, indexing="ij"), dim=-1)
self.register_buffer("_grid", grid)
self._sigma = self.hparams.sigma
self._lr = self.hparams.lr
@@ -58,8 +58,10 @@ class KohonenSOM(NonGradientMixin, UnsupervisedPrototypeModel):
diff = x.unsqueeze(dim=1) - protos
delta = self._lr * self.hparams.alpha * nh.unsqueeze(-1) * diff
updated_protos = protos + delta.sum(dim=0)
self.proto_layer.load_state_dict({"_components": updated_protos},
strict=False)
self.proto_layer.load_state_dict(
{"_components": updated_protos},
strict=False,
)
def training_epoch_end(self, training_step_outputs):
self._sigma = self.hparams.sigma * np.exp(
@@ -88,12 +90,12 @@ class NeuralGas(UnsupervisedPrototypeModel):
self.save_hyperparameters(hparams)
# Default hparams
self.hparams.setdefault("agelimit", 10)
self.hparams.setdefault("age_limit", 10)
self.hparams.setdefault("lm", 1)
self.energy_layer = NeuralGasEnergy(lm=self.hparams.lm)
self.topology_layer = ConnectionTopology(
agelimit=self.hparams.agelimit,
agelimit=self.hparams.age_limit,
num_prototypes=self.hparams.num_prototypes,
)
@@ -145,6 +147,8 @@ class GrowingNeuralGas(NeuralGas):
def configure_callbacks(self):
return [
GNGCallback(reduction=self.hparams.insert_reduction,
freq=self.hparams.insert_freq)
GNGCallback(
reduction=self.hparams.insert_reduction,
freq=self.hparams.insert_freq,
)
]

View File

@@ -7,15 +7,19 @@ import torchvision
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader, Dataset
from ..utils.colors import get_colors, get_legend_handles
from ..utils.utils import mesh2d
class Vis2DAbstract(pl.Callback):
def __init__(self,
data,
data=None,
title="Prototype Visualization",
cmap="viridis",
xlabel="Data dimension 1",
ylabel="Data dimension 2",
legend_labels=None,
border=0.1,
resolution=100,
flatten_data=True,
@@ -28,24 +32,31 @@ class Vis2DAbstract(pl.Callback):
block=False):
super().__init__()
if isinstance(data, Dataset):
x, y = next(iter(DataLoader(data, batch_size=len(data))))
elif isinstance(data, torch.utils.data.DataLoader):
x = torch.tensor([])
y = torch.tensor([])
for x_b, y_b in data:
x = torch.cat([x, x_b])
y = torch.cat([y, y_b])
if data:
if isinstance(data, Dataset):
x, y = next(iter(DataLoader(data, batch_size=len(data))))
elif isinstance(data, torch.utils.data.DataLoader):
x = torch.tensor([])
y = torch.tensor([])
for x_b, y_b in data:
x = torch.cat([x, x_b])
y = torch.cat([y, y_b])
else:
x, y = data
if flatten_data:
x = x.reshape(len(x), -1)
self.x_train = x
self.y_train = y
else:
x, y = data
if flatten_data:
x = x.reshape(len(x), -1)
self.x_train = x
self.y_train = y
self.x_train = None
self.y_train = None
self.title = title
self.xlabel = xlabel
self.ylabel = ylabel
self.legend_labels = legend_labels
self.fig = plt.figure(self.title)
self.cmap = cmap
self.border = border
@@ -64,14 +75,12 @@ class Vis2DAbstract(pl.Callback):
return False
return True
def setup_ax(self, xlabel=None, ylabel=None):
def setup_ax(self):
ax = self.fig.gca()
ax.cla()
ax.set_title(self.title)
if xlabel:
ax.set_xlabel("Data dimension 1")
if ylabel:
ax.set_ylabel("Data dimension 2")
ax.set_xlabel(self.xlabel)
ax.set_ylabel(self.ylabel)
if self.axis_off:
ax.axis("off")
return ax
@@ -114,33 +123,36 @@ class Vis2DAbstract(pl.Callback):
else:
plt.show(block=self.block)
def on_epoch_end(self, trainer, pl_module):
if not self.precheck(trainer):
return True
self.visualize(pl_module)
self.log_and_display(trainer, pl_module)
def on_train_end(self, trainer, pl_module):
plt.close()
class VisGLVQ2D(Vis2DAbstract):
def on_epoch_end(self, trainer, pl_module):
if not self.precheck(trainer):
return True
def visualize(self, pl_module):
protos = pl_module.prototypes
plabels = pl_module.prototype_labels
x_train, y_train = self.x_train, self.y_train
ax = self.setup_ax(xlabel="Data dimension 1",
ylabel="Data dimension 2")
self.plot_data(ax, x_train, y_train)
ax = self.setup_ax()
self.plot_protos(ax, protos, plabels)
x = np.vstack((x_train, protos))
mesh_input, xx, yy = mesh2d(x, self.border, self.resolution)
if x_train is not None:
self.plot_data(ax, x_train, y_train)
mesh_input, xx, yy = mesh2d(np.vstack([x_train, protos]),
self.border, self.resolution)
else:
mesh_input, xx, yy = mesh2d(protos, self.border, self.resolution)
_components = pl_module.proto_layer._components
mesh_input = torch.from_numpy(mesh_input).type_as(_components)
y_pred = pl_module.predict(mesh_input)
y_pred = y_pred.cpu().reshape(xx.shape)
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
self.log_and_display(trainer, pl_module)
class VisSiameseGLVQ2D(Vis2DAbstract):
@@ -148,10 +160,7 @@ class VisSiameseGLVQ2D(Vis2DAbstract):
super().__init__(*args, **kwargs)
self.map_protos = map_protos
def on_epoch_end(self, trainer, pl_module):
if not self.precheck(trainer):
return True
def visualize(self, pl_module):
protos = pl_module.prototypes
plabels = pl_module.prototype_labels
x_train, y_train = self.x_train, self.y_train
@@ -178,8 +187,6 @@ class VisSiameseGLVQ2D(Vis2DAbstract):
y_pred = y_pred.cpu().reshape(xx.shape)
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
self.log_and_display(trainer, pl_module)
class VisGMLVQ2D(Vis2DAbstract):
@@ -187,10 +194,7 @@ class VisGMLVQ2D(Vis2DAbstract):
super().__init__(*args, **kwargs)
self.ev_proj = ev_proj
def on_epoch_end(self, trainer, pl_module):
if not self.precheck(trainer):
return True
def visualize(self, pl_module):
protos = pl_module.prototypes
plabels = pl_module.prototype_labels
x_train, y_train = self.x_train, self.y_train
@@ -212,19 +216,13 @@ class VisGMLVQ2D(Vis2DAbstract):
if self.show_protos:
self.plot_protos(ax, protos, plabels)
self.log_and_display(trainer, pl_module)
class VisCBC2D(Vis2DAbstract):
def on_epoch_end(self, trainer, pl_module):
if not self.precheck(trainer):
return True
def visualize(self, pl_module):
x_train, y_train = self.x_train, self.y_train
protos = pl_module.components
ax = self.setup_ax(xlabel="Data dimension 1",
ylabel="Data dimension 2")
ax = self.setup_ax()
self.plot_data(ax, x_train, y_train)
self.plot_protos(ax, protos, "w")
x = np.vstack((x_train, protos))
@@ -236,21 +234,15 @@ class VisCBC2D(Vis2DAbstract):
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
self.log_and_display(trainer, pl_module)
class VisNG2D(Vis2DAbstract):
def on_epoch_end(self, trainer, pl_module):
if not self.precheck(trainer):
return True
def visualize(self, pl_module):
x_train, y_train = self.x_train, self.y_train
protos = pl_module.prototypes
cmat = pl_module.topology_layer.cmat.cpu().numpy()
ax = self.setup_ax(xlabel="Data dimension 1",
ylabel="Data dimension 2")
ax = self.setup_ax()
self.plot_data(ax, x_train, y_train)
self.plot_protos(ax, protos, "w")
@@ -264,7 +256,23 @@ class VisNG2D(Vis2DAbstract):
"k-",
)
self.log_and_display(trainer, pl_module)
class VisSpectralProtos(Vis2DAbstract):
def visualize(self, pl_module):
protos = pl_module.prototypes
plabels = pl_module.prototype_labels
ax = self.setup_ax()
colors = get_colors(vmax=max(plabels), vmin=min(plabels))
for p, pl in zip(protos, plabels):
ax.plot(p, c=colors[int(pl)])
if self.legend_labels:
handles = get_legend_handles(
colors,
self.legend_labels,
marker="lines",
)
ax.legend(handles=handles)
class VisImgComp(Vis2DAbstract):
@@ -321,14 +329,9 @@ class VisImgComp(Vis2DAbstract):
dataformats=self.dataformats,
)
def on_epoch_end(self, trainer, pl_module):
if not self.precheck(trainer):
return True
def visualize(self, pl_module):
if self.show:
components = pl_module.components
grid = torchvision.utils.make_grid(components,
nrow=self.num_columns)
plt.imshow(grid.permute((1, 2, 0)).cpu(), cmap=self.cmap)
self.log_and_display(trainer, pl_module)

View File

@@ -22,8 +22,8 @@ with open("README.md", "r") as fh:
long_description = fh.read()
INSTALL_REQUIRES = [
"prototorch>=0.7.0",
"pytorch_lightning>=1.3.5",
"prototorch>=0.7.3",
"pytorch_lightning>=1.6.0",
"torchmetrics",
]
CLI = [
@@ -54,7 +54,7 @@ ALL = CLI + DEV + DOCS + EXAMPLES + TESTS
setup(
name=safe_name("prototorch_" + PLUGIN_NAME),
version="0.4.1",
version="0.5.0",
description="Pre-packaged prototype-based "
"machine learning models using ProtoTorch and PyTorch-Lightning.",
long_description=long_description,
@@ -64,7 +64,7 @@ setup(
url=PROJECT_URL,
download_url=DOWNLOAD_URL,
license="MIT",
python_requires=">=3.6",
python_requires=">=3.7",
install_requires=INSTALL_REQUIRES,
extras_require={
"dev": DEV,
@@ -80,10 +80,10 @@ setup(
"Intended Audience :: Science/Research",
"License :: OSI Approved :: MIT License",
"Natural Language :: English",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.6",
"Operating System :: OS Independent",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Software Development :: Libraries",

View File

@@ -1,15 +0,0 @@
"""prototorch.models test suite."""
import unittest
class TestDummy(unittest.TestCase):
def setUp(self):
pass
def test_dummy(self):
pass
def tearDown(self):
pass

195
tests/test_models.py Normal file
View File

@@ -0,0 +1,195 @@
"""prototorch.models test suite."""
import prototorch as pt
import pytest
import torch
def test_glvq_model_build():
model = pt.models.GLVQ(
{"distribution": (3, 2)},
prototypes_initializer=pt.initializers.RNCI(2),
)
def test_glvq1_model_build():
model = pt.models.GLVQ1(
{"distribution": (3, 2)},
prototypes_initializer=pt.initializers.RNCI(2),
)
def test_glvq21_model_build():
model = pt.models.GLVQ1(
{"distribution": (3, 2)},
prototypes_initializer=pt.initializers.RNCI(2),
)
def test_gmlvq_model_build():
model = pt.models.GMLVQ(
{
"distribution": (3, 2),
"input_dim": 2,
"latent_dim": 2,
},
prototypes_initializer=pt.initializers.RNCI(2),
)
def test_grlvq_model_build():
model = pt.models.GRLVQ(
{
"distribution": (3, 2),
"input_dim": 2,
},
prototypes_initializer=pt.initializers.RNCI(2),
)
def test_gtlvq_model_build():
model = pt.models.GTLVQ(
{
"distribution": (3, 2),
"input_dim": 4,
"latent_dim": 2,
},
prototypes_initializer=pt.initializers.RNCI(2),
)
def test_lgmlvq_model_build():
model = pt.models.LGMLVQ(
{
"distribution": (3, 2),
"input_dim": 4,
"latent_dim": 2,
},
prototypes_initializer=pt.initializers.RNCI(2),
)
def test_image_glvq_model_build():
model = pt.models.ImageGLVQ(
{"distribution": (3, 2)},
prototypes_initializer=pt.initializers.RNCI(16),
)
def test_image_gmlvq_model_build():
model = pt.models.ImageGMLVQ(
{
"distribution": (3, 2),
"input_dim": 16,
"latent_dim": 2,
},
prototypes_initializer=pt.initializers.RNCI(16),
)
def test_image_gtlvq_model_build():
model = pt.models.ImageGMLVQ(
{
"distribution": (3, 2),
"input_dim": 16,
"latent_dim": 2,
},
prototypes_initializer=pt.initializers.RNCI(16),
)
def test_siamese_glvq_model_build():
model = pt.models.SiameseGLVQ(
{"distribution": (3, 2)},
prototypes_initializer=pt.initializers.RNCI(4),
)
def test_siamese_gmlvq_model_build():
model = pt.models.SiameseGMLVQ(
{
"distribution": (3, 2),
"input_dim": 4,
"latent_dim": 2,
},
prototypes_initializer=pt.initializers.RNCI(4),
)
def test_siamese_gtlvq_model_build():
model = pt.models.SiameseGTLVQ(
{
"distribution": (3, 2),
"input_dim": 4,
"latent_dim": 2,
},
prototypes_initializer=pt.initializers.RNCI(4),
)
def test_knn_model_build():
train_ds = pt.datasets.Iris(dims=[0, 2])
model = pt.models.KNN(dict(k=3), data=train_ds)
def test_lvq1_model_build():
model = pt.models.LVQ1(
{"distribution": (3, 2)},
prototypes_initializer=pt.initializers.RNCI(2),
)
def test_lvq21_model_build():
model = pt.models.LVQ21(
{"distribution": (3, 2)},
prototypes_initializer=pt.initializers.RNCI(2),
)
def test_median_lvq_model_build():
model = pt.models.MedianLVQ(
{"distribution": (3, 2)},
prototypes_initializer=pt.initializers.RNCI(2),
)
def test_celvq_model_build():
model = pt.models.CELVQ(
{"distribution": (3, 2)},
prototypes_initializer=pt.initializers.RNCI(2),
)
def test_rslvq_model_build():
model = pt.models.RSLVQ(
{"distribution": (3, 2)},
prototypes_initializer=pt.initializers.RNCI(2),
)
def test_slvq_model_build():
model = pt.models.SLVQ(
{"distribution": (3, 2)},
prototypes_initializer=pt.initializers.RNCI(2),
)
def test_growing_neural_gas_model_build():
model = pt.models.GrowingNeuralGas(
{"num_prototypes": 5},
prototypes_initializer=pt.initializers.RNCI(2),
)
def test_kohonen_som_model_build():
model = pt.models.KohonenSOM(
{"shape": (3, 2)},
prototypes_initializer=pt.initializers.RNCI(2),
)
def test_neural_gas_model_build():
model = pt.models.NeuralGas(
{"num_prototypes": 5},
prototypes_initializer=pt.initializers.RNCI(2),
)