Import from the newly cleaned-up prototorch namespace

This commit is contained in:
Jensun Ravichandran
2021-06-14 20:08:08 +02:00
parent c87ed5ba8b
commit 69e5ff3243
10 changed files with 80 additions and 37 deletions

View File

@@ -1,18 +1,13 @@
"""Models based on the GLVQ framework."""
import torch
from prototorch.functions.activations import get_activation
from prototorch.functions.competitions import wtac
from prototorch.functions.distances import (
lomega_distance,
omega_distance,
squared_euclidean_distance,
)
from prototorch.functions.helper import get_flat
from prototorch.functions.losses import glvq_loss, lvq1_loss, lvq21_loss
from prototorch.modules import LambdaLayer, LossLayer
from torch.nn.parameter import Parameter
from ..core.competitions import wtac
from ..core.distances import lomega_distance, omega_distance, squared_euclidean_distance
from ..core.losses import glvq_loss, lvq1_loss, lvq21_loss
from ..nn.activations import get_activation
from ..nn.wrappers import LambdaLayer, LossLayer
from .abstract import ImagePrototypesMixin, SupervisedPrototypeModel
@@ -137,7 +132,7 @@ class SiameseGLVQ(GLVQ):
def compute_distances(self, x):
protos, _ = self.proto_layer()
x, protos = get_flat(x, protos)
x, protos = [arr.view(arr.size(0), -1) for arr in (x, protos)]
latent_x = self.backbone(x)
self.backbone.requires_grad_(self.both_path_gradients)
latent_protos = self.backbone(protos)