docs: update tutorial notebook
This commit is contained in:
parent
4941c2b89d
commit
40bd7ed380
@ -21,7 +21,7 @@
|
|||||||
"id": "43b74278",
|
"id": "43b74278",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"source": [
|
"source": [
|
||||||
"This is a short tutorial for the [models](https://github.com/si-cim/prototorch_models) plugin of the [ProtoTorch](https://github.com/si-cim/prototorch) framework.\n",
|
"This is a short tutorial for the [models](https://github.com/si-cim/prototorch_models) plugin of the [ProtoTorch](https://github.com/si-cim/prototorch) framework. This is by no means a comprehensive look at all the features that the framework has to offer, but it should help you get up and running.\n",
|
||||||
"\n",
|
"\n",
|
||||||
"[ProtoTorch](https://github.com/si-cim/prototorch) provides [torch.nn](https://pytorch.org/docs/stable/nn.html) modules and utilities to implement prototype-based models. However, it is up to the user to put these modules together into models and handle the training of these models. Expert machine-learning practioners and researchers sometimes prefer this level of control. However, this leads to a lot of boilerplate code that is essentially same across many projects. Needless to say, this is a source of a lot of frustration. [PyTorch-Lightning](https://pytorch-lightning.readthedocs.io/en/latest/) is a framework that helps avoid a lot of this frustration by handling the boilerplate code for you so you don't have to reinvent the wheel every time you need to implement a new model.\n",
|
"[ProtoTorch](https://github.com/si-cim/prototorch) provides [torch.nn](https://pytorch.org/docs/stable/nn.html) modules and utilities to implement prototype-based models. However, it is up to the user to put these modules together into models and handle the training of these models. Expert machine-learning practioners and researchers sometimes prefer this level of control. However, this leads to a lot of boilerplate code that is essentially same across many projects. Needless to say, this is a source of a lot of frustration. [PyTorch-Lightning](https://pytorch-lightning.readthedocs.io/en/latest/) is a framework that helps avoid a lot of this frustration by handling the boilerplate code for you so you don't have to reinvent the wheel every time you need to implement a new model.\n",
|
||||||
"\n",
|
"\n",
|
||||||
@ -404,6 +404,55 @@
|
|||||||
"## Advanced"
|
"## Advanced"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "53a64063",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"### Warm-start a model with prototypes learned from another model"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "3177c277",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"trained_model = pt.models.GLVQ.load_from_checkpoint(\"./checkpoints/glvq_iris.ckpt\", strict=False)\n",
|
||||||
|
"model = pt.models.SiameseGMLVQ(\n",
|
||||||
|
" dict(input_dim=2,\n",
|
||||||
|
" output_dim=2,\n",
|
||||||
|
" distribution=(3, 2),\n",
|
||||||
|
" proto_lr=0.0001,\n",
|
||||||
|
" bb_lr=0.0001),\n",
|
||||||
|
" optimizer=torch.optim.Adam,\n",
|
||||||
|
" prototypes_initializer=pt.initializers.LCI(trained_model.prototypes),\n",
|
||||||
|
" labels_initializer=pt.initializers.LLI(trained_model.prototype_labels),\n",
|
||||||
|
" omega_initializer=pt.initializers.LLTI(torch.tensor([[0., 1.], [1., 0.]])), # permute axes\n",
|
||||||
|
")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "8baee9a2",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"print(model)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "cc203088",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"pt.models.VisSiameseGLVQ2D(data=train_ds, title=\"GMLVQ - Warm-start state\").visualize(model)"
|
||||||
|
]
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"id": "1f6a33a5",
|
"id": "1f6a33a5",
|
||||||
@ -423,7 +472,8 @@
|
|||||||
"import pytorch_lightning as pl\n",
|
"import pytorch_lightning as pl\n",
|
||||||
"import torch\n",
|
"import torch\n",
|
||||||
"from torchvision import transforms\n",
|
"from torchvision import transforms\n",
|
||||||
"from torchvision.datasets import MNIST"
|
"from torchvision.datasets import MNIST\n",
|
||||||
|
"from torchvision.utils import make_grid"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -484,7 +534,7 @@
|
|||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"model = pt.models.ImageGLVQ(\n",
|
"model = pt.models.ImageGLVQ(\n",
|
||||||
" dict(distribution=(10, 5)),\n",
|
" dict(distribution=(10, 1)),\n",
|
||||||
" prototypes_initializer=pt.initializers.SMCI(init_ds),\n",
|
" prototypes_initializer=pt.initializers.SMCI(init_ds),\n",
|
||||||
")"
|
")"
|
||||||
]
|
]
|
||||||
@ -496,7 +546,30 @@
|
|||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"plt.imshow(model.get_prototype_grid(num_columns=10))"
|
"plt.imshow(model.get_prototype_grid(num_columns=5))"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "1c23c7b2",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"We could, of course, just use the initializers in isolation. For example, we could quickly obtain a stratified selection from the data like so:"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "30780927",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"protos, plabels = pt.components.LabeledComponents(\n",
|
||||||
|
" distribution=(10, 5),\n",
|
||||||
|
" components_initializer=pt.initializers.SMCI(init_ds),\n",
|
||||||
|
" labels_initializer=pt.initializers.LabelsInitializer(),\n",
|
||||||
|
")()\n",
|
||||||
|
"plt.imshow(make_grid(protos, 10).permute(1, 2, 0)[:, :, 0], cmap=\"jet\")"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -564,7 +637,7 @@
|
|||||||
"name": "python",
|
"name": "python",
|
||||||
"nbconvert_exporter": "python",
|
"nbconvert_exporter": "python",
|
||||||
"pygments_lexer": "ipython3",
|
"pygments_lexer": "ipython3",
|
||||||
"version": "3.9.9"
|
"version": "3.9.12"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"nbformat": 4,
|
"nbformat": 4,
|
||||||
|
Loading…
Reference in New Issue
Block a user