Automatic Formatting.

This commit is contained in:
Alexander Engelsberger
2021-04-23 17:24:53 +02:00
parent e1d56595c1
commit 7c30ffe2c7
28 changed files with 393 additions and 321 deletions

View File

@@ -1 +0,0 @@
from .colors import color_scheme, get_legend_handles

View File

@@ -1,13 +1,13 @@
"""Easy matplotlib animation. From https://github.com/jwkvam/celluloid."""
from typing import Dict, List
from collections import defaultdict
from typing import Dict, List
from matplotlib.figure import Figure
from matplotlib.artist import Artist
from matplotlib.animation import ArtistAnimation
from matplotlib.artist import Artist
from matplotlib.figure import Figure
__version__ = '0.2.0'
__version__ = "0.2.0"
class Camera:
@@ -19,7 +19,7 @@ class Camera:
self._offsets: Dict[str, Dict[int, int]] = {
k: defaultdict(int)
for k in
['collections', 'patches', 'lines', 'texts', 'artists', 'images']
["collections", "patches", "lines", "texts", "artists", "images"]
}
self._photos: List[List[Artist]] = []

View File

@@ -1,13 +1,14 @@
"""ProtoFlow color utilities."""
from matplotlib import cm
from matplotlib.colors import Normalize
from matplotlib.colors import to_hex
from matplotlib.colors import to_rgb
import matplotlib.lines as mlines
from matplotlib import cm
from matplotlib.colors import Normalize, to_hex, to_rgb
def color_scheme(n, cmap="viridis", form="hex", tikz=False,
def color_scheme(n,
cmap="viridis",
form="hex",
tikz=False,
zero_indexed=False):
"""Return *n* colors from the color scheme.
@@ -57,13 +58,16 @@ def get_legend_handles(labels, marker="dots", zero_indexed=False):
zero_indexed=zero_indexed)
for label, color in zip(labels, colors.values()):
if marker == "dots":
handle = mlines.Line2D([], [],
color="white",
markerfacecolor=color,
marker="o",
markersize=10,
markeredgecolor="k",
label=label)
handle = mlines.Line2D(
[],
[],
color="white",
markerfacecolor=color,
marker="o",
markersize=10,
markeredgecolor="k",
label=label,
)
else:
handle = mlines.Line2D([], [],
color=color,

View File

@@ -11,17 +11,17 @@ import numpy as np
def progressbar(title, value, end, bar_width=20):
percent = float(value) / end
arrow = '=' * int(round(percent * bar_width) - 1) + '>'
spaces = '.' * (bar_width - len(arrow))
sys.stdout.write('\r{}: [{}] {}%'.format(title, arrow + spaces,
arrow = "=" * int(round(percent * bar_width) - 1) + ">"
spaces = "." * (bar_width - len(arrow))
sys.stdout.write("\r{}: [{}] {}%".format(title, arrow + spaces,
int(round(percent * 100))))
sys.stdout.flush()
if percent == 1.0:
print()
def prettify_string(inputs, start='', sep=' ', end='\n'):
outputs = start + ' '.join(inputs.split()) + end
def prettify_string(inputs, start="", sep=" ", end="\n"):
outputs = start + " ".join(inputs.split()) + end
return outputs
@@ -29,22 +29,22 @@ def pretty_print(inputs):
print(prettify_string(inputs))
def writelog(self, *logs, logdir='./logs', logfile='run.txt'):
def writelog(self, *logs, logdir="./logs", logfile="run.txt"):
f = os.path.join(logdir, logfile)
with open(f, 'a+') as fh:
with open(f, "a+") as fh:
for log in logs:
fh.write(log)
fh.write('\n')
fh.write("\n")
def start_tensorboard(self, logdir='./logs'):
cmd = f'tensorboard --logdir={logdir} --port=6006'
def start_tensorboard(self, logdir="./logs"):
cmd = f"tensorboard --logdir={logdir} --port=6006"
os.system(cmd)
def make_directory(save_dir):
if not os.path.exists(save_dir):
print(f'Making directory {save_dir}.')
print(f"Making directory {save_dir}.")
os.mkdir(save_dir)
@@ -52,36 +52,36 @@ def make_gif(filenames, duration, output_file=None):
try:
import imageio
except ModuleNotFoundError as e:
print('Please install Protoflow with [other] extra requirements.')
print("Please install Protoflow with [other] extra requirements.")
raise (e)
images = list()
for filename in filenames:
images.append(imageio.imread(filename))
if not output_file:
output_file = f'makegif.gif'
output_file = f"makegif.gif"
if images:
imageio.mimwrite(output_file, images, duration=duration)
def gif_from_dir(directory,
duration,
prefix='',
prefix="",
output_file=None,
verbose=True):
images = os.listdir(directory)
if verbose:
print(f'Making gif from {len(images)} images under {directory}.')
print(f"Making gif from {len(images)} images under {directory}.")
filenames = list()
# Sort images
images = sorted(
images,
key=lambda img: int(os.path.splitext(img)[0].replace(prefix, '')))
key=lambda img: int(os.path.splitext(img)[0].replace(prefix, "")))
for image in images:
fname = os.path.join(directory, image)
filenames.append(fname)
if not output_file:
output_file = os.path.join(directory, 'makegif.gif')
output_file = os.path.join(directory, "makegif.gif")
make_gif(filenames=filenames, duration=duration, output_file=output_file)
@@ -95,12 +95,12 @@ def predict_and_score(clf,
x_test,
y_test,
verbose=False,
title='Test accuracy'):
title="Test accuracy"):
y_pred = clf.predict(x_test)
accuracy = np.sum(y_test == y_pred)
normalized_acc = accuracy / float(len(y_test))
if verbose:
print(f'{title}: {normalized_acc * 100:06.04f}%')
print(f"{title}: {normalized_acc * 100:06.04f}%")
return normalized_acc
@@ -124,6 +124,7 @@ def replace_in(arr, replacement_dict, inplace=False):
new_arr = arr
else:
import copy
new_arr = copy.deepcopy(arr)
for k, v in replacement_dict.items():
new_arr[arr == k] = v
@@ -135,7 +136,7 @@ def train_test_split(data, train=0.7, val=0.15, shuffle=None, return_xy=False):
preserve the class distribution in subsamples of the dataset.
"""
if train + val > 1.0:
raise ValueError('Invalid split values for train and val.')
raise ValueError("Invalid split values for train and val.")
Y = data[:, -1]
labels = set(Y)
hist = dict()
@@ -183,20 +184,20 @@ def train_test_split(data, train=0.7, val=0.15, shuffle=None, return_xy=False):
return train_data, val_data, test_data
def class_histogram(data, title='Untitled'):
def class_histogram(data, title="Untitled"):
plt.figure(title)
plt.clf()
plt.title(title)
dist, counts = np.unique(data[:, -1], return_counts=True)
plt.bar(dist, counts)
plt.xticks(dist)
print('Call matplotlib.pyplot.show() to see the plot.')
print("Call matplotlib.pyplot.show() to see the plot.")
def ntimer(n=10):
"""Wraps a function which wraps another function to time it."""
if n < 1:
raise (Exception(f'Invalid n = {n} given.'))
raise (Exception(f"Invalid n = {n} given."))
def timer(func):
"""Wraps `func` with a timer and returns the wrapped `func`."""
@@ -207,7 +208,7 @@ def ntimer(n=10):
rv = func(*args, **kwargs)
after = time()
elapsed = after - before
print(f'Elapsed: {elapsed*1e3:02.02f} ms')
print(f"Elapsed: {elapsed*1e3:02.02f} ms")
return rv
return wrapper
@@ -228,15 +229,15 @@ def memoize(verbose=True):
t = (pickle.dumps(args), pickle.dumps(kwargs))
if t not in cache:
if verbose:
print(f'Adding NEW rv {func.__name__}{args}{kwargs} '
'to cache.')
print(f"Adding NEW rv {func.__name__}{args}{kwargs} "
"to cache.")
cache[t] = func(*args, **kwargs)
else:
if verbose:
print(f'Using OLD rv {func.__name__}{args}{kwargs} '
'from cache.')
print(f"Using OLD rv {func.__name__}{args}{kwargs} "
"from cache.")
return cache[t]
return wrapper
return memoizer
return memoizer