[BUGFIX] examples/cbc_iris.py works again

Jensun Ravichandran 2021-06-15 15:59:47 +02:00
parent 1b420c1f6b
commit a37095409b
6 changed files with 39 additions and 125 deletions

View File

@@ -2,11 +2,10 @@
 
 import argparse
 
+import prototorch as pt
 import pytorch_lightning as pl
 import torch
 
-import prototorch as pt
-
 if __name__ == "__main__":
     # Command-line arguments
     parser = argparse.ArgumentParser()
@@ -24,14 +23,18 @@ if __name__ == "__main__":
     # Hyperparameters
     hparams = dict(
-        distribution=[2, 2, 2],
-        proto_lr=0.1,
+        distribution=[1, 0, 3],
+        margin=0.1,
+        proto_lr=0.01,
+        bb_lr=0.01,
     )
 
     # Initialize the model
     model = pt.models.CBC(
         hparams,
-        prototype_initializer=pt.components.SSI(train_ds, noise=0.01),
+        components_initializer=pt.initializers.SSCI(train_ds, noise=0.01),
+        reasonings_initializer=pt.initializers.
+        PurePositiveReasoningsInitializer(),
     )
 
     # Callbacks
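The hunk above shows only the changed region of examples/cbc_iris.py. For orientation, a minimal sketch of how the new hyperparameters and initializer arguments slot into the rest of the example; the dataset, dataloader, argument parsing and trainer wiring are assumptions based on the usual pattern of the other examples and are not part of this diff:

    import argparse

    import prototorch as pt
    import pytorch_lightning as pl
    import torch

    if __name__ == "__main__":
        # Command-line arguments (Lightning's Trainer flags assumed)
        parser = argparse.ArgumentParser()
        parser = pl.Trainer.add_argparse_args(parser)
        args = parser.parse_args()

        # Dataset and dataloader -- assumed setup, not shown in the diff
        train_ds = pt.datasets.Iris(dims=[0, 2])
        train_loader = torch.utils.data.DataLoader(train_ds, batch_size=32)

        # Hyperparameters as introduced by this commit
        hparams = dict(
            distribution=[1, 0, 3],
            margin=0.1,
            proto_lr=0.01,
            bb_lr=0.01,
        )

        # Initialize the model with the initializer-based API
        model = pt.models.CBC(
            hparams,
            components_initializer=pt.initializers.SSCI(train_ds, noise=0.01),
            reasonings_initializer=pt.initializers.PurePositiveReasoningsInitializer(),
        )

        # Training -- assumed trainer setup
        trainer = pl.Trainer.from_argparse_args(args)
        trainer.fit(model, train_loader)

Note that distribution=[1, 0, 3] assigns no component to the second class; in CBC the class assignment comes from the reasonings rather than from the components themselves, so such a distribution is still valid.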

View File

@@ -128,7 +128,7 @@ class SupervisedPrototypeModel(PrototypeModel):
     @property
     def num_classes(self):
-        return len(self.proto_layer.distribution)
+        return self.proto_layer.num_classes
 
     def compute_distances(self, x):
         protos, _ = self.proto_layer()

View File

@@ -1,55 +1,54 @@
 import torch
 import torchmetrics
 
+from ..core.competitions import CBCC
 from ..core.components import ReasoningComponents
+from ..core.initializers import RandomReasoningsInitializer
+from ..core.losses import MarginLoss
+from ..core.similarities import euclidean_similarity
+from ..nn.wrappers import LambdaLayer
 from .abstract import ImagePrototypesMixin
-from .extras import (
-    CosineSimilarity,
-    MarginLoss,
-    ReasoningLayer,
-    euclidean_similarity,
-    rescaled_cosine_similarity,
-    shift_activation,
-)
 from .glvq import SiameseGLVQ
 
 
 class CBC(SiameseGLVQ):
     """Classification-By-Components."""
-    def __init__(self, hparams, margin=0.1, **kwargs):
+    def __init__(self, hparams, **kwargs):
         super().__init__(hparams, **kwargs)
-        self.margin = margin
-        self.similarity_fn = kwargs.get("similarity_fn", euclidean_similarity)
-        num_components = self.components.shape[0]
-        self.reasoning_layer = ReasoningLayer(num_components=num_components,
-                                              num_classes=self.num_classes)
-        self.component_layer = self.proto_layer
-
-    @property
-    def components(self):
-        return self.prototypes
-
-    @property
-    def reasonings(self):
-        return self.reasoning_layer.reasonings.cpu()
+
+        similarity_fn = kwargs.get("similarity_fn", euclidean_similarity)
+        components_initializer = kwargs.get("components_initializer", None)
+        reasonings_initializer = kwargs.get("reasonings_initializer",
+                                            RandomReasoningsInitializer())
+        self.components_layer = ReasoningComponents(
+            self.hparams.distribution,
+            components_initializer=components_initializer,
+            reasonings_initializer=reasonings_initializer,
+        )
+        self.similarity_layer = LambdaLayer(similarity_fn)
+        self.competition_layer = CBCC()
+
+        # Namespace hook
+        self.proto_layer = self.components_layer
+
+        self.loss = MarginLoss(self.hparams.margin)
 
     def forward(self, x):
-        components, _ = self.component_layer()
+        components, reasonings = self.components_layer()
         latent_x = self.backbone(x)
         self.backbone.requires_grad_(self.both_path_gradients)
         latent_components = self.backbone(components)
         self.backbone.requires_grad_(True)
-        detections = self.similarity_fn(latent_x, latent_components)
-        probs = self.reasoning_layer(detections)
+        detections = self.similarity_layer(latent_x, latent_components)
+        probs = self.competition_layer(detections, reasonings)
         return probs
 
     def shared_step(self, batch, batch_idx, optimizer_idx=None):
         x, y = batch
-        # x = x.view(x.size(0), -1)
         y_pred = self(x)
-        num_classes = self.reasoning_layer.num_classes
+        num_classes = self.num_classes
         y_true = torch.nn.functional.one_hot(y.long(), num_classes=num_classes)
-        loss = MarginLoss(self.margin)(y_pred, y_true).mean(dim=0)
+        loss = self.loss(y_pred, y_true).mean(dim=0)
         return y_pred, loss
 
     def training_step(self, batch, batch_idx, optimizer_idx=None):
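The forward pass now delegates the reasoning step to CBCC from the core package. For reference, a functional sketch of the class-probability computation that the removed ReasoningLayer performed (see the deletions in extras.py below); whether CBCC computes exactly this is an assumption, but the shapes line up with the new forward():

    import torch

    def cbc_probabilities(detections, pk, nk):
        """detections: (batch, num_components) similarities in [0, 1];
        pk, nk: (num_components, num_classes) positive/negative reasonings."""
        numerator = detections @ (pk - nk) + nk.sum(dim=0)
        return numerator / (pk + nk).sum(dim=0)

Agreement with a component raises the score of classes that reason positively with it, while disagreement raises the score of classes that reason negatively with it.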
@@ -76,7 +75,3 @@ class ImageCBC(ImagePrototypesMixin, CBC):
     """CBC model that constrains the components to the range [0, 1] by
     clamping after updates.
     """
-    def __init__(self, hparams, **kwargs):
-        super().__init__(hparams, **kwargs)
-        # Namespace hook
-        self.proto_layer = self.component_layer
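Because CBC itself now aliases proto_layer to the components layer, ImageCBC no longer needs its own namespace hook and can rely on ImagePrototypesMixin unchanged. The clamping that the docstring describes amounts to something like the following; the exact hook and attribute used by the mixin are assumptions (the parameter name _components is taken from the VisCBC2D hunk at the end of this commit):

    import torch

    def clamp_image_components_(model):
        # Keep pixel-valued components inside [0, 1] after an optimizer step.
        with torch.no_grad():
            model.proto_layer._components.clamp_(0.0, 1.0)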

View File

@@ -6,33 +6,12 @@ Modules not yet available in prototorch go here temporarily.
 import torch
 
-from ..core.distances import euclidean_distance
-from ..core.similarities import cosine_similarity
-
-
-def rescaled_cosine_similarity(x, y):
-    """Cosine Similarity rescaled to [0, 1]."""
-    similarities = cosine_similarity(x, y)
-    return (similarities + 1.0) / 2.0
-
-
-def shift_activation(x):
-    return (x + 1.0) / 2.0
-
-
-def euclidean_similarity(x, y, variance=1.0):
-    d = euclidean_distance(x, y)
-    return torch.exp(-(d * d) / (2 * variance))
-
-
-def gaussian(distances, variance):
-    return torch.exp(-(distances * distances) / (2 * variance))
+from ..core.similarities import gaussian
 
 
 def rank_scaled_gaussian(distances, lambd):
     order = torch.argsort(distances, dim=1)
     ranks = torch.argsort(order, dim=1)
     return torch.exp(-torch.exp(-ranks / lambd) * distances)
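The helpers deleted here were duplicates of functionality that now lives in prototorch core. For reference, the default similarity used by CBC, as the removed code defined it (a Gaussian of the pairwise Euclidean distance; torch.cdist stands in for the core euclidean_distance in this sketch):

    import torch

    def euclidean_similarity(x, y, variance=1.0):
        d = torch.cdist(x, y)                         # pairwise Euclidean distances
        return torch.exp(-(d * d) / (2 * variance))   # 1 at d = 0, decaying towards 0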
@@ -109,64 +88,3 @@ class ConnectionTopology(torch.nn.Module):
     def extra_repr(self):
         return f"(agelimit): ({self.agelimit})"
-
-
-class CosineSimilarity(torch.nn.Module):
-    def __init__(self, activation=shift_activation):
-        super().__init__()
-        self.activation = activation
-
-    def forward(self, x, y):
-        epsilon = torch.finfo(x.dtype).eps
-        normed_x = (x / x.pow(2).sum(dim=tuple(range(
-            1, x.ndim)), keepdim=True).clamp(min=epsilon).sqrt()).flatten(
-                start_dim=1)
-        normed_y = (y / y.pow(2).sum(dim=tuple(range(
-            1, y.ndim)), keepdim=True).clamp(min=epsilon).sqrt()).flatten(
-                start_dim=1)
-        # normed_x = (x / torch.linalg.norm(x, dim=1))
-        diss = torch.inner(normed_x, normed_y)
-        return self.activation(diss)
-
-
-class MarginLoss(torch.nn.modules.loss._Loss):
-    def __init__(self,
-                 margin=0.3,
-                 size_average=None,
-                 reduce=None,
-                 reduction="mean"):
-        super().__init__(size_average, reduce, reduction)
-        self.margin = margin
-
-    def forward(self, input_, target):
-        dp = torch.sum(target * input_, dim=-1)
-        dm = torch.max(input_ - target, dim=-1).values
-        return torch.nn.functional.relu(dm - dp + self.margin)
-
-
-class ReasoningLayer(torch.nn.Module):
-    def __init__(self, num_components, num_classes, num_replicas=1):
-        super().__init__()
-        self.num_replicas = num_replicas
-        self.num_classes = num_classes
-        probabilities_init = torch.zeros(2, 1, num_components,
-                                         self.num_classes)
-        probabilities_init.uniform_(0.4, 0.6)
-        # TODO Use `self.register_parameter("param", Parameter(param))` instead
-        self.reasoning_probabilities = torch.nn.Parameter(probabilities_init)
-
-    @property
-    def reasonings(self):
-        pk = self.reasoning_probabilities[0]
-        nk = (1 - pk) * self.reasoning_probabilities[1]
-        ik = 1 - pk - nk
-        img = torch.cat([pk, nk, ik], dim=0).permute(1, 0, 2)
-        return img.unsqueeze(1)
-
-    def forward(self, detections):
-        pk = self.reasoning_probabilities[0].clamp(0, 1)
-        nk = (1 - pk) * self.reasoning_probabilities[1].clamp(0, 1)
-        numerator = (detections @ (pk - nk)) + nk.sum(1)
-        probs = numerator / (pk + nk).sum(1)
-        probs = probs.squeeze(0)
-        return probs
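MarginLoss is likewise imported from ..core.losses now. The removed implementation computed a hinge between the output for the true class and the strongest competing output; a self-contained sketch with one-hot targets, mirroring the deleted forward():

    import torch

    def margin_loss(y_pred, y_true_onehot, margin=0.3):
        dp = torch.sum(y_true_onehot * y_pred, dim=-1)          # output for the true class
        dm = torch.max(y_pred - y_true_onehot, dim=-1).values   # strongest competitor
        return torch.nn.functional.relu(dm - dp + margin)       # zero once dp beats dm by margin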

View File

@@ -4,7 +4,8 @@ import torch
 from torch.nn.parameter import Parameter
 
 from ..core.competitions import wtac
-from ..core.distances import lomega_distance, omega_distance, squared_euclidean_distance
+from ..core.distances import (lomega_distance, omega_distance,
+                              squared_euclidean_distance)
 from ..core.losses import glvq_loss, lvq1_loss, lvq21_loss
 from ..nn.activations import get_activation
 from ..nn.wrappers import LambdaLayer, LossLayer
@@ -27,9 +28,6 @@ class GLVQ(SupervisedPrototypeModel):
         # Loss
         self.loss = LossLayer(glvq_loss)
 
-        # Prototype metrics
-        self.initialize_prototype_win_ratios()
-
     def initialize_prototype_win_ratios(self):
         self.register_buffer(
             "prototype_win_ratios",

View File

@@ -199,7 +199,7 @@ class VisCBC2D(Vis2DAbstract):
         self.plot_protos(ax, protos, "w")
         x = np.vstack((x_train, protos))
         mesh_input, xx, yy = self.get_mesh_input(x)
-        _components = pl_module.component_layer._components
+        _components = pl_module.components_layer._components
         y_pred = pl_module.predict(
             torch.Tensor(mesh_input).type_as(_components))
         y_pred = y_pred.cpu().reshape(xx.shape)
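With the rename, the visualization callback reads the components from pl_module.components_layer, matching the attribute the new CBC defines. A hypothetical way the example wires this callback up, continuing the sketch given after the first hunk (the callback arguments are assumptions based on the other examples, not part of this diff):

    vis = pt.models.VisCBC2D(data=train_ds, title="CBC Iris Example")
    trainer = pl.Trainer.from_argparse_args(args, callbacks=[vis])
    trainer.fit(model, train_loader)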