Automatic Formating.

This commit is contained in:
Alexander Engelsberger
2021-04-23 17:27:47 +02:00
parent db4499a103
commit c4c51a16fe
12 changed files with 404 additions and 159 deletions

View File

@@ -1,11 +1,9 @@
import argparse
import pytorch_lightning as pl
import torch
import torchmetrics
from prototorch.functions.competitions import wtac
from prototorch.functions.distances import euclidean_distance
from prototorch.functions.initializers import get_initializer
from prototorch.functions.losses import glvq_loss
from prototorch.modules.prototypes import Prototypes1D
@@ -54,12 +52,14 @@ class GLVQ(pl.LightningModule):
self.train_acc(
preds.int(),
y.int()) # FloatTensors are assumed to be class probabilities
self.log("acc",
self.train_acc,
on_step=False,
on_epoch=True,
prog_bar=True,
logger=True)
self.log(
"acc",
self.train_acc,
on_step=False,
on_epoch=True,
prog_bar=True,
logger=True,
)
return loss
# def training_epoch_end(self, outs):
@@ -81,4 +81,4 @@ class ImageGLVQ(GLVQ):
clamping after updates.
"""
def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
self.proto_layer.prototypes.data.clamp_(0., 1.)
self.proto_layer.prototypes.data.clamp_(0.0, 1.0)