4 Commits

Author SHA1 Message Date
Alexander Engelsberger
09256956f3 Bump version: 0.4.2 → 0.4.3 2021-05-11 17:04:08 +02:00
Jensun Ravichandran
0ca90fdcee Merge branch 'dev' of github.com:si-cim/prototorch into dev 2021-05-11 17:07:04 +02:00
Jensun Ravichandran
be21412f8a Add thin wrapper for the Iris dataset 2021-05-11 17:06:41 +02:00
Jensun Ravichandran
ae6bc47f87 [BUGFIX] Fix knnc 2021-05-11 17:06:27 +02:00
7 changed files with 13 additions and 18 deletions

View File

@@ -1,5 +1,5 @@
[bumpversion] [bumpversion]
current_version = 0.4.2 current_version = 0.4.3
commit = True commit = True
tag = True tag = True
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+) parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)

View File

@@ -23,7 +23,7 @@ author = "Jensun Ravichandran"
# The full version, including alpha/beta/rc tags # The full version, including alpha/beta/rc tags
# #
release = "0.4.2" release = "0.4.3"
# -- General configuration --------------------------------------------------- # -- General configuration ---------------------------------------------------

View File

@@ -1,7 +1,7 @@
"""ProtoTorch package.""" """ProtoTorch package."""
# Core Setup # Core Setup
__version__ = "0.4.2" __version__ = "0.4.3"
__all_core__ = [ __all_core__ = [
"datasets", "datasets",

View File

@@ -67,8 +67,9 @@ class LabeledComponents(Components):
*, *,
initialized_components=None): initialized_components=None):
if initialized_components is not None: if initialized_components is not None:
super().__init__(initialized_components=initialized_components[0]) components, component_labels = initialized_components
self._labels = initialized_components[1] super().__init__(initialized_components=components)
self._labels = component_labels
else: else:
self._initialize_labels(distribution) self._initialize_labels(distribution)
super().__init__(number_of_components=len(self._labels), super().__init__(number_of_components=len(self._labels),

View File

@@ -1,11 +1,6 @@
"""ProtoTorch datasets.""" """ProtoTorch datasets."""
from .abstract import NumpyDataset from .abstract import NumpyDataset
from .iris import Iris
from .spiral import Spiral from .spiral import Spiral
from .tecator import Tecator from .tecator import Tecator
__all__ = [
"NumpyDataset",
"Spiral",
"Tecator",
]

View File

@@ -3,7 +3,6 @@
import torch import torch
# @torch.jit.script
def stratified_min(distances, labels): def stratified_min(distances, labels):
clabels = torch.unique(labels, dim=0) clabels = torch.unique(labels, dim=0)
nclasses = clabels.size()[0] nclasses = clabels.size()[0]
@@ -31,15 +30,14 @@ def stratified_min(distances, labels):
return winning_distances.T # return with `batch_size` first return winning_distances.T # return with `batch_size` first
# @torch.jit.script
def wtac(distances, labels): def wtac(distances, labels):
winning_indices = torch.min(distances, dim=1).indices winning_indices = torch.min(distances, dim=1).indices
winning_labels = labels[winning_indices].squeeze() winning_labels = labels[winning_indices].squeeze()
return winning_labels return winning_labels
# @torch.jit.script def knnc(distances, labels, k=1):
def knnc(distances, labels, k): winning_indices = torch.topk(-distances, k=k, dim=1).indices
winning_indices = torch.topk(-distances, k=k.item(), dim=1).indices winning_labels = torch.mode(labels[winning_indices].squeeze(),
winning_labels = labels[winning_indices].squeeze() dim=1).values
return winning_labels return winning_labels

View File

@@ -23,6 +23,7 @@ INSTALL_REQUIRES = [
] ]
DATASETS = [ DATASETS = [
"requests", "requests",
"sklearn",
"tqdm", "tqdm",
] ]
DEV = ["bumpversion"] DEV = ["bumpversion"]
@@ -42,7 +43,7 @@ ALL = DATASETS + DEV + DOCS + EXAMPLES + TESTS
setup( setup(
name="prototorch", name="prototorch",
version="0.4.2", version="0.4.3",
description="Highly extensible, GPU-supported " description="Highly extensible, GPU-supported "
"Learning Vector Quantization (LVQ) toolbox " "Learning Vector Quantization (LVQ) toolbox "
"built using PyTorch and its nn API.", "built using PyTorch and its nn API.",