chore: replace relative imports
This commit is contained in:
parent
29ee326b85
commit
bccef8bef0
@ -3,13 +3,15 @@
|
|||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
import torchmetrics
|
import torchmetrics
|
||||||
|
from prototorch.core.competitions import WTAC
|
||||||
from ..core.competitions import WTAC
|
from prototorch.core.components import Components, LabeledComponents
|
||||||
from ..core.components import Components, LabeledComponents
|
from prototorch.core.distances import euclidean_distance
|
||||||
from ..core.distances import euclidean_distance
|
from prototorch.core.initializers import (
|
||||||
from ..core.initializers import LabelsInitializer, ZerosCompInitializer
|
LabelsInitializer,
|
||||||
from ..core.pooling import stratified_min_pooling
|
ZerosCompInitializer,
|
||||||
from ..nn.wrappers import LambdaLayer
|
)
|
||||||
|
from prototorch.core.pooling import stratified_min_pooling
|
||||||
|
from prototorch.nn.wrappers import LambdaLayer
|
||||||
|
|
||||||
|
|
||||||
class ProtoTorchBolt(pl.LightningModule):
|
class ProtoTorchBolt(pl.LightningModule):
|
||||||
|
@ -4,9 +4,9 @@ import logging
|
|||||||
|
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
|
from prototorch.core.components import Components
|
||||||
|
from prototorch.core.initializers import LiteralCompInitializer
|
||||||
|
|
||||||
from ..core.components import Components
|
|
||||||
from ..core.initializers import LiteralCompInitializer
|
|
||||||
from .extras import ConnectionTopology
|
from .extras import ConnectionTopology
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,12 +1,12 @@
|
|||||||
import torch
|
import torch
|
||||||
import torchmetrics
|
import torchmetrics
|
||||||
|
from prototorch.core.competitions import CBCC
|
||||||
|
from prototorch.core.components import ReasoningComponents
|
||||||
|
from prototorch.core.initializers import RandomReasoningsInitializer
|
||||||
|
from prototorch.core.losses import MarginLoss
|
||||||
|
from prototorch.core.similarities import euclidean_similarity
|
||||||
|
from prototorch.nn.wrappers import LambdaLayer
|
||||||
|
|
||||||
from ..core.competitions import CBCC
|
|
||||||
from ..core.components import ReasoningComponents
|
|
||||||
from ..core.initializers import RandomReasoningsInitializer
|
|
||||||
from ..core.losses import MarginLoss
|
|
||||||
from ..core.similarities import euclidean_similarity
|
|
||||||
from ..nn.wrappers import LambdaLayer
|
|
||||||
from .abstract import ImagePrototypesMixin
|
from .abstract import ImagePrototypesMixin
|
||||||
from .glvq import SiameseGLVQ
|
from .glvq import SiameseGLVQ
|
||||||
|
|
||||||
|
@ -5,8 +5,7 @@ Modules not yet available in prototorch go here temporarily.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from prototorch.core.similarities import gaussian
|
||||||
from ..core.similarities import gaussian
|
|
||||||
|
|
||||||
|
|
||||||
def rank_scaled_gaussian(distances, lambd):
|
def rank_scaled_gaussian(distances, lambd):
|
||||||
|
@ -1,22 +1,22 @@
|
|||||||
"""Models based on the GLVQ framework."""
|
"""Models based on the GLVQ framework."""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.nn.parameter import Parameter
|
from prototorch.core.competitions import wtac
|
||||||
|
from prototorch.core.distances import (
|
||||||
from ..core.competitions import wtac
|
|
||||||
from ..core.distances import (
|
|
||||||
lomega_distance,
|
lomega_distance,
|
||||||
omega_distance,
|
omega_distance,
|
||||||
squared_euclidean_distance,
|
squared_euclidean_distance,
|
||||||
)
|
)
|
||||||
from ..core.initializers import EyeLinearTransformInitializer
|
from prototorch.core.initializers import EyeLinearTransformInitializer
|
||||||
from ..core.losses import (
|
from prototorch.core.losses import (
|
||||||
GLVQLoss,
|
GLVQLoss,
|
||||||
lvq1_loss,
|
lvq1_loss,
|
||||||
lvq21_loss,
|
lvq21_loss,
|
||||||
)
|
)
|
||||||
from ..core.transforms import LinearTransform
|
from prototorch.core.transforms import LinearTransform
|
||||||
from ..nn.wrappers import LambdaLayer, LossLayer
|
from prototorch.nn.wrappers import LambdaLayer, LossLayer
|
||||||
|
from torch.nn.parameter import Parameter
|
||||||
|
|
||||||
from .abstract import ImagePrototypesMixin, SupervisedPrototypeModel
|
from .abstract import ImagePrototypesMixin, SupervisedPrototypeModel
|
||||||
from .extras import ltangent_distance, orthogonalization
|
from .extras import ltangent_distance, orthogonalization
|
||||||
|
|
||||||
|
@ -2,13 +2,14 @@
|
|||||||
|
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
from ..core.competitions import KNNC
|
from prototorch.core.competitions import KNNC
|
||||||
from ..core.components import LabeledComponents
|
from prototorch.core.components import LabeledComponents
|
||||||
from ..core.initializers import (
|
from prototorch.core.initializers import (
|
||||||
LiteralCompInitializer,
|
LiteralCompInitializer,
|
||||||
LiteralLabelsInitializer,
|
LiteralLabelsInitializer,
|
||||||
)
|
)
|
||||||
from ..utils.utils import parse_data_arg
|
from prototorch.utils.utils import parse_data_arg
|
||||||
|
|
||||||
from .abstract import SupervisedPrototypeModel
|
from .abstract import SupervisedPrototypeModel
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,8 +1,9 @@
|
|||||||
"""LVQ models that are optimized using non-gradient methods."""
|
"""LVQ models that are optimized using non-gradient methods."""
|
||||||
|
|
||||||
from ..core.losses import _get_dp_dm
|
from prototorch.core.losses import _get_dp_dm
|
||||||
from ..nn.activations import get_activation
|
from prototorch.nn.activations import get_activation
|
||||||
from ..nn.wrappers import LambdaLayer
|
from prototorch.nn.wrappers import LambdaLayer
|
||||||
|
|
||||||
from .abstract import NonGradientMixin
|
from .abstract import NonGradientMixin
|
||||||
from .glvq import GLVQ
|
from .glvq import GLVQ
|
||||||
|
|
||||||
|
@ -1,10 +1,13 @@
|
|||||||
"""Probabilistic GLVQ methods"""
|
"""Probabilistic GLVQ methods"""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from prototorch.core.losses import nllr_loss, rslvq_loss
|
||||||
|
from prototorch.core.pooling import (
|
||||||
|
stratified_min_pooling,
|
||||||
|
stratified_sum_pooling,
|
||||||
|
)
|
||||||
|
from prototorch.nn.wrappers import LossLayer
|
||||||
|
|
||||||
from ..core.losses import nllr_loss, rslvq_loss
|
|
||||||
from ..core.pooling import stratified_min_pooling, stratified_sum_pooling
|
|
||||||
from ..nn.wrappers import LambdaLayer, LossLayer
|
|
||||||
from .extras import GaussianPrior, RankScaledGaussianPrior
|
from .extras import GaussianPrior, RankScaledGaussianPrior
|
||||||
from .glvq import GLVQ, SiameseGMLVQ
|
from .glvq import GLVQ, SiameseGMLVQ
|
||||||
|
|
||||||
|
@ -2,11 +2,10 @@
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
from prototorch.core.competitions import wtac
|
||||||
|
from prototorch.core.distances import squared_euclidean_distance
|
||||||
|
from prototorch.core.losses import NeuralGasEnergy
|
||||||
|
|
||||||
from ..core.competitions import wtac
|
|
||||||
from ..core.distances import squared_euclidean_distance
|
|
||||||
from ..core.losses import NeuralGasEnergy
|
|
||||||
from ..nn.wrappers import LambdaLayer
|
|
||||||
from .abstract import NonGradientMixin, UnsupervisedPrototypeModel
|
from .abstract import NonGradientMixin, UnsupervisedPrototypeModel
|
||||||
from .callbacks import GNGCallback
|
from .callbacks import GNGCallback
|
||||||
from .extras import ConnectionTopology
|
from .extras import ConnectionTopology
|
||||||
|
@ -5,11 +5,10 @@ import pytorch_lightning as pl
|
|||||||
import torch
|
import torch
|
||||||
import torchvision
|
import torchvision
|
||||||
from matplotlib import pyplot as plt
|
from matplotlib import pyplot as plt
|
||||||
|
from prototorch.utils.colors import get_colors, get_legend_handles
|
||||||
|
from prototorch.utils.utils import mesh2d
|
||||||
from torch.utils.data import DataLoader, Dataset
|
from torch.utils.data import DataLoader, Dataset
|
||||||
|
|
||||||
from ..utils.colors import get_colors, get_legend_handles
|
|
||||||
from ..utils.utils import mesh2d
|
|
||||||
|
|
||||||
|
|
||||||
class Vis2DAbstract(pl.Callback):
|
class Vis2DAbstract(pl.Callback):
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user