docs: update tutorial notebook
This commit is contained in:
parent
4941c2b89d
commit
40bd7ed380
@ -21,7 +21,7 @@
|
||||
"id": "43b74278",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"This is a short tutorial for the [models](https://github.com/si-cim/prototorch_models) plugin of the [ProtoTorch](https://github.com/si-cim/prototorch) framework.\n",
|
||||
"This is a short tutorial for the [models](https://github.com/si-cim/prototorch_models) plugin of the [ProtoTorch](https://github.com/si-cim/prototorch) framework. This is by no means a comprehensive look at all the features that the framework has to offer, but it should help you get up and running.\n",
|
||||
"\n",
|
||||
"[ProtoTorch](https://github.com/si-cim/prototorch) provides [torch.nn](https://pytorch.org/docs/stable/nn.html) modules and utilities to implement prototype-based models. However, it is up to the user to put these modules together into models and handle the training of these models. Expert machine-learning practioners and researchers sometimes prefer this level of control. However, this leads to a lot of boilerplate code that is essentially same across many projects. Needless to say, this is a source of a lot of frustration. [PyTorch-Lightning](https://pytorch-lightning.readthedocs.io/en/latest/) is a framework that helps avoid a lot of this frustration by handling the boilerplate code for you so you don't have to reinvent the wheel every time you need to implement a new model.\n",
|
||||
"\n",
|
||||
@ -404,6 +404,55 @@
|
||||
"## Advanced"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "53a64063",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Warm-start a model with prototypes learned from another model"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "3177c277",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"trained_model = pt.models.GLVQ.load_from_checkpoint(\"./checkpoints/glvq_iris.ckpt\", strict=False)\n",
|
||||
"model = pt.models.SiameseGMLVQ(\n",
|
||||
" dict(input_dim=2,\n",
|
||||
" output_dim=2,\n",
|
||||
" distribution=(3, 2),\n",
|
||||
" proto_lr=0.0001,\n",
|
||||
" bb_lr=0.0001),\n",
|
||||
" optimizer=torch.optim.Adam,\n",
|
||||
" prototypes_initializer=pt.initializers.LCI(trained_model.prototypes),\n",
|
||||
" labels_initializer=pt.initializers.LLI(trained_model.prototype_labels),\n",
|
||||
" omega_initializer=pt.initializers.LLTI(torch.tensor([[0., 1.], [1., 0.]])), # permute axes\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "8baee9a2",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print(model)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "cc203088",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"pt.models.VisSiameseGLVQ2D(data=train_ds, title=\"GMLVQ - Warm-start state\").visualize(model)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "1f6a33a5",
|
||||
@ -423,7 +472,8 @@
|
||||
"import pytorch_lightning as pl\n",
|
||||
"import torch\n",
|
||||
"from torchvision import transforms\n",
|
||||
"from torchvision.datasets import MNIST"
|
||||
"from torchvision.datasets import MNIST\n",
|
||||
"from torchvision.utils import make_grid"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -484,7 +534,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model = pt.models.ImageGLVQ(\n",
|
||||
" dict(distribution=(10, 5)),\n",
|
||||
" dict(distribution=(10, 1)),\n",
|
||||
" prototypes_initializer=pt.initializers.SMCI(init_ds),\n",
|
||||
")"
|
||||
]
|
||||
@ -496,7 +546,30 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"plt.imshow(model.get_prototype_grid(num_columns=10))"
|
||||
"plt.imshow(model.get_prototype_grid(num_columns=5))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "1c23c7b2",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We could, of course, just use the initializers in isolation. For example, we could quickly obtain a stratified selection from the data like so:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "30780927",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"protos, plabels = pt.components.LabeledComponents(\n",
|
||||
" distribution=(10, 5),\n",
|
||||
" components_initializer=pt.initializers.SMCI(init_ds),\n",
|
||||
" labels_initializer=pt.initializers.LabelsInitializer(),\n",
|
||||
")()\n",
|
||||
"plt.imshow(make_grid(protos, 10).permute(1, 2, 0)[:, :, 0], cmap=\"jet\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -564,7 +637,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.9"
|
||||
"version": "3.9.12"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
Loading…
Reference in New Issue
Block a user