From 40bd7ed380830b3a0dd60d4301272fb65ff9fb6e Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Tue, 29 Mar 2022 15:04:05 +0200 Subject: [PATCH] docs: update tutorial notebook --- docs/source/tutorial.ipynb | 83 +++++++++++++++++++++++++++++++++++--- 1 file changed, 78 insertions(+), 5 deletions(-) diff --git a/docs/source/tutorial.ipynb b/docs/source/tutorial.ipynb index 2ecfffa..3329848 100644 --- a/docs/source/tutorial.ipynb +++ b/docs/source/tutorial.ipynb @@ -21,7 +21,7 @@ "id": "43b74278", "metadata": {}, "source": [ - "This is a short tutorial for the [models](https://github.com/si-cim/prototorch_models) plugin of the [ProtoTorch](https://github.com/si-cim/prototorch) framework.\n", + "This is a short tutorial for the [models](https://github.com/si-cim/prototorch_models) plugin of the [ProtoTorch](https://github.com/si-cim/prototorch) framework. This is by no means a comprehensive look at all the features that the framework has to offer, but it should help you get up and running.\n", "\n", "[ProtoTorch](https://github.com/si-cim/prototorch) provides [torch.nn](https://pytorch.org/docs/stable/nn.html) modules and utilities to implement prototype-based models. However, it is up to the user to put these modules together into models and handle the training of these models. Expert machine-learning practioners and researchers sometimes prefer this level of control. However, this leads to a lot of boilerplate code that is essentially same across many projects. Needless to say, this is a source of a lot of frustration. [PyTorch-Lightning](https://pytorch-lightning.readthedocs.io/en/latest/) is a framework that helps avoid a lot of this frustration by handling the boilerplate code for you so you don't have to reinvent the wheel every time you need to implement a new model.\n", "\n", @@ -404,6 +404,55 @@ "## Advanced" ] }, + { + "cell_type": "markdown", + "id": "53a64063", + "metadata": {}, + "source": [ + "### Warm-start a model with prototypes learned from another model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3177c277", + "metadata": {}, + "outputs": [], + "source": [ + "trained_model = pt.models.GLVQ.load_from_checkpoint(\"./checkpoints/glvq_iris.ckpt\", strict=False)\n", + "model = pt.models.SiameseGMLVQ(\n", + " dict(input_dim=2,\n", + " output_dim=2,\n", + " distribution=(3, 2),\n", + " proto_lr=0.0001,\n", + " bb_lr=0.0001),\n", + " optimizer=torch.optim.Adam,\n", + " prototypes_initializer=pt.initializers.LCI(trained_model.prototypes),\n", + " labels_initializer=pt.initializers.LLI(trained_model.prototype_labels),\n", + " omega_initializer=pt.initializers.LLTI(torch.tensor([[0., 1.], [1., 0.]])), # permute axes\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8baee9a2", + "metadata": {}, + "outputs": [], + "source": [ + "print(model)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cc203088", + "metadata": {}, + "outputs": [], + "source": [ + "pt.models.VisSiameseGLVQ2D(data=train_ds, title=\"GMLVQ - Warm-start state\").visualize(model)" + ] + }, { "cell_type": "markdown", "id": "1f6a33a5", @@ -423,7 +472,8 @@ "import pytorch_lightning as pl\n", "import torch\n", "from torchvision import transforms\n", - "from torchvision.datasets import MNIST" + "from torchvision.datasets import MNIST\n", + "from torchvision.utils import make_grid" ] }, { @@ -484,7 +534,7 @@ "outputs": [], "source": [ "model = pt.models.ImageGLVQ(\n", - " dict(distribution=(10, 5)),\n", + " dict(distribution=(10, 1)),\n", " prototypes_initializer=pt.initializers.SMCI(init_ds),\n", ")" ] @@ -496,7 +546,30 @@ "metadata": {}, "outputs": [], "source": [ - "plt.imshow(model.get_prototype_grid(num_columns=10))" + "plt.imshow(model.get_prototype_grid(num_columns=5))" + ] + }, + { + "cell_type": "markdown", + "id": "1c23c7b2", + "metadata": {}, + "source": [ + "We could, of course, just use the initializers in isolation. For example, we could quickly obtain a stratified selection from the data like so:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "30780927", + "metadata": {}, + "outputs": [], + "source": [ + "protos, plabels = pt.components.LabeledComponents(\n", + " distribution=(10, 5),\n", + " components_initializer=pt.initializers.SMCI(init_ds),\n", + " labels_initializer=pt.initializers.LabelsInitializer(),\n", + ")()\n", + "plt.imshow(make_grid(protos, 10).permute(1, 2, 0)[:, :, 0], cmap=\"jet\")" ] }, { @@ -564,7 +637,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.9" + "version": "3.9.12" } }, "nbformat": 4,