Update Examples to new initializer architecture.
Visualization still borken for some examples.
This commit is contained in:
		@@ -4,13 +4,12 @@ import numpy as np
 | 
				
			|||||||
import pytorch_lightning as pl
 | 
					import pytorch_lightning as pl
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
from matplotlib import pyplot as plt
 | 
					from matplotlib import pyplot as plt
 | 
				
			||||||
 | 
					from prototorch.components import initializers as cinit
 | 
				
			||||||
 | 
					from prototorch.datasets.abstract import NumpyDataset
 | 
				
			||||||
from sklearn.datasets import make_circles
 | 
					from sklearn.datasets import make_circles
 | 
				
			||||||
from torch.utils.data import DataLoader
 | 
					from torch.utils.data import DataLoader
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from prototorch.datasets.abstract import NumpyDataset
 | 
					 | 
				
			||||||
from prototorch.models.callbacks.visualization import VisPointProtos
 | 
					 | 
				
			||||||
from prototorch.models.cbc import CBC, euclidean_similarity
 | 
					from prototorch.models.cbc import CBC, euclidean_similarity
 | 
				
			||||||
from prototorch.models.glvq import GLVQ
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class VisualizationCallback(pl.Callback):
 | 
					class VisualizationCallback(pl.Callback):
 | 
				
			||||||
@@ -32,7 +31,7 @@ class VisualizationCallback(pl.Callback):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    def on_epoch_end(self, trainer, pl_module):
 | 
					    def on_epoch_end(self, trainer, pl_module):
 | 
				
			||||||
        if self.prototype_model:
 | 
					        if self.prototype_model:
 | 
				
			||||||
            protos = pl_module.prototypes
 | 
					            protos = pl_module.components
 | 
				
			||||||
            color = pl_module.prototype_labels
 | 
					            color = pl_module.prototype_labels
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            protos = pl_module.components
 | 
					            protos = pl_module.components
 | 
				
			||||||
@@ -83,8 +82,8 @@ if __name__ == "__main__":
 | 
				
			|||||||
    hparams = dict(
 | 
					    hparams = dict(
 | 
				
			||||||
        input_dim=x_train.shape[1],
 | 
					        input_dim=x_train.shape[1],
 | 
				
			||||||
        nclasses=len(np.unique(y_train)),
 | 
					        nclasses=len(np.unique(y_train)),
 | 
				
			||||||
        prototypes_per_class=5,
 | 
					        num_components=5,
 | 
				
			||||||
        prototype_initializer="randn",
 | 
					        component_initializer=cinit.RandomInitializer(x_train.shape[1]),
 | 
				
			||||||
        lr=0.01,
 | 
					        lr=0.01,
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -95,31 +94,15 @@ if __name__ == "__main__":
 | 
				
			|||||||
        similarity=euclidean_similarity,
 | 
					        similarity=euclidean_similarity,
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    model = GLVQ(hparams, data=[x_train, y_train])
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # Fix the component locations
 | 
					 | 
				
			||||||
    # model.proto_layer.requires_grad_(False)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # import sys
 | 
					 | 
				
			||||||
    # sys.exit()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # Model summary
 | 
					 | 
				
			||||||
    print(model)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # Callbacks
 | 
					    # Callbacks
 | 
				
			||||||
    dvis = VisPointProtos(
 | 
					    dvis = VisualizationCallback(x_train,
 | 
				
			||||||
        data=(x_train, y_train),
 | 
					                                 y_train,
 | 
				
			||||||
        save=True,
 | 
					                                 prototype_model=False,
 | 
				
			||||||
        snap=False,
 | 
					                                 title="CBC Circle Example")
 | 
				
			||||||
        voronoi=True,
 | 
					 | 
				
			||||||
        resolution=50,
 | 
					 | 
				
			||||||
        pause_time=0.1,
 | 
					 | 
				
			||||||
        make_gif=True,
 | 
					 | 
				
			||||||
    )
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Setup trainer
 | 
					    # Setup trainer
 | 
				
			||||||
    trainer = pl.Trainer(
 | 
					    trainer = pl.Trainer(
 | 
				
			||||||
        max_epochs=10,
 | 
					        max_epochs=50,
 | 
				
			||||||
        callbacks=[
 | 
					        callbacks=[
 | 
				
			||||||
            dvis,
 | 
					            dvis,
 | 
				
			||||||
        ],
 | 
					        ],
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -4,30 +4,38 @@ import numpy as np
 | 
				
			|||||||
import pytorch_lightning as pl
 | 
					import pytorch_lightning as pl
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
from matplotlib import pyplot as plt
 | 
					from matplotlib import pyplot as plt
 | 
				
			||||||
 | 
					from prototorch.components import initializers as cinit
 | 
				
			||||||
 | 
					from prototorch.datasets.abstract import NumpyDataset
 | 
				
			||||||
from sklearn.datasets import load_iris
 | 
					from sklearn.datasets import load_iris
 | 
				
			||||||
from torch.utils.data import DataLoader
 | 
					from torch.utils.data import DataLoader
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from prototorch.datasets.abstract import NumpyDataset
 | 
					from prototorch.models.cbc import CBC, euclidean_similarity
 | 
				
			||||||
from prototorch.models.cbc import CBC
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class VisualizationCallback(pl.Callback):
 | 
					class VisualizationCallback(pl.Callback):
 | 
				
			||||||
    def __init__(self,
 | 
					    def __init__(
 | 
				
			||||||
 | 
					        self,
 | 
				
			||||||
        x_train,
 | 
					        x_train,
 | 
				
			||||||
        y_train,
 | 
					        y_train,
 | 
				
			||||||
 | 
					        prototype_model=True,
 | 
				
			||||||
        title="Prototype Visualization",
 | 
					        title="Prototype Visualization",
 | 
				
			||||||
                 cmap="viridis"):
 | 
					        cmap="viridis",
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
        super().__init__()
 | 
					        super().__init__()
 | 
				
			||||||
        self.x_train = x_train
 | 
					        self.x_train = x_train
 | 
				
			||||||
        self.y_train = y_train
 | 
					        self.y_train = y_train
 | 
				
			||||||
        self.title = title
 | 
					        self.title = title
 | 
				
			||||||
        self.fig = plt.figure(self.title)
 | 
					        self.fig = plt.figure(self.title)
 | 
				
			||||||
        self.cmap = cmap
 | 
					        self.cmap = cmap
 | 
				
			||||||
 | 
					        self.prototype_model = prototype_model
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def on_epoch_end(self, trainer, pl_module):
 | 
					    def on_epoch_end(self, trainer, pl_module):
 | 
				
			||||||
        # protos = pl_module.prototypes
 | 
					        if self.prototype_model:
 | 
				
			||||||
            protos = pl_module.components
 | 
					            protos = pl_module.components
 | 
				
			||||||
        # plabels = pl_module.prototype_labels
 | 
					            color = pl_module.prototype_labels
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            protos = pl_module.components
 | 
				
			||||||
 | 
					            color = "k"
 | 
				
			||||||
        ax = self.fig.gca()
 | 
					        ax = self.fig.gca()
 | 
				
			||||||
        ax.cla()
 | 
					        ax.cla()
 | 
				
			||||||
        ax.set_title(self.title)
 | 
					        ax.set_title(self.title)
 | 
				
			||||||
@@ -37,8 +45,7 @@ class VisualizationCallback(pl.Callback):
 | 
				
			|||||||
        ax.scatter(
 | 
					        ax.scatter(
 | 
				
			||||||
            protos[:, 0],
 | 
					            protos[:, 0],
 | 
				
			||||||
            protos[:, 1],
 | 
					            protos[:, 1],
 | 
				
			||||||
            # c=plabels,
 | 
					            c=color,
 | 
				
			||||||
            c="k",
 | 
					 | 
				
			||||||
            cmap=self.cmap,
 | 
					            cmap=self.cmap,
 | 
				
			||||||
            edgecolor="k",
 | 
					            edgecolor="k",
 | 
				
			||||||
            marker="D",
 | 
					            marker="D",
 | 
				
			||||||
@@ -71,42 +78,31 @@ if __name__ == "__main__":
 | 
				
			|||||||
    # Hyperparameters
 | 
					    # Hyperparameters
 | 
				
			||||||
    hparams = dict(
 | 
					    hparams = dict(
 | 
				
			||||||
        input_dim=x_train.shape[1],
 | 
					        input_dim=x_train.shape[1],
 | 
				
			||||||
        nclasses=3,
 | 
					        nclasses=len(np.unique(y_train)),
 | 
				
			||||||
        prototypes_per_class=3,
 | 
					        num_components=9,
 | 
				
			||||||
        prototype_initializer="stratified_mean",
 | 
					        component_initializer=cinit.StratifiedMeanInitializer(
 | 
				
			||||||
 | 
					            torch.Tensor(x_train), torch.Tensor(y_train)),
 | 
				
			||||||
        lr=0.01,
 | 
					        lr=0.01,
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Initialize the model
 | 
					    # Initialize the model
 | 
				
			||||||
    model = CBC(hparams, data=[x_train, y_train])
 | 
					    model = CBC(
 | 
				
			||||||
 | 
					        hparams,
 | 
				
			||||||
    # Fix the component locations
 | 
					        data=[x_train, y_train],
 | 
				
			||||||
    # model.proto_layer.requires_grad_(False)
 | 
					        similarity=euclidean_similarity,
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
    # Pure-positive reasonings
 | 
					 | 
				
			||||||
    ncomps = 3
 | 
					 | 
				
			||||||
    nclasses = 3
 | 
					 | 
				
			||||||
    rmat = torch.stack(
 | 
					 | 
				
			||||||
        [0.9 * torch.eye(ncomps),
 | 
					 | 
				
			||||||
         torch.zeros(ncomps, nclasses)], dim=0)
 | 
					 | 
				
			||||||
    # model.reasoning_layer.load_state_dict({"reasoning_probabilities": rmat},
 | 
					 | 
				
			||||||
    #                                       strict=True)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    print(model.reasoning_layer.reasoning_probabilities)
 | 
					 | 
				
			||||||
    # import sys
 | 
					 | 
				
			||||||
    # sys.exit()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # Model summary
 | 
					 | 
				
			||||||
    print(model)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Callbacks
 | 
					    # Callbacks
 | 
				
			||||||
    vis = VisualizationCallback(x_train, y_train)
 | 
					    dvis = VisualizationCallback(x_train,
 | 
				
			||||||
 | 
					                                 y_train,
 | 
				
			||||||
 | 
					                                 prototype_model=False,
 | 
				
			||||||
 | 
					                                 title="CBC Iris Example")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Setup trainer
 | 
					    # Setup trainer
 | 
				
			||||||
    trainer = pl.Trainer(
 | 
					    trainer = pl.Trainer(
 | 
				
			||||||
        max_epochs=100,
 | 
					        max_epochs=50,
 | 
				
			||||||
        callbacks=[
 | 
					        callbacks=[
 | 
				
			||||||
            vis,
 | 
					            dvis,
 | 
				
			||||||
        ],
 | 
					        ],
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -4,9 +4,9 @@ import numpy as np
 | 
				
			|||||||
import pytorch_lightning as pl
 | 
					import pytorch_lightning as pl
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
from matplotlib import pyplot as plt
 | 
					from matplotlib import pyplot as plt
 | 
				
			||||||
 | 
					from prototorch.datasets.abstract import NumpyDataset
 | 
				
			||||||
from torch.utils.data import DataLoader
 | 
					from torch.utils.data import DataLoader
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from prototorch.datasets.abstract import NumpyDataset
 | 
					 | 
				
			||||||
from prototorch.models.cbc import CBC
 | 
					from prototorch.models.cbc import CBC
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -110,7 +110,7 @@ if __name__ == "__main__":
 | 
				
			|||||||
    # Pure-positive reasonings
 | 
					    # Pure-positive reasonings
 | 
				
			||||||
    new_reasoning = torch.zeros_like(
 | 
					    new_reasoning = torch.zeros_like(
 | 
				
			||||||
        model.reasoning_layer.reasoning_probabilities)
 | 
					        model.reasoning_layer.reasoning_probabilities)
 | 
				
			||||||
    for i, label in enumerate(model.proto_layer.prototype_labels):
 | 
					    for i, label in enumerate(model.component_layer.prototype_labels):
 | 
				
			||||||
        new_reasoning[0][0][i][int(label)] = 1.0
 | 
					        new_reasoning[0][0][i][int(label)] = 1.0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    model.reasoning_layer.reasoning_probabilities.data = new_reasoning
 | 
					    model.reasoning_layer.reasoning_probabilities.data = new_reasoning
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -8,9 +8,9 @@ import numpy as np
 | 
				
			|||||||
import pytorch_lightning as pl
 | 
					import pytorch_lightning as pl
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
from matplotlib import pyplot as plt
 | 
					from matplotlib import pyplot as plt
 | 
				
			||||||
 | 
					from prototorch.datasets.abstract import NumpyDataset
 | 
				
			||||||
from torch.utils.data import DataLoader
 | 
					from torch.utils.data import DataLoader
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from prototorch.datasets.abstract import NumpyDataset
 | 
					 | 
				
			||||||
from prototorch.models.cbc import CBC
 | 
					from prototorch.models.cbc import CBC
 | 
				
			||||||
from prototorch.models.glvq import GLVQ
 | 
					from prototorch.models.glvq import GLVQ
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -132,11 +132,12 @@ if __name__ == "__main__":
 | 
				
			|||||||
    train(glvq_model, x_train, y_train, train_loader, epochs=10)
 | 
					    train(glvq_model, x_train, y_train, train_loader, epochs=10)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Transfer Prototypes
 | 
					    # Transfer Prototypes
 | 
				
			||||||
    cbc_model.proto_layer.load_state_dict(glvq_model.proto_layer.state_dict())
 | 
					    cbc_model.component_layer.load_state_dict(
 | 
				
			||||||
 | 
					        glvq_model.proto_layer.state_dict())
 | 
				
			||||||
    # Pure-positive reasonings
 | 
					    # Pure-positive reasonings
 | 
				
			||||||
    new_reasoning = torch.zeros_like(
 | 
					    new_reasoning = torch.zeros_like(
 | 
				
			||||||
        cbc_model.reasoning_layer.reasoning_probabilities)
 | 
					        cbc_model.reasoning_layer.reasoning_probabilities)
 | 
				
			||||||
    for i, label in enumerate(cbc_model.proto_layer.prototype_labels):
 | 
					    for i, label in enumerate(cbc_model.component_layer.prototype_labels):
 | 
				
			||||||
        new_reasoning[0][0][i][int(label)] = 1.0
 | 
					        new_reasoning[0][0][i][int(label)] = 1.0
 | 
				
			||||||
        new_reasoning[1][0][i][1 - int(label)] = 1.0
 | 
					        new_reasoning[1][0][i][1 - int(label)] = 1.0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,86 +1,16 @@
 | 
				
			|||||||
"""GLVQ example using the Iris dataset."""
 | 
					"""GLVQ example using the Iris dataset."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import argparse
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import numpy as np
 | 
					 | 
				
			||||||
import pytorch_lightning as pl
 | 
					import pytorch_lightning as pl
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
from matplotlib import pyplot as plt
 | 
					from prototorch.components import initializers as cinit
 | 
				
			||||||
 | 
					from prototorch.datasets.abstract import NumpyDataset
 | 
				
			||||||
from sklearn.datasets import load_iris
 | 
					from sklearn.datasets import load_iris
 | 
				
			||||||
from torch.utils.data import DataLoader
 | 
					from torch.utils.data import DataLoader
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from prototorch.datasets.abstract import NumpyDataset
 | 
					from prototorch.models.callbacks.visualization import VisGLVQ2D
 | 
				
			||||||
from prototorch.models.glvq import GLVQ
 | 
					from prototorch.models.glvq import GLVQ
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 | 
				
			||||||
class GLVQIris(GLVQ):
 | 
					 | 
				
			||||||
    @staticmethod
 | 
					 | 
				
			||||||
    def add_model_specific_args(parent_parser):
 | 
					 | 
				
			||||||
        parser = argparse.ArgumentParser(parents=[parent_parser],
 | 
					 | 
				
			||||||
                                         add_help=False)
 | 
					 | 
				
			||||||
        parser.add_argument("--epochs", type=int, default=1)
 | 
					 | 
				
			||||||
        parser.add_argument("--lr", type=float, default=1e-1)
 | 
					 | 
				
			||||||
        parser.add_argument("--batch_size", type=int, default=150)
 | 
					 | 
				
			||||||
        parser.add_argument("--input_dim", type=int, default=2)
 | 
					 | 
				
			||||||
        parser.add_argument("--nclasses", type=int, default=3)
 | 
					 | 
				
			||||||
        parser.add_argument("--prototypes_per_class", type=int, default=3)
 | 
					 | 
				
			||||||
        parser.add_argument("--prototype_initializer",
 | 
					 | 
				
			||||||
                            type=str,
 | 
					 | 
				
			||||||
                            default="stratified_mean")
 | 
					 | 
				
			||||||
        return parser
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class VisualizationCallback(pl.Callback):
 | 
					 | 
				
			||||||
    def __init__(self,
 | 
					 | 
				
			||||||
                 x_train,
 | 
					 | 
				
			||||||
                 y_train,
 | 
					 | 
				
			||||||
                 title="Prototype Visualization",
 | 
					 | 
				
			||||||
                 cmap="viridis"):
 | 
					 | 
				
			||||||
        super().__init__()
 | 
					 | 
				
			||||||
        self.x_train = x_train
 | 
					 | 
				
			||||||
        self.y_train = y_train
 | 
					 | 
				
			||||||
        self.title = title
 | 
					 | 
				
			||||||
        self.fig = plt.figure(self.title)
 | 
					 | 
				
			||||||
        self.cmap = cmap
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def on_epoch_end(self, trainer, pl_module):
 | 
					 | 
				
			||||||
        protos = pl_module.prototypes
 | 
					 | 
				
			||||||
        plabels = pl_module.prototype_labels
 | 
					 | 
				
			||||||
        ax = self.fig.gca()
 | 
					 | 
				
			||||||
        ax.cla()
 | 
					 | 
				
			||||||
        ax.set_title(self.title)
 | 
					 | 
				
			||||||
        ax.set_xlabel("Data dimension 1")
 | 
					 | 
				
			||||||
        ax.set_ylabel("Data dimension 2")
 | 
					 | 
				
			||||||
        ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k")
 | 
					 | 
				
			||||||
        ax.scatter(
 | 
					 | 
				
			||||||
            protos[:, 0],
 | 
					 | 
				
			||||||
            protos[:, 1],
 | 
					 | 
				
			||||||
            c=plabels,
 | 
					 | 
				
			||||||
            cmap=self.cmap,
 | 
					 | 
				
			||||||
            edgecolor="k",
 | 
					 | 
				
			||||||
            marker="D",
 | 
					 | 
				
			||||||
            s=50,
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        x = np.vstack((x_train, protos))
 | 
					 | 
				
			||||||
        x_min, x_max = x[:, 0].min() - 1, x[:, 0].max() + 1
 | 
					 | 
				
			||||||
        y_min, y_max = x[:, 1].min() - 1, x[:, 1].max() + 1
 | 
					 | 
				
			||||||
        xx, yy = np.meshgrid(np.arange(x_min, x_max, 1 / 50),
 | 
					 | 
				
			||||||
                             np.arange(y_min, y_max, 1 / 50))
 | 
					 | 
				
			||||||
        mesh_input = np.c_[xx.ravel(), yy.ravel()]
 | 
					 | 
				
			||||||
        y_pred = pl_module.predict(torch.Tensor(mesh_input))
 | 
					 | 
				
			||||||
        y_pred = y_pred.reshape(xx.shape)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
 | 
					 | 
				
			||||||
        ax.set_xlim(left=x_min + 0, right=x_max - 0)
 | 
					 | 
				
			||||||
        ax.set_ylim(bottom=y_min + 0, top=y_max - 0)
 | 
					 | 
				
			||||||
        plt.pause(0.1)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
if __name__ == "__main__":
 | 
					if __name__ == "__main__":
 | 
				
			||||||
    # For best-practices when using `argparse` with `pytorch_lightning`, see
 | 
					 | 
				
			||||||
    # https://pytorch-lightning.readthedocs.io/en/stable/common/hyperparameters.html
 | 
					 | 
				
			||||||
    parser = argparse.ArgumentParser()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # Dataset
 | 
					    # Dataset
 | 
				
			||||||
    x_train, y_train = load_iris(return_X_y=True)
 | 
					    x_train, y_train = load_iris(return_X_y=True)
 | 
				
			||||||
    x_train = x_train[:, [0, 2]]
 | 
					    x_train = x_train[:, [0, 2]]
 | 
				
			||||||
@@ -89,43 +19,23 @@ if __name__ == "__main__":
 | 
				
			|||||||
    # Dataloaders
 | 
					    # Dataloaders
 | 
				
			||||||
    train_loader = DataLoader(train_ds, num_workers=0, batch_size=150)
 | 
					    train_loader = DataLoader(train_ds, num_workers=0, batch_size=150)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Add model specific args
 | 
					    # Hyperparameters
 | 
				
			||||||
    parser = GLVQIris.add_model_specific_args(parser)
 | 
					    hparams = dict(
 | 
				
			||||||
 | 
					        nclasses=3,
 | 
				
			||||||
    # Callbacks
 | 
					        prototypes_per_class=2,
 | 
				
			||||||
    vis = VisualizationCallback(x_train, y_train)
 | 
					        prototype_initializer=cinit.StratifiedMeanInitializer(
 | 
				
			||||||
 | 
					            torch.Tensor(x_train), torch.Tensor(y_train)),
 | 
				
			||||||
    # Automatically add trainer-specific-args like `--gpus`, `--num_nodes` etc.
 | 
					        lr=0.01,
 | 
				
			||||||
    parser = pl.Trainer.add_argparse_args(parser)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # Setup trainer
 | 
					 | 
				
			||||||
    trainer = pl.Trainer.from_argparse_args(
 | 
					 | 
				
			||||||
        parser,
 | 
					 | 
				
			||||||
        max_epochs=10,
 | 
					 | 
				
			||||||
        callbacks=[
 | 
					 | 
				
			||||||
            vis,
 | 
					 | 
				
			||||||
        ],  # comment this line out to disable the visualization
 | 
					 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    # trainer.tune(model)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Initialize the model
 | 
					    # Initialize the model
 | 
				
			||||||
    args = parser.parse_args()
 | 
					    model = GLVQ(hparams)
 | 
				
			||||||
    model = GLVQIris(args, data=[x_train, y_train])
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Model summary
 | 
					    # Setup trainer
 | 
				
			||||||
    print(model)
 | 
					    trainer = pl.Trainer(
 | 
				
			||||||
 | 
					        max_epochs=50,
 | 
				
			||||||
 | 
					        callbacks=[VisGLVQ2D(x_train, y_train)],
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Training loop
 | 
					    # Training loop
 | 
				
			||||||
    trainer.fit(model, train_loader)
 | 
					    trainer.fit(model, train_loader)
 | 
				
			||||||
 | 
					 | 
				
			||||||
    # Save the model manually (use `pl.callbacks.ModelCheckpoint` to automate)
 | 
					 | 
				
			||||||
    ckpt = "glvq_iris.ckpt"
 | 
					 | 
				
			||||||
    trainer.save_checkpoint(ckpt)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # Load the checkpoint
 | 
					 | 
				
			||||||
    new_model = GLVQIris.load_from_checkpoint(checkpoint_path=ckpt)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    print(new_model)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # Continue training
 | 
					 | 
				
			||||||
    trainer.fit(new_model, train_loader)  # TODO See why this fails!
 | 
					 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,40 +0,0 @@
 | 
				
			|||||||
"""GLVQ example using the Iris dataset."""
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import pytorch_lightning as pl
 | 
					 | 
				
			||||||
import torch
 | 
					 | 
				
			||||||
from prototorch.components import initializers as cinit
 | 
					 | 
				
			||||||
from prototorch.datasets.abstract import NumpyDataset
 | 
					 | 
				
			||||||
from prototorch.models.callbacks.visualization import VisGLVQ2D
 | 
					 | 
				
			||||||
from prototorch.models.glvq import GLVQ
 | 
					 | 
				
			||||||
from sklearn.datasets import load_iris
 | 
					 | 
				
			||||||
from torch.utils.data import DataLoader
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
if __name__ == "__main__":
 | 
					 | 
				
			||||||
    # Dataset
 | 
					 | 
				
			||||||
    x_train, y_train = load_iris(return_X_y=True)
 | 
					 | 
				
			||||||
    x_train = x_train[:, [0, 2]]
 | 
					 | 
				
			||||||
    train_ds = NumpyDataset(x_train, y_train)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # Dataloaders
 | 
					 | 
				
			||||||
    train_loader = DataLoader(train_ds, num_workers=0, batch_size=150)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # Hyperparameters
 | 
					 | 
				
			||||||
    hparams = dict(
 | 
					 | 
				
			||||||
        nclasses=3,
 | 
					 | 
				
			||||||
        prototypes_per_class=2,
 | 
					 | 
				
			||||||
        prototype_initializer=cinit.StratifiedMeanInitializer(
 | 
					 | 
				
			||||||
            torch.Tensor(x_train), torch.Tensor(y_train)),
 | 
					 | 
				
			||||||
        lr=0.01,
 | 
					 | 
				
			||||||
    )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # Initialize the model
 | 
					 | 
				
			||||||
    model = GLVQ(hparams)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # Setup trainer
 | 
					 | 
				
			||||||
    trainer = pl.Trainer(
 | 
					 | 
				
			||||||
        max_epochs=50,
 | 
					 | 
				
			||||||
        callbacks=[VisGLVQ2D(x_train, y_train)],
 | 
					 | 
				
			||||||
    )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # Training loop
 | 
					 | 
				
			||||||
    trainer.fit(model, train_loader)
 | 
					 | 
				
			||||||
@@ -7,6 +7,7 @@ import argparse
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import pytorch_lightning as pl
 | 
					import pytorch_lightning as pl
 | 
				
			||||||
import torchvision
 | 
					import torchvision
 | 
				
			||||||
 | 
					from prototorch.components import initializers as cinit
 | 
				
			||||||
from torch.utils.data import DataLoader
 | 
					from torch.utils.data import DataLoader
 | 
				
			||||||
from torchvision import transforms
 | 
					from torchvision import transforms
 | 
				
			||||||
from torchvision.datasets import MNIST
 | 
					from torchvision.datasets import MNIST
 | 
				
			||||||
@@ -92,12 +93,12 @@ if __name__ == "__main__":
 | 
				
			|||||||
        input_dim=28 * 28,
 | 
					        input_dim=28 * 28,
 | 
				
			||||||
        nclasses=10,
 | 
					        nclasses=10,
 | 
				
			||||||
        prototypes_per_class=1,
 | 
					        prototypes_per_class=1,
 | 
				
			||||||
        prototype_initializer="stratified_mean",
 | 
					        prototype_initializer=cinit.StratifiedMeanInitializer(x, y),
 | 
				
			||||||
        lr=args.lr,
 | 
					        lr=args.lr,
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Initialize the model
 | 
					    # Initialize the model
 | 
				
			||||||
    model = ImageGLVQ(hparams, data=[x, y])
 | 
					    model = ImageGLVQ(hparams)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Model summary
 | 
					    # Model summary
 | 
				
			||||||
    print(model)
 | 
					    print(model)
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -5,9 +5,10 @@ import torch
 | 
				
			|||||||
from prototorch.components import initializers as cinit
 | 
					from prototorch.components import initializers as cinit
 | 
				
			||||||
from prototorch.datasets.abstract import NumpyDataset
 | 
					from prototorch.datasets.abstract import NumpyDataset
 | 
				
			||||||
from prototorch.datasets.spiral import make_spiral
 | 
					from prototorch.datasets.spiral import make_spiral
 | 
				
			||||||
 | 
					from torch.utils.data import DataLoader
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from prototorch.models.callbacks.visualization import VisGLVQ2D
 | 
					from prototorch.models.callbacks.visualization import VisGLVQ2D
 | 
				
			||||||
from prototorch.models.glvq import GLVQ
 | 
					from prototorch.models.glvq import GLVQ
 | 
				
			||||||
from torch.utils.data import DataLoader
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class StopOnNaN(pl.Callback):
 | 
					class StopOnNaN(pl.Callback):
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -4,11 +4,12 @@ import pytorch_lightning as pl
 | 
				
			|||||||
import torch
 | 
					import torch
 | 
				
			||||||
from prototorch.components import initializers as cinit
 | 
					from prototorch.components import initializers as cinit
 | 
				
			||||||
from prototorch.datasets.abstract import NumpyDataset
 | 
					from prototorch.datasets.abstract import NumpyDataset
 | 
				
			||||||
from prototorch.models.callbacks.visualization import VisSiameseGLVQ2D
 | 
					 | 
				
			||||||
from prototorch.models.glvq import GMLVQ
 | 
					 | 
				
			||||||
from sklearn.datasets import load_iris
 | 
					from sklearn.datasets import load_iris
 | 
				
			||||||
from torch.utils.data import DataLoader
 | 
					from torch.utils.data import DataLoader
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from prototorch.models.callbacks.visualization import VisSiameseGLVQ2D
 | 
				
			||||||
 | 
					from prototorch.models.glvq import GMLVQ
 | 
				
			||||||
 | 
					
 | 
				
			||||||
if __name__ == "__main__":
 | 
					if __name__ == "__main__":
 | 
				
			||||||
    # Dataset
 | 
					    # Dataset
 | 
				
			||||||
    x_train, y_train = load_iris(return_X_y=True)
 | 
					    x_train, y_train = load_iris(return_X_y=True)
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,12 +1,12 @@
 | 
				
			|||||||
"""GMLVQ example using the Tecator dataset."""
 | 
					"""GMLVQ example using the Tecator dataset."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import pytorch_lightning as pl
 | 
					import pytorch_lightning as pl
 | 
				
			||||||
import torch
 | 
					 | 
				
			||||||
from prototorch.components import initializers as cinit
 | 
					from prototorch.components import initializers as cinit
 | 
				
			||||||
from prototorch.datasets.tecator import Tecator
 | 
					from prototorch.datasets.tecator import Tecator
 | 
				
			||||||
 | 
					from torch.utils.data import DataLoader
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from prototorch.models.callbacks.visualization import VisSiameseGLVQ2D
 | 
					from prototorch.models.callbacks.visualization import VisSiameseGLVQ2D
 | 
				
			||||||
from prototorch.models.glvq import GMLVQ
 | 
					from prototorch.models.glvq import GMLVQ
 | 
				
			||||||
from torch.utils.data import DataLoader
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
if __name__ == "__main__":
 | 
					if __name__ == "__main__":
 | 
				
			||||||
    # Dataset
 | 
					    # Dataset
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,15 +1,14 @@
 | 
				
			|||||||
"""Neural Gas example using the Iris dataset."""
 | 
					"""Neural Gas example using the Iris dataset."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import numpy as np
 | 
					 | 
				
			||||||
import pytorch_lightning as pl
 | 
					import pytorch_lightning as pl
 | 
				
			||||||
from matplotlib import pyplot as plt
 | 
					 | 
				
			||||||
from prototorch.datasets.abstract import NumpyDataset
 | 
					from prototorch.datasets.abstract import NumpyDataset
 | 
				
			||||||
from prototorch.models.callbacks.visualization import VisNG2D
 | 
					 | 
				
			||||||
from prototorch.models.neural_gas import NeuralGas
 | 
					 | 
				
			||||||
from sklearn.datasets import load_iris
 | 
					from sklearn.datasets import load_iris
 | 
				
			||||||
from sklearn.preprocessing import StandardScaler
 | 
					from sklearn.preprocessing import StandardScaler
 | 
				
			||||||
from torch.utils.data import DataLoader
 | 
					from torch.utils.data import DataLoader
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from prototorch.models.callbacks.visualization import VisNG2D
 | 
				
			||||||
 | 
					from prototorch.models.neural_gas import NeuralGas
 | 
				
			||||||
 | 
					
 | 
				
			||||||
if __name__ == "__main__":
 | 
					if __name__ == "__main__":
 | 
				
			||||||
    # Dataset
 | 
					    # Dataset
 | 
				
			||||||
    x_train, y_train = load_iris(return_X_y=True)
 | 
					    x_train, y_train = load_iris(return_X_y=True)
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -2,14 +2,16 @@
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import pytorch_lightning as pl
 | 
					import pytorch_lightning as pl
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
from prototorch.components import (StratifiedMeanInitializer,
 | 
					from prototorch.components import (
 | 
				
			||||||
                                   StratifiedSelectionInitializer)
 | 
					    StratifiedMeanInitializer
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
from prototorch.datasets.abstract import NumpyDataset
 | 
					from prototorch.datasets.abstract import NumpyDataset
 | 
				
			||||||
from prototorch.models.callbacks.visualization import VisSiameseGLVQ2D
 | 
					 | 
				
			||||||
from prototorch.models.glvq import SiameseGLVQ
 | 
					 | 
				
			||||||
from sklearn.datasets import load_iris
 | 
					from sklearn.datasets import load_iris
 | 
				
			||||||
from torch.utils.data import DataLoader
 | 
					from torch.utils.data import DataLoader
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from prototorch.models.callbacks.visualization import VisSiameseGLVQ2D
 | 
				
			||||||
 | 
					from prototorch.models.glvq import SiameseGLVQ
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class Backbone(torch.nn.Module):
 | 
					class Backbone(torch.nn.Module):
 | 
				
			||||||
    def __init__(self, input_size=4, hidden_size=10, latent_size=2):
 | 
					    def __init__(self, input_size=4, hidden_size=10, latent_size=2):
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,10 +1,9 @@
 | 
				
			|||||||
import pytorch_lightning as pl
 | 
					import pytorch_lightning as pl
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
import torchmetrics
 | 
					import torchmetrics
 | 
				
			||||||
 | 
					from prototorch.components.components import Components
 | 
				
			||||||
from prototorch.functions.distances import euclidean_distance
 | 
					from prototorch.functions.distances import euclidean_distance
 | 
				
			||||||
from prototorch.functions.similarities import cosine_similarity
 | 
					from prototorch.functions.similarities import cosine_similarity
 | 
				
			||||||
from prototorch.modules.prototypes import Prototypes1D
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def rescaled_cosine_similarity(x, y):
 | 
					def rescaled_cosine_similarity(x, y):
 | 
				
			||||||
@@ -93,12 +92,8 @@ class CBC(pl.LightningModule):
 | 
				
			|||||||
        super().__init__()
 | 
					        super().__init__()
 | 
				
			||||||
        self.save_hyperparameters(hparams)
 | 
					        self.save_hyperparameters(hparams)
 | 
				
			||||||
        self.margin = margin
 | 
					        self.margin = margin
 | 
				
			||||||
        self.proto_layer = Prototypes1D(
 | 
					        self.component_layer = Components(self.hparams.num_components,
 | 
				
			||||||
            input_dim=self.hparams.input_dim,
 | 
					                                          self.hparams.component_initializer)
 | 
				
			||||||
            nclasses=self.hparams.nclasses,
 | 
					 | 
				
			||||||
            prototypes_per_class=self.hparams.prototypes_per_class,
 | 
					 | 
				
			||||||
            prototype_initializer=self.hparams.prototype_initializer,
 | 
					 | 
				
			||||||
            **kwargs)
 | 
					 | 
				
			||||||
        # self.similarity = CosineSimilarity()
 | 
					        # self.similarity = CosineSimilarity()
 | 
				
			||||||
        self.similarity = similarity
 | 
					        self.similarity = similarity
 | 
				
			||||||
        self.backbone = backbone_class()
 | 
					        self.backbone = backbone_class()
 | 
				
			||||||
@@ -110,7 +105,7 @@ class CBC(pl.LightningModule):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    @property
 | 
					    @property
 | 
				
			||||||
    def components(self):
 | 
					    def components(self):
 | 
				
			||||||
        return self.proto_layer.prototypes.detach().cpu()
 | 
					        return self.component_layer.components.detach().cpu()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @property
 | 
					    @property
 | 
				
			||||||
    def reasonings(self):
 | 
					    def reasonings(self):
 | 
				
			||||||
@@ -126,7 +121,7 @@ class CBC(pl.LightningModule):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    def forward(self, x):
 | 
					    def forward(self, x):
 | 
				
			||||||
        self.sync_backbones()
 | 
					        self.sync_backbones()
 | 
				
			||||||
        protos, _ = self.proto_layer()
 | 
					        protos = self.component_layer()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        latent_x = self.backbone(x)
 | 
					        latent_x = self.backbone(x)
 | 
				
			||||||
        latent_protos = self.backbone_dependent(protos)
 | 
					        latent_protos = self.backbone_dependent(protos)
 | 
				
			||||||
@@ -167,4 +162,4 @@ class ImageCBC(CBC):
 | 
				
			|||||||
    """
 | 
					    """
 | 
				
			||||||
    def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
 | 
					    def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
 | 
				
			||||||
        # super().on_train_batch_end(outputs, batch, batch_idx, dataloader_idx)
 | 
					        # super().on_train_batch_end(outputs, batch, batch_idx, dataloader_idx)
 | 
				
			||||||
        self.proto_layer.prototypes.data.clamp_(0.0, 1.0)
 | 
					        self.component_layer.prototypes.data.clamp_(0.0, 1.0)
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,4 +1,3 @@
 | 
				
			|||||||
import pytorch_lightning as pl
 | 
					 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
import torchmetrics
 | 
					import torchmetrics
 | 
				
			||||||
from prototorch.components import LabeledComponents
 | 
					from prototorch.components import LabeledComponents
 | 
				
			||||||
@@ -7,7 +6,6 @@ from prototorch.functions.competitions import wtac
 | 
				
			|||||||
from prototorch.functions.distances import (euclidean_distance,
 | 
					from prototorch.functions.distances import (euclidean_distance,
 | 
				
			||||||
                                            squared_euclidean_distance)
 | 
					                                            squared_euclidean_distance)
 | 
				
			||||||
from prototorch.functions.losses import glvq_loss
 | 
					from prototorch.functions.losses import glvq_loss
 | 
				
			||||||
from prototorch.modules.prototypes import Prototypes1D
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
from .abstract import AbstractPrototypeModel
 | 
					from .abstract import AbstractPrototypeModel
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -55,7 +53,6 @@ class GLVQ(AbstractPrototypeModel):
 | 
				
			|||||||
        with torch.no_grad():
 | 
					        with torch.no_grad():
 | 
				
			||||||
            preds = wtac(dis, plabels)
 | 
					            preds = wtac(dis, plabels)
 | 
				
			||||||
        # `.int()` because FloatTensors are assumed to be class probabilities
 | 
					        # `.int()` because FloatTensors are assumed to be class probabilities
 | 
				
			||||||
        self.train_acc(preds.int(), y.int())
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # Logging
 | 
					        # Logging
 | 
				
			||||||
        self.log("train_loss", loss)
 | 
					        self.log("train_loss", loss)
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,9 +1,7 @@
 | 
				
			|||||||
import pytorch_lightning as pl
 | 
					 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
from prototorch.components import Components
 | 
					from prototorch.components import Components
 | 
				
			||||||
from prototorch.components import initializers as cinit
 | 
					from prototorch.components import initializers as cinit
 | 
				
			||||||
from prototorch.functions.distances import euclidean_distance
 | 
					from prototorch.functions.distances import euclidean_distance
 | 
				
			||||||
from prototorch.modules import Prototypes1D
 | 
					 | 
				
			||||||
from prototorch.modules.losses import NeuralGasEnergy
 | 
					from prototorch.modules.losses import NeuralGasEnergy
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from .abstract import AbstractPrototypeModel
 | 
					from .abstract import AbstractPrototypeModel
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user