[BUGFIX] examples/cbc_iris.py works again
				
					
				
			This commit is contained in:
		@@ -2,11 +2,10 @@
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import argparse
 | 
					import argparse
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import prototorch as pt
 | 
				
			||||||
import pytorch_lightning as pl
 | 
					import pytorch_lightning as pl
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import prototorch as pt
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
if __name__ == "__main__":
 | 
					if __name__ == "__main__":
 | 
				
			||||||
    # Command-line arguments
 | 
					    # Command-line arguments
 | 
				
			||||||
    parser = argparse.ArgumentParser()
 | 
					    parser = argparse.ArgumentParser()
 | 
				
			||||||
@@ -24,14 +23,18 @@ if __name__ == "__main__":
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    # Hyperparameters
 | 
					    # Hyperparameters
 | 
				
			||||||
    hparams = dict(
 | 
					    hparams = dict(
 | 
				
			||||||
        distribution=[2, 2, 2],
 | 
					        distribution=[1, 0, 3],
 | 
				
			||||||
        proto_lr=0.1,
 | 
					        margin=0.1,
 | 
				
			||||||
 | 
					        proto_lr=0.01,
 | 
				
			||||||
 | 
					        bb_lr=0.01,
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Initialize the model
 | 
					    # Initialize the model
 | 
				
			||||||
    model = pt.models.CBC(
 | 
					    model = pt.models.CBC(
 | 
				
			||||||
        hparams,
 | 
					        hparams,
 | 
				
			||||||
        prototype_initializer=pt.components.SSI(train_ds, noise=0.01),
 | 
					        components_initializer=pt.initializers.SSCI(train_ds, noise=0.01),
 | 
				
			||||||
 | 
					        reasonings_iniitializer=pt.initializers.
 | 
				
			||||||
 | 
					        PurePositiveReasoningsInitializer(),
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Callbacks
 | 
					    # Callbacks
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -128,7 +128,7 @@ class SupervisedPrototypeModel(PrototypeModel):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    @property
 | 
					    @property
 | 
				
			||||||
    def num_classes(self):
 | 
					    def num_classes(self):
 | 
				
			||||||
        return len(self.proto_layer.distribution)
 | 
					        return self.proto_layer.num_classes
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def compute_distances(self, x):
 | 
					    def compute_distances(self, x):
 | 
				
			||||||
        protos, _ = self.proto_layer()
 | 
					        protos, _ = self.proto_layer()
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,55 +1,54 @@
 | 
				
			|||||||
import torch
 | 
					import torch
 | 
				
			||||||
import torchmetrics
 | 
					import torchmetrics
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from ..core.competitions import CBCC
 | 
				
			||||||
from ..core.components import ReasoningComponents
 | 
					from ..core.components import ReasoningComponents
 | 
				
			||||||
 | 
					from ..core.initializers import RandomReasoningsInitializer
 | 
				
			||||||
 | 
					from ..core.losses import MarginLoss
 | 
				
			||||||
 | 
					from ..core.similarities import euclidean_similarity
 | 
				
			||||||
 | 
					from ..nn.wrappers import LambdaLayer
 | 
				
			||||||
from .abstract import ImagePrototypesMixin
 | 
					from .abstract import ImagePrototypesMixin
 | 
				
			||||||
from .extras import (
 | 
					 | 
				
			||||||
    CosineSimilarity,
 | 
					 | 
				
			||||||
    MarginLoss,
 | 
					 | 
				
			||||||
    ReasoningLayer,
 | 
					 | 
				
			||||||
    euclidean_similarity,
 | 
					 | 
				
			||||||
    rescaled_cosine_similarity,
 | 
					 | 
				
			||||||
    shift_activation,
 | 
					 | 
				
			||||||
)
 | 
					 | 
				
			||||||
from .glvq import SiameseGLVQ
 | 
					from .glvq import SiameseGLVQ
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class CBC(SiameseGLVQ):
 | 
					class CBC(SiameseGLVQ):
 | 
				
			||||||
    """Classification-By-Components."""
 | 
					    """Classification-By-Components."""
 | 
				
			||||||
    def __init__(self, hparams, margin=0.1, **kwargs):
 | 
					    def __init__(self, hparams, **kwargs):
 | 
				
			||||||
        super().__init__(hparams, **kwargs)
 | 
					        super().__init__(hparams, **kwargs)
 | 
				
			||||||
        self.margin = margin
 | 
					 | 
				
			||||||
        self.similarity_fn = kwargs.get("similarity_fn", euclidean_similarity)
 | 
					 | 
				
			||||||
        num_components = self.components.shape[0]
 | 
					 | 
				
			||||||
        self.reasoning_layer = ReasoningLayer(num_components=num_components,
 | 
					 | 
				
			||||||
                                              num_classes=self.num_classes)
 | 
					 | 
				
			||||||
        self.component_layer = self.proto_layer
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @property
 | 
					        similarity_fn = kwargs.get("similarity_fn", euclidean_similarity)
 | 
				
			||||||
    def components(self):
 | 
					        components_initializer = kwargs.get("components_initializer", None)
 | 
				
			||||||
        return self.prototypes
 | 
					        reasonings_initializer = kwargs.get("reasonings_initializer",
 | 
				
			||||||
 | 
					                                            RandomReasoningsInitializer())
 | 
				
			||||||
 | 
					        self.components_layer = ReasoningComponents(
 | 
				
			||||||
 | 
					            self.hparams.distribution,
 | 
				
			||||||
 | 
					            components_initializer=components_initializer,
 | 
				
			||||||
 | 
					            reasonings_initializer=reasonings_initializer,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.similarity_layer = LambdaLayer(similarity_fn)
 | 
				
			||||||
 | 
					        self.competition_layer = CBCC()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @property
 | 
					        # Namespace hook
 | 
				
			||||||
    def reasonings(self):
 | 
					        self.proto_layer = self.components_layer
 | 
				
			||||||
        return self.reasoning_layer.reasonings.cpu()
 | 
					
 | 
				
			||||||
 | 
					        self.loss = MarginLoss(self.hparams.margin)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def forward(self, x):
 | 
					    def forward(self, x):
 | 
				
			||||||
        components, _ = self.component_layer()
 | 
					        components, reasonings = self.components_layer()
 | 
				
			||||||
        latent_x = self.backbone(x)
 | 
					        latent_x = self.backbone(x)
 | 
				
			||||||
        self.backbone.requires_grad_(self.both_path_gradients)
 | 
					        self.backbone.requires_grad_(self.both_path_gradients)
 | 
				
			||||||
        latent_components = self.backbone(components)
 | 
					        latent_components = self.backbone(components)
 | 
				
			||||||
        self.backbone.requires_grad_(True)
 | 
					        self.backbone.requires_grad_(True)
 | 
				
			||||||
        detections = self.similarity_fn(latent_x, latent_components)
 | 
					        detections = self.similarity_layer(latent_x, latent_components)
 | 
				
			||||||
        probs = self.reasoning_layer(detections)
 | 
					        probs = self.competition_layer(detections, reasonings)
 | 
				
			||||||
        return probs
 | 
					        return probs
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def shared_step(self, batch, batch_idx, optimizer_idx=None):
 | 
					    def shared_step(self, batch, batch_idx, optimizer_idx=None):
 | 
				
			||||||
        x, y = batch
 | 
					        x, y = batch
 | 
				
			||||||
        # x = x.view(x.size(0), -1)
 | 
					 | 
				
			||||||
        y_pred = self(x)
 | 
					        y_pred = self(x)
 | 
				
			||||||
        num_classes = self.reasoning_layer.num_classes
 | 
					        num_classes = self.num_classes
 | 
				
			||||||
        y_true = torch.nn.functional.one_hot(y.long(), num_classes=num_classes)
 | 
					        y_true = torch.nn.functional.one_hot(y.long(), num_classes=num_classes)
 | 
				
			||||||
        loss = MarginLoss(self.margin)(y_pred, y_true).mean(dim=0)
 | 
					        loss = self.loss(y_pred, y_true).mean(dim=0)
 | 
				
			||||||
        return y_pred, loss
 | 
					        return y_pred, loss
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def training_step(self, batch, batch_idx, optimizer_idx=None):
 | 
					    def training_step(self, batch, batch_idx, optimizer_idx=None):
 | 
				
			||||||
@@ -76,7 +75,3 @@ class ImageCBC(ImagePrototypesMixin, CBC):
 | 
				
			|||||||
    """CBC model that constrains the components to the range [0, 1] by
 | 
					    """CBC model that constrains the components to the range [0, 1] by
 | 
				
			||||||
    clamping after updates.
 | 
					    clamping after updates.
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    def __init__(self, hparams, **kwargs):
 | 
					 | 
				
			||||||
        super().__init__(hparams, **kwargs)
 | 
					 | 
				
			||||||
        # Namespace hook
 | 
					 | 
				
			||||||
        self.proto_layer = self.component_layer
 | 
					 | 
				
			||||||
 
 | 
				
			|||||||
@@ -6,33 +6,12 @@ Modules not yet available in prototorch go here temporarily.
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from ..core.distances import euclidean_distance
 | 
					from ..core.similarities import gaussian
 | 
				
			||||||
from ..core.similarities import cosine_similarity
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def rescaled_cosine_similarity(x, y):
 | 
					 | 
				
			||||||
    """Cosine Similarity rescaled to [0, 1]."""
 | 
					 | 
				
			||||||
    similarities = cosine_similarity(x, y)
 | 
					 | 
				
			||||||
    return (similarities + 1.0) / 2.0
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def shift_activation(x):
 | 
					 | 
				
			||||||
    return (x + 1.0) / 2.0
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def euclidean_similarity(x, y, variance=1.0):
 | 
					 | 
				
			||||||
    d = euclidean_distance(x, y)
 | 
					 | 
				
			||||||
    return torch.exp(-(d * d) / (2 * variance))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def gaussian(distances, variance):
 | 
					 | 
				
			||||||
    return torch.exp(-(distances * distances) / (2 * variance))
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def rank_scaled_gaussian(distances, lambd):
 | 
					def rank_scaled_gaussian(distances, lambd):
 | 
				
			||||||
    order = torch.argsort(distances, dim=1)
 | 
					    order = torch.argsort(distances, dim=1)
 | 
				
			||||||
    ranks = torch.argsort(order, dim=1)
 | 
					    ranks = torch.argsort(order, dim=1)
 | 
				
			||||||
 | 
					 | 
				
			||||||
    return torch.exp(-torch.exp(-ranks / lambd) * distances)
 | 
					    return torch.exp(-torch.exp(-ranks / lambd) * distances)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -109,64 +88,3 @@ class ConnectionTopology(torch.nn.Module):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    def extra_repr(self):
 | 
					    def extra_repr(self):
 | 
				
			||||||
        return f"(agelimit): ({self.agelimit})"
 | 
					        return f"(agelimit): ({self.agelimit})"
 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class CosineSimilarity(torch.nn.Module):
 | 
					 | 
				
			||||||
    def __init__(self, activation=shift_activation):
 | 
					 | 
				
			||||||
        super().__init__()
 | 
					 | 
				
			||||||
        self.activation = activation
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def forward(self, x, y):
 | 
					 | 
				
			||||||
        epsilon = torch.finfo(x.dtype).eps
 | 
					 | 
				
			||||||
        normed_x = (x / x.pow(2).sum(dim=tuple(range(
 | 
					 | 
				
			||||||
            1, x.ndim)), keepdim=True).clamp(min=epsilon).sqrt()).flatten(
 | 
					 | 
				
			||||||
                start_dim=1)
 | 
					 | 
				
			||||||
        normed_y = (y / y.pow(2).sum(dim=tuple(range(
 | 
					 | 
				
			||||||
            1, y.ndim)), keepdim=True).clamp(min=epsilon).sqrt()).flatten(
 | 
					 | 
				
			||||||
                start_dim=1)
 | 
					 | 
				
			||||||
        # normed_x = (x / torch.linalg.norm(x, dim=1))
 | 
					 | 
				
			||||||
        diss = torch.inner(normed_x, normed_y)
 | 
					 | 
				
			||||||
        return self.activation(diss)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class MarginLoss(torch.nn.modules.loss._Loss):
 | 
					 | 
				
			||||||
    def __init__(self,
 | 
					 | 
				
			||||||
                 margin=0.3,
 | 
					 | 
				
			||||||
                 size_average=None,
 | 
					 | 
				
			||||||
                 reduce=None,
 | 
					 | 
				
			||||||
                 reduction="mean"):
 | 
					 | 
				
			||||||
        super().__init__(size_average, reduce, reduction)
 | 
					 | 
				
			||||||
        self.margin = margin
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def forward(self, input_, target):
 | 
					 | 
				
			||||||
        dp = torch.sum(target * input_, dim=-1)
 | 
					 | 
				
			||||||
        dm = torch.max(input_ - target, dim=-1).values
 | 
					 | 
				
			||||||
        return torch.nn.functional.relu(dm - dp + self.margin)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class ReasoningLayer(torch.nn.Module):
 | 
					 | 
				
			||||||
    def __init__(self, num_components, num_classes, num_replicas=1):
 | 
					 | 
				
			||||||
        super().__init__()
 | 
					 | 
				
			||||||
        self.num_replicas = num_replicas
 | 
					 | 
				
			||||||
        self.num_classes = num_classes
 | 
					 | 
				
			||||||
        probabilities_init = torch.zeros(2, 1, num_components,
 | 
					 | 
				
			||||||
                                         self.num_classes)
 | 
					 | 
				
			||||||
        probabilities_init.uniform_(0.4, 0.6)
 | 
					 | 
				
			||||||
        # TODO Use `self.register_parameter("param", Paramater(param))` instead
 | 
					 | 
				
			||||||
        self.reasoning_probabilities = torch.nn.Parameter(probabilities_init)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    @property
 | 
					 | 
				
			||||||
    def reasonings(self):
 | 
					 | 
				
			||||||
        pk = self.reasoning_probabilities[0]
 | 
					 | 
				
			||||||
        nk = (1 - pk) * self.reasoning_probabilities[1]
 | 
					 | 
				
			||||||
        ik = 1 - pk - nk
 | 
					 | 
				
			||||||
        img = torch.cat([pk, nk, ik], dim=0).permute(1, 0, 2)
 | 
					 | 
				
			||||||
        return img.unsqueeze(1)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def forward(self, detections):
 | 
					 | 
				
			||||||
        pk = self.reasoning_probabilities[0].clamp(0, 1)
 | 
					 | 
				
			||||||
        nk = (1 - pk) * self.reasoning_probabilities[1].clamp(0, 1)
 | 
					 | 
				
			||||||
        numerator = (detections @ (pk - nk)) + nk.sum(1)
 | 
					 | 
				
			||||||
        probs = numerator / (pk + nk).sum(1)
 | 
					 | 
				
			||||||
        probs = probs.squeeze(0)
 | 
					 | 
				
			||||||
        return probs
 | 
					 | 
				
			||||||
 
 | 
				
			|||||||
@@ -4,7 +4,8 @@ import torch
 | 
				
			|||||||
from torch.nn.parameter import Parameter
 | 
					from torch.nn.parameter import Parameter
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from ..core.competitions import wtac
 | 
					from ..core.competitions import wtac
 | 
				
			||||||
from ..core.distances import lomega_distance, omega_distance, squared_euclidean_distance
 | 
					from ..core.distances import (lomega_distance, omega_distance,
 | 
				
			||||||
 | 
					                              squared_euclidean_distance)
 | 
				
			||||||
from ..core.losses import glvq_loss, lvq1_loss, lvq21_loss
 | 
					from ..core.losses import glvq_loss, lvq1_loss, lvq21_loss
 | 
				
			||||||
from ..nn.activations import get_activation
 | 
					from ..nn.activations import get_activation
 | 
				
			||||||
from ..nn.wrappers import LambdaLayer, LossLayer
 | 
					from ..nn.wrappers import LambdaLayer, LossLayer
 | 
				
			||||||
@@ -27,9 +28,6 @@ class GLVQ(SupervisedPrototypeModel):
 | 
				
			|||||||
        # Loss
 | 
					        # Loss
 | 
				
			||||||
        self.loss = LossLayer(glvq_loss)
 | 
					        self.loss = LossLayer(glvq_loss)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # Prototype metrics
 | 
					 | 
				
			||||||
        self.initialize_prototype_win_ratios()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def initialize_prototype_win_ratios(self):
 | 
					    def initialize_prototype_win_ratios(self):
 | 
				
			||||||
        self.register_buffer(
 | 
					        self.register_buffer(
 | 
				
			||||||
            "prototype_win_ratios",
 | 
					            "prototype_win_ratios",
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -199,7 +199,7 @@ class VisCBC2D(Vis2DAbstract):
 | 
				
			|||||||
        self.plot_protos(ax, protos, "w")
 | 
					        self.plot_protos(ax, protos, "w")
 | 
				
			||||||
        x = np.vstack((x_train, protos))
 | 
					        x = np.vstack((x_train, protos))
 | 
				
			||||||
        mesh_input, xx, yy = self.get_mesh_input(x)
 | 
					        mesh_input, xx, yy = self.get_mesh_input(x)
 | 
				
			||||||
        _components = pl_module.component_layer._components
 | 
					        _components = pl_module.components_layer._components
 | 
				
			||||||
        y_pred = pl_module.predict(
 | 
					        y_pred = pl_module.predict(
 | 
				
			||||||
            torch.Tensor(mesh_input).type_as(_components))
 | 
					            torch.Tensor(mesh_input).type_as(_components))
 | 
				
			||||||
        y_pred = y_pred.cpu().reshape(xx.shape)
 | 
					        y_pred = y_pred.cpu().reshape(xx.shape)
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user