Compare commits
4 Commits
v0.7.3
...
fix/sklear
Author | SHA1 | Date | |
---|---|---|---|
|
7ae7578845 | ||
|
0649d5bb45 | ||
|
339316aa7e | ||
|
2a85c94b55 |
@@ -1,5 +1,5 @@
|
|||||||
[bumpversion]
|
[bumpversion]
|
||||||
current_version = 0.7.3
|
current_version = 0.7.4
|
||||||
commit = True
|
commit = True
|
||||||
tag = True
|
tag = True
|
||||||
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)
|
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)
|
||||||
|
@@ -23,7 +23,7 @@ author = "Jensun Ravichandran"
|
|||||||
|
|
||||||
# The full version, including alpha/beta/rc tags
|
# The full version, including alpha/beta/rc tags
|
||||||
#
|
#
|
||||||
release = "0.7.3"
|
release = "0.7.4"
|
||||||
|
|
||||||
# -- General configuration ---------------------------------------------------
|
# -- General configuration ---------------------------------------------------
|
||||||
|
|
||||||
|
@@ -1,5 +1,7 @@
|
|||||||
"""ProtoTorch CBC example using 2D Iris data."""
|
"""ProtoTorch CBC example using 2D Iris data."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from matplotlib import pyplot as plt
|
from matplotlib import pyplot as plt
|
||||||
|
|
||||||
@@ -34,7 +36,7 @@ class VisCBC2D():
|
|||||||
self.resolution = 100
|
self.resolution = 100
|
||||||
self.cmap = "viridis"
|
self.cmap = "viridis"
|
||||||
|
|
||||||
def on_epoch_end(self):
|
def on_train_epoch_end(self):
|
||||||
x_train, y_train = self.x_train, self.y_train
|
x_train, y_train = self.x_train, self.y_train
|
||||||
_components = self.model.components_layer._components.detach()
|
_components = self.model.components_layer._components.detach()
|
||||||
ax = self.fig.gca()
|
ax = self.fig.gca()
|
||||||
@@ -94,5 +96,5 @@ if __name__ == "__main__":
|
|||||||
correct += (y_pred.argmax(1) == y).float().sum(0)
|
correct += (y_pred.argmax(1) == y).float().sum(0)
|
||||||
|
|
||||||
acc = 100 * correct / len(train_ds)
|
acc = 100 * correct / len(train_ds)
|
||||||
print(f"Epoch: {epoch} Accuracy: {acc:05.02f}%")
|
logging.info(f"Epoch: {epoch} Accuracy: {acc:05.02f}%")
|
||||||
vis.on_epoch_end()
|
vis.on_train_epoch_end()
|
||||||
|
@@ -17,7 +17,7 @@ from .core import similarities # noqa: F401
|
|||||||
from .core import transforms # noqa: F401
|
from .core import transforms # noqa: F401
|
||||||
|
|
||||||
# Core Setup
|
# Core Setup
|
||||||
__version__ = "0.7.3"
|
__version__ = "0.7.4"
|
||||||
|
|
||||||
__all_core__ = [
|
__all_core__ = [
|
||||||
"competitions",
|
"competitions",
|
||||||
|
@@ -38,7 +38,7 @@ def cbcc(detections: torch.Tensor, reasonings: torch.Tensor):
|
|||||||
pk = A
|
pk = A
|
||||||
nk = (1 - A) * B
|
nk = (1 - A) * B
|
||||||
numerator = (detections @ (pk - nk).T) + nk.sum(1)
|
numerator = (detections @ (pk - nk).T) + nk.sum(1)
|
||||||
probs = numerator / (pk + nk).sum(1)
|
probs = numerator / ((pk + nk).sum(1) + 1e-8)
|
||||||
return probs
|
return probs
|
||||||
|
|
||||||
|
|
||||||
|
@@ -5,8 +5,10 @@ URL:
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import warnings
|
import warnings
|
||||||
from typing import Sequence, Union
|
from typing import Sequence
|
||||||
|
|
||||||
from sklearn.datasets import (
|
from sklearn.datasets import (
|
||||||
load_iris,
|
load_iris,
|
||||||
@@ -41,9 +43,9 @@ class Iris(NumpyDataset):
|
|||||||
:param dims: select a subset of dimensions
|
:param dims: select a subset of dimensions
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, dims: Sequence[int] = None):
|
def __init__(self, dims: Sequence[int] | None = None):
|
||||||
x, y = load_iris(return_X_y=True)
|
x, y = load_iris(return_X_y=True)
|
||||||
if dims:
|
if dims is not None:
|
||||||
x = x[:, dims]
|
x = x[:, dims]
|
||||||
super().__init__(x, y)
|
super().__init__(x, y)
|
||||||
|
|
||||||
@@ -56,15 +58,19 @@ class Blobs(NumpyDataset):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(
|
||||||
num_samples: int = 300,
|
self,
|
||||||
num_features: int = 2,
|
num_samples: int = 300,
|
||||||
seed: Union[None, int] = 0):
|
num_features: int = 2,
|
||||||
x, y = make_blobs(num_samples,
|
seed: None | int = 0,
|
||||||
num_features,
|
):
|
||||||
centers=None,
|
x, y = make_blobs(
|
||||||
random_state=seed,
|
num_samples,
|
||||||
shuffle=False)
|
num_features,
|
||||||
|
centers=None,
|
||||||
|
random_state=seed,
|
||||||
|
shuffle=False,
|
||||||
|
)
|
||||||
super().__init__(x, y)
|
super().__init__(x, y)
|
||||||
|
|
||||||
|
|
||||||
@@ -77,29 +83,33 @@ class Random(NumpyDataset):
|
|||||||
Note: n_classes * n_clusters_per_class <= 2**n_informative must satisfy.
|
Note: n_classes * n_clusters_per_class <= 2**n_informative must satisfy.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(
|
||||||
num_samples: int = 300,
|
self,
|
||||||
num_features: int = 2,
|
num_samples: int = 300,
|
||||||
num_classes: int = 2,
|
num_features: int = 2,
|
||||||
num_clusters: int = 2,
|
num_classes: int = 2,
|
||||||
num_informative: Union[None, int] = None,
|
num_clusters: int = 2,
|
||||||
separation: float = 1.0,
|
num_informative: None | int = None,
|
||||||
seed: Union[None, int] = 0):
|
separation: float = 1.0,
|
||||||
|
seed: None | int = 0,
|
||||||
|
):
|
||||||
if not num_informative:
|
if not num_informative:
|
||||||
import math
|
import math
|
||||||
num_informative = math.ceil(math.log2(num_classes * num_clusters))
|
num_informative = math.ceil(math.log2(num_classes * num_clusters))
|
||||||
if num_features < num_informative:
|
if num_features < num_informative:
|
||||||
warnings.warn("Generating more features than requested.")
|
warnings.warn("Generating more features than requested.")
|
||||||
num_features = num_informative
|
num_features = num_informative
|
||||||
x, y = make_classification(num_samples,
|
x, y = make_classification(
|
||||||
num_features,
|
num_samples,
|
||||||
n_informative=num_informative,
|
num_features,
|
||||||
n_redundant=0,
|
n_informative=num_informative,
|
||||||
n_classes=num_classes,
|
n_redundant=0,
|
||||||
n_clusters_per_class=num_clusters,
|
n_classes=num_classes,
|
||||||
class_sep=separation,
|
n_clusters_per_class=num_clusters,
|
||||||
random_state=seed,
|
class_sep=separation,
|
||||||
shuffle=False)
|
random_state=seed,
|
||||||
|
shuffle=False,
|
||||||
|
)
|
||||||
super().__init__(x, y)
|
super().__init__(x, y)
|
||||||
|
|
||||||
|
|
||||||
@@ -113,16 +123,20 @@ class Circles(NumpyDataset):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(
|
||||||
num_samples: int = 300,
|
self,
|
||||||
noise: float = 0.3,
|
num_samples: int = 300,
|
||||||
factor: float = 0.8,
|
noise: float = 0.3,
|
||||||
seed: Union[None, int] = 0):
|
factor: float = 0.8,
|
||||||
x, y = make_circles(num_samples,
|
seed: None | int = 0,
|
||||||
noise=noise,
|
):
|
||||||
factor=factor,
|
x, y = make_circles(
|
||||||
random_state=seed,
|
num_samples,
|
||||||
shuffle=False)
|
noise=noise,
|
||||||
|
factor=factor,
|
||||||
|
random_state=seed,
|
||||||
|
shuffle=False,
|
||||||
|
)
|
||||||
super().__init__(x, y)
|
super().__init__(x, y)
|
||||||
|
|
||||||
|
|
||||||
@@ -136,12 +150,16 @@ class Moons(NumpyDataset):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(
|
||||||
num_samples: int = 300,
|
self,
|
||||||
noise: float = 0.3,
|
num_samples: int = 300,
|
||||||
seed: Union[None, int] = 0):
|
noise: float = 0.3,
|
||||||
x, y = make_moons(num_samples,
|
seed: None | int = 0,
|
||||||
noise=noise,
|
):
|
||||||
random_state=seed,
|
x, y = make_moons(
|
||||||
shuffle=False)
|
num_samples,
|
||||||
|
noise=noise,
|
||||||
|
random_state=seed,
|
||||||
|
shuffle=False,
|
||||||
|
)
|
||||||
super().__init__(x, y)
|
super().__init__(x, y)
|
||||||
|
@@ -36,6 +36,7 @@ Description:
|
|||||||
are determined by analytic chemistry.
|
are determined by analytic chemistry.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -81,13 +82,11 @@ class Tecator(ProtoDataset):
|
|||||||
if self._check_exists():
|
if self._check_exists():
|
||||||
return
|
return
|
||||||
|
|
||||||
if self.verbose:
|
logging.debug("Making directories...")
|
||||||
print("Making directories...")
|
|
||||||
os.makedirs(self.raw_folder, exist_ok=True)
|
os.makedirs(self.raw_folder, exist_ok=True)
|
||||||
os.makedirs(self.processed_folder, exist_ok=True)
|
os.makedirs(self.processed_folder, exist_ok=True)
|
||||||
|
|
||||||
if self.verbose:
|
logging.debug("Downloading...")
|
||||||
print("Downloading...")
|
|
||||||
for fileid, md5 in self._resources:
|
for fileid, md5 in self._resources:
|
||||||
filename = "tecator.npz"
|
filename = "tecator.npz"
|
||||||
download_file_from_google_drive(fileid,
|
download_file_from_google_drive(fileid,
|
||||||
@@ -95,8 +94,7 @@ class Tecator(ProtoDataset):
|
|||||||
filename=filename,
|
filename=filename,
|
||||||
md5=md5)
|
md5=md5)
|
||||||
|
|
||||||
if self.verbose:
|
logging.debug("Processing...")
|
||||||
print("Processing...")
|
|
||||||
with np.load(os.path.join(self.raw_folder, "tecator.npz"),
|
with np.load(os.path.join(self.raw_folder, "tecator.npz"),
|
||||||
allow_pickle=False) as f:
|
allow_pickle=False) as f:
|
||||||
x_train, y_train = f["x_train"], f["y_train"]
|
x_train, y_train = f["x_train"], f["y_train"]
|
||||||
@@ -117,5 +115,4 @@ class Tecator(ProtoDataset):
|
|||||||
"wb") as f:
|
"wb") as f:
|
||||||
torch.save(test_set, f)
|
torch.save(test_set, f)
|
||||||
|
|
||||||
if self.verbose:
|
logging.debug("Done!")
|
||||||
print("Done!")
|
|
||||||
|
6
setup.py
6
setup.py
@@ -20,9 +20,9 @@ with open("README.md", "r") as fh:
|
|||||||
|
|
||||||
INSTALL_REQUIRES = [
|
INSTALL_REQUIRES = [
|
||||||
"torch>=1.3.1",
|
"torch>=1.3.1",
|
||||||
"torchvision>=0.7.3",
|
"torchvision>=0.7.4",
|
||||||
"numpy>=1.9.1",
|
"numpy>=1.9.1",
|
||||||
"sklearn",
|
"scikit-learn",
|
||||||
"matplotlib",
|
"matplotlib",
|
||||||
]
|
]
|
||||||
DATASETS = [
|
DATASETS = [
|
||||||
@@ -51,7 +51,7 @@ ALL = DATASETS + DEV + DOCS + EXAMPLES + TESTS
|
|||||||
|
|
||||||
setup(
|
setup(
|
||||||
name="prototorch",
|
name="prototorch",
|
||||||
version="0.7.3",
|
version="0.7.4",
|
||||||
description="Highly extensible, GPU-supported "
|
description="Highly extensible, GPU-supported "
|
||||||
"Learning Vector Quantization (LVQ) toolbox "
|
"Learning Vector Quantization (LVQ) toolbox "
|
||||||
"built using PyTorch and its nn API.",
|
"built using PyTorch and its nn API.",
|
||||||
|
Reference in New Issue
Block a user