{
"cells": [
{
"cell_type": "markdown",
"id": "7ac5eff0",
"metadata": {},
"source": [
"# A short tutorial for the `prototorch.models` plugin"
]
},
{
"cell_type": "markdown",
"id": "beb83780",
"metadata": {},
"source": [
"## Introduction"
]
},
{
"cell_type": "markdown",
"id": "43b74278",
"metadata": {},
"source": [
"This is a short tutorial for the [models](https://github.com/si-cim/prototorch_models) plugin of the [ProtoTorch](https://github.com/si-cim/prototorch) framework. This is by no means a comprehensive look at all the features that the framework has to offer, but it should help you get up and running.\n",
"\n",
"[ProtoTorch](https://github.com/si-cim/prototorch) provides [torch.nn](https://pytorch.org/docs/stable/nn.html) modules and utilities to implement prototype-based models. However, it is up to the user to put these modules together into models and handle the training of these models. Expert machine-learning practitioners and researchers sometimes prefer this level of control. However, this leads to a lot of boilerplate code that is essentially the same across many projects. Needless to say, this is a source of a lot of frustration. [PyTorch-Lightning](https://pytorch-lightning.readthedocs.io/en/latest/) is a framework that helps avoid a lot of this frustration by handling the boilerplate code for you so you don't have to reinvent the wheel every time you need to implement a new model.\n",
"\n",
"With the [prototorch.models](https://github.com/si-cim/prototorch_models) plugin, we've gone one step further and pre-packaged commonly used prototype-models like GMLVQ as [Lightning-Modules](https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.core.lightning.html?highlight=lightning%20module#pytorch_lightning.core.lightning.LightningModule). With only a few lines of code, it is now possible to build and train prototype-models. It quite simply cannot get any simpler than this."
]
},
{
"cell_type": "markdown",
"id": "4e5d1fad",
"metadata": {},
"source": [
"## Basics"
]
},
{
"cell_type": "markdown",
"id": "1244b66b",
"metadata": {},
"source": [
"First things first. When working with the models plugin, you'll probably need `torch`, `prototorch` and `pytorch_lightning`. So, we recommend that you import all three like so:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dcb88e8a",
"metadata": {},
"outputs": [],
"source": [
"import prototorch as pt\n",
"import pytorch_lightning as pl\n",
"import torch"
]
},
{
"cell_type": "markdown",
"id": "1adbe2f8",
"metadata": {},
"source": [
"### Building Models"
]
},
{
"cell_type": "markdown",
"id": "96663ab1",
"metadata": {},
"source": [
"Let's start by building a `GLVQ` model. It is one of the simplest models to build. The only requirements are a prototype distribution and an initializer."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "819ba756",
"metadata": {},
"outputs": [],
"source": [
"model = pt.models.GLVQ(\n",
"    hparams=dict(distribution=[1, 1, 1]),\n",
"    prototypes_initializer=pt.initializers.ZerosCompInitializer(2),\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1b37e97c",
"metadata": {},
"outputs": [],
"source": [
"print(model)"
]
},
{
"cell_type": "markdown",
"id": "d2c86903",
"metadata": {},
"source": [
"The key `distribution` in the `hparams` argument describes the prototype distribution. If it is a Python [list](https://docs.python.org/3/tutorial/datastructures.html), it is assumed that there are as many entries in this list as there are classes, and the number at each location of this list describes the number of prototypes to be used for that particular class. So, `[1, 1, 1]` implies that we have three classes with one prototype per class. If it is a Python [tuple](https://docs.python.org/3/tutorial/datastructures.html), a shorthand of `(num_classes, prototypes_per_class)` is assumed. If it is a Python [dictionary](https://docs.python.org/3/tutorial/datastructures.html), the key-value pairs describe the class label and the number of prototypes for that class respectively. So, `{0: 2, 1: 2, 2: 2}` implies that we have three classes with labels `{0, 1, 2}`, each equipped with two prototypes. If, however, the dictionary contains the keys `\"num_classes\"` and `\"per_class\"`, their values are used as the number of classes and the number of prototypes per class.\n",
"\n",
"The `prototypes_initializer` argument describes how the prototypes are meant to be initialized. This argument has to be an instantiated object of some kind of [AbstractComponentsInitializer](https://github.com/si-cim/prototorch/blob/dev/prototorch/components/initializers.py#L18). If this is a [ShapeAwareCompInitializer](https://github.com/si-cim/prototorch/blob/dev/prototorch/components/initializers.py#L41), this only requires a `shape` argument that describes the shape of the prototypes. So, `pt.initializers.ZerosCompInitializer(3)` creates 3d-vector prototypes all initialized to zeros."
]
},
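{
"cell_type": "markdown",
"id": "a3f0c9d2",
"metadata": {},
"source": [
"To make the accepted `distribution` formats concrete, the following pure-Python sketch normalizes each of them into a `{class_label: number_of_prototypes}` mapping. The helper `parse_distribution` is hypothetical and written only for illustration; it is not the parser that ProtoTorch uses internally, and it assumes that the class labels for the list and tuple forms start at 0."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b7e41d58",
"metadata": {},
"outputs": [],
"source": [
"def parse_distribution(distribution):\n",
"    # Hypothetical helper for illustration only; NOT ProtoTorch's internal parser.\n",
"    # Assumption: class labels for the list and tuple forms are 0, 1, 2, ...\n",
"    if isinstance(distribution, dict):\n",
"        if 'num_classes' in distribution and 'per_class' in distribution:\n",
"            num_classes = distribution['num_classes']\n",
"            per_class = distribution['per_class']\n",
"            return {label: per_class for label in range(num_classes)}\n",
"        return dict(distribution)  # already a {label: count} mapping\n",
"    if isinstance(distribution, tuple):\n",
"        num_classes, per_class = distribution  # shorthand form\n",
"        return {label: per_class for label in range(num_classes)}\n",
"    return dict(enumerate(distribution))  # list: the position is the label\n",
"\n",
"\n",
"for spec in ([1, 1, 1], (3, 2), {0: 2, 1: 2, 2: 2}, {'num_classes': 3, 'per_class': 2}):\n",
"    print(spec, '->', parse_distribution(spec))"
]
},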
{
"cell_type": "markdown",
"id": "45806052",
"metadata": {},
"source": [
"### Data"
]
},
{
"cell_type": "markdown",
"id": "9d62c4c6",
"metadata": {},
"source": [
"The preferred way to work with data in `torch` is to use the [Dataset and DataLoader API](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html). There are a few pre-packaged datasets available under `prototorch.datasets`. See [here](https://prototorch.readthedocs.io/en/latest/api.html#module-prototorch.datasets) for a full list of available datasets."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "504df02c",
"metadata": {},
"outputs": [],
"source": [
"train_ds = pt.datasets.Iris(dims=[0, 2])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3b8e7756",
"metadata": {},
"outputs": [],
"source": [
"type(train_ds)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bce43afa",
"metadata": {},
"outputs": [],
"source": [
"train_ds.data.shape, train_ds.targets.shape"
]
},
{
"cell_type": "markdown",
"id": "26a83328",
"metadata": {},
"source": [
"Once we have such a dataset, we could wrap it in a `DataLoader` to load the data in batches, and possibly apply some transformations on the fly."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "67b80fbe",
"metadata": {},
"outputs": [],
"source": [
"train_loader = torch.utils.data.DataLoader(train_ds, batch_size=2)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c1185f31",
"metadata": {},
"outputs": [],
"source": [
"type(train_loader)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9b5a8963",
"metadata": {},
"outputs": [],
"source": [
"x_batch, y_batch = next(iter(train_loader))\n",
"print(f\"{x_batch=}, {y_batch=}\")"
]
},
{
"cell_type": "markdown",
"id": "dd492ee2",
"metadata": {},
"source": [
"This perhaps seems like a lot of work for a small dataset that fits completely in memory. However, this comes in very handy when dealing with huge datasets that can only be processed in batches."
]
},
{
"cell_type": "markdown",
"id": "5176b055",
"metadata": {},
"source": [
"### Training"
]
},
{
"cell_type": "markdown",
"id": "46a7a506",
"metadata": {},
"source": [
"If you're familiar with other deep learning frameworks, you might expect the model to have a `.fit(...)` or `.train(...)` method. However, in PyTorch-Lightning, this is done slightly differently. We first create a trainer and then pass the model and the `DataLoader` to `trainer.fit(...)` instead. So, it is more functional in style than object-oriented."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "279e75b7",
"metadata": {},
"outputs": [],
"source": [
"trainer = pl.Trainer(max_epochs=2, weights_summary=None)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e496b492",
"metadata": {},
"outputs": [],
"source": [
"trainer.fit(model, train_loader)"
]
},
{
"cell_type": "markdown",
"id": "497fbff6",
"metadata": {},
"source": [
"### From data to a trained model - a very minimal example"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ab069c5d",
"metadata": {},
"outputs": [],
"source": [
"train_ds = pt.datasets.Iris(dims=[0, 2])\n",
"train_loader = torch.utils.data.DataLoader(train_ds, batch_size=32)\n",
"\n",
"model = pt.models.GLVQ(\n",
"    dict(distribution=(3, 2), lr=0.1),\n",
"    prototypes_initializer=pt.initializers.SMCI(train_ds),\n",
")\n",
"\n",
"trainer = pl.Trainer(max_epochs=50, weights_summary=None)\n",
"trainer.fit(model, train_loader)"
]
},
{
"cell_type": "markdown",
"id": "30c71a93",
"metadata": {},
"source": [
"### Saving/Loading trained models"
]
},
{
"cell_type": "markdown",
"id": "f74ed2c1",
"metadata": {},
"source": [
"PyTorch-Lightning can automatically checkpoint the model during various stages of training, but it is also possible to manually save a checkpoint after training."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3156658d",
"metadata": {},
"outputs": [],
"source": [
"ckpt_path = \"./checkpoints/glvq_iris.ckpt\"\n",
"trainer.save_checkpoint(ckpt_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c1c34055",
"metadata": {},
"outputs": [],
"source": [
"loaded_model = pt.models.GLVQ.load_from_checkpoint(ckpt_path, strict=False)"
]
},
{
"cell_type": "markdown",
"id": "bbbb08e9",
"metadata": {},
"source": [
"### Visualizing decision boundaries in 2D"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "53ca52dc",
"metadata": {},
"outputs": [],
"source": [
"pt.models.VisGLVQ2D(data=train_ds).visualize(loaded_model)"
]
},
{
"cell_type": "markdown",
"id": "8373531f",
"metadata": {},
"source": [
"### Saving/Loading trained weights"
]
},
{
"cell_type": "markdown",
"id": "937bc458",
"metadata": {},
"source": [
"In most cases, the checkpointing workflow is sufficient. In some cases, however, one might want to save only the trained weights of the model. The disadvantage of this method is that the model has to be re-created using compatible initialization parameters before the weights can be loaded."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1f2035af",
"metadata": {},
"outputs": [],
"source": [
"ckpt_path = \"./checkpoints/glvq_iris_weights.pth\"\n",
"torch.save(model.state_dict(), ckpt_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1206021a",
"metadata": {},
"outputs": [],
"source": [
"model = pt.models.GLVQ(\n",
" dict(distribution=(3, 2)),\n",
" prototypes_initializer=pt.initializers.ZerosCompInitializer(2),\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9f2a4beb",
"metadata": {},
"outputs": [],
"source": [
"pt.models.VisGLVQ2D(data=train_ds, title=\"Before loading the weights\").visualize(model)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "528d2fc2",
"metadata": {},
"outputs": [],
"source": [
"torch.load(ckpt_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ec817e6b",
"metadata": {},
"outputs": [],
"source": [
"model.load_state_dict(torch.load(ckpt_path), strict=False)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a208eab7",
"metadata": {},
"outputs": [],
"source": [
"pt.models.VisGLVQ2D(data=train_ds, title=\"After loading the weights\").visualize(model)"
]
},
{
"cell_type": "markdown",
"id": "f8de748f",
"metadata": {},
"source": [
"## Advanced"
]
},
{
"cell_type": "markdown",
"id": "53a64063",
"metadata": {},
"source": [
"### Warm-start a model with prototypes learned from another model"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3177c277",
"metadata": {},
"outputs": [],
"source": [
"trained_model = pt.models.GLVQ.load_from_checkpoint(\"./checkpoints/glvq_iris.ckpt\", strict=False)\n",
"model = pt.models.SiameseGMLVQ(\n",
"    dict(input_dim=2,\n",
"         latent_dim=2,\n",
"         distribution=(3, 2),\n",
"         proto_lr=0.0001,\n",
"         bb_lr=0.0001),\n",
"    optimizer=torch.optim.Adam,\n",
"    prototypes_initializer=pt.initializers.LCI(trained_model.prototypes),\n",
"    labels_initializer=pt.initializers.LLI(trained_model.prototype_labels),\n",
"    omega_initializer=pt.initializers.LLTI(torch.tensor([[0., 1.], [1., 0.]])),  # permute axes\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8baee9a2",
"metadata": {},
"outputs": [],
"source": [
"print(model)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cc203088",
"metadata": {},
"outputs": [],
"source": [
"pt.models.VisSiameseGLVQ2D(data=train_ds, title=\"GMLVQ - Warm-start state\").visualize(model)"
]
},
{
"cell_type": "markdown",
"id": "1f6a33a5",
"metadata": {},
"source": [
"### Initializing prototypes with a subset of a dataset (along with transformations)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "946ce341",
"metadata": {},
"outputs": [],
"source": [
"import prototorch as pt\n",
"import pytorch_lightning as pl\n",
"import torch\n",
"from torchvision import transforms\n",
"from torchvision.datasets import MNIST\n",
"from torchvision.utils import make_grid"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "510d9bd4",
"metadata": {},
"outputs": [],
"source": [
"from matplotlib import pyplot as plt"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ea7c1228",
"metadata": {},
"outputs": [],
"source": [
"train_ds = MNIST(\n",
"    \"~/datasets\",\n",
"    train=True,\n",
"    download=True,\n",
"    transform=transforms.Compose([\n",
"        transforms.RandomHorizontalFlip(p=1.0),\n",
"        transforms.RandomVerticalFlip(p=1.0),\n",
"        transforms.ToTensor(),\n",
"    ]),\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1b9eaf5c",
"metadata": {},
"outputs": [],
"source": [
"s = int(0.05 * len(train_ds))\n",
"init_ds, rest_ds = torch.utils.data.random_split(train_ds, [s, len(train_ds) - s])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8c32c9f2",
"metadata": {},
"outputs": [],
"source": [
"init_ds"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "68a9a8b9",
"metadata": {},
"outputs": [],
"source": [
"model = pt.models.ImageGLVQ(\n",
"    dict(distribution=(10, 1)),\n",
"    prototypes_initializer=pt.initializers.SMCI(init_ds),\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6f23df86",
"metadata": {},
"outputs": [],
"source": [
"plt.imshow(model.get_prototype_grid(num_columns=5))"
]
},
{
"cell_type": "markdown",
"id": "1c23c7b2",
"metadata": {},
"source": [
"We could, of course, just use the initializers in isolation. For example, we could quickly obtain a stratified selection from the data like so:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "30780927",
"metadata": {},
"outputs": [],
"source": [
"protos, plabels = pt.components.LabeledComponents(\n",
"    distribution=(10, 5),\n",
"    components_initializer=pt.initializers.SMCI(init_ds),\n",
"    labels_initializer=pt.initializers.LabelsInitializer(),\n",
")()\n",
"plt.imshow(make_grid(protos, 10).permute(1, 2, 0)[:, :, 0], cmap=\"jet\")"
]
},
{
"cell_type": "markdown",
"id": "4fa69f92",
"metadata": {},
"source": [
"## FAQs"
]
},
{
"cell_type": "markdown",
"id": "fa20f9ac",
"metadata": {},
"source": [
"### How do I retrieve the prototypes and their respective labels from the model?\n",
"\n",
"For prototype models, the prototypes can be retrieved (as a `torch.Tensor`) via `model.prototypes`. You can convert it to a NumPy array by calling `.numpy()` on the tensor if required.\n",
"\n",
"```python\n",
">>> model.prototypes.numpy()\n",
"```\n",
"\n",
"Similarly, the labels of the prototypes can be retrieved via `model.prototype_labels`.\n",
"\n",
"```python\n",
">>> model.prototype_labels\n",
"```"
]
},
{
"cell_type": "markdown",
"id": "ba8215bf",
"metadata": {},
"source": [
"### How do I make inferences/predictions/recall with my trained model?\n",
"\n",
"The models under [prototorch.models](https://github.com/si-cim/prototorch_models) provide a `.predict(x)` method for making predictions. This returns the predicted class labels. It is essential that the input to this method is a `torch.Tensor` and not a NumPy array. Model instances are also callable. So, you could also just say `model(x)` as if `model` were just a function. However, this returns a (pseudo-)probability distribution over the classes.\n",
"\n",
"#### Example\n",
"\n",
"```python\n",
">>> y_pred = model.predict(torch.Tensor(x_train))  # returns class labels\n",
"```\n",
"or, to obtain the class (pseudo-)probabilities instead,\n",
"```python\n",
">>> y_prob = model(torch.Tensor(x_train))  # returns probabilities\n",
"```"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}