Update Examples to new initializer architecture.
Visualization still borken for some examples.
This commit is contained in:
@@ -1,4 +1,3 @@
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
import torchmetrics
|
||||
from prototorch.components import LabeledComponents
|
||||
@@ -7,7 +6,6 @@ from prototorch.functions.competitions import wtac
|
||||
from prototorch.functions.distances import (euclidean_distance,
|
||||
squared_euclidean_distance)
|
||||
from prototorch.functions.losses import glvq_loss
|
||||
from prototorch.modules.prototypes import Prototypes1D
|
||||
|
||||
from .abstract import AbstractPrototypeModel
|
||||
|
||||
@@ -55,7 +53,6 @@ class GLVQ(AbstractPrototypeModel):
|
||||
with torch.no_grad():
|
||||
preds = wtac(dis, plabels)
|
||||
# `.int()` because FloatTensors are assumed to be class probabilities
|
||||
self.train_acc(preds.int(), y.int())
|
||||
|
||||
# Logging
|
||||
self.log("train_loss", loss)
|
||||
|
Reference in New Issue
Block a user