chore: minor changes and version updates
This commit is contained in:
parent
6714cb7915
commit
2a85c94b55
@ -1,5 +1,7 @@
|
||||
"""ProtoTorch CBC example using 2D Iris data."""
|
||||
|
||||
import logging
|
||||
|
||||
import torch
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
@ -34,7 +36,7 @@ class VisCBC2D():
|
||||
self.resolution = 100
|
||||
self.cmap = "viridis"
|
||||
|
||||
def on_epoch_end(self):
|
||||
def on_train_epoch_end(self):
|
||||
x_train, y_train = self.x_train, self.y_train
|
||||
_components = self.model.components_layer._components.detach()
|
||||
ax = self.fig.gca()
|
||||
@ -94,5 +96,5 @@ if __name__ == "__main__":
|
||||
correct += (y_pred.argmax(1) == y).float().sum(0)
|
||||
|
||||
acc = 100 * correct / len(train_ds)
|
||||
print(f"Epoch: {epoch} Accuracy: {acc:05.02f}%")
|
||||
vis.on_epoch_end()
|
||||
logging.info(f"Epoch: {epoch} Accuracy: {acc:05.02f}%")
|
||||
vis.on_train_epoch_end()
|
||||
|
@ -5,8 +5,10 @@ URL:
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import warnings
|
||||
from typing import Sequence, Union
|
||||
from typing import Sequence
|
||||
|
||||
from sklearn.datasets import (
|
||||
load_iris,
|
||||
@ -41,9 +43,9 @@ class Iris(NumpyDataset):
|
||||
:param dims: select a subset of dimensions
|
||||
"""
|
||||
|
||||
def __init__(self, dims: Sequence[int] = None):
|
||||
def __init__(self, dims: Sequence[int] | None = None):
|
||||
x, y = load_iris(return_X_y=True)
|
||||
if dims:
|
||||
if dims is not None:
|
||||
x = x[:, dims]
|
||||
super().__init__(x, y)
|
||||
|
||||
@ -56,15 +58,19 @@ class Blobs(NumpyDataset):
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
def __init__(
|
||||
self,
|
||||
num_samples: int = 300,
|
||||
num_features: int = 2,
|
||||
seed: Union[None, int] = 0):
|
||||
x, y = make_blobs(num_samples,
|
||||
seed: None | int = 0,
|
||||
):
|
||||
x, y = make_blobs(
|
||||
num_samples,
|
||||
num_features,
|
||||
centers=None,
|
||||
random_state=seed,
|
||||
shuffle=False)
|
||||
shuffle=False,
|
||||
)
|
||||
super().__init__(x, y)
|
||||
|
||||
|
||||
@ -77,21 +83,24 @@ class Random(NumpyDataset):
|
||||
Note: n_classes * n_clusters_per_class <= 2**n_informative must satisfy.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
def __init__(
|
||||
self,
|
||||
num_samples: int = 300,
|
||||
num_features: int = 2,
|
||||
num_classes: int = 2,
|
||||
num_clusters: int = 2,
|
||||
num_informative: Union[None, int] = None,
|
||||
num_informative: None | int = None,
|
||||
separation: float = 1.0,
|
||||
seed: Union[None, int] = 0):
|
||||
seed: None | int = 0,
|
||||
):
|
||||
if not num_informative:
|
||||
import math
|
||||
num_informative = math.ceil(math.log2(num_classes * num_clusters))
|
||||
if num_features < num_informative:
|
||||
warnings.warn("Generating more features than requested.")
|
||||
num_features = num_informative
|
||||
x, y = make_classification(num_samples,
|
||||
x, y = make_classification(
|
||||
num_samples,
|
||||
num_features,
|
||||
n_informative=num_informative,
|
||||
n_redundant=0,
|
||||
@ -99,7 +108,8 @@ class Random(NumpyDataset):
|
||||
n_clusters_per_class=num_clusters,
|
||||
class_sep=separation,
|
||||
random_state=seed,
|
||||
shuffle=False)
|
||||
shuffle=False,
|
||||
)
|
||||
super().__init__(x, y)
|
||||
|
||||
|
||||
@ -113,16 +123,20 @@ class Circles(NumpyDataset):
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
def __init__(
|
||||
self,
|
||||
num_samples: int = 300,
|
||||
noise: float = 0.3,
|
||||
factor: float = 0.8,
|
||||
seed: Union[None, int] = 0):
|
||||
x, y = make_circles(num_samples,
|
||||
seed: None | int = 0,
|
||||
):
|
||||
x, y = make_circles(
|
||||
num_samples,
|
||||
noise=noise,
|
||||
factor=factor,
|
||||
random_state=seed,
|
||||
shuffle=False)
|
||||
shuffle=False,
|
||||
)
|
||||
super().__init__(x, y)
|
||||
|
||||
|
||||
@ -136,12 +150,16 @@ class Moons(NumpyDataset):
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
def __init__(
|
||||
self,
|
||||
num_samples: int = 300,
|
||||
noise: float = 0.3,
|
||||
seed: Union[None, int] = 0):
|
||||
x, y = make_moons(num_samples,
|
||||
seed: None | int = 0,
|
||||
):
|
||||
x, y = make_moons(
|
||||
num_samples,
|
||||
noise=noise,
|
||||
random_state=seed,
|
||||
shuffle=False)
|
||||
shuffle=False,
|
||||
)
|
||||
super().__init__(x, y)
|
||||
|
@ -36,6 +36,7 @@ Description:
|
||||
are determined by analytic chemistry.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
@ -81,13 +82,11 @@ class Tecator(ProtoDataset):
|
||||
if self._check_exists():
|
||||
return
|
||||
|
||||
if self.verbose:
|
||||
print("Making directories...")
|
||||
logging.debug("Making directories...")
|
||||
os.makedirs(self.raw_folder, exist_ok=True)
|
||||
os.makedirs(self.processed_folder, exist_ok=True)
|
||||
|
||||
if self.verbose:
|
||||
print("Downloading...")
|
||||
logging.debug("Downloading...")
|
||||
for fileid, md5 in self._resources:
|
||||
filename = "tecator.npz"
|
||||
download_file_from_google_drive(fileid,
|
||||
@ -95,8 +94,7 @@ class Tecator(ProtoDataset):
|
||||
filename=filename,
|
||||
md5=md5)
|
||||
|
||||
if self.verbose:
|
||||
print("Processing...")
|
||||
logging.debug("Processing...")
|
||||
with np.load(os.path.join(self.raw_folder, "tecator.npz"),
|
||||
allow_pickle=False) as f:
|
||||
x_train, y_train = f["x_train"], f["y_train"]
|
||||
@ -117,5 +115,4 @@ class Tecator(ProtoDataset):
|
||||
"wb") as f:
|
||||
torch.save(test_set, f)
|
||||
|
||||
if self.verbose:
|
||||
print("Done!")
|
||||
logging.debug("Done!")
|
||||
|
Loading…
Reference in New Issue
Block a user