Import from the newly cleaned-up prototorch namespace

This commit is contained in:
Jensun Ravichandran 2021-06-14 20:08:08 +02:00
parent c87ed5ba8b
commit 69e5ff3243
10 changed files with 80 additions and 37 deletions

View File

@ -4,8 +4,19 @@ from importlib.metadata import PackageNotFoundError, version
from .callbacks import PrototypeConvergence, PruneLoserPrototypes
from .cbc import CBC, ImageCBC
from .glvq import (GLVQ, GLVQ1, GLVQ21, GMLVQ, GRLVQ, LGMLVQ, LVQMLN,
ImageGLVQ, ImageGMLVQ, SiameseGLVQ, SiameseGMLVQ)
from .glvq import (
GLVQ,
GLVQ1,
GLVQ21,
GMLVQ,
GRLVQ,
LGMLVQ,
LVQMLN,
ImageGLVQ,
ImageGMLVQ,
SiameseGLVQ,
SiameseGMLVQ,
)
from .knn import KNN
from .lvq import LVQ1, LVQ21, MedianLVQ
from .probabilistic import CELVQ, PLVQ, RSLVQ, SLVQ

View File

@ -5,9 +5,12 @@ from typing import Final, final
import pytorch_lightning as pl
import torch
import torchmetrics
from prototorch.components import Components, LabeledComponents
from prototorch.functions.distances import euclidean_distance
from prototorch.modules import WTAC, LambdaLayer
from ..core.competitions import WTAC
from ..core.components import Components, LabeledComponents
from ..core.distances import euclidean_distance
from ..core.pooling import stratified_min_pooling
from ..nn.wrappers import LambdaLayer
class ProtoTorchMixin(object):

View File

@ -4,8 +4,8 @@ import logging
import pytorch_lightning as pl
import torch
from prototorch.components import Components
from ..core.components import Components
from .extras import ConnectionTopology

View File

@ -1,10 +1,16 @@
import torch
import torchmetrics
from ..core.components import ReasoningComponents
from .abstract import ImagePrototypesMixin
from .extras import (CosineSimilarity, MarginLoss, ReasoningLayer,
euclidean_similarity, rescaled_cosine_similarity,
shift_activation)
from .extras import (
CosineSimilarity,
MarginLoss,
ReasoningLayer,
euclidean_similarity,
rescaled_cosine_similarity,
shift_activation,
)
from .glvq import SiameseGLVQ

View File

@ -5,8 +5,9 @@ Modules not yet available in prototorch go here temporarily.
"""
import torch
from prototorch.functions.distances import euclidean_distance
from prototorch.functions.similarities import cosine_similarity
from ..core.distances import euclidean_distance
from ..core.similarities import cosine_similarity
def rescaled_cosine_similarity(x, y):
@ -24,6 +25,35 @@ def euclidean_similarity(x, y, variance=1.0):
return torch.exp(-(d * d) / (2 * variance))
def gaussian(distances, variance):
return torch.exp(-(distances * distances) / (2 * variance))
def rank_scaled_gaussian(distances, lambd):
order = torch.argsort(distances, dim=1)
ranks = torch.argsort(order, dim=1)
return torch.exp(-torch.exp(-ranks / lambd) * distances)
class GaussianPrior(torch.nn.Module):
def __init__(self, variance):
super().__init__()
self.variance = variance
def forward(self, distances):
return gaussian(distances, self.variance)
class RankScaledGaussianPrior(torch.nn.Module):
def __init__(self, lambd):
super().__init__()
self.lambd = lambd
def forward(self, distances):
return rank_scaled_gaussian(distances, self.lambd)
class ConnectionTopology(torch.nn.Module):
def __init__(self, agelimit, num_prototypes):
super().__init__()

View File

@ -1,18 +1,13 @@
"""Models based on the GLVQ framework."""
import torch
from prototorch.functions.activations import get_activation
from prototorch.functions.competitions import wtac
from prototorch.functions.distances import (
lomega_distance,
omega_distance,
squared_euclidean_distance,
)
from prototorch.functions.helper import get_flat
from prototorch.functions.losses import glvq_loss, lvq1_loss, lvq21_loss
from prototorch.modules import LambdaLayer, LossLayer
from torch.nn.parameter import Parameter
from ..core.competitions import wtac
from ..core.distances import lomega_distance, omega_distance, squared_euclidean_distance
from ..core.losses import glvq_loss, lvq1_loss, lvq21_loss
from ..nn.activations import get_activation
from ..nn.wrappers import LambdaLayer, LossLayer
from .abstract import ImagePrototypesMixin, SupervisedPrototypeModel
@ -137,7 +132,7 @@ class SiameseGLVQ(GLVQ):
def compute_distances(self, x):
protos, _ = self.proto_layer()
x, protos = get_flat(x, protos)
x, protos = [arr.view(arr.size(0), -1) for arr in (x, protos)]
latent_x = self.backbone(x)
self.backbone.requires_grad_(self.both_path_gradients)
latent_protos = self.backbone(protos)

View File

@ -2,9 +2,10 @@
import warnings
from prototorch.components import LabeledComponents
from prototorch.modules import KNNC
from ..core.competitions import KNNC
from ..core.components import LabeledComponents
from ..core.initializers import LiteralCompInitializer, LiteralLabelsInitializer
from ..utils.utils import parse_data_arg
from .abstract import SupervisedPrototypeModel

View File

@ -1,7 +1,6 @@
"""LVQ models that are optimized using non-gradient methods."""
from prototorch.functions.losses import _get_dp_dm
from ..core.losses import _get_dp_dm
from .abstract import NonGradientMixin
from .glvq import GLVQ

View File

@ -1,13 +1,11 @@
"""Probabilistic GLVQ methods"""
import torch
from prototorch.functions.losses import nllr_loss, rslvq_loss
from prototorch.functions.pooling import (stratified_min_pooling,
stratified_sum_pooling)
from prototorch.functions.transforms import (GaussianPrior,
RankScaledGaussianPrior)
from prototorch.modules import LambdaLayer, LossLayer
from ..core.losses import nllr_loss, rslvq_loss
from ..core.pooling import stratified_min_pooling, stratified_sum_pooling
from ..nn.wrappers import LambdaLayer, LossLayer
from .extras import GaussianPrior, RankScaledGaussianPrior
from .glvq import GLVQ, SiameseGMLVQ

View File

@ -2,11 +2,11 @@
import numpy as np
import torch
from prototorch.functions.competitions import wtac
from prototorch.functions.distances import squared_euclidean_distance
from prototorch.modules import LambdaLayer
from prototorch.modules.losses import NeuralGasEnergy
from ..core.competitions import wtac
from ..core.distances import squared_euclidean_distance
from ..core.losses import NeuralGasEnergy
from ..nn.wrappers import LambdaLayer
from .abstract import NonGradientMixin, UnsupervisedPrototypeModel
from .callbacks import GNGCallback
from .extras import ConnectionTopology