Automatic Formatting.
@@ -4,26 +4,24 @@ import numpy as np
 import pytorch_lightning as pl
 import torch
 from matplotlib import pyplot as plt
-from prototorch.models.cbc import CBC, rescaled_cosine_similarity, euclidean_similarity
-from prototorch.models.glvq import GLVQ
 from sklearn.datasets import make_circles
-from torch.utils.data import DataLoader, TensorDataset
+from torch.utils.data import DataLoader
 
-
-class NumpyDataset(TensorDataset):
-    def __init__(self, *arrays):
-        # tensors = [torch.from_numpy(arr) for arr in arrays]
-        tensors = [torch.Tensor(arr) for arr in arrays]
-        super().__init__(*tensors)
+from prototorch.datasets.abstract import NumpyDataset
+from prototorch.models.callbacks.visualization import VisPointProtos
+from prototorch.models.cbc import CBC, euclidean_similarity
+from prototorch.models.glvq import GLVQ
 
 
 class VisualizationCallback(pl.Callback):
-    def __init__(self,
-                 x_train,
-                 y_train,
-                 prototype_model=True,
-                 title="Prototype Visualization",
-                 cmap="viridis"):
+    def __init__(
+        self,
+        x_train,
+        y_train,
+        prototype_model=True,
+        title="Prototype Visualization",
+        cmap="viridis",
+    ):
         super().__init__()
         self.x_train = x_train
         self.y_train = y_train
@@ -38,20 +36,22 @@ class VisualizationCallback(pl.Callback):
             color = pl_module.prototype_labels
         else:
             protos = pl_module.components
-            color = 'k'
+            color = "k"
         ax = self.fig.gca()
         ax.cla()
         ax.set_title(self.title)
         ax.set_xlabel("Data dimension 1")
         ax.set_ylabel("Data dimension 2")
         ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k")
-        ax.scatter(protos[:, 0],
-                   protos[:, 1],
-                   c=color,
-                   cmap=self.cmap,
-                   edgecolor="k",
-                   marker="D",
-                   s=50)
+        ax.scatter(
+            protos[:, 0],
+            protos[:, 1],
+            c=color,
+            cmap=self.cmap,
+            edgecolor="k",
+            marker="D",
+            s=50,
+        )
         x = np.vstack((x_train, protos))
         x_min, x_max = x[:, 0].min() - 1, x[:, 0].max() + 1
         y_min, y_max = x[:, 1].min() - 1, x[:, 1].max() + 1
@@ -95,7 +95,7 @@ if __name__ == "__main__":
         similarity=euclidean_similarity,
     )
 
-    #model = GLVQ(hparams, data=[x_train, y_train])
+    model = GLVQ(hparams, data=[x_train, y_train])
 
     # Fix the component locations
     # model.proto_layer.requires_grad_(False)
@@ -107,13 +107,21 @@ if __name__ == "__main__":
     print(model)
 
     # Callbacks
-    vis = VisualizationCallback(x_train, y_train, prototype_model=False)
+    dvis = VisPointProtos(
+        data=(x_train, y_train),
+        save=True,
+        snap=False,
+        voronoi=True,
+        resolution=50,
+        pause_time=0.1,
+        make_gif=True,
+    )
 
     # Setup trainer
     trainer = pl.Trainer(
-        max_epochs=500,
+        max_epochs=10,
         callbacks=[
-            vis,
+            dvis,
         ],
     )
 
@@ -4,16 +4,11 @@ import numpy as np
 import pytorch_lightning as pl
 import torch
 from matplotlib import pyplot as plt
-from prototorch.models.cbc import CBC
 from sklearn.datasets import load_iris
-from torch.utils.data import DataLoader, TensorDataset
+from torch.utils.data import DataLoader
 
-
-class NumpyDataset(TensorDataset):
-    def __init__(self, *arrays):
-        # tensors = [torch.from_numpy(arr) for arr in arrays]
-        tensors = [torch.Tensor(arr) for arr in arrays]
-        super().__init__(*tensors)
+from prototorch.datasets.abstract import NumpyDataset
+from prototorch.models.cbc import CBC
 
 
 class VisualizationCallback(pl.Callback):
@@ -47,7 +42,8 @@ class VisualizationCallback(pl.Callback):
             cmap=self.cmap,
             edgecolor="k",
             marker="D",
-            s=50)
+            s=50,
+        )
         x = np.vstack((x_train, protos))
         x_min, x_max = x[:, 0].min() - 1, x[:, 0].max() + 1
         y_min, y_max = x[:, 1].min() - 1, x[:, 1].max() + 1
@@ -73,11 +69,13 @@ if __name__ == "__main__":
     train_loader = DataLoader(train_ds, num_workers=0, batch_size=150)
 
     # Hyperparameters
-    hparams = dict(input_dim=x_train.shape[1],
-                   nclasses=3,
-                   prototypes_per_class=3,
-                   prototype_initializer="stratified_mean",
-                   lr=0.01)
+    hparams = dict(
+        input_dim=x_train.shape[1],
+        nclasses=3,
+        prototypes_per_class=3,
+        prototype_initializer="stratified_mean",
+        lr=0.01,
+    )
 
     # Initialize the model
     model = CBC(hparams, data=[x_train, y_train])
@@ -7,12 +7,12 @@ import argparse
 
 import pytorch_lightning as pl
 import torchvision
-from matplotlib import pyplot as plt
-from prototorch.models.cbc import ImageCBC, euclidean_similarity, rescaled_cosine_similarity
 from torch.utils.data import DataLoader
 from torchvision import transforms
 from torchvision.datasets import MNIST
 
+from prototorch.models.cbc import CBC, ImageCBC, euclidean_similarity
+
 
 class VisualizationCallback(pl.Callback):
     def __init__(self, to_shape=(-1, 1, 28, 28), nrow=2):
@@ -89,8 +89,8 @@ if __name__ == "__main__":
     )
 
     # Dataloaders
-    train_loader = DataLoader(mnist_train, batch_size=1024)
-    test_loader = DataLoader(mnist_test, batch_size=1024)
+    train_loader = DataLoader(mnist_train, batch_size=32)
+    test_loader = DataLoader(mnist_test, batch_size=32)
 
     # Grab the full dataset to warm-start prototypes
     x, y = next(iter(DataLoader(mnist_train, batch_size=len(mnist_train))))
@@ -102,12 +102,12 @@ if __name__ == "__main__":
         nclasses=10,
         prototypes_per_class=args.ppc,
         prototype_initializer="randn",
-        lr=1,
+        lr=0.01,
         similarity=euclidean_similarity,
     )
 
     # Initialize the model
-    model = ImageCBC(hparams, data=[x, y])
+    model = CBC(hparams, data=[x, y])
     # Model summary
     print(model)
 

examples/cbc_spiral.py | 135 (new file)
@@ -0,0 +1,135 @@
					"""CBC example using the Iris dataset."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import numpy as np
 | 
				
			||||||
 | 
					import pytorch_lightning as pl
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					from matplotlib import pyplot as plt
 | 
				
			||||||
 | 
					from torch.utils.data import DataLoader
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from prototorch.datasets.abstract import NumpyDataset
 | 
				
			||||||
 | 
					from prototorch.models.cbc import CBC
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class VisualizationCallback(pl.Callback):
 | 
				
			||||||
 | 
					    def __init__(
 | 
				
			||||||
 | 
					        self,
 | 
				
			||||||
 | 
					        x_train,
 | 
				
			||||||
 | 
					        y_train,
 | 
				
			||||||
 | 
					        prototype_model=True,
 | 
				
			||||||
 | 
					        title="Prototype Visualization",
 | 
				
			||||||
 | 
					        cmap="viridis",
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					        self.x_train = x_train
 | 
				
			||||||
 | 
					        self.y_train = y_train
 | 
				
			||||||
 | 
					        self.title = title
 | 
				
			||||||
 | 
					        self.fig = plt.figure(self.title)
 | 
				
			||||||
 | 
					        self.cmap = cmap
 | 
				
			||||||
 | 
					        self.prototype_model = prototype_model
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def on_epoch_end(self, trainer, pl_module):
 | 
				
			||||||
 | 
					        if self.prototype_model:
 | 
				
			||||||
 | 
					            protos = pl_module.prototypes
 | 
				
			||||||
 | 
					            color = pl_module.prototype_labels
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            protos = pl_module.components
 | 
				
			||||||
 | 
					            color = "k"
 | 
				
			||||||
 | 
					        ax = self.fig.gca()
 | 
				
			||||||
 | 
					        ax.cla()
 | 
				
			||||||
 | 
					        ax.set_title(self.title)
 | 
				
			||||||
 | 
					        ax.set_xlabel("Data dimension 1")
 | 
				
			||||||
 | 
					        ax.set_ylabel("Data dimension 2")
 | 
				
			||||||
 | 
					        ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k")
 | 
				
			||||||
 | 
					        ax.scatter(
 | 
				
			||||||
 | 
					            protos[:, 0],
 | 
				
			||||||
 | 
					            protos[:, 1],
 | 
				
			||||||
 | 
					            c=color,
 | 
				
			||||||
 | 
					            cmap=self.cmap,
 | 
				
			||||||
 | 
					            edgecolor="k",
 | 
				
			||||||
 | 
					            marker="D",
 | 
				
			||||||
 | 
					            s=50,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        x = np.vstack((x_train, protos))
 | 
				
			||||||
 | 
					        x_min, x_max = x[:, 0].min() - 1, x[:, 0].max() + 1
 | 
				
			||||||
 | 
					        y_min, y_max = x[:, 1].min() - 1, x[:, 1].max() + 1
 | 
				
			||||||
 | 
					        xx, yy = np.meshgrid(np.arange(x_min, x_max, 1 / 50),
 | 
				
			||||||
 | 
					                             np.arange(y_min, y_max, 1 / 50))
 | 
				
			||||||
 | 
					        mesh_input = np.c_[xx.ravel(), yy.ravel()]
 | 
				
			||||||
 | 
					        y_pred = pl_module.predict(torch.Tensor(mesh_input))
 | 
				
			||||||
 | 
					        y_pred = y_pred.reshape(xx.shape)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
 | 
				
			||||||
 | 
					        ax.set_xlim(left=x_min + 0, right=x_max - 0)
 | 
				
			||||||
 | 
					        ax.set_ylim(bottom=y_min + 0, top=y_max - 0)
 | 
				
			||||||
 | 
					        plt.pause(0.1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def make_spirals(n_samples=500, noise=0.3):
 | 
				
			||||||
 | 
					    def get_samples(n, delta_t):
 | 
				
			||||||
 | 
					        points = []
 | 
				
			||||||
 | 
					        for i in range(n):
 | 
				
			||||||
 | 
					            r = i / n_samples * 5
 | 
				
			||||||
 | 
					            t = 1.75 * i / n * 2 * np.pi + delta_t
 | 
				
			||||||
 | 
					            x = r * np.sin(t) + np.random.rand(1) * noise
 | 
				
			||||||
 | 
					            y = r * np.cos(t) + np.random.rand(1) * noise
 | 
				
			||||||
 | 
					            points.append([x, y])
 | 
				
			||||||
 | 
					        return points
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    n = n_samples // 2
 | 
				
			||||||
 | 
					    positive = get_samples(n=n, delta_t=0)
 | 
				
			||||||
 | 
					    negative = get_samples(n=n, delta_t=np.pi)
 | 
				
			||||||
 | 
					    x = np.concatenate(
 | 
				
			||||||
 | 
					        [np.array(positive).reshape(n, -1),
 | 
				
			||||||
 | 
					         np.array(negative).reshape(n, -1)],
 | 
				
			||||||
 | 
					        axis=0)
 | 
				
			||||||
 | 
					    y = np.concatenate([np.zeros(n), np.ones(n)])
 | 
				
			||||||
 | 
					    return x, y
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					if __name__ == "__main__":
 | 
				
			||||||
 | 
					    # Dataset
 | 
				
			||||||
 | 
					    x_train, y_train = make_spirals(n_samples=1000, noise=0.3)
 | 
				
			||||||
 | 
					    train_ds = NumpyDataset(x_train, y_train)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Dataloaders
 | 
				
			||||||
 | 
					    train_loader = DataLoader(train_ds, num_workers=0, batch_size=150)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Hyperparameters
 | 
				
			||||||
 | 
					    hparams = dict(
 | 
				
			||||||
 | 
					        input_dim=x_train.shape[1],
 | 
				
			||||||
 | 
					        nclasses=2,
 | 
				
			||||||
 | 
					        prototypes_per_class=40,
 | 
				
			||||||
 | 
					        prototype_initializer="stratified_random",
 | 
				
			||||||
 | 
					        lr=0.05,
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Initialize the model
 | 
				
			||||||
 | 
					    model_class = CBC
 | 
				
			||||||
 | 
					    model = model_class(hparams, data=[x_train, y_train])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Pure-positive reasonings
 | 
				
			||||||
 | 
					    new_reasoning = torch.zeros_like(
 | 
				
			||||||
 | 
					        model.reasoning_layer.reasoning_probabilities)
 | 
				
			||||||
 | 
					    for i, label in enumerate(model.proto_layer.prototype_labels):
 | 
				
			||||||
 | 
					        new_reasoning[0][0][i][int(label)] = 1.0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    model.reasoning_layer.reasoning_probabilities.data = new_reasoning
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Model summary
 | 
				
			||||||
 | 
					    print(model)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Callbacks
 | 
				
			||||||
 | 
					    vis = VisualizationCallback(x_train,
 | 
				
			||||||
 | 
					                                y_train,
 | 
				
			||||||
 | 
					                                prototype_model=hasattr(model, "prototypes"))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Setup trainer
 | 
				
			||||||
 | 
					    trainer = pl.Trainer(
 | 
				
			||||||
 | 
					        max_epochs=500,
 | 
				
			||||||
 | 
					        callbacks=[
 | 
				
			||||||
 | 
					            vis,
 | 
				
			||||||
 | 
					        ],
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Training loop
 | 
				
			||||||
 | 
					    trainer.fit(model, train_loader)
 | 
				
			||||||
							
								
								
									
										142
									
								
								examples/cbc_spiral_with_GLVQ_start.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										142
									
								
								examples/cbc_spiral_with_GLVQ_start.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,142 @@
 | 
				
			|||||||
 | 
					"""CBC example using the Iris dataset."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import numpy as np
 | 
				
			||||||
 | 
					import pytorch_lightning as pl
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					from matplotlib import pyplot as plt
 | 
				
			||||||
 | 
					from torch.utils.data import DataLoader
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from prototorch.datasets.abstract import NumpyDataset
 | 
				
			||||||
 | 
					from prototorch.models.cbc import CBC
 | 
				
			||||||
 | 
					from prototorch.models.glvq import GLVQ
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class VisualizationCallback(pl.Callback):
 | 
				
			||||||
 | 
					    def __init__(
 | 
				
			||||||
 | 
					        self,
 | 
				
			||||||
 | 
					        x_train,
 | 
				
			||||||
 | 
					        y_train,
 | 
				
			||||||
 | 
					        prototype_model=True,
 | 
				
			||||||
 | 
					        title="Prototype Visualization",
 | 
				
			||||||
 | 
					        cmap="viridis",
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					        self.x_train = x_train
 | 
				
			||||||
 | 
					        self.y_train = y_train
 | 
				
			||||||
 | 
					        self.title = title
 | 
				
			||||||
 | 
					        self.fig = plt.figure(self.title)
 | 
				
			||||||
 | 
					        self.cmap = cmap
 | 
				
			||||||
 | 
					        self.prototype_model = prototype_model
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def on_epoch_end(self, trainer, pl_module):
 | 
				
			||||||
 | 
					        if self.prototype_model:
 | 
				
			||||||
 | 
					            protos = pl_module.prototypes
 | 
				
			||||||
 | 
					            color = pl_module.prototype_labels
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            protos = pl_module.components
 | 
				
			||||||
 | 
					            color = "k"
 | 
				
			||||||
 | 
					        ax = self.fig.gca()
 | 
				
			||||||
 | 
					        ax.cla()
 | 
				
			||||||
 | 
					        ax.set_title(self.title)
 | 
				
			||||||
 | 
					        ax.set_xlabel("Data dimension 1")
 | 
				
			||||||
 | 
					        ax.set_ylabel("Data dimension 2")
 | 
				
			||||||
 | 
					        ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k")
 | 
				
			||||||
 | 
					        ax.scatter(
 | 
				
			||||||
 | 
					            protos[:, 0],
 | 
				
			||||||
 | 
					            protos[:, 1],
 | 
				
			||||||
 | 
					            c=color,
 | 
				
			||||||
 | 
					            cmap=self.cmap,
 | 
				
			||||||
 | 
					            edgecolor="k",
 | 
				
			||||||
 | 
					            marker="D",
 | 
				
			||||||
 | 
					            s=50,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        x = np.vstack((x_train, protos))
 | 
				
			||||||
 | 
					        x_min, x_max = x[:, 0].min() - 1, x[:, 0].max() + 1
 | 
				
			||||||
 | 
					        y_min, y_max = x[:, 1].min() - 1, x[:, 1].max() + 1
 | 
				
			||||||
 | 
					        xx, yy = np.meshgrid(np.arange(x_min, x_max, 1 / 50),
 | 
				
			||||||
 | 
					                             np.arange(y_min, y_max, 1 / 50))
 | 
				
			||||||
 | 
					        mesh_input = np.c_[xx.ravel(), yy.ravel()]
 | 
				
			||||||
 | 
					        y_pred = pl_module.predict(torch.Tensor(mesh_input))
 | 
				
			||||||
 | 
					        y_pred = y_pred.reshape(xx.shape)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
 | 
				
			||||||
 | 
					        ax.set_xlim(left=x_min + 0, right=x_max - 0)
 | 
				
			||||||
 | 
					        ax.set_ylim(bottom=y_min + 0, top=y_max - 0)
 | 
				
			||||||
 | 
					        plt.pause(0.1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def make_spirals(n_samples=500, noise=0.3):
 | 
				
			||||||
 | 
					    def get_samples(n, delta_t):
 | 
				
			||||||
 | 
					        points = []
 | 
				
			||||||
 | 
					        for i in range(n):
 | 
				
			||||||
 | 
					            r = i / n_samples * 5
 | 
				
			||||||
 | 
					            t = 1.75 * i / n * 2 * np.pi + delta_t
 | 
				
			||||||
 | 
					            x = r * np.sin(t) + np.random.rand(1) * noise
 | 
				
			||||||
 | 
					            y = r * np.cos(t) + np.random.rand(1) * noise
 | 
				
			||||||
 | 
					            points.append([x, y])
 | 
				
			||||||
 | 
					        return points
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    n = n_samples // 2
 | 
				
			||||||
 | 
					    positive = get_samples(n=n, delta_t=0)
 | 
				
			||||||
 | 
					    negative = get_samples(n=n, delta_t=np.pi)
 | 
				
			||||||
 | 
					    x = np.concatenate(
 | 
				
			||||||
 | 
					        [np.array(positive).reshape(n, -1),
 | 
				
			||||||
 | 
					         np.array(negative).reshape(n, -1)],
 | 
				
			||||||
 | 
					        axis=0)
 | 
				
			||||||
 | 
					    y = np.concatenate([np.zeros(n), np.ones(n)])
 | 
				
			||||||
 | 
					    return x, y
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def train(model, x_train, y_train, train_loader, epochs=100):
 | 
				
			||||||
 | 
					    # Callbacks
 | 
				
			||||||
 | 
					    vis = VisualizationCallback(x_train,
 | 
				
			||||||
 | 
					                                y_train,
 | 
				
			||||||
 | 
					                                prototype_model=hasattr(model, "prototypes"))
 | 
				
			||||||
 | 
					    # Setup trainer
 | 
				
			||||||
 | 
					    trainer = pl.Trainer(
 | 
				
			||||||
 | 
					        max_epochs=epochs,
 | 
				
			||||||
 | 
					        callbacks=[
 | 
				
			||||||
 | 
					            vis,
 | 
				
			||||||
 | 
					        ],
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    # Training loop
 | 
				
			||||||
 | 
					    trainer.fit(model, train_loader)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					if __name__ == "__main__":
 | 
				
			||||||
 | 
					    # Dataset
 | 
				
			||||||
 | 
					    x_train, y_train = make_spirals(n_samples=1000, noise=0.3)
 | 
				
			||||||
 | 
					    train_ds = NumpyDataset(x_train, y_train)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Dataloaders
 | 
				
			||||||
 | 
					    train_loader = DataLoader(train_ds, num_workers=0, batch_size=150)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Hyperparameters
 | 
				
			||||||
 | 
					    hparams = dict(
 | 
				
			||||||
 | 
					        input_dim=x_train.shape[1],
 | 
				
			||||||
 | 
					        nclasses=2,
 | 
				
			||||||
 | 
					        prototypes_per_class=40,
 | 
				
			||||||
 | 
					        prototype_initializer="stratified_random",
 | 
				
			||||||
 | 
					        lr=0.05,
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Initialize the model
 | 
				
			||||||
 | 
					    glvq_model = GLVQ(hparams, data=[x_train, y_train])
 | 
				
			||||||
 | 
					    cbc_model = CBC(hparams, data=[x_train, y_train])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Train GLVQ
 | 
				
			||||||
 | 
					    train(glvq_model, x_train, y_train, train_loader, epochs=10)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Transfer Prototypes
 | 
				
			||||||
 | 
					    cbc_model.proto_layer.load_state_dict(glvq_model.proto_layer.state_dict())
 | 
				
			||||||
 | 
					    # Pure-positive reasonings
 | 
				
			||||||
 | 
					    new_reasoning = torch.zeros_like(
 | 
				
			||||||
 | 
					        cbc_model.reasoning_layer.reasoning_probabilities)
 | 
				
			||||||
 | 
					    for i, label in enumerate(cbc_model.proto_layer.prototype_labels):
 | 
				
			||||||
 | 
					        new_reasoning[0][0][i][int(label)] = 1.0
 | 
				
			||||||
 | 
					        new_reasoning[1][0][i][1 - int(label)] = 1.0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    cbc_model.reasoning_layer.reasoning_probabilities.data = new_reasoning
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Train CBC
 | 
				
			||||||
 | 
					    train(cbc_model, x_train, y_train, train_loader, epochs=50)
 | 
				
			||||||
@@ -6,15 +6,11 @@ import numpy as np
 import pytorch_lightning as pl
 import torch
 from matplotlib import pyplot as plt
-from prototorch.models.glvq import GLVQ
 from sklearn.datasets import load_iris
-from torch.utils.data import DataLoader, TensorDataset
+from torch.utils.data import DataLoader
 
-
-class NumpyDataset(TensorDataset):
-    def __init__(self, *arrays):
-        tensors = [torch.from_numpy(arr) for arr in arrays]
-        super().__init__(*tensors)
+from prototorch.datasets.abstract import NumpyDataset
+from prototorch.models.glvq import GLVQ
 
 
 class GLVQIris(GLVQ):
@@ -56,13 +52,15 @@ class VisualizationCallback(pl.Callback):
         ax.set_xlabel("Data dimension 1")
         ax.set_ylabel("Data dimension 2")
         ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k")
-        ax.scatter(protos[:, 0],
-                   protos[:, 1],
-                   c=plabels,
-                   cmap=self.cmap,
-                   edgecolor="k",
-                   marker="D",
-                   s=50)
+        ax.scatter(
+            protos[:, 0],
+            protos[:, 1],
+            c=plabels,
+            cmap=self.cmap,
+            edgecolor="k",
+            marker="D",
+            s=50,
+        )
         x = np.vstack((x_train, protos))
         x_min, x_max = x[:, 0].min() - 1, x[:, 0].max() + 1
         y_min, y_max = x[:, 1].min() - 1, x[:, 1].max() + 1
@@ -105,8 +103,8 @@ if __name__ == "__main__":
         parser,
         max_epochs=10,
         callbacks=[
-            vis,  # comment this line out to disable the visualization
-        ],
+            vis,
+        ],  # comment this line out to disable the visualization
     )
     # trainer.tune(model)
 
@@ -4,15 +4,11 @@ import numpy as np
 import pytorch_lightning as pl
 import torch
 from matplotlib import pyplot as plt
-from prototorch.models.glvq import GLVQ
 from sklearn.datasets import load_iris
-from torch.utils.data import DataLoader, TensorDataset
+from torch.utils.data import DataLoader
 
-
-class NumpyDataset(TensorDataset):
-    def __init__(self, *arrays):
-        tensors = [torch.from_numpy(arr) for arr in arrays]
-        super().__init__(*tensors)
+from prototorch.datasets.abstract import NumpyDataset
+from prototorch.models.glvq import GLVQ
 
 
 class VisualizationCallback(pl.Callback):
@@ -37,13 +33,15 @@ class VisualizationCallback(pl.Callback):
         ax.set_xlabel("Data dimension 1")
         ax.set_ylabel("Data dimension 2")
         ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k")
-        ax.scatter(protos[:, 0],
-                   protos[:, 1],
-                   c=plabels,
-                   cmap=self.cmap,
-                   edgecolor="k",
-                   marker="D",
-                   s=50)
+        ax.scatter(
+            protos[:, 0],
+            protos[:, 1],
+            c=plabels,
+            cmap=self.cmap,
+            edgecolor="k",
+            marker="D",
+            s=50,
+        )
         x = np.vstack((x_train, protos))
         x_min, x_max = x[:, 0].min() - 1, x[:, 0].max() + 1
         y_min, y_max = x[:, 1].min() - 1, x[:, 1].max() + 1
@@ -69,11 +67,13 @@ if __name__ == "__main__":
     train_loader = DataLoader(train_ds, num_workers=0, batch_size=150)
 
     # Hyperparameters
-    hparams = dict(input_dim=x_train.shape[1],
-                   nclasses=3,
-                   prototypes_per_class=3,
-                   prototype_initializer="stratified_mean",
-                   lr=0.1)
+    hparams = dict(
+        input_dim=x_train.shape[1],
+        nclasses=3,
+        prototypes_per_class=3,
+        prototype_initializer="stratified_mean",
+        lr=0.1,
+    )
 
     # Initialize the model
     model = GLVQ(hparams, data=[x_train, y_train])
@@ -11,13 +11,12 @@ import argparse
 
 import pytorch_lightning as pl
 import torchvision
-from matplotlib import pyplot as plt
-from prototorch.functions.initializers import stratified_mean
-from prototorch.models.glvq import ImageGLVQ
 from torch.utils.data import DataLoader
 from torchvision import transforms
 from torchvision.datasets import MNIST
 
+from prototorch.models.glvq import ImageGLVQ
+
 
 class VisualizationCallback(pl.Callback):
     def __init__(self, to_shape=(-1, 1, 28, 28), nrow=2):
@@ -31,10 +30,12 @@ class VisualizationCallback(pl.Callback):
         grid = torchvision.utils.make_grid(protos_img, nrow=self.nrow)
         # grid = grid.permute((1, 2, 0))
         tb = pl_module.logger.experiment
-        tb.add_image(tag="MNIST Prototypes",
-                     img_tensor=grid,
-                     global_step=trainer.current_epoch,
-                     dataformats="CHW")
+        tb.add_image(
+            tag="MNIST Prototypes",
+            img_tensor=grid,
+            global_step=trainer.current_epoch,
+            dataformats="CHW",
+        )
 
 
 if __name__ == "__main__":
@@ -91,11 +92,13 @@ if __name__ == "__main__":
     x = x.view(len(mnist_train), -1)
 
     # Initialize the model
-    model = ImageGLVQ(input_dim=28 * 28,
-                      nclasses=10,
-                      prototypes_per_class=args.ppc,
-                      prototype_initializer="stratified_mean",
-                      data=[x, y])
+    model = ImageGLVQ(
+        input_dim=28 * 28,
+        nclasses=10,
+        prototypes_per_class=args.ppc,
+        prototype_initializer="stratified_mean",
+        data=[x, y],
+    )
     # Model summary
     print(model)
 
@@ -1,8 +1,8 @@
-from importlib.metadata import version, PackageNotFoundError
+from importlib.metadata import PackageNotFoundError, version
 
 VERSION_FALLBACK = "uninstalled_version"
 try:
     __version__ = version(__name__.replace(".", "-"))
 except PackageNotFoundError:
     __version__ = VERSION_FALLBACK
     pass
@@ -1,13 +1,9 @@
-import argparse
-
 import pytorch_lightning as pl
 import torch
 import torchmetrics
-from prototorch.functions.competitions import wtac
+
 from prototorch.functions.distances import euclidean_distance
 from prototorch.functions.similarities import cosine_similarity
-from prototorch.functions.initializers import get_initializer
-from prototorch.functions.losses import glvq_loss
 from prototorch.modules.prototypes import Prototypes1D
 
 
@@ -64,9 +60,6 @@ class ReasoningLayer(torch.nn.Module):
         super().__init__()
         self.n_replicas = n_replicas
         self.n_classes = n_classes
-        # probabilities_init = torch.zeros(2, self.n_replicas, n_components,
-        #                                  self.n_classes)
-        # probabilities_init = torch.zeros(2, n_components, self.n_classes)
         probabilities_init = torch.zeros(2, 1, n_components, self.n_classes)
         probabilities_init.uniform_(0.4, 0.6)
         self.reasoning_probabilities = torch.nn.Parameter(probabilities_init)
@@ -75,37 +68,28 @@ class ReasoningLayer(torch.nn.Module):
     def reasonings(self):
         pk = self.reasoning_probabilities[0]
         nk = (1 - pk) * self.reasoning_probabilities[1]
-        ik = (1 - pk - nk)
-        # pk is of shape (1, n_components, n_classes)
+        ik = 1 - pk - nk
         img = torch.cat([pk, nk, ik], dim=0).permute(1, 0, 2)
-        return img.unsqueeze(1)  # (n_components, 1, 3, n_classes)
+        return img.unsqueeze(1)
 
     def forward(self, detections):
         pk = self.reasoning_probabilities[0].clamp(0, 1)
         nk = (1 - pk) * self.reasoning_probabilities[1].clamp(0, 1)
         epsilon = torch.finfo(pk.dtype).eps
-        # print(f"{detections.shape=}")
-        # print(f"{pk.shape=}")
-        # print(f"{detections.min()=}")
-        # print(f"{detections.max()=}")
         numerator = (detections @ (pk - nk)) + nk.sum(1)
-        # probs = numerator / (pk + nk).sum(1).clamp(min=epsilon)
         probs = numerator / (pk + nk).sum(1)
-        # probs = probs.squeeze(0)
         probs = probs.squeeze(0)
         return probs
 
 
 class CBC(pl.LightningModule):
     """Classification-By-Components."""
-    def __init__(
-            self,
-            hparams,
-            margin=0.1,
-            backbone_class=torch.nn.Identity,
-            # similarity=rescaled_cosine_similarity,
-            similarity=euclidean_similarity,
-            **kwargs):
+    def __init__(self,
+                 hparams,
+                 margin=0.1,
+                 backbone_class=torch.nn.Identity,
+                 similarity=euclidean_similarity,
+                 **kwargs):
         super().__init__()
         self.save_hyperparameters(hparams)
         self.margin = margin
@@ -142,15 +126,11 @@ class CBC(pl.LightningModule):
 
     def forward(self, x):
         self.sync_backbones()
-        protos = self.proto_layer.prototypes
-        # protos, _ = self.proto_layer()
+        protos, _ = self.proto_layer()
 
         latent_x = self.backbone(x)
         latent_protos = self.backbone_dependent(protos)
 
-        # print(f"{latent_x.dtype=}")
-        # print(f"{latent_protos.dtype=}")
-
         detections = self.similarity(latent_x, latent_protos)
         probs = self.reasoning_layer(detections)
         return probs
@@ -159,20 +139,10 @@ class CBC(pl.LightningModule):
         x, y = train_batch
         x = x.view(x.size(0), -1)
         y_pred = self(x)
-        # print(f"{y_pred.min()=}")
-        # print(f"{y_pred.max()=}")
         nclasses = self.reasoning_layer.n_classes
-        # y_true = torch.nn.functional.one_hot(y, num_classes=nclasses)
-        # y_true = torch.eye(nclasses)[y.long()]
         y_true = torch.nn.functional.one_hot(y.long(), num_classes=nclasses)
         loss = MarginLoss(self.margin)(y_pred, y_true).mean(dim=0)
         self.log("train_loss", loss)
-        # with torch.no_grad():
-        #     preds = torch.argmax(y_pred, dim=1)
-        # # self.train_acc.update(preds.int(), y.int())
-        # self.train_acc(
-        #     preds.int(),
-        #     y.int())  # FloatTensors are assumed to be class probabilities
         self.train_acc(y_pred, y_true)
         self.log(
             "acc",
@@ -184,17 +154,8 @@ class CBC(pl.LightningModule):
         )
         return loss
 
-    #def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
-    #    self.reasoning_layer.reasoning_probabilities.data.clamp_(0., 1.)
-
-    # def training_epoch_end(self, outs):
-    #     # Calling `self.train_acc.compute()` is
-    #     # automatically done by setting `on_epoch=True` when logging in `self.training_step(...)`
-    #     self.log("train_acc_epoch", self.train_acc.compute())
-
     def predict(self, x):
         with torch.no_grad():
-            # model.eval()  # ?!
             y_pred = self(x)
             y_pred = torch.argmax(y_pred, dim=1)
         return y_pred.numpy()
@@ -205,5 +166,5 @@ class ImageCBC(CBC):
     clamping after updates.
     """
     def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
-        #super().on_train_batch_end(outputs, batch, batch_idx, dataloader_idx)
+        # super().on_train_batch_end(outputs, batch, batch_idx, dataloader_idx)
         self.proto_layer.prototypes.data.clamp_(0.0, 1.0)
@@ -1,11 +1,9 @@
-import argparse
-
 import pytorch_lightning as pl
 import torch
 import torchmetrics
+
 from prototorch.functions.competitions import wtac
 from prototorch.functions.distances import euclidean_distance
-from prototorch.functions.initializers import get_initializer
 from prototorch.functions.losses import glvq_loss
 from prototorch.modules.prototypes import Prototypes1D
 
@@ -54,12 +52,14 @@ class GLVQ(pl.LightningModule):
         self.train_acc(
             preds.int(),
             y.int())  # FloatTensors are assumed to be class probabilities
-        self.log("acc",
-                 self.train_acc,
-                 on_step=False,
-                 on_epoch=True,
-                 prog_bar=True,
-                 logger=True)
+        self.log(
+            "acc",
+            self.train_acc,
+            on_step=False,
+            on_epoch=True,
+            prog_bar=True,
+            logger=True,
+        )
         return loss
 
     # def training_epoch_end(self, outs):
@@ -81,4 +81,4 @@ class ImageGLVQ(GLVQ):
     clamping after updates.
     """
    def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
-        self.proto_layer.prototypes.data.clamp_(0., 1.)
+        self.proto_layer.prototypes.data.clamp_(0.0, 1.0)

setup.py | 6
@@ -9,8 +9,7 @@
 ProtoTorch models Plugin Package
 """
 from pkg_resources import safe_name
-from setuptools import setup
-from setuptools import find_namespace_packages
+from setuptools import find_namespace_packages, setup
 
 PLUGIN_NAME = "models"
 
@@ -28,7 +27,8 @@ ALL = EXAMPLES + TESTS
 setup(
     name=safe_name("prototorch_" + PLUGIN_NAME),
     use_scm_version=True,
-    descripion="Pre-packaged prototype-based machine learning models using ProtoTorch and PyTorch-Lightning.",
+    descripion=
+    "Pre-packaged prototype-based machine learning models using ProtoTorch and PyTorch-Lightning.",
     long_description=long_description,
     author="Alexander Engelsberger",
     author_email="engelsbe@hs-mittweida.de",