docs: update tutorial notebook
This commit is contained in:
		@@ -21,7 +21,7 @@
 | 
				
			|||||||
   "id": "43b74278",
 | 
					   "id": "43b74278",
 | 
				
			||||||
   "metadata": {},
 | 
					   "metadata": {},
 | 
				
			||||||
   "source": [
 | 
					   "source": [
 | 
				
			||||||
    "This is a short tutorial for the [models](https://github.com/si-cim/prototorch_models) plugin of the [ProtoTorch](https://github.com/si-cim/prototorch) framework.\n",
 | 
					    "This is a short tutorial for the [models](https://github.com/si-cim/prototorch_models) plugin of the [ProtoTorch](https://github.com/si-cim/prototorch) framework. This is by no means a comprehensive look at all the features that the framework has to offer, but it should help you get up and running.\n",
 | 
				
			||||||
    "\n",
 | 
					    "\n",
 | 
				
			||||||
    "[ProtoTorch](https://github.com/si-cim/prototorch) provides [torch.nn](https://pytorch.org/docs/stable/nn.html) modules and utilities to implement prototype-based models. However, it is up to the user to put these modules together into models and handle the training of these models. Expert machine-learning practioners and researchers sometimes prefer this level of control. However, this leads to a lot of boilerplate code that is essentially same across many projects. Needless to say, this is a source of a lot of frustration. [PyTorch-Lightning](https://pytorch-lightning.readthedocs.io/en/latest/) is a framework that helps avoid a lot of this frustration by handling the boilerplate code for you so you don't have to reinvent the wheel every time you need to implement a new model.\n",
 | 
					    "[ProtoTorch](https://github.com/si-cim/prototorch) provides [torch.nn](https://pytorch.org/docs/stable/nn.html) modules and utilities to implement prototype-based models. However, it is up to the user to put these modules together into models and handle the training of these models. Expert machine-learning practioners and researchers sometimes prefer this level of control. However, this leads to a lot of boilerplate code that is essentially same across many projects. Needless to say, this is a source of a lot of frustration. [PyTorch-Lightning](https://pytorch-lightning.readthedocs.io/en/latest/) is a framework that helps avoid a lot of this frustration by handling the boilerplate code for you so you don't have to reinvent the wheel every time you need to implement a new model.\n",
 | 
				
			||||||
    "\n",
 | 
					    "\n",
 | 
				
			||||||
@@ -404,6 +404,55 @@
 | 
				
			|||||||
    "## Advanced"
 | 
					    "## Advanced"
 | 
				
			||||||
   ]
 | 
					   ]
 | 
				
			||||||
  },
 | 
					  },
 | 
				
			||||||
 | 
					  {
 | 
				
			||||||
 | 
					   "cell_type": "markdown",
 | 
				
			||||||
 | 
					   "id": "53a64063",
 | 
				
			||||||
 | 
					   "metadata": {},
 | 
				
			||||||
 | 
					   "source": [
 | 
				
			||||||
 | 
					    "### Warm-start a model with prototypes learned from another model"
 | 
				
			||||||
 | 
					   ]
 | 
				
			||||||
 | 
					  },
 | 
				
			||||||
 | 
					  {
 | 
				
			||||||
 | 
					   "cell_type": "code",
 | 
				
			||||||
 | 
					   "execution_count": null,
 | 
				
			||||||
 | 
					   "id": "3177c277",
 | 
				
			||||||
 | 
					   "metadata": {},
 | 
				
			||||||
 | 
					   "outputs": [],
 | 
				
			||||||
 | 
					   "source": [
 | 
				
			||||||
 | 
					    "trained_model = pt.models.GLVQ.load_from_checkpoint(\"./checkpoints/glvq_iris.ckpt\", strict=False)\n",
 | 
				
			||||||
 | 
					    "model = pt.models.SiameseGMLVQ(\n",
 | 
				
			||||||
 | 
					    "    dict(input_dim=2,\n",
 | 
				
			||||||
 | 
					    "         output_dim=2,\n",
 | 
				
			||||||
 | 
					    "         distribution=(3, 2),\n",
 | 
				
			||||||
 | 
					    "         proto_lr=0.0001,\n",
 | 
				
			||||||
 | 
					    "         bb_lr=0.0001),\n",
 | 
				
			||||||
 | 
					    "    optimizer=torch.optim.Adam,\n",
 | 
				
			||||||
 | 
					    "    prototypes_initializer=pt.initializers.LCI(trained_model.prototypes),\n",
 | 
				
			||||||
 | 
					    "    labels_initializer=pt.initializers.LLI(trained_model.prototype_labels),\n",
 | 
				
			||||||
 | 
					    "    omega_initializer=pt.initializers.LLTI(torch.tensor([[0., 1.], [1., 0.]])),  # permute axes\n",
 | 
				
			||||||
 | 
					    ")"
 | 
				
			||||||
 | 
					   ]
 | 
				
			||||||
 | 
					  },
 | 
				
			||||||
 | 
					  {
 | 
				
			||||||
 | 
					   "cell_type": "code",
 | 
				
			||||||
 | 
					   "execution_count": null,
 | 
				
			||||||
 | 
					   "id": "8baee9a2",
 | 
				
			||||||
 | 
					   "metadata": {},
 | 
				
			||||||
 | 
					   "outputs": [],
 | 
				
			||||||
 | 
					   "source": [
 | 
				
			||||||
 | 
					    "print(model)"
 | 
				
			||||||
 | 
					   ]
 | 
				
			||||||
 | 
					  },
 | 
				
			||||||
 | 
					  {
 | 
				
			||||||
 | 
					   "cell_type": "code",
 | 
				
			||||||
 | 
					   "execution_count": null,
 | 
				
			||||||
 | 
					   "id": "cc203088",
 | 
				
			||||||
 | 
					   "metadata": {},
 | 
				
			||||||
 | 
					   "outputs": [],
 | 
				
			||||||
 | 
					   "source": [
 | 
				
			||||||
 | 
					    "pt.models.VisSiameseGLVQ2D(data=train_ds, title=\"GMLVQ - Warm-start state\").visualize(model)"
 | 
				
			||||||
 | 
					   ]
 | 
				
			||||||
 | 
					  },
 | 
				
			||||||
  {
 | 
					  {
 | 
				
			||||||
   "cell_type": "markdown",
 | 
					   "cell_type": "markdown",
 | 
				
			||||||
   "id": "1f6a33a5",
 | 
					   "id": "1f6a33a5",
 | 
				
			||||||
@@ -423,7 +472,8 @@
 | 
				
			|||||||
    "import pytorch_lightning as pl\n",
 | 
					    "import pytorch_lightning as pl\n",
 | 
				
			||||||
    "import torch\n",
 | 
					    "import torch\n",
 | 
				
			||||||
    "from torchvision import transforms\n",
 | 
					    "from torchvision import transforms\n",
 | 
				
			||||||
    "from torchvision.datasets import MNIST"
 | 
					    "from torchvision.datasets import MNIST\n",
 | 
				
			||||||
 | 
					    "from torchvision.utils import make_grid"
 | 
				
			||||||
   ]
 | 
					   ]
 | 
				
			||||||
  },
 | 
					  },
 | 
				
			||||||
  {
 | 
					  {
 | 
				
			||||||
@@ -484,7 +534,7 @@
 | 
				
			|||||||
   "outputs": [],
 | 
					   "outputs": [],
 | 
				
			||||||
   "source": [
 | 
					   "source": [
 | 
				
			||||||
    "model = pt.models.ImageGLVQ(\n",
 | 
					    "model = pt.models.ImageGLVQ(\n",
 | 
				
			||||||
    "    dict(distribution=(10, 5)),\n",
 | 
					    "    dict(distribution=(10, 1)),\n",
 | 
				
			||||||
    "    prototypes_initializer=pt.initializers.SMCI(init_ds),\n",
 | 
					    "    prototypes_initializer=pt.initializers.SMCI(init_ds),\n",
 | 
				
			||||||
    ")"
 | 
					    ")"
 | 
				
			||||||
   ]
 | 
					   ]
 | 
				
			||||||
@@ -496,7 +546,30 @@
 | 
				
			|||||||
   "metadata": {},
 | 
					   "metadata": {},
 | 
				
			||||||
   "outputs": [],
 | 
					   "outputs": [],
 | 
				
			||||||
   "source": [
 | 
					   "source": [
 | 
				
			||||||
    "plt.imshow(model.get_prototype_grid(num_columns=10))"
 | 
					    "plt.imshow(model.get_prototype_grid(num_columns=5))"
 | 
				
			||||||
 | 
					   ]
 | 
				
			||||||
 | 
					  },
 | 
				
			||||||
 | 
					  {
 | 
				
			||||||
 | 
					   "cell_type": "markdown",
 | 
				
			||||||
 | 
					   "id": "1c23c7b2",
 | 
				
			||||||
 | 
					   "metadata": {},
 | 
				
			||||||
 | 
					   "source": [
 | 
				
			||||||
 | 
					    "We could, of course, just use the initializers in isolation. For example, we could quickly obtain a stratified selection from the data like so:"
 | 
				
			||||||
 | 
					   ]
 | 
				
			||||||
 | 
					  },
 | 
				
			||||||
 | 
					  {
 | 
				
			||||||
 | 
					   "cell_type": "code",
 | 
				
			||||||
 | 
					   "execution_count": null,
 | 
				
			||||||
 | 
					   "id": "30780927",
 | 
				
			||||||
 | 
					   "metadata": {},
 | 
				
			||||||
 | 
					   "outputs": [],
 | 
				
			||||||
 | 
					   "source": [
 | 
				
			||||||
 | 
					    "protos, plabels = pt.components.LabeledComponents(\n",
 | 
				
			||||||
 | 
					    "    distribution=(10, 5),\n",
 | 
				
			||||||
 | 
					    "    components_initializer=pt.initializers.SMCI(init_ds),\n",
 | 
				
			||||||
 | 
					    "    labels_initializer=pt.initializers.LabelsInitializer(),\n",
 | 
				
			||||||
 | 
					    ")()\n",
 | 
				
			||||||
 | 
					    "plt.imshow(make_grid(protos, 10).permute(1, 2, 0)[:, :, 0], cmap=\"jet\")"
 | 
				
			||||||
   ]
 | 
					   ]
 | 
				
			||||||
  },
 | 
					  },
 | 
				
			||||||
  {
 | 
					  {
 | 
				
			||||||
@@ -564,7 +637,7 @@
 | 
				
			|||||||
   "name": "python",
 | 
					   "name": "python",
 | 
				
			||||||
   "nbconvert_exporter": "python",
 | 
					   "nbconvert_exporter": "python",
 | 
				
			||||||
   "pygments_lexer": "ipython3",
 | 
					   "pygments_lexer": "ipython3",
 | 
				
			||||||
   "version": "3.9.9"
 | 
					   "version": "3.9.12"
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 },
 | 
					 },
 | 
				
			||||||
 "nbformat": 4,
 | 
					 "nbformat": 4,
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user