Import from the newly cleaned-up prototorch namespace
This commit is contained in:
		@@ -4,8 +4,19 @@ from importlib.metadata import PackageNotFoundError, version
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
from .callbacks import PrototypeConvergence, PruneLoserPrototypes
 | 
					from .callbacks import PrototypeConvergence, PruneLoserPrototypes
 | 
				
			||||||
from .cbc import CBC, ImageCBC
 | 
					from .cbc import CBC, ImageCBC
 | 
				
			||||||
from .glvq import (GLVQ, GLVQ1, GLVQ21, GMLVQ, GRLVQ, LGMLVQ, LVQMLN,
 | 
					from .glvq import (
 | 
				
			||||||
                   ImageGLVQ, ImageGMLVQ, SiameseGLVQ, SiameseGMLVQ)
 | 
					    GLVQ,
 | 
				
			||||||
 | 
					    GLVQ1,
 | 
				
			||||||
 | 
					    GLVQ21,
 | 
				
			||||||
 | 
					    GMLVQ,
 | 
				
			||||||
 | 
					    GRLVQ,
 | 
				
			||||||
 | 
					    LGMLVQ,
 | 
				
			||||||
 | 
					    LVQMLN,
 | 
				
			||||||
 | 
					    ImageGLVQ,
 | 
				
			||||||
 | 
					    ImageGMLVQ,
 | 
				
			||||||
 | 
					    SiameseGLVQ,
 | 
				
			||||||
 | 
					    SiameseGMLVQ,
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
from .knn import KNN
 | 
					from .knn import KNN
 | 
				
			||||||
from .lvq import LVQ1, LVQ21, MedianLVQ
 | 
					from .lvq import LVQ1, LVQ21, MedianLVQ
 | 
				
			||||||
from .probabilistic import CELVQ, PLVQ, RSLVQ, SLVQ
 | 
					from .probabilistic import CELVQ, PLVQ, RSLVQ, SLVQ
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -5,9 +5,12 @@ from typing import Final, final
 | 
				
			|||||||
import pytorch_lightning as pl
 | 
					import pytorch_lightning as pl
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
import torchmetrics
 | 
					import torchmetrics
 | 
				
			||||||
from prototorch.components import Components, LabeledComponents
 | 
					
 | 
				
			||||||
from prototorch.functions.distances import euclidean_distance
 | 
					from ..core.competitions import WTAC
 | 
				
			||||||
from prototorch.modules import WTAC, LambdaLayer
 | 
					from ..core.components import Components, LabeledComponents
 | 
				
			||||||
 | 
					from ..core.distances import euclidean_distance
 | 
				
			||||||
 | 
					from ..core.pooling import stratified_min_pooling
 | 
				
			||||||
 | 
					from ..nn.wrappers import LambdaLayer
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class ProtoTorchMixin(object):
 | 
					class ProtoTorchMixin(object):
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -4,8 +4,8 @@ import logging
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import pytorch_lightning as pl
 | 
					import pytorch_lightning as pl
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
from prototorch.components import Components
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from ..core.components import Components
 | 
				
			||||||
from .extras import ConnectionTopology
 | 
					from .extras import ConnectionTopology
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,10 +1,16 @@
 | 
				
			|||||||
import torch
 | 
					import torch
 | 
				
			||||||
import torchmetrics
 | 
					import torchmetrics
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from ..core.components import ReasoningComponents
 | 
				
			||||||
from .abstract import ImagePrototypesMixin
 | 
					from .abstract import ImagePrototypesMixin
 | 
				
			||||||
from .extras import (CosineSimilarity, MarginLoss, ReasoningLayer,
 | 
					from .extras import (
 | 
				
			||||||
                     euclidean_similarity, rescaled_cosine_similarity,
 | 
					    CosineSimilarity,
 | 
				
			||||||
                     shift_activation)
 | 
					    MarginLoss,
 | 
				
			||||||
 | 
					    ReasoningLayer,
 | 
				
			||||||
 | 
					    euclidean_similarity,
 | 
				
			||||||
 | 
					    rescaled_cosine_similarity,
 | 
				
			||||||
 | 
					    shift_activation,
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
from .glvq import SiameseGLVQ
 | 
					from .glvq import SiameseGLVQ
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -5,8 +5,9 @@ Modules not yet available in prototorch go here temporarily.
 | 
				
			|||||||
"""
 | 
					"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
from prototorch.functions.distances import euclidean_distance
 | 
					
 | 
				
			||||||
from prototorch.functions.similarities import cosine_similarity
 | 
					from ..core.distances import euclidean_distance
 | 
				
			||||||
 | 
					from ..core.similarities import cosine_similarity
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def rescaled_cosine_similarity(x, y):
 | 
					def rescaled_cosine_similarity(x, y):
 | 
				
			||||||
@@ -24,6 +25,35 @@ def euclidean_similarity(x, y, variance=1.0):
 | 
				
			|||||||
    return torch.exp(-(d * d) / (2 * variance))
 | 
					    return torch.exp(-(d * d) / (2 * variance))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def gaussian(distances, variance):
 | 
				
			||||||
 | 
					    return torch.exp(-(distances * distances) / (2 * variance))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def rank_scaled_gaussian(distances, lambd):
 | 
				
			||||||
 | 
					    order = torch.argsort(distances, dim=1)
 | 
				
			||||||
 | 
					    ranks = torch.argsort(order, dim=1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return torch.exp(-torch.exp(-ranks / lambd) * distances)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class GaussianPrior(torch.nn.Module):
 | 
				
			||||||
 | 
					    def __init__(self, variance):
 | 
				
			||||||
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					        self.variance = variance
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def forward(self, distances):
 | 
				
			||||||
 | 
					        return gaussian(distances, self.variance)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class RankScaledGaussianPrior(torch.nn.Module):
 | 
				
			||||||
 | 
					    def __init__(self, lambd):
 | 
				
			||||||
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					        self.lambd = lambd
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def forward(self, distances):
 | 
				
			||||||
 | 
					        return rank_scaled_gaussian(distances, self.lambd)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class ConnectionTopology(torch.nn.Module):
 | 
					class ConnectionTopology(torch.nn.Module):
 | 
				
			||||||
    def __init__(self, agelimit, num_prototypes):
 | 
					    def __init__(self, agelimit, num_prototypes):
 | 
				
			||||||
        super().__init__()
 | 
					        super().__init__()
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,18 +1,13 @@
 | 
				
			|||||||
"""Models based on the GLVQ framework."""
 | 
					"""Models based on the GLVQ framework."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
from prototorch.functions.activations import get_activation
 | 
					 | 
				
			||||||
from prototorch.functions.competitions import wtac
 | 
					 | 
				
			||||||
from prototorch.functions.distances import (
 | 
					 | 
				
			||||||
    lomega_distance,
 | 
					 | 
				
			||||||
    omega_distance,
 | 
					 | 
				
			||||||
    squared_euclidean_distance,
 | 
					 | 
				
			||||||
)
 | 
					 | 
				
			||||||
from prototorch.functions.helper import get_flat
 | 
					 | 
				
			||||||
from prototorch.functions.losses import glvq_loss, lvq1_loss, lvq21_loss
 | 
					 | 
				
			||||||
from prototorch.modules import LambdaLayer, LossLayer
 | 
					 | 
				
			||||||
from torch.nn.parameter import Parameter
 | 
					from torch.nn.parameter import Parameter
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from ..core.competitions import wtac
 | 
				
			||||||
 | 
					from ..core.distances import lomega_distance, omega_distance, squared_euclidean_distance
 | 
				
			||||||
 | 
					from ..core.losses import glvq_loss, lvq1_loss, lvq21_loss
 | 
				
			||||||
 | 
					from ..nn.activations import get_activation
 | 
				
			||||||
 | 
					from ..nn.wrappers import LambdaLayer, LossLayer
 | 
				
			||||||
from .abstract import ImagePrototypesMixin, SupervisedPrototypeModel
 | 
					from .abstract import ImagePrototypesMixin, SupervisedPrototypeModel
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -137,7 +132,7 @@ class SiameseGLVQ(GLVQ):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    def compute_distances(self, x):
 | 
					    def compute_distances(self, x):
 | 
				
			||||||
        protos, _ = self.proto_layer()
 | 
					        protos, _ = self.proto_layer()
 | 
				
			||||||
        x, protos = get_flat(x, protos)
 | 
					        x, protos = [arr.view(arr.size(0), -1) for arr in (x, protos)]
 | 
				
			||||||
        latent_x = self.backbone(x)
 | 
					        latent_x = self.backbone(x)
 | 
				
			||||||
        self.backbone.requires_grad_(self.both_path_gradients)
 | 
					        self.backbone.requires_grad_(self.both_path_gradients)
 | 
				
			||||||
        latent_protos = self.backbone(protos)
 | 
					        latent_protos = self.backbone(protos)
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -2,9 +2,10 @@
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import warnings
 | 
					import warnings
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from prototorch.components import LabeledComponents
 | 
					from ..core.competitions import KNNC
 | 
				
			||||||
from prototorch.modules import KNNC
 | 
					from ..core.components import LabeledComponents
 | 
				
			||||||
 | 
					from ..core.initializers import LiteralCompInitializer, LiteralLabelsInitializer
 | 
				
			||||||
 | 
					from ..utils.utils import parse_data_arg
 | 
				
			||||||
from .abstract import SupervisedPrototypeModel
 | 
					from .abstract import SupervisedPrototypeModel
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,7 +1,6 @@
 | 
				
			|||||||
"""LVQ models that are optimized using non-gradient methods."""
 | 
					"""LVQ models that are optimized using non-gradient methods."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from prototorch.functions.losses import _get_dp_dm
 | 
					from ..core.losses import _get_dp_dm
 | 
				
			||||||
 | 
					 | 
				
			||||||
from .abstract import NonGradientMixin
 | 
					from .abstract import NonGradientMixin
 | 
				
			||||||
from .glvq import GLVQ
 | 
					from .glvq import GLVQ
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,13 +1,11 @@
 | 
				
			|||||||
"""Probabilistic GLVQ methods"""
 | 
					"""Probabilistic GLVQ methods"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
from prototorch.functions.losses import nllr_loss, rslvq_loss
 | 
					 | 
				
			||||||
from prototorch.functions.pooling import (stratified_min_pooling,
 | 
					 | 
				
			||||||
                                          stratified_sum_pooling)
 | 
					 | 
				
			||||||
from prototorch.functions.transforms import (GaussianPrior,
 | 
					 | 
				
			||||||
                                             RankScaledGaussianPrior)
 | 
					 | 
				
			||||||
from prototorch.modules import LambdaLayer, LossLayer
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from ..core.losses import nllr_loss, rslvq_loss
 | 
				
			||||||
 | 
					from ..core.pooling import stratified_min_pooling, stratified_sum_pooling
 | 
				
			||||||
 | 
					from ..nn.wrappers import LambdaLayer, LossLayer
 | 
				
			||||||
 | 
					from .extras import GaussianPrior, RankScaledGaussianPrior
 | 
				
			||||||
from .glvq import GLVQ, SiameseGMLVQ
 | 
					from .glvq import GLVQ, SiameseGMLVQ
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -2,11 +2,11 @@
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import numpy as np
 | 
					import numpy as np
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
from prototorch.functions.competitions import wtac
 | 
					 | 
				
			||||||
from prototorch.functions.distances import squared_euclidean_distance
 | 
					 | 
				
			||||||
from prototorch.modules import LambdaLayer
 | 
					 | 
				
			||||||
from prototorch.modules.losses import NeuralGasEnergy
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from ..core.competitions import wtac
 | 
				
			||||||
 | 
					from ..core.distances import squared_euclidean_distance
 | 
				
			||||||
 | 
					from ..core.losses import NeuralGasEnergy
 | 
				
			||||||
 | 
					from ..nn.wrappers import LambdaLayer
 | 
				
			||||||
from .abstract import NonGradientMixin, UnsupervisedPrototypeModel
 | 
					from .abstract import NonGradientMixin, UnsupervisedPrototypeModel
 | 
				
			||||||
from .callbacks import GNGCallback
 | 
					from .callbacks import GNGCallback
 | 
				
			||||||
from .extras import ConnectionTopology
 | 
					from .extras import ConnectionTopology
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user