6 Commits

Author                  SHA1        Message                                                       Date
Alexander Engelsberger  09256956f3  Bump version: 0.4.2 → 0.4.3                                   2021-05-11 17:04:08 +02:00
Jensun Ravichandran     0ca90fdcee  Merge branch 'dev' of github.com:si-cim/prototorch into dev  2021-05-11 17:07:04 +02:00
Jensun Ravichandran     be21412f8a  Add thin wrapper for the Iris dataset                        2021-05-11 17:06:41 +02:00
Jensun Ravichandran     ae6bc47f87  [BUGFIX] Fix knnc                                            2021-05-11 17:06:27 +02:00
Jensun Ravichandran     7bb93f027a  Support for unequal prototype distributions                  2021-05-11 16:11:11 +02:00
Alexander Engelsberger  bc20acd63b  Bump version: 0.4.1 → 0.4.2                                  2021-05-11 16:08:37 +02:00
8 changed files with 73 additions and 36 deletions

View File

@@ -1,5 +1,5 @@
 [bumpversion]
-current_version = 0.4.1
+current_version = 0.4.3
 commit = True
 tag = True
 parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)
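
The `parse` pattern is what bumpversion uses to split the current version into named groups before incrementing one of them. A quick illustrative check of the same regex in Python:

import re

# Same pattern as in .bumpversion.cfg above.
pattern = re.compile(r"(?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)")
match = pattern.match("0.4.3")
print(match.groupdict())  # {'major': '0', 'minor': '4', 'patch': '3'}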

View File

@@ -23,7 +23,7 @@ author = "Jensun Ravichandran"
 # The full version, including alpha/beta/rc tags
 #
-release = "0.4.1"
+release = "0.4.3"
 
 # -- General configuration ---------------------------------------------------

View File

@@ -1,7 +1,7 @@
 """ProtoTorch package."""
 
 # Core Setup
-__version__ = "0.4.1"
+__version__ = "0.4.3"
 
 __all_core__ = [
     "datasets",

View File

@@ -4,8 +4,10 @@ import warnings
 from typing import Tuple
 
 import torch
-from prototorch.components.initializers import (ComponentsInitializer,
-                                                EqualLabelInitializer,
+from prototorch.components.initializers import (ClassAwareInitializer,
+                                                ComponentsInitializer,
+                                                EqualLabelsInitializer,
+                                                UnequalLabelsInitializer,
                                                 ZeroReasoningsInitializer)
 from prototorch.functions.initializers import get_initializer
 from torch.nn.parameter import Parameter
@@ -30,12 +32,15 @@ class Components(torch.nn.Module):
         else:
             self._initialize_components(number_of_components, initializer)
 
-    def _initialize_components(self, number_of_components, initializer):
+    def _precheck_initializer(self, initializer):
         if not isinstance(initializer, ComponentsInitializer):
             emsg = f"`initializer` has to be some subtype of " \
                 f"{ComponentsInitializer}. " \
                 f"You have provided: {initializer=} instead."
             raise TypeError(emsg)
 
+    def _initialize_components(self, number_of_components, initializer):
+        self._precheck_initializer(initializer)
         self._components = Parameter(
             initializer.generate(number_of_components))
@@ -57,23 +62,36 @@ class LabeledComponents(Components):
     Every Component has a label assigned.
     """
     def __init__(self,
-                 labels=None,
+                 distribution=None,
                  initializer=None,
                  *,
                  initialized_components=None):
         if initialized_components is not None:
-            super().__init__(initialized_components=initialized_components[0])
-            self._labels = initialized_components[1]
+            components, component_labels = initialized_components
+            super().__init__(initialized_components=components)
+            self._labels = component_labels
         else:
-            self._initialize_labels(labels)
+            self._initialize_labels(distribution)
             super().__init__(number_of_components=len(self._labels),
                              initializer=initializer)
 
-    def _initialize_labels(self, labels):
-        if type(labels) == tuple:
-            num_classes, prototypes_per_class = labels
-            labels = EqualLabelInitializer(num_classes, prototypes_per_class)
+    def _initialize_components(self, number_of_components, initializer):
+        if isinstance(initializer, ClassAwareInitializer):
+            self._precheck_initializer(initializer)
+            self._components = Parameter(
+                initializer.generate(number_of_components, self.distribution))
+        else:
+            super()._initialize_components(number_of_components, initializer)
+
+    def _initialize_labels(self, distribution):
+        if type(distribution) == tuple:
+            num_classes, prototypes_per_class = distribution
+            labels = EqualLabelsInitializer(num_classes, prototypes_per_class)
+        elif type(distribution) == list:
+            labels = UnequalLabelsInitializer(distribution)
+        self.distribution = labels.distribution
         self._labels = labels.generate()
 
     @property
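
The upshot of this hunk: `LabeledComponents` now accepts either a `(num_classes, prototypes_per_class)` tuple or a per-class list as `distribution`. A minimal sketch of the new call pattern; the import paths and the exact data format expected by `StratifiedMeanInitializer` are assumptions based on this diff, not a verified API:

import torch
from prototorch.components.components import LabeledComponents
from prototorch.components.initializers import StratifiedMeanInitializer

# Dummy data: 3 classes, 4 two-dimensional samples each.
x = torch.randn(12, 2)
y = torch.arange(3).repeat_interleave(4)

# Unequal distribution: 1 prototype for class 0, 3 for class 1, 2 for class 2.
comps = LabeledComponents(distribution=[1, 3, 2],
                          initializer=StratifiedMeanInitializer([x, y]))

print(comps.distribution)  # [1, 3, 2]
# The public accessor's name is cut off by the diff context above,
# so the private attribute is shown here instead.
print(comps._labels)       # tensor([0, 1, 1, 1, 2, 2])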

View File

@@ -1,6 +1,7 @@
 """ProtoTorch Initializers."""
 import warnings
 from collections.abc import Iterable
+from itertools import chain
 
 import torch
 from torch.utils.data import DataLoader, Dataset
@@ -91,6 +92,15 @@ class ClassAwareInitializer(ComponentsInitializer):
         self.clabels = torch.unique(self.labels)
         self.num_classes = len(self.clabels)
 
+    def _get_samples_from_initializer(self, length, dist):
+        if not dist:
+            per_class = length // self.num_classes
+            dist = self.num_classes * [per_class]
+        samples_list = [
+            init.generate(n) for init, n in zip(self.initializers, dist)
+        ]
+        return torch.vstack(samples_list)
+
 
 class StratifiedMeanInitializer(ClassAwareInitializer):
     def __init__(self, arg):
@@ -102,10 +112,9 @@ class StratifiedMeanInitializer(ClassAwareInitializer):
             class_initializer = MeanInitializer(class_data)
             self.initializers.append(class_initializer)
 
-    def generate(self, length):
-        per_class = length // self.num_classes
-        samples_list = [init.generate(per_class) for init in self.initializers]
-        return torch.vstack(samples_list)
+    def generate(self, length, dist=[]):
+        samples = self._get_samples_from_initializer(length, dist)
+        return samples
 
 
 class StratifiedSelectionInitializer(ClassAwareInitializer):
@@ -126,10 +135,8 @@ class StratifiedSelectionInitializer(ClassAwareInitializer):
         mask = torch.bernoulli(n1) - torch.bernoulli(n2)
         return x + (self.noise * mask)
 
-    def generate(self, length):
-        per_class = length // self.num_classes
-        samples_list = [init.generate(per_class) for init in self.initializers]
-        samples = torch.vstack(samples_list)
+    def generate(self, length, dist=[]):
+        samples = self._get_samples_from_initializer(length, dist)
         if self.noise is not None:
             # samples = self.add_noise(samples)
             samples = samples + self.noise
@@ -142,11 +149,29 @@ class LabelsInitializer:
         raise NotImplementedError("Subclasses should implement this!")
 
 
-class EqualLabelInitializer(LabelsInitializer):
+class UnequalLabelsInitializer(LabelsInitializer):
+    def __init__(self, dist):
+        self.dist = dist
+
+    @property
+    def distribution(self):
+        return self.dist
+
+    def generate(self):
+        clabels = range(len(self.dist))
+        labels = list(chain(*[[i] * n for i, n in zip(clabels, self.dist)]))
+        return torch.tensor(labels)
+
+
+class EqualLabelsInitializer(LabelsInitializer):
     def __init__(self, classes, per_class):
         self.classes = classes
         self.per_class = per_class
 
+    @property
+    def distribution(self):
+        return self.classes * [self.per_class]
+
     def generate(self):
         return torch.arange(self.classes).repeat(self.per_class, 1).T.flatten()
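
What the two label initializers produce follows directly from their `generate` implementations; a quick sanity check:

import torch
from prototorch.components.initializers import (EqualLabelsInitializer,
                                                UnequalLabelsInitializer)

# Equal case: 3 classes, 2 prototypes per class.
equal = EqualLabelsInitializer(classes=3, per_class=2)
print(equal.distribution)  # [2, 2, 2]
print(equal.generate())    # tensor([0, 0, 1, 1, 2, 2])

# Unequal case: explicit per-class counts.
unequal = UnequalLabelsInitializer([1, 3, 2])
print(unequal.distribution)  # [1, 3, 2]
print(unequal.generate())    # tensor([0, 1, 1, 1, 2, 2])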

View File

@@ -1,11 +1,6 @@
 """ProtoTorch datasets."""
 
 from .abstract import NumpyDataset
+from .iris import Iris
 from .spiral import Spiral
 from .tecator import Tecator
-
-__all__ = [
-    "NumpyDataset",
-    "Spiral",
-    "Tecator",
-]
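
The wrapper itself (`prototorch/datasets/iris.py`) is not part of this diff; only the re-export is shown. Given the new `sklearn` dependency added to `setup.py` below, a thin wrapper could plausibly look like the following sketch; the `NumpyDataset` constructor signature is an assumption:

# Hypothetical sketch, not the actual iris.py from this commit.
from sklearn.datasets import load_iris

from prototorch.datasets.abstract import NumpyDataset


class Iris(NumpyDataset):
    def __init__(self):
        x, y = load_iris(return_X_y=True)  # 150 samples, 4 features, 3 classes
        super().__init__(x, y)  # assumes NumpyDataset(data, targets)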

View File

@@ -3,7 +3,6 @@
 import torch
 
 
-# @torch.jit.script
 def stratified_min(distances, labels):
     clabels = torch.unique(labels, dim=0)
     nclasses = clabels.size()[0]
@@ -31,15 +30,14 @@ def stratified_min(distances, labels):
     return winning_distances.T  # return with `batch_size` first
 
 
-# @torch.jit.script
 def wtac(distances, labels):
     winning_indices = torch.min(distances, dim=1).indices
     winning_labels = labels[winning_indices].squeeze()
     return winning_labels
 
 
-# @torch.jit.script
-def knnc(distances, labels, k):
-    winning_indices = torch.topk(-distances, k=k.item(), dim=1).indices
-    winning_labels = labels[winning_indices].squeeze()
+def knnc(distances, labels, k=1):
+    winning_indices = torch.topk(-distances, k=k, dim=1).indices
+    winning_labels = torch.mode(labels[winning_indices], dim=1).values
     return winning_labels
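
With the fix, `k` is a plain int with a default (no more `k.item()`), and the winning label is chosen by majority vote among the k nearest prototypes via `torch.mode` over the k-dimension. A hand-checkable example:

import torch
from prototorch.functions.competitions import knnc

# 2 samples x 4 prototypes; smaller distance = closer.
distances = torch.tensor([[0.1, 0.2, 0.9, 0.8],
                          [0.9, 0.8, 0.1, 0.2]])
labels = torch.tensor([0, 0, 1, 1])

print(knnc(distances, labels, k=1))  # tensor([0, 1])
print(knnc(distances, labels, k=3))  # majority of 3 nearest: tensor([0, 1])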

View File

@@ -23,6 +23,7 @@ INSTALL_REQUIRES = [
 ]
 DATASETS = [
     "requests",
+    "sklearn",
     "tqdm",
 ]
 DEV = ["bumpversion"]
@@ -42,7 +43,7 @@ ALL = DATASETS + DEV + DOCS + EXAMPLES + TESTS
 setup(
     name="prototorch",
-    version="0.4.1",
+    version="0.4.3",
     description="Highly extensible, GPU-supported "
     "Learning Vector Quantization (LVQ) toolbox "
     "built using PyTorch and its nn API.",