Update Examples to new initializer architecture.
Visualization is still broken for some examples.
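The diff below swaps the old `Prototypes1D` layer for the new component/initializer API. Below is a minimal sketch of the new pattern, pieced together from the calls visible in this diff; the initializer class `cinit.RandomInitializer` and its constructor argument are assumptions, not confirmed by this commit:

    from prototorch.components import Components
    from prototorch.components import initializers as cinit

    # Assumed initializer: the class name and the input-dimension argument are
    # illustrative guesses; only Components(num_components, initializer) and
    # calling the layer to obtain the component tensor appear in the diff below.
    component_initializer = cinit.RandomInitializer(2)
    component_layer = Components(6, component_initializer)

    components = component_layer()  # analogous to `protos = self.component_layer()` below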
@@ -1,10 +1,9 @@
import pytorch_lightning as pl
import torch
import torchmetrics

from prototorch.components.components import Components
from prototorch.functions.distances import euclidean_distance
from prototorch.functions.similarities import cosine_similarity
from prototorch.modules.prototypes import Prototypes1D


def rescaled_cosine_similarity(x, y):
@@ -93,12 +92,8 @@ class CBC(pl.LightningModule):
        super().__init__()
        self.save_hyperparameters(hparams)
        self.margin = margin
-        self.proto_layer = Prototypes1D(
-            input_dim=self.hparams.input_dim,
-            nclasses=self.hparams.nclasses,
-            prototypes_per_class=self.hparams.prototypes_per_class,
-            prototype_initializer=self.hparams.prototype_initializer,
-            **kwargs)
+        self.component_layer = Components(self.hparams.num_components,
+                                          self.hparams.component_initializer)
        # self.similarity = CosineSimilarity()
        self.similarity = similarity
        self.backbone = backbone_class()
@@ -110,7 +105,7 @@ class CBC(pl.LightningModule):

    @property
    def components(self):
-        return self.proto_layer.prototypes.detach().cpu()
+        return self.component_layer.components.detach().cpu()

    @property
    def reasonings(self):
@@ -126,7 +121,7 @@ class CBC(pl.LightningModule):

    def forward(self, x):
        self.sync_backbones()
-        protos, _ = self.proto_layer()
+        protos = self.component_layer()

        latent_x = self.backbone(x)
        latent_protos = self.backbone_dependent(protos)
@@ -167,4 +162,4 @@ class ImageCBC(CBC):
    """
    def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
        # super().on_train_batch_end(outputs, batch, batch_idx, dataloader_idx)
-        self.proto_layer.prototypes.data.clamp_(0.0, 1.0)
+        self.component_layer.prototypes.data.clamp_(0.0, 1.0)
|
@@ -1,4 +1,3 @@
import pytorch_lightning as pl
import torch
import torchmetrics
from prototorch.components import LabeledComponents
@@ -7,7 +6,6 @@ from prototorch.functions.competitions import wtac
from prototorch.functions.distances import (euclidean_distance,
                                             squared_euclidean_distance)
from prototorch.functions.losses import glvq_loss
-from prototorch.modules.prototypes import Prototypes1D

from .abstract import AbstractPrototypeModel

@@ -55,7 +53,6 @@ class GLVQ(AbstractPrototypeModel):
        with torch.no_grad():
            preds = wtac(dis, plabels)
        # `.int()` because FloatTensors are assumed to be class probabilities
        self.train_acc(preds.int(), y.int())

        # Logging
        self.log("train_loss", loss)

@@ -1,9 +1,7 @@
import pytorch_lightning as pl
import torch
from prototorch.components import Components
from prototorch.components import initializers as cinit
from prototorch.functions.distances import euclidean_distance
from prototorch.modules import Prototypes1D
from prototorch.modules.losses import NeuralGasEnergy

from .abstract import AbstractPrototypeModel