Update Examples to new initializer architecture.

Visualization still borken for some examples.
This commit is contained in:
Alexander Engelsberger
2021-05-06 14:10:09 +02:00
parent d644114090
commit 1c3613019b
15 changed files with 92 additions and 248 deletions

View File

@@ -1,4 +1,3 @@
import pytorch_lightning as pl
import torch
import torchmetrics
from prototorch.components import LabeledComponents
@@ -7,7 +6,6 @@ from prototorch.functions.competitions import wtac
from prototorch.functions.distances import (euclidean_distance,
squared_euclidean_distance)
from prototorch.functions.losses import glvq_loss
from prototorch.modules.prototypes import Prototypes1D
from .abstract import AbstractPrototypeModel
@@ -55,7 +53,6 @@ class GLVQ(AbstractPrototypeModel):
with torch.no_grad():
preds = wtac(dis, plabels)
# `.int()` because FloatTensors are assumed to be class probabilities
self.train_acc(preds.int(), y.int())
# Logging
self.log("train_loss", loss)