Add partial cbc implementation

examples/cbc_iris.py | 116 (new file)
@@ -0,0 +1,116 @@
"""CBC example using the Iris dataset."""

import numpy as np
import pytorch_lightning as pl
import torch
from matplotlib import pyplot as plt
from prototorch.models.cbc import CBC
from sklearn.datasets import load_iris
from torch.utils.data import DataLoader, TensorDataset


class NumpyDataset(TensorDataset):
    """TensorDataset that converts numpy arrays to float tensors."""
    def __init__(self, *arrays):
        tensors = [torch.Tensor(arr) for arr in arrays]
        super().__init__(*tensors)
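
# Example: NumpyDataset simply wraps numpy arrays as float tensors, so it
# can feed a DataLoader directly:
#   ds = NumpyDataset(np.eye(3), np.arange(3))
#   x, y = ds[0]  # x: tensor of shape (3,), y: tensor(0.)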


class VisualizationCallback(pl.Callback):
    """Redraw the data, the components, and the decision regions after
    every epoch."""
    def __init__(self,
                 x_train,
                 y_train,
                 title="Prototype Visualization",
                 cmap="viridis"):
        super().__init__()
        self.x_train = x_train
        self.y_train = y_train
        self.title = title
        self.fig = plt.figure(self.title)
        self.cmap = cmap

    def on_epoch_end(self, trainer, pl_module):
        protos = pl_module.components
        ax = self.fig.gca()
        ax.cla()
        ax.set_title(self.title)
        ax.set_xlabel("Data dimension 1")
        ax.set_ylabel("Data dimension 2")
        ax.scatter(self.x_train[:, 0],
                   self.x_train[:, 1],
                   c=self.y_train,
                   edgecolor="k")
        ax.scatter(
            protos[:, 0],
            protos[:, 1],
            c="k",
            cmap=self.cmap,
            edgecolor="k",
            marker="D",
            s=50)
        # Sweep a mesh over the data range to paint the decision regions.
        x = np.vstack((self.x_train, protos))
        x_min, x_max = x[:, 0].min() - 1, x[:, 0].max() + 1
        y_min, y_max = x[:, 1].min() - 1, x[:, 1].max() + 1
        xx, yy = np.meshgrid(np.arange(x_min, x_max, 1 / 50),
                             np.arange(y_min, y_max, 1 / 50))
        mesh_input = np.c_[xx.ravel(), yy.ravel()]
        y_pred = pl_module.predict(torch.Tensor(mesh_input))
        y_pred = y_pred.reshape(xx.shape)

        ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
        ax.set_xlim(left=x_min, right=x_max)
        ax.set_ylim(bottom=y_min, top=y_max)
        plt.pause(0.1)
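
# The plotting logic only uses `pl_module`, so a decision-region plot can
# also be produced outside of training (a sketch, assuming a trained
# `model`):
#   vis = VisualizationCallback(x_train, y_train)
#   vis.on_epoch_end(trainer=None, pl_module=model)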


if __name__ == "__main__":
    # Dataset
    x_train, y_train = load_iris(return_X_y=True)
    x_train = x_train[:, [0, 2]]  # use only two dimensions for plotting
    train_ds = NumpyDataset(x_train, y_train)

    # Dataloaders
    train_loader = DataLoader(train_ds, num_workers=0, batch_size=150)

    # Hyperparameters
    hparams = dict(input_dim=x_train.shape[1],
                   nclasses=3,
                   prototypes_per_class=3,
                   prototype_initializer="stratified_mean",
                   lr=0.01)

    # Initialize the model
    model = CBC(hparams, data=[x_train, y_train])

    # Fix the component locations
    # model.proto_layer.requires_grad_(False)

    # Pure-positive reasonings: identity-like positive probabilities and
    # zero negative probabilities.  Note that the reasoning layer stores
    # its probabilities with shape (2, 1, n_components, n_classes), so
    # `rmat` would need an extra `unsqueeze(1)` before loading.
    ncomps = 3
    nclasses = 3
    rmat = torch.stack(
        [0.9 * torch.eye(ncomps),
         torch.zeros(ncomps, nclasses)], dim=0)
    # model.reasoning_layer.load_state_dict({"reasoning_probabilities": rmat},
    #                                       strict=True)

    print(model.reasoning_layer.reasoning_probabilities)

    # Model summary
    print(model)

    # Callbacks
    vis = VisualizationCallback(x_train, y_train)

    # Setup trainer
    trainer = pl.Trainer(
        max_epochs=100,
        callbacks=[
            vis,
        ],
    )

    # Training loop
    trainer.fit(model, train_loader)

prototorch/models/cbc.py | 205 (new file)
@@ -0,0 +1,205 @@
"""Classification-By-Components (CBC) model and building blocks."""

import pytorch_lightning as pl
import torch
import torchmetrics
from prototorch.functions.distances import euclidean_distance
from prototorch.functions.similarities import cosine_similarity
from prototorch.modules.prototypes import Prototypes1D


def rescaled_cosine_similarity(x, y):
    """Cosine similarity rescaled from [-1, 1] to [0, 1]."""
    similarities = cosine_similarity(x, y)
    return (similarities + 1.0) / 2.0


def shift_activation(x):
    """Shift and rescale activations from [-1, 1] to [0, 1]."""
    return (x + 1.0) / 2.0


def euclidean_similarity(x, y):
    """Turn Euclidean distances into similarities in (0, 1]."""
    d = euclidean_distance(x, y)
    return torch.exp(-d * 3)
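
# For example, a distance of 0 gives similarity exp(0) = 1, while a
# distance of 1 already gives exp(-3) ≈ 0.05; the factor 3 controls how
# quickly the similarity decays with distance.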


class CosineSimilarity(torch.nn.Module):
    def __init__(self, activation=shift_activation):
        super().__init__()
        self.activation = activation

    def forward(self, x, y):
        epsilon = torch.finfo(x.dtype).eps
        # Normalize over all non-batch dimensions, then flatten.
        normed_x = (x / x.pow(2)
                    .sum(dim=tuple(range(1, x.ndim)), keepdim=True)
                    .clamp(min=epsilon)
                    .sqrt()).flatten(start_dim=1)
        normed_y = (y / y.pow(2)
                    .sum(dim=tuple(range(1, y.ndim)), keepdim=True)
                    .clamp(min=epsilon)
                    .sqrt()).flatten(start_dim=1)
        diss = torch.inner(normed_x, normed_y)
        return self.activation(diss)
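
# This is roughly `F.normalize` on both inputs followed by an inner
# product; with `shift_activation` the [-1, 1] cosine values are mapped
# to [0, 1], as the detection probabilities of the reasoning layer
# require.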


class MarginLoss(torch.nn.modules.loss._Loss):
    def __init__(self,
                 margin=0.3,
                 size_average=None,
                 reduce=None,
                 reduction="mean"):
        super().__init__(size_average, reduce, reduction)
        self.margin = margin

    def forward(self, input_, target):
        # Probability of the correct class vs the highest wrong class.
        dp = torch.sum(target * input_, dim=-1)
        dm = torch.max(input_ - target, dim=-1).values
        return torch.nn.functional.relu(dm - dp + self.margin)
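
# Worked example with margin 0.3: for y_pred = [0.8, 0.1, 0.1] and one-hot
# target [1, 0, 0], dp = 0.8 and dm = max([-0.2, 0.1, 0.1]) = 0.1, so the
# loss is relu(0.1 - 0.8 + 0.3) = 0; the margin is satisfied.  For
# y_pred = [0.4, 0.5, 0.1] the loss is relu(0.5 - 0.4 + 0.3) = 0.4.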


class ReasoningLayer(torch.nn.Module):
    def __init__(self, n_components, n_classes, n_replicas=1):
        super().__init__()
        self.n_replicas = n_replicas
        self.n_classes = n_classes
        probabilities_init = torch.zeros(2, 1, n_components, self.n_classes)
        probabilities_init.uniform_(0.4, 0.6)
        self.reasoning_probabilities = torch.nn.Parameter(probabilities_init)

    # @property
    # def reasonings(self):
    #     pk = self.reasoning_probabilities[0]
    #     nk = (1 - pk) * self.reasoning_probabilities[1]
    #     ik = (1 - pk - nk)
    #     # pk is of shape (1, n_components, n_classes)
    #     img = torch.cat([pk, nk, ik], dim=0).permute(1, 0, 2)
    #     return img.unsqueeze(1)  # (n_components, 1, 3, n_classes)

    def forward(self, detections):
        pk = self.reasoning_probabilities[0].clamp(0, 1)
        nk = (1 - pk) * self.reasoning_probabilities[1].clamp(0, 1)
        epsilon = torch.finfo(pk.dtype).eps
        numerator = (detections @ (pk - nk)) + nk.sum(1)
        # Guard against an all-zero reasoning column in the denominator.
        probs = numerator / (pk + nk).sum(1).clamp(min=epsilon)
        probs = probs.squeeze(0)
        return probs
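
# In reasoning terms: with positive probabilities p = pk and negative
# probabilities n = nk, the class probability for detections d is
#   probs_c = (sum_k d_k * (p_kc - n_kc) + sum_k n_kc) / sum_k (p_kc + n_kc)
# so a component with a high positive reasoning raises the class
# probability when it is detected, while a high negative reasoning raises
# it when the component is absent.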


class CBC(pl.LightningModule):
    """Classification-By-Components."""
    def __init__(
            self,
            hparams,
            margin=0.1,
            backbone_class=torch.nn.Identity,
            # similarity=rescaled_cosine_similarity,
            similarity=euclidean_similarity,
            **kwargs):
        super().__init__()
        self.save_hyperparameters(hparams)
        self.margin = margin
        self.proto_layer = Prototypes1D(
            input_dim=self.hparams.input_dim,
            nclasses=self.hparams.nclasses,
            prototypes_per_class=self.hparams.prototypes_per_class,
            prototype_initializer=self.hparams.prototype_initializer,
            **kwargs)
        self.similarity = similarity
        self.backbone = backbone_class()
        self.backbone_dependent = backbone_class().requires_grad_(False)
        n_components = self.components.shape[0]
        self.reasoning_layer = ReasoningLayer(n_components=n_components,
                                              n_classes=self.hparams.nclasses)
        self.train_acc = torchmetrics.Accuracy()

    @property
    def components(self):
        return self.proto_layer.prototypes.detach().numpy()

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
        return optimizer

    def sync_backbones(self):
        master_state = self.backbone.state_dict()
        self.backbone_dependent.load_state_dict(master_state, strict=True)

    def forward(self, x):
        self.sync_backbones()
        protos, _ = self.proto_layer()

        latent_x = self.backbone(x)
        latent_protos = self.backbone_dependent(protos)

        detections = self.similarity(latent_x, latent_protos)
        probs = self.reasoning_layer(detections)
        return probs

    def training_step(self, train_batch, batch_idx):
        x, y = train_batch
        x = x.view(x.size(0), -1)
        y_pred = self(x)
        nclasses = self.reasoning_layer.n_classes
        y_true = torch.nn.functional.one_hot(y.long(), num_classes=nclasses)
        loss = MarginLoss(self.margin)(y_pred, y_true).sum(dim=0)
        self.log("train_loss", loss)
        self.train_acc(y_pred, y_true)
        # `on_epoch=True` makes Lightning call `self.train_acc.compute()`
        # at the end of every epoch.
        self.log("acc",
                 self.train_acc,
                 on_step=False,
                 on_epoch=True,
                 prog_bar=True,
                 logger=True)
        return loss

    # def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
    #     self.reasoning_layer.reasoning_probabilities.data.clamp_(0., 1.)

    def predict(self, x):
        with torch.no_grad():
            y_pred = self(x)
            y_pred = torch.argmax(y_pred, dim=1)
        return y_pred.numpy()
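
# A minimal usage sketch (hyperparameters as in examples/cbc_iris.py):
#   model = CBC(dict(input_dim=2, nclasses=3, prototypes_per_class=3,
#                    prototype_initializer="stratified_mean", lr=0.01),
#               data=[x_train, y_train])
#   probs = model(torch.Tensor(x_train))           # (N, nclasses)
#   labels = model.predict(torch.Tensor(x_train))  # numpy class indices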


class ImageCBC(CBC):
    """CBC model that constrains the components to the range [0, 1] by
    clamping after updates.
    """
    def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
        super().on_train_batch_end(outputs, batch, batch_idx, dataloader_idx)
        self.proto_layer.prototypes.data.clamp_(0., 1.)