Update Examples to new initializer architecture.

Visualization still borken for some examples.
This commit is contained in:
Alexander Engelsberger
2021-05-06 14:10:09 +02:00
parent d644114090
commit 1c3613019b
15 changed files with 92 additions and 248 deletions

View File

@@ -1,10 +1,9 @@
import pytorch_lightning as pl
import torch
import torchmetrics
from prototorch.components.components import Components
from prototorch.functions.distances import euclidean_distance
from prototorch.functions.similarities import cosine_similarity
from prototorch.modules.prototypes import Prototypes1D
def rescaled_cosine_similarity(x, y):
@@ -93,12 +92,8 @@ class CBC(pl.LightningModule):
super().__init__()
self.save_hyperparameters(hparams)
self.margin = margin
self.proto_layer = Prototypes1D(
input_dim=self.hparams.input_dim,
nclasses=self.hparams.nclasses,
prototypes_per_class=self.hparams.prototypes_per_class,
prototype_initializer=self.hparams.prototype_initializer,
**kwargs)
self.component_layer = Components(self.hparams.num_components,
self.hparams.component_initializer)
# self.similarity = CosineSimilarity()
self.similarity = similarity
self.backbone = backbone_class()
@@ -110,7 +105,7 @@ class CBC(pl.LightningModule):
@property
def components(self):
return self.proto_layer.prototypes.detach().cpu()
return self.component_layer.components.detach().cpu()
@property
def reasonings(self):
@@ -126,7 +121,7 @@ class CBC(pl.LightningModule):
def forward(self, x):
self.sync_backbones()
protos, _ = self.proto_layer()
protos = self.component_layer()
latent_x = self.backbone(x)
latent_protos = self.backbone_dependent(protos)
@@ -167,4 +162,4 @@ class ImageCBC(CBC):
"""
def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
# super().on_train_batch_end(outputs, batch, batch_idx, dataloader_idx)
self.proto_layer.prototypes.data.clamp_(0.0, 1.0)
self.component_layer.prototypes.data.clamp_(0.0, 1.0)

View File

@@ -1,4 +1,3 @@
import pytorch_lightning as pl
import torch
import torchmetrics
from prototorch.components import LabeledComponents
@@ -7,7 +6,6 @@ from prototorch.functions.competitions import wtac
from prototorch.functions.distances import (euclidean_distance,
squared_euclidean_distance)
from prototorch.functions.losses import glvq_loss
from prototorch.modules.prototypes import Prototypes1D
from .abstract import AbstractPrototypeModel
@@ -55,7 +53,6 @@ class GLVQ(AbstractPrototypeModel):
with torch.no_grad():
preds = wtac(dis, plabels)
# `.int()` because FloatTensors are assumed to be class probabilities
self.train_acc(preds.int(), y.int())
# Logging
self.log("train_loss", loss)

View File

@@ -1,9 +1,7 @@
import pytorch_lightning as pl
import torch
from prototorch.components import Components
from prototorch.components import initializers as cinit
from prototorch.functions.distances import euclidean_distance
from prototorch.modules import Prototypes1D
from prototorch.modules.losses import NeuralGasEnergy
from .abstract import AbstractPrototypeModel