Automatic Formatting.
@@ -5,8 +5,6 @@
 # #############################################
 __version__ = "0.3.0-dev0"

 from prototorch import datasets, functions, modules

 __all_core__ = [
     "datasets",
     "functions",
@@ -17,6 +15,7 @@ __all_core__ = [
 # Plugin Loader
 # #############################################
 import pkgutil
+
 import pkg_resources

 __path__ = pkgutil.extend_path(__path__, __name__)
@@ -25,7 +24,8 @@ __path__ = pkgutil.extend_path(__path__, __name__)
 def discover_plugins():
     return {
         entry_point.name: entry_point.load()
-        for entry_point in pkg_resources.iter_entry_points("prototorch.plugins")
+        for entry_point in pkg_resources.iter_entry_points(
+            "prototorch.plugins")
     }

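Aside: `discover_plugins` builds on setuptools entry points. A minimal sketch, not code from this commit, of how a hypothetical plugin package ("prototorch_hello" is made up) could register itself under the `prototorch.plugins` group so that `pkg_resources.iter_entry_points` finds it:

    # setup.py of a hypothetical plugin package
    from setuptools import setup

    setup(
        name="prototorch_hello",
        version="0.1.0",
        py_modules=["prototorch_hello"],
        # The group name must match the one queried in discover_plugins().
        entry_points={"prototorch.plugins": ["hello = prototorch_hello"]},
    )

Once installed, `discover_plugins()` would return `{"hello": <module prototorch_hello>}`, and the `locals().update(...)` below injects it into the `prototorch` namespace as `prototorch.hello`.
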
@@ -33,14 +33,12 @@ discovered_plugins = discover_plugins()
 locals().update(discovered_plugins)

 # Generate combines __version__ and __all__
-version_plugins = "\n".join(
-    [
-        "- " + name + ": v" + plugin.__version__
-        for name, plugin in discovered_plugins.items()
-    ]
-)
+version_plugins = "\n".join([
+    "- " + name + ": v" + plugin.__version__
+    for name, plugin in discovered_plugins.items()
+])
 if version_plugins != "":
     version_plugins = "\nPlugins: \n" + version_plugins

 version = "core: v" + __version__ + version_plugins
-__all__ = __all_core__ + list(discovered_plugins.keys())
+__all__ = __all_core__ + list(discovered_plugins.keys())
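Aside: a worked example of the string assembled here, assuming (made-up values) the core at 0.3.0-dev0 and the hypothetical `hello` plugin from above installed at 0.1.0:

    >>> print(prototorch.version)
    core: v0.3.0-dev0
    Plugins: 
    - hello: v0.1.0
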
@@ -3,5 +3,5 @@
 from .tecator import Tecator

 __all__ = [
-    'Tecator',
+    "Tecator",
 ]
@@ -52,7 +52,8 @@ class Tecator(ProtoDataset):
     """

     _resources = [
-        ("1MMuUK8V41IgNpnPDbg3E-QAL6wlErTk0", "ba5607c580d0f91bb27dc29d13c2f8df"),
+        ("1MMuUK8V41IgNpnPDbg3E-QAL6wlErTk0",
+         "ba5607c580d0f91bb27dc29d13c2f8df"),
     ]  # (google_storage_id, md5hash)
     classes = ["0 - low_fat", "1 - high_fat"]

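Aside: each `_resources` entry pairs a Google Drive file id with the md5 checksum used to verify the download. `download_file_from_google_drive` is presumably torchvision's helper of that name; a standalone sketch of the same call (the target directory is made up):

    from torchvision.datasets.utils import download_file_from_google_drive

    download_file_from_google_drive(
        "1MMuUK8V41IgNpnPDbg3E-QAL6wlErTk0",     # google_storage_id
        root="./data/Tecator/raw",               # download target directory
        filename="tecator.npz",
        md5="ba5607c580d0f91bb27dc29d13c2f8df",  # verified after download
    )
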
@@ -74,15 +75,15 @@ class Tecator(ProtoDataset):
             print("Downloading...")
         for fileid, md5 in self._resources:
             filename = "tecator.npz"
-            download_file_from_google_drive(
-                fileid, root=self.raw_folder, filename=filename, md5=md5
-            )
+            download_file_from_google_drive(fileid,
+                                            root=self.raw_folder,
+                                            filename=filename,
+                                            md5=md5)

         if self.verbose:
             print("Processing...")
-        with np.load(
-            os.path.join(self.raw_folder, "tecator.npz"), allow_pickle=False
-        ) as f:
+        with np.load(os.path.join(self.raw_folder, "tecator.npz"),
+                     allow_pickle=False) as f:
             x_train, y_train = f["x_train"], f["y_train"]
             x_test, y_test = f["x_test"], f["y_test"]
             training_set = [
@@ -94,9 +95,11 @@ class Tecator(ProtoDataset):
             torch.tensor(y_test),
         ]

-        with open(os.path.join(self.processed_folder, self.training_file), "wb") as f:
+        with open(os.path.join(self.processed_folder, self.training_file),
+                  "wb") as f:
             torch.save(training_set, f)
-        with open(os.path.join(self.processed_folder, self.test_file), "wb") as f:
+        with open(os.path.join(self.processed_folder, self.test_file),
+                  "wb") as f:
             torch.save(test_set, f)

         if self.verbose:
@@ -4,9 +4,9 @@ from .activations import identity, sigmoid_beta, swish_beta
 from .competitions import knnc, wtac

 __all__ = [
-    'identity',
-    'sigmoid_beta',
-    'swish_beta',
-    'knnc',
-    'wtac',
+    "identity",
+    "sigmoid_beta",
+    "swish_beta",
+    "knnc",
+    "wtac",
 ]
@@ -61,4 +61,4 @@ def get_activation(funcname):
         return funcname
     if funcname in ACTIVATIONS:
         return ACTIVATIONS.get(funcname)
-    raise NameError(f'Activation {funcname} was not found.')
+    raise NameError(f"Activation {funcname} was not found.")
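Aside: `get_activation` follows a lookup-or-passthrough pattern: strings are resolved against the `ACTIVATIONS` registry, while callables are returned unchanged. A minimal usage sketch, assuming the names exported in `functions/__init__.py` above are registered:

    from prototorch.functions.activations import get_activation

    squash = get_activation("sigmoid_beta")  # resolved from the ACTIVATIONS registry
    same = get_activation(lambda x: x)       # callables pass straight through
    get_activation("nope")                   # raises NameError
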
@@ -12,7 +12,7 @@ def stratified_min(distances, labels):
         return distances
     batch_size = distances.size()[0]
     winning_distances = torch.zeros(nclasses, batch_size)
-    inf = torch.full_like(distances.T, fill_value=float('inf'))
+    inf = torch.full_like(distances.T, fill_value=float("inf"))
     # distances_to_wpluses = torch.where(matcher, distances, inf)
     for i, cl in enumerate(clabels):
         # cdists = distances.T[labels == cl]
@@ -1,12 +1,10 @@
 """ProtoTorch distance functions."""

-import torch
-from prototorch.functions.helper import (
-    equal_int_shape,
-    _int_and_mixed_shape,
-    _check_shapes,
-)
 import numpy as np
+import torch
+
+from prototorch.functions.helper import (_check_shapes, _int_and_mixed_shape,
+                                         equal_int_shape)


 def squared_euclidean_distance(x, y):
@@ -23,7 +23,7 @@ def predict_label(y_pred, plabels):

 def mixed_shape(inputs):
     if not torch.is_tensor(inputs):
-        raise ValueError('Input must be a tensor.')
+        raise ValueError("Input must be a tensor.")
     else:
         int_shape = list(inputs.shape)
         # sometimes int_shape returns mixed integer types
@@ -39,11 +39,11 @@ def mixed_shape(inputs):
 def equal_int_shape(shape_1, shape_2):
     if not isinstance(shape_1,
                       (tuple, list)) or not isinstance(shape_2, (tuple, list)):
-        raise ValueError('Input shapes must list or tuple.')
+        raise ValueError("Input shapes must list or tuple.")
     for shape in [shape_1, shape_2]:
         if not all([isinstance(x, int) or x is None for x in shape]):
             raise ValueError(
-                'Input shapes must be list or tuple of int and None values.')
+                "Input shapes must be list or tuple of int and None values.")

     if len(shape_1) != len(shape_2):
         return False
@@ -104,4 +104,4 @@ def get_initializer(funcname):
         return funcname
     if funcname in INITIALIZERS:
         return INITIALIZERS.get(funcname)
-    raise NameError(f'Initializer {funcname} was not found.')
+    raise NameError(f"Initializer {funcname} was not found.")
@@ -11,7 +11,7 @@ def _get_dp_dm(distances, targets, plabels):
     matcher = torch.eq(torch.sum(matcher, dim=-1), nclasses)
     not_matcher = torch.bitwise_not(matcher)

-    inf = torch.full_like(distances, fill_value=float('inf'))
+    inf = torch.full_like(distances, fill_value=float("inf"))
     d_matching = torch.where(matcher, distances, inf)
     d_unmatching = torch.where(not_matcher, distances, inf)
     dp = torch.min(d_matching, dim=1, keepdim=True).values
@@ -1,7 +1,5 @@
 # -*- coding: utf-8 -*-
-from __future__ import print_function
-from __future__ import absolute_import
-from __future__ import division
+from __future__ import absolute_import, division, print_function

 import torch

@@ -3,5 +3,5 @@
 from .prototypes import Prototypes1D

 __all__ = [
-    'Prototypes1D',
+    "Prototypes1D",
 ]
@@ -7,7 +7,7 @@ from prototorch.functions.losses import glvq_loss


 class GLVQLoss(torch.nn.Module):
-    def __init__(self, margin=0.0, squashing='identity', beta=10, **kwargs):
+    def __init__(self, margin=0.0, squashing="identity", beta=10, **kwargs):
         super().__init__(**kwargs)
         self.margin = margin
         self.squashing = get_activation(squashing)
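Aside: a minimal construction sketch for the `__init__` touched here (module path assumed; training wiring omitted). The `squashing` string is resolved through `get_activation`:

    from prototorch.modules.losses import GLVQLoss  # module path assumed

    criterion = GLVQLoss(margin=0.0, squashing="sigmoid_beta", beta=10)
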
@@ -37,4 +37,4 @@ class NeuralGasEnergy(torch.nn.Module):

     @staticmethod
     def _nghood_fn(rankings, lm):
-        return torch.exp(-rankings / lm)
+        return torch.exp(-rankings / lm)
@@ -1,9 +1,11 @@
-from torch import nn
 import torch
-from prototorch.modules.prototypes import Prototypes1D
-from prototorch.functions.distances import tangent_distance, euclidean_distance_matrix
-from prototorch.functions.normalization import orthogonalization
+from torch import nn
+
+from prototorch.functions.distances import (euclidean_distance_matrix,
+                                            tangent_distance)
+from prototorch.functions.helper import _check_shapes, _int_and_mixed_shape
+from prototorch.functions.normalization import orthogonalization
+from prototorch.modules.prototypes import Prototypes1D


 class GTLVQ(nn.Module):
@@ -71,7 +73,7 @@ class GTLVQ(nn.Module):
                  subspace_data=None,
                  prototype_data=None,
                  subspace_size=256,
-                 tangent_projection_type='local',
+                 tangent_projection_type="local",
                  prototypes_per_class=2,
                  feature_dim=256,
     ):
@@ -82,37 +84,39 @@ class GTLVQ(nn.Module):
         self.feature_dim = feature_dim

         if subspace_data is None:
-            raise ValueError('Init Data must be specified!')
+            raise ValueError("Init Data must be specified!")

         self.tpt = tangent_projection_type
         with torch.no_grad():
-            if self.tpt == 'local' or self.tpt == 'local_proj':
+            if self.tpt == "local" or self.tpt == "local_proj":
                 self.init_local_subspace(subspace_data)
-            elif self.tpt == 'global':
+            elif self.tpt == "global":
                 self.init_gobal_subspace(subspace_data, subspace_size)
             else:
                 self.subspaces = None

         # Hypothesis-Margin-Classifier
-        self.cls = Prototypes1D(input_dim=feature_dim,
-                                prototypes_per_class=prototypes_per_class,
-                                nclasses=num_classes,
-                                prototype_initializer='stratified_mean',
-                                data=prototype_data)
+        self.cls = Prototypes1D(
+            input_dim=feature_dim,
+            prototypes_per_class=prototypes_per_class,
+            nclasses=num_classes,
+            prototype_initializer="stratified_mean",
+            data=prototype_data,
+        )

     def forward(self, x):
         # Tangent Projection
-        if self.tpt == 'local_proj':
-            x_conform = x.unsqueeze(1).repeat_interleave(self.num_protos,
-                                                         1).unsqueeze(2)
+        if self.tpt == "local_proj":
+            x_conform = (x.unsqueeze(1).repeat_interleave(self.num_protos,
+                                                          1).unsqueeze(2))
             dis, proj_x = self.local_tangent_projection(x_conform)

             proj_x = proj_x.reshape(x.shape[0] * self.num_protos,
                                     self.feature_dim)
             return proj_x, dis
         elif self.tpt == "local":
-            x_conform = x.unsqueeze(1).repeat_interleave(self.num_protos,
-                                                         1).unsqueeze(2)
+            x_conform = (x.unsqueeze(1).repeat_interleave(self.num_protos,
+                                                          1).unsqueeze(2))
             dis = tangent_distance(x_conform, self.cls.prototypes,
                                    self.subspaces)
         elif self.tpt == "gloabl":
@@ -127,25 +131,27 @@ class GTLVQ(nn.Module):
         _, _, v = torch.svd(data)
         subspace = (torch.eye(v.shape[0]) - (v @ v.T)).T
         subspaces = subspace[:, :num_subspaces]
-        self.subspaces = torch.nn.Parameter(
-            subspaces).clone().detach().requires_grad_(True)
+        self.subspaces = (torch.nn.Parameter(
+            subspaces).clone().detach().requires_grad_(True))

     def init_local_subspace(self, data):
         _, _, v = torch.svd(data)
         inital_projector = (torch.eye(v.shape[0]) - (v @ v.T)).T
         subspaces = inital_projector.unsqueeze(0).repeat_interleave(
             self.num_protos, 0)
-        self.subspaces = torch.nn.Parameter(
-            subspaces).clone().detach().requires_grad_(True)
+        self.subspaces = (torch.nn.Parameter(
+            subspaces).clone().detach().requires_grad_(True))

     def global_tangent_distances(self, x):
         # Tangent Projection
-        x, projected_prototypes = x @ self.subspaces, self.cls.prototypes @ self.subspaces
+        x, projected_prototypes = (
+            x @ self.subspaces,
+            self.cls.prototypes @ self.subspaces,
+        )
         # Euclidean Distance
         return euclidean_distance_matrix(x, projected_prototypes)

-    def local_tangent_projection(self,
-                                 signals):
+    def local_tangent_projection(self, signals):
         # Note: subspaces is always assumed as transposed and must be orthogonal!
         # shape(signals): batch x proto_number x channels x dim1 x dim2 x ... x dimN
         # shape(protos): proto_number x dim1 x dim2 x ... x dimN
@@ -183,8 +189,7 @@ class GTLVQ(nn.Module):
     def orthogonalize_subspace(self):
         if self.subspaces is not None:
             with torch.no_grad():
-                ortho_subpsaces = orthogonalization(
-                    self.subspaces
-                ) if self.tpt == 'global' else torch.nn.init.orthogonal_(
-                    self.subspaces)
+                ortho_subpsaces = (orthogonalization(self.subspaces)
+                                   if self.tpt == "global" else
+                                   torch.nn.init.orthogonal_(self.subspaces))
                 self.subspaces.copy_(ortho_subpsaces)
@@ -29,14 +29,16 @@ class Prototypes1D(_Prototypes):

     TODO Complete this doc-string.
     """
-    def __init__(self,
-                 prototypes_per_class=1,
-                 prototype_initializer="ones",
-                 prototype_distribution=None,
-                 data=None,
-                 dtype=torch.float32,
-                 one_hot_labels=False,
-                 **kwargs):
+    def __init__(
+        self,
+        prototypes_per_class=1,
+        prototype_initializer="ones",
+        prototype_distribution=None,
+        data=None,
+        dtype=torch.float32,
+        one_hot_labels=False,
+        **kwargs,
+    ):

         # Convert tensors to python lists before processing
         if prototype_distribution is not None:
@@ -1 +0,0 @@
-from .colors import color_scheme, get_legend_handles
@@ -1,13 +1,13 @@
 """Easy matplotlib animation. From https://github.com/jwkvam/celluloid."""

-from typing import Dict, List
 from collections import defaultdict
+from typing import Dict, List

-from matplotlib.figure import Figure
-from matplotlib.artist import Artist
 from matplotlib.animation import ArtistAnimation
+from matplotlib.artist import Artist
+from matplotlib.figure import Figure

-__version__ = '0.2.0'
+__version__ = "0.2.0"


 class Camera:
@@ -19,7 +19,7 @@ class Camera:
         self._offsets: Dict[str, Dict[int, int]] = {
             k: defaultdict(int)
             for k in
-            ['collections', 'patches', 'lines', 'texts', 'artists', 'images']
+            ["collections", "patches", "lines", "texts", "artists", "images"]
         }
         self._photos: List[List[Artist]] = []

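Aside: this vendored `Camera` keeps celluloid's two-call interface: `snap()` photographs the artists drawn so far, `animate()` assembles the photos into an `ArtistAnimation`. A minimal usage sketch (module path assumed):

    import matplotlib.pyplot as plt

    from prototorch.utils.celluloid import Camera  # module path assumed

    fig = plt.figure()
    camera = Camera(fig)
    for i in range(5):
        plt.plot([0, 1], [0, i], color="blue")  # draw this frame's artists
        camera.snap()                           # photograph them
    anim = camera.animate()                     # build the ArtistAnimation
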
@@ -1,13 +1,14 @@
 """ProtoFlow color utilities."""

-from matplotlib import cm
-from matplotlib.colors import Normalize
-from matplotlib.colors import to_hex
-from matplotlib.colors import to_rgb
 import matplotlib.lines as mlines
+from matplotlib import cm
+from matplotlib.colors import Normalize, to_hex, to_rgb


-def color_scheme(n, cmap="viridis", form="hex", tikz=False,
+def color_scheme(n,
+                 cmap="viridis",
+                 form="hex",
+                 tikz=False,
                  zero_indexed=False):
     """Return *n* colors from the color scheme.

@@ -57,13 +58,16 @@ def get_legend_handles(labels, marker="dots", zero_indexed=False):
                           zero_indexed=zero_indexed)
     for label, color in zip(labels, colors.values()):
         if marker == "dots":
-            handle = mlines.Line2D([], [],
-                                   color="white",
-                                   markerfacecolor=color,
-                                   marker="o",
-                                   markersize=10,
-                                   markeredgecolor="k",
-                                   label=label)
+            handle = mlines.Line2D(
+                [],
+                [],
+                color="white",
+                markerfacecolor=color,
+                marker="o",
+                markersize=10,
+                markeredgecolor="k",
+                label=label,
+            )
         else:
             handle = mlines.Line2D([], [],
                                    color=color,
@@ -11,17 +11,17 @@ import numpy as np

 def progressbar(title, value, end, bar_width=20):
     percent = float(value) / end
-    arrow = '=' * int(round(percent * bar_width) - 1) + '>'
-    spaces = '.' * (bar_width - len(arrow))
-    sys.stdout.write('\r{}: [{}] {}%'.format(title, arrow + spaces,
+    arrow = "=" * int(round(percent * bar_width) - 1) + ">"
+    spaces = "." * (bar_width - len(arrow))
+    sys.stdout.write("\r{}: [{}] {}%".format(title, arrow + spaces,
                                              int(round(percent * 100))))
     sys.stdout.flush()
     if percent == 1.0:
         print()


-def prettify_string(inputs, start='', sep=' ', end='\n'):
-    outputs = start + ' '.join(inputs.split()) + end
+def prettify_string(inputs, start="", sep=" ", end="\n"):
+    outputs = start + " ".join(inputs.split()) + end
     return outputs
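Aside: `progressbar` redraws in place via the leading carriage return. A hypothetical loop illustrating it:

    from prototorch.utils.utils import progressbar  # module path assumed

    for i in range(1, 101):
        progressbar("Training", value=i, end=100)
    # renders in place, e.g.:  Training: [=========>..........] 50%
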
@@ -29,22 +29,22 @@ def pretty_print(inputs):
     print(prettify_string(inputs))


-def writelog(self, *logs, logdir='./logs', logfile='run.txt'):
+def writelog(self, *logs, logdir="./logs", logfile="run.txt"):
     f = os.path.join(logdir, logfile)
-    with open(f, 'a+') as fh:
+    with open(f, "a+") as fh:
         for log in logs:
             fh.write(log)
-            fh.write('\n')
+            fh.write("\n")


-def start_tensorboard(self, logdir='./logs'):
-    cmd = f'tensorboard --logdir={logdir} --port=6006'
+def start_tensorboard(self, logdir="./logs"):
+    cmd = f"tensorboard --logdir={logdir} --port=6006"
     os.system(cmd)


 def make_directory(save_dir):
     if not os.path.exists(save_dir):
-        print(f'Making directory {save_dir}.')
+        print(f"Making directory {save_dir}.")
         os.mkdir(save_dir)
@@ -52,36 +52,36 @@ def make_gif(filenames, duration, output_file=None):
     try:
         import imageio
     except ModuleNotFoundError as e:
-        print('Please install Protoflow with [other] extra requirements.')
+        print("Please install Protoflow with [other] extra requirements.")
        raise (e)

     images = list()
     for filename in filenames:
         images.append(imageio.imread(filename))
     if not output_file:
-        output_file = f'makegif.gif'
+        output_file = f"makegif.gif"
     if images:
         imageio.mimwrite(output_file, images, duration=duration)


 def gif_from_dir(directory,
                  duration,
-                 prefix='',
+                 prefix="",
                  output_file=None,
                  verbose=True):
     images = os.listdir(directory)
     if verbose:
-        print(f'Making gif from {len(images)} images under {directory}.')
+        print(f"Making gif from {len(images)} images under {directory}.")
     filenames = list()
     # Sort images
     images = sorted(
         images,
-        key=lambda img: int(os.path.splitext(img)[0].replace(prefix, '')))
+        key=lambda img: int(os.path.splitext(img)[0].replace(prefix, "")))
     for image in images:
         fname = os.path.join(directory, image)
         filenames.append(fname)
     if not output_file:
-        output_file = os.path.join(directory, 'makegif.gif')
+        output_file = os.path.join(directory, "makegif.gif")
     make_gif(filenames=filenames, duration=duration, output_file=output_file)

@@ -95,12 +95,12 @@ def predict_and_score(clf,
                       x_test,
                       y_test,
                       verbose=False,
-                      title='Test accuracy'):
+                      title="Test accuracy"):
     y_pred = clf.predict(x_test)
     accuracy = np.sum(y_test == y_pred)
     normalized_acc = accuracy / float(len(y_test))
     if verbose:
-        print(f'{title}: {normalized_acc * 100:06.04f}%')
+        print(f"{title}: {normalized_acc * 100:06.04f}%")
     return normalized_acc
@@ -124,6 +124,7 @@ def replace_in(arr, replacement_dict, inplace=False):
         new_arr = arr
     else:
         import copy
+
         new_arr = copy.deepcopy(arr)
     for k, v in replacement_dict.items():
         new_arr[arr == k] = v
@@ -135,7 +136,7 @@ def train_test_split(data, train=0.7, val=0.15, shuffle=None, return_xy=False):
     preserve the class distribution in subsamples of the dataset.
     """
     if train + val > 1.0:
-        raise ValueError('Invalid split values for train and val.')
+        raise ValueError("Invalid split values for train and val.")
     Y = data[:, -1]
     labels = set(Y)
     hist = dict()
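Aside: as the body suggests (`Y = data[:, -1]`), `train_test_split` expects the class label in the last column and splits per class to preserve the label distribution. A usage sketch with made-up data:

    import numpy as np

    from prototorch.utils.utils import train_test_split  # module path assumed

    # Four feature columns plus a binary class label in the last column.
    data = np.hstack([np.random.randn(100, 4),
                      np.random.randint(0, 2, size=(100, 1))])
    train_data, val_data, test_data = train_test_split(data, train=0.7, val=0.15)
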
@@ -183,20 +184,20 @@ def train_test_split(data, train=0.7, val=0.15, shuffle=None, return_xy=False):
     return train_data, val_data, test_data


-def class_histogram(data, title='Untitled'):
+def class_histogram(data, title="Untitled"):
     plt.figure(title)
     plt.clf()
     plt.title(title)
     dist, counts = np.unique(data[:, -1], return_counts=True)
     plt.bar(dist, counts)
     plt.xticks(dist)
-    print('Call matplotlib.pyplot.show() to see the plot.')
+    print("Call matplotlib.pyplot.show() to see the plot.")


 def ntimer(n=10):
     """Wraps a function which wraps another function to time it."""
     if n < 1:
-        raise (Exception(f'Invalid n = {n} given.'))
+        raise (Exception(f"Invalid n = {n} given."))

     def timer(func):
         """Wraps `func` with a timer and returns the wrapped `func`."""
@@ -207,7 +208,7 @@ def ntimer(n=10):
             rv = func(*args, **kwargs)
             after = time()
             elapsed = after - before
-            print(f'Elapsed: {elapsed*1e3:02.02f} ms')
+            print(f"Elapsed: {elapsed*1e3:02.02f} ms")
             return rv

         return wrapper
@@ -228,15 +229,15 @@ def memoize(verbose=True):
             t = (pickle.dumps(args), pickle.dumps(kwargs))
             if t not in cache:
                 if verbose:
-                    print(f'Adding NEW rv {func.__name__}{args}{kwargs} '
-                          'to cache.')
+                    print(f"Adding NEW rv {func.__name__}{args}{kwargs} "
+                          "to cache.")
                 cache[t] = func(*args, **kwargs)
             else:
                 if verbose:
-                    print(f'Using OLD rv {func.__name__}{args}{kwargs} '
-                          'from cache.')
+                    print(f"Using OLD rv {func.__name__}{args}{kwargs} "
+                          "from cache.")
                 return cache[t]

         return wrapper

-    return memoizer
+    return memoizer
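Aside: `memoize` is a decorator factory whose cache key is the pickled positional and keyword arguments, so arguments only need to be picklable, not hashable. A usage sketch (module path assumed):

    from prototorch.utils.utils import memoize  # module path assumed

    @memoize(verbose=True)
    def slow_add(a, b):
        return a + b

    slow_add(1, 2)  # "Adding NEW rv slow_add(1, 2){} to cache." -- computed once
    slow_add(1, 2)  # "Using OLD rv slow_add(1, 2){} from cache." -- cache hit
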