8 Commits

Author                   SHA1        Message                                       Date
Alexander Engelsberger   09c80e2d54  Merge branch 'master' into kernel_distances   2021-05-11 16:10:56 +02:00
Alexander Engelsberger   bc20acd63b  Bump version: 0.4.1 → 0.4.2                   2021-05-11 16:08:37 +02:00
Jensun Ravichandran      7bb93f027a  Support for unequal prototype distributions   2021-05-11 16:11:11 +02:00
Alexander Engelsberger   65e0637b17  Fix RBF Kernel Dimensions.                    2021-04-27 17:58:05 +02:00
Alexander Engelsberger   209f9e641b  Fix kernel dimensions.                        2021-04-27 16:56:56 +02:00
Alexander Engelsberger   ba537fe1d5  Automatic formatting.                         2021-04-27 15:43:10 +02:00
Alexander Engelsberger   b0cd2de18e  Batch Kernel. [Inefficient]                   2021-04-27 15:38:34 +02:00
Alexander Engelsberger   7d353f5b5a  Kernel Distances.                             2021-04-27 12:06:15 +02:00
11 changed files with 285 additions and 29 deletions

.bumpversion.cfg

@@ -1,5 +1,5 @@
 [bumpversion]
-current_version = 0.4.1
+current_version = 0.4.2
 commit = True
 tag = True
 parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)
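This config is what keeps the version strings in the files below in sync: bump2version rewrites every tracked file, and with commit = True and tag = True it also commits and tags in one step. A hedged sketch of the release step behind commit bc20acd63b (the exact invocation is not recorded in this changeset):

    # assuming bump2version is installed and the working tree is clean
    bumpversion patch   # 0.4.1 -> 0.4.2, then commits and tags the release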

docs/source/conf.py

@@ -23,7 +23,7 @@ author = "Jensun Ravichandran"

 # The full version, including alpha/beta/rc tags
 #
-release = "0.4.1"
+release = "0.4.2"

 # -- General configuration ---------------------------------------------------

prototorch/__init__.py

@@ -1,7 +1,7 @@
"""ProtoTorch package.""" """ProtoTorch package."""
# Core Setup # Core Setup
__version__ = "0.4.1" __version__ = "0.4.2"
__all_core__ = [ __all_core__ = [
"datasets", "datasets",

prototorch/components/components.py

@@ -4,8 +4,10 @@ import warnings
 from typing import Tuple

 import torch
-from prototorch.components.initializers import (ComponentsInitializer,
-                                                EqualLabelInitializer,
-                                                ZeroReasoningsInitializer)
+from prototorch.components.initializers import (ClassAwareInitializer,
+                                                ComponentsInitializer,
+                                                EqualLabelsInitializer,
+                                                UnequalLabelsInitializer,
+                                                ZeroReasoningsInitializer)
 from prototorch.functions.initializers import get_initializer
 from torch.nn.parameter import Parameter
@@ -30,12 +32,15 @@ class Components(torch.nn.Module):
         else:
             self._initialize_components(number_of_components, initializer)

-    def _initialize_components(self, number_of_components, initializer):
+    def _precheck_initializer(self, initializer):
         if not isinstance(initializer, ComponentsInitializer):
             emsg = f"`initializer` has to be some subtype of " \
                 f"{ComponentsInitializer}. " \
                 f"You have provided: {initializer=} instead."
             raise TypeError(emsg)
+
+    def _initialize_components(self, number_of_components, initializer):
+        self._precheck_initializer(initializer)
         self._components = Parameter(
             initializer.generate(number_of_components))
@@ -57,7 +62,7 @@ class LabeledComponents(Components):
     Every Component has a label assigned.
     """
     def __init__(self,
-                 labels=None,
+                 distribution=None,
                  initializer=None,
                  *,
                  initialized_components=None):
@@ -65,15 +70,27 @@ class LabeledComponents(Components):
             super().__init__(initialized_components=initialized_components[0])
             self._labels = initialized_components[1]
         else:
-            self._initialize_labels(labels)
+            self._initialize_labels(distribution)
             super().__init__(number_of_components=len(self._labels),
                              initializer=initializer)

-    def _initialize_labels(self, labels):
-        if type(labels) == tuple:
-            num_classes, prototypes_per_class = labels
-            labels = EqualLabelInitializer(num_classes, prototypes_per_class)
+    def _initialize_components(self, number_of_components, initializer):
+        if isinstance(initializer, ClassAwareInitializer):
+            self._precheck_initializer(initializer)
+            self._components = Parameter(
+                initializer.generate(number_of_components, self.distribution))
+        else:
+            super()._initialize_components(number_of_components, initializer)
+
+    def _initialize_labels(self, distribution):
+        if type(distribution) == tuple:
+            num_classes, prototypes_per_class = distribution
+            labels = EqualLabelsInitializer(num_classes, prototypes_per_class)
+        elif type(distribution) == list:
+            labels = UnequalLabelsInitializer(distribution)
+
+        self.distribution = labels.distribution
         self._labels = labels.generate()

     @property
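Taken together with the initializer changes below, this lets callers request a different number of prototypes per class: a tuple (num_classes, prototypes_per_class) keeps the old equal spread, while a list gives one count per class. A minimal usage sketch (hedged: the import path of LabeledComponents and the (data, targets) argument format of StratifiedMeanInitializer are assumptions, not shown in this diff):

    import torch
    from prototorch.components import LabeledComponents
    from prototorch.components.initializers import StratifiedMeanInitializer

    x_train = torch.randn(100, 2)          # toy data with 2 features
    y_train = torch.randint(0, 3, (100,))  # 3 classes

    # One prototype for class 0, three for class 1, two for class 2.
    components = LabeledComponents(
        distribution=[1, 3, 2],
        initializer=StratifiedMeanInitializer((x_train, y_train)),
    )
    # Internally: labels tensor([0, 1, 1, 1, 2, 2]) and 6 components.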

prototorch/components/initializers.py

@@ -1,6 +1,7 @@
 """ProtoTorch Initializers."""

 import warnings
 from collections.abc import Iterable
+from itertools import chain

 import torch
 from torch.utils.data import DataLoader, Dataset
@@ -91,6 +92,15 @@ class ClassAwareInitializer(ComponentsInitializer):
         self.clabels = torch.unique(self.labels)
         self.num_classes = len(self.clabels)

+    def _get_samples_from_initializer(self, length, dist):
+        if not dist:
+            per_class = length // self.num_classes
+            dist = self.num_classes * [per_class]
+        samples_list = [
+            init.generate(n) for init, n in zip(self.initializers, dist)
+        ]
+        return torch.vstack(samples_list)
+

 class StratifiedMeanInitializer(ClassAwareInitializer):
     def __init__(self, arg):
@@ -102,10 +112,9 @@ class StratifiedMeanInitializer(ClassAwareInitializer):
             class_initializer = MeanInitializer(class_data)
             self.initializers.append(class_initializer)

-    def generate(self, length):
-        per_class = length // self.num_classes
-        samples_list = [init.generate(per_class) for init in self.initializers]
-        return torch.vstack(samples_list)
+    def generate(self, length, dist=[]):
+        samples = self._get_samples_from_initializer(length, dist)
+        return samples


 class StratifiedSelectionInitializer(ClassAwareInitializer):
@@ -126,10 +135,8 @@ class StratifiedSelectionInitializer(ClassAwareInitializer):
         mask = torch.bernoulli(n1) - torch.bernoulli(n2)
         return x + (self.noise * mask)

-    def generate(self, length):
-        per_class = length // self.num_classes
-        samples_list = [init.generate(per_class) for init in self.initializers]
-        samples = torch.vstack(samples_list)
+    def generate(self, length, dist=[]):
+        samples = self._get_samples_from_initializer(length, dist)
         if self.noise is not None:
             # samples = self.add_noise(samples)
             samples = samples + self.noise
@@ -142,11 +149,29 @@ class LabelsInitializer:
         raise NotImplementedError("Subclasses should implement this!")


-class EqualLabelInitializer(LabelsInitializer):
+class UnequalLabelsInitializer(LabelsInitializer):
+    def __init__(self, dist):
+        self.dist = dist
+
+    @property
+    def distribution(self):
+        return self.dist
+
+    def generate(self):
+        clabels = range(len(self.dist))
+        labels = list(chain(*[[i] * n for i, n in zip(clabels, self.dist)]))
+        return torch.tensor(labels)
+
+
+class EqualLabelsInitializer(LabelsInitializer):
     def __init__(self, classes, per_class):
         self.classes = classes
         self.per_class = per_class

+    @property
+    def distribution(self):
+        return self.classes * [self.per_class]
+
     def generate(self):
         return torch.arange(self.classes).repeat(self.per_class, 1).T.flatten()
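The two label initializers differ only in how the per-class counts are specified; their output is easy to verify by hand from the code above:

    from prototorch.components.initializers import (EqualLabelsInitializer,
                                                    UnequalLabelsInitializer)

    EqualLabelsInitializer(3, 2).generate()
    # tensor([0, 0, 1, 1, 2, 2]), distribution [2, 2, 2]

    UnequalLabelsInitializer([1, 3, 2]).generate()
    # tensor([0, 1, 1, 1, 2, 2]), distribution [1, 3, 2]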

prototorch/functions/distances.py

@@ -3,8 +3,11 @@
 import numpy as np
 import torch

-from prototorch.functions.helper import (_check_shapes, _int_and_mixed_shape,
-                                         equal_int_shape)
+from prototorch.functions.helper import (
+    _check_shapes,
+    _int_and_mixed_shape,
+    equal_int_shape,
+)


 def squared_euclidean_distance(x, y):
@@ -261,5 +264,86 @@ def tangent_distance(signals, protos, subspaces, squared=False, epsilon=1e-10):
     return diss.permute([1, 0, 2]).squeeze(-1)


+class KernelDistance:
+    r"""Kernel Distance
+
+    Distance based on a kernel function.
+    """
+    def __init__(self, kernel_fn):
+        self.kernel_fn = kernel_fn
+
+    def __call__(self, x_batch: torch.Tensor, y_batch: torch.Tensor):
+        return self._single_call(x_batch, y_batch)
+
+    def _single_call(self, x, y):
+        remove_dims = []
+        if len(x.shape) == 1:
+            x = x.unsqueeze(0)
+            remove_dims.append(0)
+        if len(y.shape) == 1:
+            y = y.unsqueeze(0)
+            remove_dims.append(-1)
+
+        output = self.kernel_fn(x, x).diag().unsqueeze(1) - 2 * self.kernel_fn(
+            x, y) + self.kernel_fn(y, y).diag()
+
+        for dim in remove_dims:
+            output.squeeze_(dim)
+
+        return torch.sqrt(output)
+
+
+class BatchKernelDistance:
+    r"""Batch Kernel Distance
+
+    Distance based on a kernel function, computed sample pair by sample
+    pair in a naive double loop.
+    """
+    def __init__(self, kernel_fn):
+        self.kernel_fn = kernel_fn
+
+    def __call__(self, x_batch: torch.Tensor, y_batch: torch.Tensor):
+        remove_dims = 0
+        # Extend single inputs
+        if len(x_batch.shape) == 1:
+            x_batch = x_batch.unsqueeze(0)
+            remove_dims += 1
+        if len(y_batch.shape) == 1:
+            y_batch = y_batch.unsqueeze(0)
+            remove_dims += 1
+        # Loop over batches
+        output = torch.FloatTensor(len(x_batch), len(y_batch))
+        for i, x in enumerate(x_batch):
+            for j, y in enumerate(y_batch):
+                output[i][j] = self._single_call(x, y)
+
+        for _ in range(remove_dims):
+            output.squeeze_(0)
+
+        return output
+
+    def _single_call(self, x, y):
+        kappa_xx = self.kernel_fn(x, x)
+        kappa_xy = self.kernel_fn(x, y)
+        kappa_yy = self.kernel_fn(y, y)
+
+        squared_distance = kappa_xx - 2 * kappa_xy + kappa_yy
+        return torch.sqrt(squared_distance)
+
+
+class SquaredKernelDistance(KernelDistance):
+    r"""Squared Kernel Distance
+
+    Kernel distance without the final square root.
+    """
+    def _single_call(self, x, y):
+        kappa_xx = self.kernel_fn(x, x)
+        kappa_xy = self.kernel_fn(x, y)
+        kappa_yy = self.kernel_fn(y, y)
+        return kappa_xx - 2 * kappa_xy + kappa_yy
+
+
 # Aliases
 sed = squared_euclidean_distance
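KernelDistance computes d(x, y) = sqrt(k(x, x) - 2 k(x, y) + k(y, y)), i.e. the distance in the feature space induced by the kernel. A quick sanity check using only the classes from this changeset: with the linear kernel k(x, y) = <x, y> (ExplicitKernel with its default identity projection), the kernel distance must reduce to the plain Euclidean distance:

    import torch
    from prototorch.functions.distances import KernelDistance
    from prototorch.functions.kernels import ExplicitKernel

    x = torch.randn(32, 1024)
    y = torch.randn(32, 1024)

    dist = KernelDistance(ExplicitKernel())
    d = dist(x, y)  # (32, 32) pairwise distances

    # sqrt(<x,x> - 2<x,y> + <y,y>) == ||x - y||
    assert torch.allclose(d, torch.cdist(x, y), atol=1e-3)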

prototorch/functions/kernels.py (new file)

@@ -0,0 +1,28 @@
+"""
+Experimental Kernels
+"""
+
+import torch
+
+
+class ExplicitKernel:
+    def __init__(self, projection=torch.nn.Identity()):
+        self.projection = projection
+
+    def __call__(self, x, y):
+        return self.projection(x) @ self.projection(y).T
+
+
+class RadialBasisFunctionKernel:
+    def __init__(self, sigma) -> None:
+        self.s2 = sigma * sigma
+
+    def __call__(self, x, y):
+        remove_dim = False
+        if len(x.shape) > 1:
+            x = x.unsqueeze(1)
+            remove_dim = True
+
+        output = torch.exp(-torch.sum((x - y)**2, dim=-1) / (2 * self.s2))
+
+        if remove_dim:
+            output = output.squeeze(1)
+
+        return output
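A sanity check of the RBF kernel above, with the value worked out by hand:

    import torch
    from prototorch.functions.kernels import RadialBasisFunctionKernel

    kernel = RadialBasisFunctionKernel(sigma=1.0)
    x, y = torch.zeros(3), torch.ones(3)

    # exp(-||x - y||^2 / (2 * sigma^2)) = exp(-3 / 2) ≈ 0.2231
    print(kernel(x, y))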

View File

@@ -1,8 +1,7 @@
 import torch
 from torch import nn

-from prototorch.functions.distances import (euclidean_distance_matrix,
-                                            tangent_distance)
+from prototorch.functions.distances import euclidean_distance_matrix, tangent_distance
 from prototorch.functions.helper import _check_shapes, _int_and_mixed_shape
 from prototorch.functions.normalization import orthogonalization
 from prototorch.modules.prototypes import Prototypes1D

setup.py

@@ -42,7 +42,7 @@ ALL = DATASETS + DEV + DOCS + EXAMPLES + TESTS

 setup(
     name="prototorch",
-    version="0.4.1",
+    version="0.4.2",
     description="Highly extensible, GPU-supported "
     "Learning Vector Quantization (LVQ) toolbox "
     "built using PyTorch and its nn API.",

tests/test_functions.py

@@ -5,8 +5,13 @@ import unittest

 import numpy as np
 import torch

-from prototorch.functions import (activations, competitions, distances,
-                                  initializers, losses)
+from prototorch.functions import (
+    activations,
+    competitions,
+    distances,
+    initializers,
+    losses,
+)


 class TestActivations(unittest.TestCase):

tests/test_kernels.py (new file, 98 lines)

@@ -0,0 +1,98 @@
+"""ProtoTorch kernels test suite."""
+
+import unittest
+
+import numpy as np
+import torch
+
+from prototorch.functions.distances import KernelDistance
+from prototorch.functions.kernels import ExplicitKernel, RadialBasisFunctionKernel
+
+
+class TestExplicitKernel(unittest.TestCase):
+    def setUp(self):
+        self.single_x = torch.randn(1024)
+        self.single_y = torch.randn(1024)
+
+        self.batch_x = torch.randn(32, 1024)
+        self.batch_y = torch.randn(32, 1024)
+
+    def test_single_values(self):
+        kernel = ExplicitKernel()
+        self.assertEqual(
+            kernel(self.single_x, self.single_y).shape, torch.Size([]))
+
+    def test_single_batch(self):
+        kernel = ExplicitKernel()
+        self.assertEqual(
+            kernel(self.single_x, self.batch_y).shape, torch.Size([32]))
+
+    def test_batch_single(self):
+        kernel = ExplicitKernel()
+        self.assertEqual(
+            kernel(self.batch_x, self.single_y).shape, torch.Size([32]))
+
+    def test_batch_values(self):
+        kernel = ExplicitKernel()
+        self.assertEqual(
+            kernel(self.batch_x, self.batch_y).shape, torch.Size([32, 32]))
+
+
+class TestRadialBasisFunctionKernel(unittest.TestCase):
+    def setUp(self):
+        self.single_x = torch.randn(1024)
+        self.single_y = torch.randn(1024)
+
+        self.batch_x = torch.randn(32, 1024)
+        self.batch_y = torch.randn(32, 1024)
+
+    def test_single_values(self):
+        kernel = RadialBasisFunctionKernel(1)
+        self.assertEqual(
+            kernel(self.single_x, self.single_y).shape, torch.Size([]))
+
+    def test_single_batch(self):
+        kernel = RadialBasisFunctionKernel(1)
+        self.assertEqual(
+            kernel(self.single_x, self.batch_y).shape, torch.Size([32]))
+
+    def test_batch_single(self):
+        kernel = RadialBasisFunctionKernel(1)
+        self.assertEqual(
+            kernel(self.batch_x, self.single_y).shape, torch.Size([32]))
+
+    def test_batch_values(self):
+        kernel = RadialBasisFunctionKernel(1)
+        self.assertEqual(
+            kernel(self.batch_x, self.batch_y).shape, torch.Size([32, 32]))
+
+
+class TestKernelDistance(unittest.TestCase):
+    def setUp(self):
+        self.single_x = torch.randn(1024)
+        self.single_y = torch.randn(1024)
+
+        self.batch_x = torch.randn(32, 1024)
+        self.batch_y = torch.randn(32, 1024)
+
+        self.kernel = ExplicitKernel()
+
+    def test_single_values(self):
+        distance = KernelDistance(self.kernel)
+        self.assertEqual(
+            distance(self.single_x, self.single_y).shape, torch.Size([]))
+
+    def test_single_batch(self):
+        distance = KernelDistance(self.kernel)
+        self.assertEqual(
+            distance(self.single_x, self.batch_y).shape, torch.Size([32]))
+
+    def test_batch_single(self):
+        distance = KernelDistance(self.kernel)
+        self.assertEqual(
+            distance(self.batch_x, self.single_y).shape, torch.Size([32]))
+
+    def test_batch_values(self):
+        distance = KernelDistance(self.kernel)
+        self.assertEqual(
+            distance(self.batch_x, self.batch_y).shape, torch.Size([32, 32]))
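The new suite follows the stdlib unittest conventions of the existing tests, so it can be run with, for example (assuming pytest is available; pytest also collects unittest classes):

    python -m pytest tests/test_kernels.py -v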