chore: minor changes and version updates

This commit is contained in:
Alexander Engelsberger 2022-05-17 11:56:18 +02:00
parent 6714cb7915
commit 2a85c94b55
No known key found for this signature in database
GPG Key ID: 72E54A9DAE51EB96
3 changed files with 75 additions and 58 deletions

View File

@ -1,5 +1,7 @@
"""ProtoTorch CBC example using 2D Iris data.""" """ProtoTorch CBC example using 2D Iris data."""
import logging
import torch import torch
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
@ -34,7 +36,7 @@ class VisCBC2D():
self.resolution = 100 self.resolution = 100
self.cmap = "viridis" self.cmap = "viridis"
def on_epoch_end(self): def on_train_epoch_end(self):
x_train, y_train = self.x_train, self.y_train x_train, y_train = self.x_train, self.y_train
_components = self.model.components_layer._components.detach() _components = self.model.components_layer._components.detach()
ax = self.fig.gca() ax = self.fig.gca()
@ -94,5 +96,5 @@ if __name__ == "__main__":
correct += (y_pred.argmax(1) == y).float().sum(0) correct += (y_pred.argmax(1) == y).float().sum(0)
acc = 100 * correct / len(train_ds) acc = 100 * correct / len(train_ds)
print(f"Epoch: {epoch} Accuracy: {acc:05.02f}%") logging.info(f"Epoch: {epoch} Accuracy: {acc:05.02f}%")
vis.on_epoch_end() vis.on_train_epoch_end()

View File

@ -5,8 +5,10 @@ URL:
""" """
from __future__ import annotations
import warnings import warnings
from typing import Sequence, Union from typing import Sequence
from sklearn.datasets import ( from sklearn.datasets import (
load_iris, load_iris,
@ -41,9 +43,9 @@ class Iris(NumpyDataset):
:param dims: select a subset of dimensions :param dims: select a subset of dimensions
""" """
def __init__(self, dims: Sequence[int] = None): def __init__(self, dims: Sequence[int] | None = None):
x, y = load_iris(return_X_y=True) x, y = load_iris(return_X_y=True)
if dims: if dims is not None:
x = x[:, dims] x = x[:, dims]
super().__init__(x, y) super().__init__(x, y)
@ -56,15 +58,19 @@ class Blobs(NumpyDataset):
""" """
def __init__(self, def __init__(
self,
num_samples: int = 300, num_samples: int = 300,
num_features: int = 2, num_features: int = 2,
seed: Union[None, int] = 0): seed: None | int = 0,
x, y = make_blobs(num_samples, ):
x, y = make_blobs(
num_samples,
num_features, num_features,
centers=None, centers=None,
random_state=seed, random_state=seed,
shuffle=False) shuffle=False,
)
super().__init__(x, y) super().__init__(x, y)
@ -77,21 +83,24 @@ class Random(NumpyDataset):
Note: n_classes * n_clusters_per_class <= 2**n_informative must satisfy. Note: n_classes * n_clusters_per_class <= 2**n_informative must satisfy.
""" """
def __init__(self, def __init__(
self,
num_samples: int = 300, num_samples: int = 300,
num_features: int = 2, num_features: int = 2,
num_classes: int = 2, num_classes: int = 2,
num_clusters: int = 2, num_clusters: int = 2,
num_informative: Union[None, int] = None, num_informative: None | int = None,
separation: float = 1.0, separation: float = 1.0,
seed: Union[None, int] = 0): seed: None | int = 0,
):
if not num_informative: if not num_informative:
import math import math
num_informative = math.ceil(math.log2(num_classes * num_clusters)) num_informative = math.ceil(math.log2(num_classes * num_clusters))
if num_features < num_informative: if num_features < num_informative:
warnings.warn("Generating more features than requested.") warnings.warn("Generating more features than requested.")
num_features = num_informative num_features = num_informative
x, y = make_classification(num_samples, x, y = make_classification(
num_samples,
num_features, num_features,
n_informative=num_informative, n_informative=num_informative,
n_redundant=0, n_redundant=0,
@ -99,7 +108,8 @@ class Random(NumpyDataset):
n_clusters_per_class=num_clusters, n_clusters_per_class=num_clusters,
class_sep=separation, class_sep=separation,
random_state=seed, random_state=seed,
shuffle=False) shuffle=False,
)
super().__init__(x, y) super().__init__(x, y)
@ -113,16 +123,20 @@ class Circles(NumpyDataset):
""" """
def __init__(self, def __init__(
self,
num_samples: int = 300, num_samples: int = 300,
noise: float = 0.3, noise: float = 0.3,
factor: float = 0.8, factor: float = 0.8,
seed: Union[None, int] = 0): seed: None | int = 0,
x, y = make_circles(num_samples, ):
x, y = make_circles(
num_samples,
noise=noise, noise=noise,
factor=factor, factor=factor,
random_state=seed, random_state=seed,
shuffle=False) shuffle=False,
)
super().__init__(x, y) super().__init__(x, y)
@ -136,12 +150,16 @@ class Moons(NumpyDataset):
""" """
def __init__(self, def __init__(
self,
num_samples: int = 300, num_samples: int = 300,
noise: float = 0.3, noise: float = 0.3,
seed: Union[None, int] = 0): seed: None | int = 0,
x, y = make_moons(num_samples, ):
x, y = make_moons(
num_samples,
noise=noise, noise=noise,
random_state=seed, random_state=seed,
shuffle=False) shuffle=False,
)
super().__init__(x, y) super().__init__(x, y)

View File

@ -36,6 +36,7 @@ Description:
are determined by analytic chemistry. are determined by analytic chemistry.
""" """
import logging
import os import os
import numpy as np import numpy as np
@ -81,13 +82,11 @@ class Tecator(ProtoDataset):
if self._check_exists(): if self._check_exists():
return return
if self.verbose: logging.debug("Making directories...")
print("Making directories...")
os.makedirs(self.raw_folder, exist_ok=True) os.makedirs(self.raw_folder, exist_ok=True)
os.makedirs(self.processed_folder, exist_ok=True) os.makedirs(self.processed_folder, exist_ok=True)
if self.verbose: logging.debug("Downloading...")
print("Downloading...")
for fileid, md5 in self._resources: for fileid, md5 in self._resources:
filename = "tecator.npz" filename = "tecator.npz"
download_file_from_google_drive(fileid, download_file_from_google_drive(fileid,
@ -95,8 +94,7 @@ class Tecator(ProtoDataset):
filename=filename, filename=filename,
md5=md5) md5=md5)
if self.verbose: logging.debug("Processing...")
print("Processing...")
with np.load(os.path.join(self.raw_folder, "tecator.npz"), with np.load(os.path.join(self.raw_folder, "tecator.npz"),
allow_pickle=False) as f: allow_pickle=False) as f:
x_train, y_train = f["x_train"], f["y_train"] x_train, y_train = f["x_train"], f["y_train"]
@ -117,5 +115,4 @@ class Tecator(ProtoDataset):
"wb") as f: "wb") as f:
torch.save(test_set, f) torch.save(test_set, f)
if self.verbose: logging.debug("Done!")
print("Done!")