Import from the newly cleaned-up prototorch namespace
This commit is contained in:
parent
c87ed5ba8b
commit
69e5ff3243
@ -4,8 +4,19 @@ from importlib.metadata import PackageNotFoundError, version
|
|||||||
|
|
||||||
from .callbacks import PrototypeConvergence, PruneLoserPrototypes
|
from .callbacks import PrototypeConvergence, PruneLoserPrototypes
|
||||||
from .cbc import CBC, ImageCBC
|
from .cbc import CBC, ImageCBC
|
||||||
from .glvq import (GLVQ, GLVQ1, GLVQ21, GMLVQ, GRLVQ, LGMLVQ, LVQMLN,
|
from .glvq import (
|
||||||
ImageGLVQ, ImageGMLVQ, SiameseGLVQ, SiameseGMLVQ)
|
GLVQ,
|
||||||
|
GLVQ1,
|
||||||
|
GLVQ21,
|
||||||
|
GMLVQ,
|
||||||
|
GRLVQ,
|
||||||
|
LGMLVQ,
|
||||||
|
LVQMLN,
|
||||||
|
ImageGLVQ,
|
||||||
|
ImageGMLVQ,
|
||||||
|
SiameseGLVQ,
|
||||||
|
SiameseGMLVQ,
|
||||||
|
)
|
||||||
from .knn import KNN
|
from .knn import KNN
|
||||||
from .lvq import LVQ1, LVQ21, MedianLVQ
|
from .lvq import LVQ1, LVQ21, MedianLVQ
|
||||||
from .probabilistic import CELVQ, PLVQ, RSLVQ, SLVQ
|
from .probabilistic import CELVQ, PLVQ, RSLVQ, SLVQ
|
||||||
|
@ -5,9 +5,12 @@ from typing import Final, final
|
|||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
import torchmetrics
|
import torchmetrics
|
||||||
from prototorch.components import Components, LabeledComponents
|
|
||||||
from prototorch.functions.distances import euclidean_distance
|
from ..core.competitions import WTAC
|
||||||
from prototorch.modules import WTAC, LambdaLayer
|
from ..core.components import Components, LabeledComponents
|
||||||
|
from ..core.distances import euclidean_distance
|
||||||
|
from ..core.pooling import stratified_min_pooling
|
||||||
|
from ..nn.wrappers import LambdaLayer
|
||||||
|
|
||||||
|
|
||||||
class ProtoTorchMixin(object):
|
class ProtoTorchMixin(object):
|
||||||
|
@ -4,8 +4,8 @@ import logging
|
|||||||
|
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
from prototorch.components import Components
|
|
||||||
|
|
||||||
|
from ..core.components import Components
|
||||||
from .extras import ConnectionTopology
|
from .extras import ConnectionTopology
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,10 +1,16 @@
|
|||||||
import torch
|
import torch
|
||||||
import torchmetrics
|
import torchmetrics
|
||||||
|
|
||||||
|
from ..core.components import ReasoningComponents
|
||||||
from .abstract import ImagePrototypesMixin
|
from .abstract import ImagePrototypesMixin
|
||||||
from .extras import (CosineSimilarity, MarginLoss, ReasoningLayer,
|
from .extras import (
|
||||||
euclidean_similarity, rescaled_cosine_similarity,
|
CosineSimilarity,
|
||||||
shift_activation)
|
MarginLoss,
|
||||||
|
ReasoningLayer,
|
||||||
|
euclidean_similarity,
|
||||||
|
rescaled_cosine_similarity,
|
||||||
|
shift_activation,
|
||||||
|
)
|
||||||
from .glvq import SiameseGLVQ
|
from .glvq import SiameseGLVQ
|
||||||
|
|
||||||
|
|
||||||
|
@ -5,8 +5,9 @@ Modules not yet available in prototorch go here temporarily.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from prototorch.functions.distances import euclidean_distance
|
|
||||||
from prototorch.functions.similarities import cosine_similarity
|
from ..core.distances import euclidean_distance
|
||||||
|
from ..core.similarities import cosine_similarity
|
||||||
|
|
||||||
|
|
||||||
def rescaled_cosine_similarity(x, y):
|
def rescaled_cosine_similarity(x, y):
|
||||||
@ -24,6 +25,35 @@ def euclidean_similarity(x, y, variance=1.0):
|
|||||||
return torch.exp(-(d * d) / (2 * variance))
|
return torch.exp(-(d * d) / (2 * variance))
|
||||||
|
|
||||||
|
|
||||||
|
def gaussian(distances, variance):
|
||||||
|
return torch.exp(-(distances * distances) / (2 * variance))
|
||||||
|
|
||||||
|
|
||||||
|
def rank_scaled_gaussian(distances, lambd):
|
||||||
|
order = torch.argsort(distances, dim=1)
|
||||||
|
ranks = torch.argsort(order, dim=1)
|
||||||
|
|
||||||
|
return torch.exp(-torch.exp(-ranks / lambd) * distances)
|
||||||
|
|
||||||
|
|
||||||
|
class GaussianPrior(torch.nn.Module):
|
||||||
|
def __init__(self, variance):
|
||||||
|
super().__init__()
|
||||||
|
self.variance = variance
|
||||||
|
|
||||||
|
def forward(self, distances):
|
||||||
|
return gaussian(distances, self.variance)
|
||||||
|
|
||||||
|
|
||||||
|
class RankScaledGaussianPrior(torch.nn.Module):
|
||||||
|
def __init__(self, lambd):
|
||||||
|
super().__init__()
|
||||||
|
self.lambd = lambd
|
||||||
|
|
||||||
|
def forward(self, distances):
|
||||||
|
return rank_scaled_gaussian(distances, self.lambd)
|
||||||
|
|
||||||
|
|
||||||
class ConnectionTopology(torch.nn.Module):
|
class ConnectionTopology(torch.nn.Module):
|
||||||
def __init__(self, agelimit, num_prototypes):
|
def __init__(self, agelimit, num_prototypes):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -1,18 +1,13 @@
|
|||||||
"""Models based on the GLVQ framework."""
|
"""Models based on the GLVQ framework."""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from prototorch.functions.activations import get_activation
|
|
||||||
from prototorch.functions.competitions import wtac
|
|
||||||
from prototorch.functions.distances import (
|
|
||||||
lomega_distance,
|
|
||||||
omega_distance,
|
|
||||||
squared_euclidean_distance,
|
|
||||||
)
|
|
||||||
from prototorch.functions.helper import get_flat
|
|
||||||
from prototorch.functions.losses import glvq_loss, lvq1_loss, lvq21_loss
|
|
||||||
from prototorch.modules import LambdaLayer, LossLayer
|
|
||||||
from torch.nn.parameter import Parameter
|
from torch.nn.parameter import Parameter
|
||||||
|
|
||||||
|
from ..core.competitions import wtac
|
||||||
|
from ..core.distances import lomega_distance, omega_distance, squared_euclidean_distance
|
||||||
|
from ..core.losses import glvq_loss, lvq1_loss, lvq21_loss
|
||||||
|
from ..nn.activations import get_activation
|
||||||
|
from ..nn.wrappers import LambdaLayer, LossLayer
|
||||||
from .abstract import ImagePrototypesMixin, SupervisedPrototypeModel
|
from .abstract import ImagePrototypesMixin, SupervisedPrototypeModel
|
||||||
|
|
||||||
|
|
||||||
@ -137,7 +132,7 @@ class SiameseGLVQ(GLVQ):
|
|||||||
|
|
||||||
def compute_distances(self, x):
|
def compute_distances(self, x):
|
||||||
protos, _ = self.proto_layer()
|
protos, _ = self.proto_layer()
|
||||||
x, protos = get_flat(x, protos)
|
x, protos = [arr.view(arr.size(0), -1) for arr in (x, protos)]
|
||||||
latent_x = self.backbone(x)
|
latent_x = self.backbone(x)
|
||||||
self.backbone.requires_grad_(self.both_path_gradients)
|
self.backbone.requires_grad_(self.both_path_gradients)
|
||||||
latent_protos = self.backbone(protos)
|
latent_protos = self.backbone(protos)
|
||||||
|
@ -2,9 +2,10 @@
|
|||||||
|
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
from prototorch.components import LabeledComponents
|
from ..core.competitions import KNNC
|
||||||
from prototorch.modules import KNNC
|
from ..core.components import LabeledComponents
|
||||||
|
from ..core.initializers import LiteralCompInitializer, LiteralLabelsInitializer
|
||||||
|
from ..utils.utils import parse_data_arg
|
||||||
from .abstract import SupervisedPrototypeModel
|
from .abstract import SupervisedPrototypeModel
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
"""LVQ models that are optimized using non-gradient methods."""
|
"""LVQ models that are optimized using non-gradient methods."""
|
||||||
|
|
||||||
from prototorch.functions.losses import _get_dp_dm
|
from ..core.losses import _get_dp_dm
|
||||||
|
|
||||||
from .abstract import NonGradientMixin
|
from .abstract import NonGradientMixin
|
||||||
from .glvq import GLVQ
|
from .glvq import GLVQ
|
||||||
|
|
||||||
|
@ -1,13 +1,11 @@
|
|||||||
"""Probabilistic GLVQ methods"""
|
"""Probabilistic GLVQ methods"""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from prototorch.functions.losses import nllr_loss, rslvq_loss
|
|
||||||
from prototorch.functions.pooling import (stratified_min_pooling,
|
|
||||||
stratified_sum_pooling)
|
|
||||||
from prototorch.functions.transforms import (GaussianPrior,
|
|
||||||
RankScaledGaussianPrior)
|
|
||||||
from prototorch.modules import LambdaLayer, LossLayer
|
|
||||||
|
|
||||||
|
from ..core.losses import nllr_loss, rslvq_loss
|
||||||
|
from ..core.pooling import stratified_min_pooling, stratified_sum_pooling
|
||||||
|
from ..nn.wrappers import LambdaLayer, LossLayer
|
||||||
|
from .extras import GaussianPrior, RankScaledGaussianPrior
|
||||||
from .glvq import GLVQ, SiameseGMLVQ
|
from .glvq import GLVQ, SiameseGMLVQ
|
||||||
|
|
||||||
|
|
||||||
|
@ -2,11 +2,11 @@
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from prototorch.functions.competitions import wtac
|
|
||||||
from prototorch.functions.distances import squared_euclidean_distance
|
|
||||||
from prototorch.modules import LambdaLayer
|
|
||||||
from prototorch.modules.losses import NeuralGasEnergy
|
|
||||||
|
|
||||||
|
from ..core.competitions import wtac
|
||||||
|
from ..core.distances import squared_euclidean_distance
|
||||||
|
from ..core.losses import NeuralGasEnergy
|
||||||
|
from ..nn.wrappers import LambdaLayer
|
||||||
from .abstract import NonGradientMixin, UnsupervisedPrototypeModel
|
from .abstract import NonGradientMixin, UnsupervisedPrototypeModel
|
||||||
from .callbacks import GNGCallback
|
from .callbacks import GNGCallback
|
||||||
from .extras import ConnectionTopology
|
from .extras import ConnectionTopology
|
||||||
|
Loading…
Reference in New Issue
Block a user