Automatic Formating.
This commit is contained in:
@@ -1,11 +1,9 @@
|
||||
import argparse
|
||||
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
import torchmetrics
|
||||
|
||||
from prototorch.functions.competitions import wtac
|
||||
from prototorch.functions.distances import euclidean_distance
|
||||
from prototorch.functions.initializers import get_initializer
|
||||
from prototorch.functions.losses import glvq_loss
|
||||
from prototorch.modules.prototypes import Prototypes1D
|
||||
|
||||
@@ -54,12 +52,14 @@ class GLVQ(pl.LightningModule):
|
||||
self.train_acc(
|
||||
preds.int(),
|
||||
y.int()) # FloatTensors are assumed to be class probabilities
|
||||
self.log("acc",
|
||||
self.train_acc,
|
||||
on_step=False,
|
||||
on_epoch=True,
|
||||
prog_bar=True,
|
||||
logger=True)
|
||||
self.log(
|
||||
"acc",
|
||||
self.train_acc,
|
||||
on_step=False,
|
||||
on_epoch=True,
|
||||
prog_bar=True,
|
||||
logger=True,
|
||||
)
|
||||
return loss
|
||||
|
||||
# def training_epoch_end(self, outs):
|
||||
@@ -81,4 +81,4 @@ class ImageGLVQ(GLVQ):
|
||||
clamping after updates.
|
||||
"""
|
||||
def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
|
||||
self.proto_layer.prototypes.data.clamp_(0., 1.)
|
||||
self.proto_layer.prototypes.data.clamp_(0.0, 1.0)
|
||||
|
Reference in New Issue
Block a user