chore: minor changes and version updates

This commit is contained in:
Alexander Engelsberger 2022-05-17 11:56:18 +02:00
parent 6714cb7915
commit 2a85c94b55
No known key found for this signature in database
GPG Key ID: 72E54A9DAE51EB96
3 changed files with 75 additions and 58 deletions

View File

@ -1,5 +1,7 @@
"""ProtoTorch CBC example using 2D Iris data."""
import logging
import torch
from matplotlib import pyplot as plt
@ -34,7 +36,7 @@ class VisCBC2D():
self.resolution = 100
self.cmap = "viridis"
def on_epoch_end(self):
def on_train_epoch_end(self):
x_train, y_train = self.x_train, self.y_train
_components = self.model.components_layer._components.detach()
ax = self.fig.gca()
@ -94,5 +96,5 @@ if __name__ == "__main__":
correct += (y_pred.argmax(1) == y).float().sum(0)
acc = 100 * correct / len(train_ds)
print(f"Epoch: {epoch} Accuracy: {acc:05.02f}%")
vis.on_epoch_end()
logging.info(f"Epoch: {epoch} Accuracy: {acc:05.02f}%")
vis.on_train_epoch_end()

View File

@ -5,8 +5,10 @@ URL:
"""
from __future__ import annotations
import warnings
from typing import Sequence, Union
from typing import Sequence
from sklearn.datasets import (
load_iris,
@ -41,9 +43,9 @@ class Iris(NumpyDataset):
:param dims: select a subset of dimensions
"""
def __init__(self, dims: Sequence[int] = None):
def __init__(self, dims: Sequence[int] | None = None):
x, y = load_iris(return_X_y=True)
if dims:
if dims is not None:
x = x[:, dims]
super().__init__(x, y)
@ -56,15 +58,19 @@ class Blobs(NumpyDataset):
"""
def __init__(self,
num_samples: int = 300,
num_features: int = 2,
seed: Union[None, int] = 0):
x, y = make_blobs(num_samples,
num_features,
centers=None,
random_state=seed,
shuffle=False)
def __init__(
self,
num_samples: int = 300,
num_features: int = 2,
seed: None | int = 0,
):
x, y = make_blobs(
num_samples,
num_features,
centers=None,
random_state=seed,
shuffle=False,
)
super().__init__(x, y)
@ -77,29 +83,33 @@ class Random(NumpyDataset):
Note: n_classes * n_clusters_per_class <= 2**n_informative must satisfy.
"""
def __init__(self,
num_samples: int = 300,
num_features: int = 2,
num_classes: int = 2,
num_clusters: int = 2,
num_informative: Union[None, int] = None,
separation: float = 1.0,
seed: Union[None, int] = 0):
def __init__(
self,
num_samples: int = 300,
num_features: int = 2,
num_classes: int = 2,
num_clusters: int = 2,
num_informative: None | int = None,
separation: float = 1.0,
seed: None | int = 0,
):
if not num_informative:
import math
num_informative = math.ceil(math.log2(num_classes * num_clusters))
if num_features < num_informative:
warnings.warn("Generating more features than requested.")
num_features = num_informative
x, y = make_classification(num_samples,
num_features,
n_informative=num_informative,
n_redundant=0,
n_classes=num_classes,
n_clusters_per_class=num_clusters,
class_sep=separation,
random_state=seed,
shuffle=False)
x, y = make_classification(
num_samples,
num_features,
n_informative=num_informative,
n_redundant=0,
n_classes=num_classes,
n_clusters_per_class=num_clusters,
class_sep=separation,
random_state=seed,
shuffle=False,
)
super().__init__(x, y)
@ -113,16 +123,20 @@ class Circles(NumpyDataset):
"""
def __init__(self,
num_samples: int = 300,
noise: float = 0.3,
factor: float = 0.8,
seed: Union[None, int] = 0):
x, y = make_circles(num_samples,
noise=noise,
factor=factor,
random_state=seed,
shuffle=False)
def __init__(
self,
num_samples: int = 300,
noise: float = 0.3,
factor: float = 0.8,
seed: None | int = 0,
):
x, y = make_circles(
num_samples,
noise=noise,
factor=factor,
random_state=seed,
shuffle=False,
)
super().__init__(x, y)
@ -136,12 +150,16 @@ class Moons(NumpyDataset):
"""
def __init__(self,
num_samples: int = 300,
noise: float = 0.3,
seed: Union[None, int] = 0):
x, y = make_moons(num_samples,
noise=noise,
random_state=seed,
shuffle=False)
def __init__(
self,
num_samples: int = 300,
noise: float = 0.3,
seed: None | int = 0,
):
x, y = make_moons(
num_samples,
noise=noise,
random_state=seed,
shuffle=False,
)
super().__init__(x, y)

View File

@ -36,6 +36,7 @@ Description:
are determined by analytic chemistry.
"""
import logging
import os
import numpy as np
@ -81,13 +82,11 @@ class Tecator(ProtoDataset):
if self._check_exists():
return
if self.verbose:
print("Making directories...")
logging.debug("Making directories...")
os.makedirs(self.raw_folder, exist_ok=True)
os.makedirs(self.processed_folder, exist_ok=True)
if self.verbose:
print("Downloading...")
logging.debug("Downloading...")
for fileid, md5 in self._resources:
filename = "tecator.npz"
download_file_from_google_drive(fileid,
@ -95,8 +94,7 @@ class Tecator(ProtoDataset):
filename=filename,
md5=md5)
if self.verbose:
print("Processing...")
logging.debug("Processing...")
with np.load(os.path.join(self.raw_folder, "tecator.npz"),
allow_pickle=False) as f:
x_train, y_train = f["x_train"], f["y_train"]
@ -117,5 +115,4 @@ class Tecator(ProtoDataset):
"wb") as f:
torch.save(test_set, f)
if self.verbose:
print("Done!")
logging.debug("Done!")