[BUGFIX] examples/cbc_iris.py
works again
This commit is contained in:
parent
1b420c1f6b
commit
a37095409b
@ -2,11 +2,10 @@
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
|
import prototorch as pt
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
import prototorch as pt
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Command-line arguments
|
# Command-line arguments
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
@ -24,14 +23,18 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
# Hyperparameters
|
# Hyperparameters
|
||||||
hparams = dict(
|
hparams = dict(
|
||||||
distribution=[2, 2, 2],
|
distribution=[1, 0, 3],
|
||||||
proto_lr=0.1,
|
margin=0.1,
|
||||||
|
proto_lr=0.01,
|
||||||
|
bb_lr=0.01,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Initialize the model
|
# Initialize the model
|
||||||
model = pt.models.CBC(
|
model = pt.models.CBC(
|
||||||
hparams,
|
hparams,
|
||||||
prototype_initializer=pt.components.SSI(train_ds, noise=0.01),
|
components_initializer=pt.initializers.SSCI(train_ds, noise=0.01),
|
||||||
|
reasonings_iniitializer=pt.initializers.
|
||||||
|
PurePositiveReasoningsInitializer(),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Callbacks
|
# Callbacks
|
||||||
|
@ -128,7 +128,7 @@ class SupervisedPrototypeModel(PrototypeModel):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def num_classes(self):
|
def num_classes(self):
|
||||||
return len(self.proto_layer.distribution)
|
return self.proto_layer.num_classes
|
||||||
|
|
||||||
def compute_distances(self, x):
|
def compute_distances(self, x):
|
||||||
protos, _ = self.proto_layer()
|
protos, _ = self.proto_layer()
|
||||||
|
@ -1,55 +1,54 @@
|
|||||||
import torch
|
import torch
|
||||||
import torchmetrics
|
import torchmetrics
|
||||||
|
|
||||||
|
from ..core.competitions import CBCC
|
||||||
from ..core.components import ReasoningComponents
|
from ..core.components import ReasoningComponents
|
||||||
|
from ..core.initializers import RandomReasoningsInitializer
|
||||||
|
from ..core.losses import MarginLoss
|
||||||
|
from ..core.similarities import euclidean_similarity
|
||||||
|
from ..nn.wrappers import LambdaLayer
|
||||||
from .abstract import ImagePrototypesMixin
|
from .abstract import ImagePrototypesMixin
|
||||||
from .extras import (
|
|
||||||
CosineSimilarity,
|
|
||||||
MarginLoss,
|
|
||||||
ReasoningLayer,
|
|
||||||
euclidean_similarity,
|
|
||||||
rescaled_cosine_similarity,
|
|
||||||
shift_activation,
|
|
||||||
)
|
|
||||||
from .glvq import SiameseGLVQ
|
from .glvq import SiameseGLVQ
|
||||||
|
|
||||||
|
|
||||||
class CBC(SiameseGLVQ):
|
class CBC(SiameseGLVQ):
|
||||||
"""Classification-By-Components."""
|
"""Classification-By-Components."""
|
||||||
def __init__(self, hparams, margin=0.1, **kwargs):
|
def __init__(self, hparams, **kwargs):
|
||||||
super().__init__(hparams, **kwargs)
|
super().__init__(hparams, **kwargs)
|
||||||
self.margin = margin
|
|
||||||
self.similarity_fn = kwargs.get("similarity_fn", euclidean_similarity)
|
|
||||||
num_components = self.components.shape[0]
|
|
||||||
self.reasoning_layer = ReasoningLayer(num_components=num_components,
|
|
||||||
num_classes=self.num_classes)
|
|
||||||
self.component_layer = self.proto_layer
|
|
||||||
|
|
||||||
@property
|
similarity_fn = kwargs.get("similarity_fn", euclidean_similarity)
|
||||||
def components(self):
|
components_initializer = kwargs.get("components_initializer", None)
|
||||||
return self.prototypes
|
reasonings_initializer = kwargs.get("reasonings_initializer",
|
||||||
|
RandomReasoningsInitializer())
|
||||||
|
self.components_layer = ReasoningComponents(
|
||||||
|
self.hparams.distribution,
|
||||||
|
components_initializer=components_initializer,
|
||||||
|
reasonings_initializer=reasonings_initializer,
|
||||||
|
)
|
||||||
|
self.similarity_layer = LambdaLayer(similarity_fn)
|
||||||
|
self.competition_layer = CBCC()
|
||||||
|
|
||||||
@property
|
# Namespace hook
|
||||||
def reasonings(self):
|
self.proto_layer = self.components_layer
|
||||||
return self.reasoning_layer.reasonings.cpu()
|
|
||||||
|
self.loss = MarginLoss(self.hparams.margin)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
components, _ = self.component_layer()
|
components, reasonings = self.components_layer()
|
||||||
latent_x = self.backbone(x)
|
latent_x = self.backbone(x)
|
||||||
self.backbone.requires_grad_(self.both_path_gradients)
|
self.backbone.requires_grad_(self.both_path_gradients)
|
||||||
latent_components = self.backbone(components)
|
latent_components = self.backbone(components)
|
||||||
self.backbone.requires_grad_(True)
|
self.backbone.requires_grad_(True)
|
||||||
detections = self.similarity_fn(latent_x, latent_components)
|
detections = self.similarity_layer(latent_x, latent_components)
|
||||||
probs = self.reasoning_layer(detections)
|
probs = self.competition_layer(detections, reasonings)
|
||||||
return probs
|
return probs
|
||||||
|
|
||||||
def shared_step(self, batch, batch_idx, optimizer_idx=None):
|
def shared_step(self, batch, batch_idx, optimizer_idx=None):
|
||||||
x, y = batch
|
x, y = batch
|
||||||
# x = x.view(x.size(0), -1)
|
|
||||||
y_pred = self(x)
|
y_pred = self(x)
|
||||||
num_classes = self.reasoning_layer.num_classes
|
num_classes = self.num_classes
|
||||||
y_true = torch.nn.functional.one_hot(y.long(), num_classes=num_classes)
|
y_true = torch.nn.functional.one_hot(y.long(), num_classes=num_classes)
|
||||||
loss = MarginLoss(self.margin)(y_pred, y_true).mean(dim=0)
|
loss = self.loss(y_pred, y_true).mean(dim=0)
|
||||||
return y_pred, loss
|
return y_pred, loss
|
||||||
|
|
||||||
def training_step(self, batch, batch_idx, optimizer_idx=None):
|
def training_step(self, batch, batch_idx, optimizer_idx=None):
|
||||||
@ -76,7 +75,3 @@ class ImageCBC(ImagePrototypesMixin, CBC):
|
|||||||
"""CBC model that constrains the components to the range [0, 1] by
|
"""CBC model that constrains the components to the range [0, 1] by
|
||||||
clamping after updates.
|
clamping after updates.
|
||||||
"""
|
"""
|
||||||
def __init__(self, hparams, **kwargs):
|
|
||||||
super().__init__(hparams, **kwargs)
|
|
||||||
# Namespace hook
|
|
||||||
self.proto_layer = self.component_layer
|
|
||||||
|
@ -6,33 +6,12 @@ Modules not yet available in prototorch go here temporarily.
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from ..core.distances import euclidean_distance
|
from ..core.similarities import gaussian
|
||||||
from ..core.similarities import cosine_similarity
|
|
||||||
|
|
||||||
|
|
||||||
def rescaled_cosine_similarity(x, y):
|
|
||||||
"""Cosine Similarity rescaled to [0, 1]."""
|
|
||||||
similarities = cosine_similarity(x, y)
|
|
||||||
return (similarities + 1.0) / 2.0
|
|
||||||
|
|
||||||
|
|
||||||
def shift_activation(x):
|
|
||||||
return (x + 1.0) / 2.0
|
|
||||||
|
|
||||||
|
|
||||||
def euclidean_similarity(x, y, variance=1.0):
|
|
||||||
d = euclidean_distance(x, y)
|
|
||||||
return torch.exp(-(d * d) / (2 * variance))
|
|
||||||
|
|
||||||
|
|
||||||
def gaussian(distances, variance):
|
|
||||||
return torch.exp(-(distances * distances) / (2 * variance))
|
|
||||||
|
|
||||||
|
|
||||||
def rank_scaled_gaussian(distances, lambd):
|
def rank_scaled_gaussian(distances, lambd):
|
||||||
order = torch.argsort(distances, dim=1)
|
order = torch.argsort(distances, dim=1)
|
||||||
ranks = torch.argsort(order, dim=1)
|
ranks = torch.argsort(order, dim=1)
|
||||||
|
|
||||||
return torch.exp(-torch.exp(-ranks / lambd) * distances)
|
return torch.exp(-torch.exp(-ranks / lambd) * distances)
|
||||||
|
|
||||||
|
|
||||||
@ -109,64 +88,3 @@ class ConnectionTopology(torch.nn.Module):
|
|||||||
|
|
||||||
def extra_repr(self):
|
def extra_repr(self):
|
||||||
return f"(agelimit): ({self.agelimit})"
|
return f"(agelimit): ({self.agelimit})"
|
||||||
|
|
||||||
|
|
||||||
class CosineSimilarity(torch.nn.Module):
|
|
||||||
def __init__(self, activation=shift_activation):
|
|
||||||
super().__init__()
|
|
||||||
self.activation = activation
|
|
||||||
|
|
||||||
def forward(self, x, y):
|
|
||||||
epsilon = torch.finfo(x.dtype).eps
|
|
||||||
normed_x = (x / x.pow(2).sum(dim=tuple(range(
|
|
||||||
1, x.ndim)), keepdim=True).clamp(min=epsilon).sqrt()).flatten(
|
|
||||||
start_dim=1)
|
|
||||||
normed_y = (y / y.pow(2).sum(dim=tuple(range(
|
|
||||||
1, y.ndim)), keepdim=True).clamp(min=epsilon).sqrt()).flatten(
|
|
||||||
start_dim=1)
|
|
||||||
# normed_x = (x / torch.linalg.norm(x, dim=1))
|
|
||||||
diss = torch.inner(normed_x, normed_y)
|
|
||||||
return self.activation(diss)
|
|
||||||
|
|
||||||
|
|
||||||
class MarginLoss(torch.nn.modules.loss._Loss):
|
|
||||||
def __init__(self,
|
|
||||||
margin=0.3,
|
|
||||||
size_average=None,
|
|
||||||
reduce=None,
|
|
||||||
reduction="mean"):
|
|
||||||
super().__init__(size_average, reduce, reduction)
|
|
||||||
self.margin = margin
|
|
||||||
|
|
||||||
def forward(self, input_, target):
|
|
||||||
dp = torch.sum(target * input_, dim=-1)
|
|
||||||
dm = torch.max(input_ - target, dim=-1).values
|
|
||||||
return torch.nn.functional.relu(dm - dp + self.margin)
|
|
||||||
|
|
||||||
|
|
||||||
class ReasoningLayer(torch.nn.Module):
|
|
||||||
def __init__(self, num_components, num_classes, num_replicas=1):
|
|
||||||
super().__init__()
|
|
||||||
self.num_replicas = num_replicas
|
|
||||||
self.num_classes = num_classes
|
|
||||||
probabilities_init = torch.zeros(2, 1, num_components,
|
|
||||||
self.num_classes)
|
|
||||||
probabilities_init.uniform_(0.4, 0.6)
|
|
||||||
# TODO Use `self.register_parameter("param", Paramater(param))` instead
|
|
||||||
self.reasoning_probabilities = torch.nn.Parameter(probabilities_init)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def reasonings(self):
|
|
||||||
pk = self.reasoning_probabilities[0]
|
|
||||||
nk = (1 - pk) * self.reasoning_probabilities[1]
|
|
||||||
ik = 1 - pk - nk
|
|
||||||
img = torch.cat([pk, nk, ik], dim=0).permute(1, 0, 2)
|
|
||||||
return img.unsqueeze(1)
|
|
||||||
|
|
||||||
def forward(self, detections):
|
|
||||||
pk = self.reasoning_probabilities[0].clamp(0, 1)
|
|
||||||
nk = (1 - pk) * self.reasoning_probabilities[1].clamp(0, 1)
|
|
||||||
numerator = (detections @ (pk - nk)) + nk.sum(1)
|
|
||||||
probs = numerator / (pk + nk).sum(1)
|
|
||||||
probs = probs.squeeze(0)
|
|
||||||
return probs
|
|
||||||
|
@ -4,7 +4,8 @@ import torch
|
|||||||
from torch.nn.parameter import Parameter
|
from torch.nn.parameter import Parameter
|
||||||
|
|
||||||
from ..core.competitions import wtac
|
from ..core.competitions import wtac
|
||||||
from ..core.distances import lomega_distance, omega_distance, squared_euclidean_distance
|
from ..core.distances import (lomega_distance, omega_distance,
|
||||||
|
squared_euclidean_distance)
|
||||||
from ..core.losses import glvq_loss, lvq1_loss, lvq21_loss
|
from ..core.losses import glvq_loss, lvq1_loss, lvq21_loss
|
||||||
from ..nn.activations import get_activation
|
from ..nn.activations import get_activation
|
||||||
from ..nn.wrappers import LambdaLayer, LossLayer
|
from ..nn.wrappers import LambdaLayer, LossLayer
|
||||||
@ -27,9 +28,6 @@ class GLVQ(SupervisedPrototypeModel):
|
|||||||
# Loss
|
# Loss
|
||||||
self.loss = LossLayer(glvq_loss)
|
self.loss = LossLayer(glvq_loss)
|
||||||
|
|
||||||
# Prototype metrics
|
|
||||||
self.initialize_prototype_win_ratios()
|
|
||||||
|
|
||||||
def initialize_prototype_win_ratios(self):
|
def initialize_prototype_win_ratios(self):
|
||||||
self.register_buffer(
|
self.register_buffer(
|
||||||
"prototype_win_ratios",
|
"prototype_win_ratios",
|
||||||
|
@ -199,7 +199,7 @@ class VisCBC2D(Vis2DAbstract):
|
|||||||
self.plot_protos(ax, protos, "w")
|
self.plot_protos(ax, protos, "w")
|
||||||
x = np.vstack((x_train, protos))
|
x = np.vstack((x_train, protos))
|
||||||
mesh_input, xx, yy = self.get_mesh_input(x)
|
mesh_input, xx, yy = self.get_mesh_input(x)
|
||||||
_components = pl_module.component_layer._components
|
_components = pl_module.components_layer._components
|
||||||
y_pred = pl_module.predict(
|
y_pred = pl_module.predict(
|
||||||
torch.Tensor(mesh_input).type_as(_components))
|
torch.Tensor(mesh_input).type_as(_components))
|
||||||
y_pred = y_pred.cpu().reshape(xx.shape)
|
y_pred = y_pred.cpu().reshape(xx.shape)
|
||||||
|
Loading…
Reference in New Issue
Block a user