Import from the newly cleaned-up prototorch namespace
This commit is contained in:
parent
c87ed5ba8b
commit
69e5ff3243
@ -4,8 +4,19 @@ from importlib.metadata import PackageNotFoundError, version
|
||||
|
||||
from .callbacks import PrototypeConvergence, PruneLoserPrototypes
|
||||
from .cbc import CBC, ImageCBC
|
||||
from .glvq import (GLVQ, GLVQ1, GLVQ21, GMLVQ, GRLVQ, LGMLVQ, LVQMLN,
|
||||
ImageGLVQ, ImageGMLVQ, SiameseGLVQ, SiameseGMLVQ)
|
||||
from .glvq import (
|
||||
GLVQ,
|
||||
GLVQ1,
|
||||
GLVQ21,
|
||||
GMLVQ,
|
||||
GRLVQ,
|
||||
LGMLVQ,
|
||||
LVQMLN,
|
||||
ImageGLVQ,
|
||||
ImageGMLVQ,
|
||||
SiameseGLVQ,
|
||||
SiameseGMLVQ,
|
||||
)
|
||||
from .knn import KNN
|
||||
from .lvq import LVQ1, LVQ21, MedianLVQ
|
||||
from .probabilistic import CELVQ, PLVQ, RSLVQ, SLVQ
|
||||
|
@ -5,9 +5,12 @@ from typing import Final, final
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
import torchmetrics
|
||||
from prototorch.components import Components, LabeledComponents
|
||||
from prototorch.functions.distances import euclidean_distance
|
||||
from prototorch.modules import WTAC, LambdaLayer
|
||||
|
||||
from ..core.competitions import WTAC
|
||||
from ..core.components import Components, LabeledComponents
|
||||
from ..core.distances import euclidean_distance
|
||||
from ..core.pooling import stratified_min_pooling
|
||||
from ..nn.wrappers import LambdaLayer
|
||||
|
||||
|
||||
class ProtoTorchMixin(object):
|
||||
|
@ -4,8 +4,8 @@ import logging
|
||||
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from prototorch.components import Components
|
||||
|
||||
from ..core.components import Components
|
||||
from .extras import ConnectionTopology
|
||||
|
||||
|
||||
|
@ -1,10 +1,16 @@
|
||||
import torch
|
||||
import torchmetrics
|
||||
|
||||
from ..core.components import ReasoningComponents
|
||||
from .abstract import ImagePrototypesMixin
|
||||
from .extras import (CosineSimilarity, MarginLoss, ReasoningLayer,
|
||||
euclidean_similarity, rescaled_cosine_similarity,
|
||||
shift_activation)
|
||||
from .extras import (
|
||||
CosineSimilarity,
|
||||
MarginLoss,
|
||||
ReasoningLayer,
|
||||
euclidean_similarity,
|
||||
rescaled_cosine_similarity,
|
||||
shift_activation,
|
||||
)
|
||||
from .glvq import SiameseGLVQ
|
||||
|
||||
|
||||
|
@ -5,8 +5,9 @@ Modules not yet available in prototorch go here temporarily.
|
||||
"""
|
||||
|
||||
import torch
|
||||
from prototorch.functions.distances import euclidean_distance
|
||||
from prototorch.functions.similarities import cosine_similarity
|
||||
|
||||
from ..core.distances import euclidean_distance
|
||||
from ..core.similarities import cosine_similarity
|
||||
|
||||
|
||||
def rescaled_cosine_similarity(x, y):
|
||||
@ -24,6 +25,35 @@ def euclidean_similarity(x, y, variance=1.0):
|
||||
return torch.exp(-(d * d) / (2 * variance))
|
||||
|
||||
|
||||
def gaussian(distances, variance):
|
||||
return torch.exp(-(distances * distances) / (2 * variance))
|
||||
|
||||
|
||||
def rank_scaled_gaussian(distances, lambd):
|
||||
order = torch.argsort(distances, dim=1)
|
||||
ranks = torch.argsort(order, dim=1)
|
||||
|
||||
return torch.exp(-torch.exp(-ranks / lambd) * distances)
|
||||
|
||||
|
||||
class GaussianPrior(torch.nn.Module):
|
||||
def __init__(self, variance):
|
||||
super().__init__()
|
||||
self.variance = variance
|
||||
|
||||
def forward(self, distances):
|
||||
return gaussian(distances, self.variance)
|
||||
|
||||
|
||||
class RankScaledGaussianPrior(torch.nn.Module):
|
||||
def __init__(self, lambd):
|
||||
super().__init__()
|
||||
self.lambd = lambd
|
||||
|
||||
def forward(self, distances):
|
||||
return rank_scaled_gaussian(distances, self.lambd)
|
||||
|
||||
|
||||
class ConnectionTopology(torch.nn.Module):
|
||||
def __init__(self, agelimit, num_prototypes):
|
||||
super().__init__()
|
||||
|
@ -1,18 +1,13 @@
|
||||
"""Models based on the GLVQ framework."""
|
||||
|
||||
import torch
|
||||
from prototorch.functions.activations import get_activation
|
||||
from prototorch.functions.competitions import wtac
|
||||
from prototorch.functions.distances import (
|
||||
lomega_distance,
|
||||
omega_distance,
|
||||
squared_euclidean_distance,
|
||||
)
|
||||
from prototorch.functions.helper import get_flat
|
||||
from prototorch.functions.losses import glvq_loss, lvq1_loss, lvq21_loss
|
||||
from prototorch.modules import LambdaLayer, LossLayer
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from ..core.competitions import wtac
|
||||
from ..core.distances import lomega_distance, omega_distance, squared_euclidean_distance
|
||||
from ..core.losses import glvq_loss, lvq1_loss, lvq21_loss
|
||||
from ..nn.activations import get_activation
|
||||
from ..nn.wrappers import LambdaLayer, LossLayer
|
||||
from .abstract import ImagePrototypesMixin, SupervisedPrototypeModel
|
||||
|
||||
|
||||
@ -137,7 +132,7 @@ class SiameseGLVQ(GLVQ):
|
||||
|
||||
def compute_distances(self, x):
|
||||
protos, _ = self.proto_layer()
|
||||
x, protos = get_flat(x, protos)
|
||||
x, protos = [arr.view(arr.size(0), -1) for arr in (x, protos)]
|
||||
latent_x = self.backbone(x)
|
||||
self.backbone.requires_grad_(self.both_path_gradients)
|
||||
latent_protos = self.backbone(protos)
|
||||
|
@ -2,9 +2,10 @@
|
||||
|
||||
import warnings
|
||||
|
||||
from prototorch.components import LabeledComponents
|
||||
from prototorch.modules import KNNC
|
||||
|
||||
from ..core.competitions import KNNC
|
||||
from ..core.components import LabeledComponents
|
||||
from ..core.initializers import LiteralCompInitializer, LiteralLabelsInitializer
|
||||
from ..utils.utils import parse_data_arg
|
||||
from .abstract import SupervisedPrototypeModel
|
||||
|
||||
|
||||
|
@ -1,7 +1,6 @@
|
||||
"""LVQ models that are optimized using non-gradient methods."""
|
||||
|
||||
from prototorch.functions.losses import _get_dp_dm
|
||||
|
||||
from ..core.losses import _get_dp_dm
|
||||
from .abstract import NonGradientMixin
|
||||
from .glvq import GLVQ
|
||||
|
||||
|
@ -1,13 +1,11 @@
|
||||
"""Probabilistic GLVQ methods"""
|
||||
|
||||
import torch
|
||||
from prototorch.functions.losses import nllr_loss, rslvq_loss
|
||||
from prototorch.functions.pooling import (stratified_min_pooling,
|
||||
stratified_sum_pooling)
|
||||
from prototorch.functions.transforms import (GaussianPrior,
|
||||
RankScaledGaussianPrior)
|
||||
from prototorch.modules import LambdaLayer, LossLayer
|
||||
|
||||
from ..core.losses import nllr_loss, rslvq_loss
|
||||
from ..core.pooling import stratified_min_pooling, stratified_sum_pooling
|
||||
from ..nn.wrappers import LambdaLayer, LossLayer
|
||||
from .extras import GaussianPrior, RankScaledGaussianPrior
|
||||
from .glvq import GLVQ, SiameseGMLVQ
|
||||
|
||||
|
||||
|
@ -2,11 +2,11 @@
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from prototorch.functions.competitions import wtac
|
||||
from prototorch.functions.distances import squared_euclidean_distance
|
||||
from prototorch.modules import LambdaLayer
|
||||
from prototorch.modules.losses import NeuralGasEnergy
|
||||
|
||||
from ..core.competitions import wtac
|
||||
from ..core.distances import squared_euclidean_distance
|
||||
from ..core.losses import NeuralGasEnergy
|
||||
from ..nn.wrappers import LambdaLayer
|
||||
from .abstract import NonGradientMixin, UnsupervisedPrototypeModel
|
||||
from .callbacks import GNGCallback
|
||||
from .extras import ConnectionTopology
|
||||
|
Loading…
Reference in New Issue
Block a user