chore: minor changes and version updates
This commit is contained in:
parent
6714cb7915
commit
2a85c94b55
@ -1,5 +1,7 @@
|
|||||||
"""ProtoTorch CBC example using 2D Iris data."""
|
"""ProtoTorch CBC example using 2D Iris data."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from matplotlib import pyplot as plt
|
from matplotlib import pyplot as plt
|
||||||
|
|
||||||
@ -34,7 +36,7 @@ class VisCBC2D():
|
|||||||
self.resolution = 100
|
self.resolution = 100
|
||||||
self.cmap = "viridis"
|
self.cmap = "viridis"
|
||||||
|
|
||||||
def on_epoch_end(self):
|
def on_train_epoch_end(self):
|
||||||
x_train, y_train = self.x_train, self.y_train
|
x_train, y_train = self.x_train, self.y_train
|
||||||
_components = self.model.components_layer._components.detach()
|
_components = self.model.components_layer._components.detach()
|
||||||
ax = self.fig.gca()
|
ax = self.fig.gca()
|
||||||
@ -94,5 +96,5 @@ if __name__ == "__main__":
|
|||||||
correct += (y_pred.argmax(1) == y).float().sum(0)
|
correct += (y_pred.argmax(1) == y).float().sum(0)
|
||||||
|
|
||||||
acc = 100 * correct / len(train_ds)
|
acc = 100 * correct / len(train_ds)
|
||||||
print(f"Epoch: {epoch} Accuracy: {acc:05.02f}%")
|
logging.info(f"Epoch: {epoch} Accuracy: {acc:05.02f}%")
|
||||||
vis.on_epoch_end()
|
vis.on_train_epoch_end()
|
||||||
|
@ -5,8 +5,10 @@ URL:
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import warnings
|
import warnings
|
||||||
from typing import Sequence, Union
|
from typing import Sequence
|
||||||
|
|
||||||
from sklearn.datasets import (
|
from sklearn.datasets import (
|
||||||
load_iris,
|
load_iris,
|
||||||
@ -41,9 +43,9 @@ class Iris(NumpyDataset):
|
|||||||
:param dims: select a subset of dimensions
|
:param dims: select a subset of dimensions
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, dims: Sequence[int] = None):
|
def __init__(self, dims: Sequence[int] | None = None):
|
||||||
x, y = load_iris(return_X_y=True)
|
x, y = load_iris(return_X_y=True)
|
||||||
if dims:
|
if dims is not None:
|
||||||
x = x[:, dims]
|
x = x[:, dims]
|
||||||
super().__init__(x, y)
|
super().__init__(x, y)
|
||||||
|
|
||||||
@ -56,15 +58,19 @@ class Blobs(NumpyDataset):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(
|
||||||
|
self,
|
||||||
num_samples: int = 300,
|
num_samples: int = 300,
|
||||||
num_features: int = 2,
|
num_features: int = 2,
|
||||||
seed: Union[None, int] = 0):
|
seed: None | int = 0,
|
||||||
x, y = make_blobs(num_samples,
|
):
|
||||||
|
x, y = make_blobs(
|
||||||
|
num_samples,
|
||||||
num_features,
|
num_features,
|
||||||
centers=None,
|
centers=None,
|
||||||
random_state=seed,
|
random_state=seed,
|
||||||
shuffle=False)
|
shuffle=False,
|
||||||
|
)
|
||||||
super().__init__(x, y)
|
super().__init__(x, y)
|
||||||
|
|
||||||
|
|
||||||
@ -77,21 +83,24 @@ class Random(NumpyDataset):
|
|||||||
Note: n_classes * n_clusters_per_class <= 2**n_informative must satisfy.
|
Note: n_classes * n_clusters_per_class <= 2**n_informative must satisfy.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(
|
||||||
|
self,
|
||||||
num_samples: int = 300,
|
num_samples: int = 300,
|
||||||
num_features: int = 2,
|
num_features: int = 2,
|
||||||
num_classes: int = 2,
|
num_classes: int = 2,
|
||||||
num_clusters: int = 2,
|
num_clusters: int = 2,
|
||||||
num_informative: Union[None, int] = None,
|
num_informative: None | int = None,
|
||||||
separation: float = 1.0,
|
separation: float = 1.0,
|
||||||
seed: Union[None, int] = 0):
|
seed: None | int = 0,
|
||||||
|
):
|
||||||
if not num_informative:
|
if not num_informative:
|
||||||
import math
|
import math
|
||||||
num_informative = math.ceil(math.log2(num_classes * num_clusters))
|
num_informative = math.ceil(math.log2(num_classes * num_clusters))
|
||||||
if num_features < num_informative:
|
if num_features < num_informative:
|
||||||
warnings.warn("Generating more features than requested.")
|
warnings.warn("Generating more features than requested.")
|
||||||
num_features = num_informative
|
num_features = num_informative
|
||||||
x, y = make_classification(num_samples,
|
x, y = make_classification(
|
||||||
|
num_samples,
|
||||||
num_features,
|
num_features,
|
||||||
n_informative=num_informative,
|
n_informative=num_informative,
|
||||||
n_redundant=0,
|
n_redundant=0,
|
||||||
@ -99,7 +108,8 @@ class Random(NumpyDataset):
|
|||||||
n_clusters_per_class=num_clusters,
|
n_clusters_per_class=num_clusters,
|
||||||
class_sep=separation,
|
class_sep=separation,
|
||||||
random_state=seed,
|
random_state=seed,
|
||||||
shuffle=False)
|
shuffle=False,
|
||||||
|
)
|
||||||
super().__init__(x, y)
|
super().__init__(x, y)
|
||||||
|
|
||||||
|
|
||||||
@ -113,16 +123,20 @@ class Circles(NumpyDataset):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(
|
||||||
|
self,
|
||||||
num_samples: int = 300,
|
num_samples: int = 300,
|
||||||
noise: float = 0.3,
|
noise: float = 0.3,
|
||||||
factor: float = 0.8,
|
factor: float = 0.8,
|
||||||
seed: Union[None, int] = 0):
|
seed: None | int = 0,
|
||||||
x, y = make_circles(num_samples,
|
):
|
||||||
|
x, y = make_circles(
|
||||||
|
num_samples,
|
||||||
noise=noise,
|
noise=noise,
|
||||||
factor=factor,
|
factor=factor,
|
||||||
random_state=seed,
|
random_state=seed,
|
||||||
shuffle=False)
|
shuffle=False,
|
||||||
|
)
|
||||||
super().__init__(x, y)
|
super().__init__(x, y)
|
||||||
|
|
||||||
|
|
||||||
@ -136,12 +150,16 @@ class Moons(NumpyDataset):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(
|
||||||
|
self,
|
||||||
num_samples: int = 300,
|
num_samples: int = 300,
|
||||||
noise: float = 0.3,
|
noise: float = 0.3,
|
||||||
seed: Union[None, int] = 0):
|
seed: None | int = 0,
|
||||||
x, y = make_moons(num_samples,
|
):
|
||||||
|
x, y = make_moons(
|
||||||
|
num_samples,
|
||||||
noise=noise,
|
noise=noise,
|
||||||
random_state=seed,
|
random_state=seed,
|
||||||
shuffle=False)
|
shuffle=False,
|
||||||
|
)
|
||||||
super().__init__(x, y)
|
super().__init__(x, y)
|
||||||
|
@ -36,6 +36,7 @@ Description:
|
|||||||
are determined by analytic chemistry.
|
are determined by analytic chemistry.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -81,13 +82,11 @@ class Tecator(ProtoDataset):
|
|||||||
if self._check_exists():
|
if self._check_exists():
|
||||||
return
|
return
|
||||||
|
|
||||||
if self.verbose:
|
logging.debug("Making directories...")
|
||||||
print("Making directories...")
|
|
||||||
os.makedirs(self.raw_folder, exist_ok=True)
|
os.makedirs(self.raw_folder, exist_ok=True)
|
||||||
os.makedirs(self.processed_folder, exist_ok=True)
|
os.makedirs(self.processed_folder, exist_ok=True)
|
||||||
|
|
||||||
if self.verbose:
|
logging.debug("Downloading...")
|
||||||
print("Downloading...")
|
|
||||||
for fileid, md5 in self._resources:
|
for fileid, md5 in self._resources:
|
||||||
filename = "tecator.npz"
|
filename = "tecator.npz"
|
||||||
download_file_from_google_drive(fileid,
|
download_file_from_google_drive(fileid,
|
||||||
@ -95,8 +94,7 @@ class Tecator(ProtoDataset):
|
|||||||
filename=filename,
|
filename=filename,
|
||||||
md5=md5)
|
md5=md5)
|
||||||
|
|
||||||
if self.verbose:
|
logging.debug("Processing...")
|
||||||
print("Processing...")
|
|
||||||
with np.load(os.path.join(self.raw_folder, "tecator.npz"),
|
with np.load(os.path.join(self.raw_folder, "tecator.npz"),
|
||||||
allow_pickle=False) as f:
|
allow_pickle=False) as f:
|
||||||
x_train, y_train = f["x_train"], f["y_train"]
|
x_train, y_train = f["x_train"], f["y_train"]
|
||||||
@ -117,5 +115,4 @@ class Tecator(ProtoDataset):
|
|||||||
"wb") as f:
|
"wb") as f:
|
||||||
torch.save(test_set, f)
|
torch.save(test_set, f)
|
||||||
|
|
||||||
if self.verbose:
|
logging.debug("Done!")
|
||||||
print("Done!")
|
|
||||||
|
Loading…
Reference in New Issue
Block a user