2 Commits

Author SHA1 Message Date
Alexander Engelsberger
bc20acd63b Bump version: 0.4.1 → 0.4.2 2021-05-11 16:08:37 +02:00
Jensun Ravichandran
7bb93f027a Support for unequal prototype distributions 2021-05-11 16:11:11 +02:00
6 changed files with 64 additions and 22 deletions

View File

@@ -1,5 +1,5 @@
[bumpversion] [bumpversion]
current_version = 0.4.1 current_version = 0.4.2
commit = True commit = True
tag = True tag = True
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+) parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)

View File

@@ -23,7 +23,7 @@ author = "Jensun Ravichandran"
# The full version, including alpha/beta/rc tags # The full version, including alpha/beta/rc tags
# #
release = "0.4.1" release = "0.4.2"
# -- General configuration --------------------------------------------------- # -- General configuration ---------------------------------------------------

View File

@@ -1,7 +1,7 @@
"""ProtoTorch package.""" """ProtoTorch package."""
# Core Setup # Core Setup
__version__ = "0.4.1" __version__ = "0.4.2"
__all_core__ = [ __all_core__ = [
"datasets", "datasets",

View File

@@ -4,8 +4,10 @@ import warnings
from typing import Tuple from typing import Tuple
import torch import torch
from prototorch.components.initializers import (ComponentsInitializer, from prototorch.components.initializers import (ClassAwareInitializer,
EqualLabelInitializer, ComponentsInitializer,
EqualLabelsInitializer,
UnequalLabelsInitializer,
ZeroReasoningsInitializer) ZeroReasoningsInitializer)
from prototorch.functions.initializers import get_initializer from prototorch.functions.initializers import get_initializer
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
@@ -30,12 +32,15 @@ class Components(torch.nn.Module):
else: else:
self._initialize_components(number_of_components, initializer) self._initialize_components(number_of_components, initializer)
def _initialize_components(self, number_of_components, initializer): def _precheck_initializer(self, initializer):
if not isinstance(initializer, ComponentsInitializer): if not isinstance(initializer, ComponentsInitializer):
emsg = f"`initializer` has to be some subtype of " \ emsg = f"`initializer` has to be some subtype of " \
f"{ComponentsInitializer}. " \ f"{ComponentsInitializer}. " \
f"You have provided: {initializer=} instead." f"You have provided: {initializer=} instead."
raise TypeError(emsg) raise TypeError(emsg)
def _initialize_components(self, number_of_components, initializer):
self._precheck_initializer(initializer)
self._components = Parameter( self._components = Parameter(
initializer.generate(number_of_components)) initializer.generate(number_of_components))
@@ -57,7 +62,7 @@ class LabeledComponents(Components):
Every Component has a label assigned. Every Component has a label assigned.
""" """
def __init__(self, def __init__(self,
labels=None, distribution=None,
initializer=None, initializer=None,
*, *,
initialized_components=None): initialized_components=None):
@@ -65,15 +70,27 @@ class LabeledComponents(Components):
super().__init__(initialized_components=initialized_components[0]) super().__init__(initialized_components=initialized_components[0])
self._labels = initialized_components[1] self._labels = initialized_components[1]
else: else:
self._initialize_labels(labels) self._initialize_labels(distribution)
super().__init__(number_of_components=len(self._labels), super().__init__(number_of_components=len(self._labels),
initializer=initializer) initializer=initializer)
def _initialize_labels(self, labels): def _initialize_components(self, number_of_components, initializer):
if type(labels) == tuple: if isinstance(initializer, ClassAwareInitializer):
num_classes, prototypes_per_class = labels self._precheck_initializer(initializer)
labels = EqualLabelInitializer(num_classes, prototypes_per_class) self._components = Parameter(
initializer.generate(number_of_components, self.distribution))
else:
super()._initialize_components(self, number_of_components,
initializer)
def _initialize_labels(self, distribution):
if type(distribution) == tuple:
num_classes, prototypes_per_class = distribution
labels = EqualLabelsInitializer(num_classes, prototypes_per_class)
elif type(distribution) == list:
labels = UnequalLabelsInitializer(distribution)
self.distribution = labels.distribution
self._labels = labels.generate() self._labels = labels.generate()
@property @property

View File

@@ -1,6 +1,7 @@
"""ProtoTorch Initializers.""" """ProtoTorch Initializers."""
import warnings import warnings
from collections.abc import Iterable from collections.abc import Iterable
from itertools import chain
import torch import torch
from torch.utils.data import DataLoader, Dataset from torch.utils.data import DataLoader, Dataset
@@ -91,6 +92,15 @@ class ClassAwareInitializer(ComponentsInitializer):
self.clabels = torch.unique(self.labels) self.clabels = torch.unique(self.labels)
self.num_classes = len(self.clabels) self.num_classes = len(self.clabels)
def _get_samples_from_initializer(self, length, dist):
if not dist:
per_class = length // self.num_classes
dist = self.num_classes * [per_class]
samples_list = [
init.generate(n) for init, n in zip(self.initializers, dist)
]
return torch.vstack(samples_list)
class StratifiedMeanInitializer(ClassAwareInitializer): class StratifiedMeanInitializer(ClassAwareInitializer):
def __init__(self, arg): def __init__(self, arg):
@@ -102,10 +112,9 @@ class StratifiedMeanInitializer(ClassAwareInitializer):
class_initializer = MeanInitializer(class_data) class_initializer = MeanInitializer(class_data)
self.initializers.append(class_initializer) self.initializers.append(class_initializer)
def generate(self, length): def generate(self, length, dist=[]):
per_class = length // self.num_classes samples = self._get_samples_from_initializer(length, dist)
samples_list = [init.generate(per_class) for init in self.initializers] return samples
return torch.vstack(samples_list)
class StratifiedSelectionInitializer(ClassAwareInitializer): class StratifiedSelectionInitializer(ClassAwareInitializer):
@@ -126,10 +135,8 @@ class StratifiedSelectionInitializer(ClassAwareInitializer):
mask = torch.bernoulli(n1) - torch.bernoulli(n2) mask = torch.bernoulli(n1) - torch.bernoulli(n2)
return x + (self.noise * mask) return x + (self.noise * mask)
def generate(self, length): def generate(self, length, dist=[]):
per_class = length // self.num_classes samples = self._get_samples_from_initializer(length, dist)
samples_list = [init.generate(per_class) for init in self.initializers]
samples = torch.vstack(samples_list)
if self.noise is not None: if self.noise is not None:
# samples = self.add_noise(samples) # samples = self.add_noise(samples)
samples = samples + self.noise samples = samples + self.noise
@@ -142,11 +149,29 @@ class LabelsInitializer:
raise NotImplementedError("Subclasses should implement this!") raise NotImplementedError("Subclasses should implement this!")
class EqualLabelInitializer(LabelsInitializer): class UnequalLabelsInitializer(LabelsInitializer):
def __init__(self, dist):
self.dist = dist
@property
def distribution(self):
return self.dist
def generate(self):
clabels = range(len(self.dist))
labels = list(chain(*[[i] * n for i, n in zip(clabels, self.dist)]))
return torch.tensor(labels)
class EqualLabelsInitializer(LabelsInitializer):
def __init__(self, classes, per_class): def __init__(self, classes, per_class):
self.classes = classes self.classes = classes
self.per_class = per_class self.per_class = per_class
@property
def distribution(self):
return self.classes * [self.per_class]
def generate(self): def generate(self):
return torch.arange(self.classes).repeat(self.per_class, 1).T.flatten() return torch.arange(self.classes).repeat(self.per_class, 1).T.flatten()

View File

@@ -42,7 +42,7 @@ ALL = DATASETS + DEV + DOCS + EXAMPLES + TESTS
setup( setup(
name="prototorch", name="prototorch",
version="0.4.1", version="0.4.2",
description="Highly extensible, GPU-supported " description="Highly extensible, GPU-supported "
"Learning Vector Quantization (LVQ) toolbox " "Learning Vector Quantization (LVQ) toolbox "
"built using PyTorch and its nn API.", "built using PyTorch and its nn API.",