Import from the newly cleaned-up prototorch namespace
This commit is contained in:
@@ -1,18 +1,13 @@
|
||||
"""Models based on the GLVQ framework."""
|
||||
|
||||
import torch
|
||||
from prototorch.functions.activations import get_activation
|
||||
from prototorch.functions.competitions import wtac
|
||||
from prototorch.functions.distances import (
|
||||
lomega_distance,
|
||||
omega_distance,
|
||||
squared_euclidean_distance,
|
||||
)
|
||||
from prototorch.functions.helper import get_flat
|
||||
from prototorch.functions.losses import glvq_loss, lvq1_loss, lvq21_loss
|
||||
from prototorch.modules import LambdaLayer, LossLayer
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from ..core.competitions import wtac
|
||||
from ..core.distances import lomega_distance, omega_distance, squared_euclidean_distance
|
||||
from ..core.losses import glvq_loss, lvq1_loss, lvq21_loss
|
||||
from ..nn.activations import get_activation
|
||||
from ..nn.wrappers import LambdaLayer, LossLayer
|
||||
from .abstract import ImagePrototypesMixin, SupervisedPrototypeModel
|
||||
|
||||
|
||||
@@ -137,7 +132,7 @@ class SiameseGLVQ(GLVQ):
|
||||
|
||||
def compute_distances(self, x):
|
||||
protos, _ = self.proto_layer()
|
||||
x, protos = get_flat(x, protos)
|
||||
x, protos = [arr.view(arr.size(0), -1) for arr in (x, protos)]
|
||||
latent_x = self.backbone(x)
|
||||
self.backbone.requires_grad_(self.both_path_gradients)
|
||||
latent_protos = self.backbone(protos)
|
||||
|
Reference in New Issue
Block a user