diff --git a/LICENSE b/LICENSE index ebc4b67..3fa8fdf 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,7 @@ MIT License -Copyright (c) 2020 si-cim +Copyright (c) 2020 Saxon Institute for Computational Intelligence and Machine +Learning (SICIM) Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/docs/source/conf.py b/docs/source/conf.py index d620239..ff51cff 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -120,7 +120,7 @@ html_css_files = [ # -- Options for HTMLHelp output ------------------------------------------ # Output file base name for HTML help builder. -htmlhelp_basename = "protoflowdoc" +htmlhelp_basename = "prototorchdoc" # -- Options for LaTeX output --------------------------------------------- diff --git a/prototorch/core/initializers.py b/prototorch/core/initializers.py index a1a75ee..d6ae098 100644 --- a/prototorch/core/initializers.py +++ b/prototorch/core/initializers.py @@ -32,6 +32,12 @@ class LiteralCompInitializer(AbstractComponentsInitializer): def generate(self, num_components: int = 0): """Ignore `num_components` and simply return `self.components`.""" + provided_num_components = len(self.components) + if provided_num_components != num_components: + wmsg = f"The number of components ({provided_num_components}) " \ + f"provided to {self.__class__.__name__} " \ + f"does not match the expected number ({num_components})." + warnings.warn(wmsg) if not isinstance(self.components, torch.Tensor): wmsg = f"Converting components to {torch.Tensor}..." warnings.warn(wmsg) @@ -231,6 +237,8 @@ class AbstractStratifiedCompInitializer(AbstractClassAwareCompInitializer): components = torch.tensor([]) for k, v in distribution.items(): stratified_data = self.data[self.targets == k] + if len(stratified_data) == 0: + raise ValueError(f"No data available for class {k}.") initializer = self.subinit_type( stratified_data, noise=self.noise, @@ -457,7 +465,15 @@ class OnesLinearTransformInitializer(AbstractLinearTransformInitializer): return self.generate_end_hook(weights) -class EyeTransformInitializer(AbstractLinearTransformInitializer): +class RandomLinearTransformInitializer(AbstractLinearTransformInitializer): + """Initialize a matrix with random values.""" + + def generate(self, in_dim: int, out_dim: int): + weights = torch.rand(in_dim, out_dim) + return self.generate_end_hook(weights) + + +class EyeLinearTransformInitializer(AbstractLinearTransformInitializer): """Initialize a matrix with the largest possible identity matrix.""" def generate(self, in_dim: int, out_dim: int): @@ -496,6 +512,13 @@ class PCALinearTransformInitializer(AbstractDataAwareLTInitializer): return self.generate_end_hook(weights) +class LiteralLinearTransformInitializer(AbstractDataAwareLTInitializer): + """'Generate' the provided weights.""" + + def generate(self, in_dim: int, out_dim: int): + return self.generate_end_hook(self.data) + + # Aliases - Components CACI = ClassAwareCompInitializer DACI = DataAwareCompInitializer @@ -524,7 +547,9 @@ RRI = RandomReasoningsInitializer ZRI = ZerosReasoningsInitializer # Aliases - Transforms -Eye = EyeTransformInitializer +ELTI = Eye = EyeLinearTransformInitializer OLTI = OnesLinearTransformInitializer +RLTI = RandomLinearTransformInitializer ZLTI = ZerosLinearTransformInitializer PCALTI = PCALinearTransformInitializer +LLTI = LiteralLinearTransformInitializer diff --git a/prototorch/core/losses.py b/prototorch/core/losses.py index ed35510..873538a 100644 --- a/prototorch/core/losses.py +++ b/prototorch/core/losses.py @@ -107,14 +107,24 @@ def margin_loss(y_pred, y_true, margin=0.3): class GLVQLoss(torch.nn.Module): - def __init__(self, margin=0.0, transfer_fn="identity", beta=10, **kwargs): + def __init__(self, + margin=0.0, + transfer_fn="identity", + beta=10, + add_dp=False, + **kwargs): super().__init__(**kwargs) self.margin = margin self.transfer_fn = get_activation(transfer_fn) self.beta = torch.tensor(beta) + self.add_dp = add_dp def forward(self, outputs, targets, plabels): - mu = glvq_loss(outputs, targets, prototype_labels=plabels) + # mu = glvq_loss(outputs, targets, plabels) + dp, dm = _get_dp_dm(outputs, targets, plabels) + mu = (dp - dm) / (dp + dm) + if self.add_dp: + mu = mu + dp batch_loss = self.transfer_fn(mu + self.margin, beta=self.beta) return batch_loss.sum() diff --git a/prototorch/core/transforms.py b/prototorch/core/transforms.py index 7ad31c9..8cd080e 100644 --- a/prototorch/core/transforms.py +++ b/prototorch/core/transforms.py @@ -5,7 +5,7 @@ from torch.nn.parameter import Parameter from .initializers import ( AbstractLinearTransformInitializer, - EyeTransformInitializer, + EyeLinearTransformInitializer, ) @@ -16,7 +16,7 @@ class LinearTransform(torch.nn.Module): in_dim: int, out_dim: int, initializer: - AbstractLinearTransformInitializer = EyeTransformInitializer()): + AbstractLinearTransformInitializer = EyeLinearTransformInitializer()): super().__init__() self.set_weights(in_dim, out_dim, initializer) @@ -32,12 +32,15 @@ class LinearTransform(torch.nn.Module): in_dim: int, out_dim: int, initializer: - AbstractLinearTransformInitializer = EyeTransformInitializer()): + AbstractLinearTransformInitializer = EyeLinearTransformInitializer()): weights = initializer.generate(in_dim, out_dim) self._register_weights(weights) def forward(self, x): - return x @ self.weights + return x @ self._weights + + def extra_repr(self): + return f"weights: (shape: {tuple(self._weights.shape)})" # Aliases diff --git a/prototorch/utils/__init__.py b/prototorch/utils/__init__.py index 26ccedd..ac89f24 100644 --- a/prototorch/utils/__init__.py +++ b/prototorch/utils/__init__.py @@ -1,6 +1,11 @@ -"""ProtoFlow utils module""" +"""ProtoTorch utils module""" -from .colors import hex_to_rgb, rgb_to_hex +from .colors import ( + get_colors, + get_legend_handles, + hex_to_rgb, + rgb_to_hex, +) from .utils import ( mesh2d, parse_data_arg, diff --git a/prototorch/utils/colors.py b/prototorch/utils/colors.py index 61ad1a0..169ff65 100644 --- a/prototorch/utils/colors.py +++ b/prototorch/utils/colors.py @@ -1,4 +1,13 @@ -"""ProtoFlow color utilities""" +"""ProtoTorch color utilities""" + +import matplotlib.lines as mlines +import torch +from matplotlib import cm +from matplotlib.colors import ( + Normalize, + to_hex, + to_rgb, +) def hex_to_rgb(hex_values): @@ -13,3 +22,39 @@ def rgb_to_hex(rgb_values): for v in rgb_values: c = "%02x%02x%02x" % tuple(v) yield c + + +def get_colors(vmax, vmin=0, cmap="viridis"): + cmap = cm.get_cmap(cmap) + colornorm = Normalize(vmin=vmin, vmax=vmax) + colors = dict() + for c in range(vmin, vmax + 1): + colors[c] = to_hex(cmap(colornorm(c))) + return colors + + +def get_legend_handles(colors, labels, marker="dots", zero_indexed=False): + handles = list() + for color, label in zip(colors.values(), labels): + if marker == "dots": + handle = mlines.Line2D( + xdata=[], + ydata=[], + label=label, + color="white", + markerfacecolor=color, + marker="o", + markersize=10, + markeredgecolor="k", + ) + else: + handle = mlines.Line2D( + xdata=[], + ydata=[], + label=label, + color=color, + marker="", + markersize=15, + ) + handles.append(handle) + return handles diff --git a/prototorch/utils/utils.py b/prototorch/utils/utils.py index 46ed01b..6e4d6c8 100644 --- a/prototorch/utils/utils.py +++ b/prototorch/utils/utils.py @@ -1,4 +1,4 @@ -"""ProtoFlow utilities""" +"""ProtoTorch utilities""" import warnings from typing import ( diff --git a/setup.py b/setup.py index b9cb1ec..66aa3fd 100644 --- a/setup.py +++ b/setup.py @@ -23,6 +23,7 @@ INSTALL_REQUIRES = [ "torchvision>=0.7.2", "numpy>=1.9.1", "sklearn", + "matplotlib", ] DATASETS = [ "requests", @@ -40,7 +41,6 @@ DOCS = [ "sphinx-autodoc-typehints", ] EXAMPLES = [ - "matplotlib", "torchinfo", ] TESTS = [ diff --git a/tests/test_core.py b/tests/test_core.py index 296bf74..09a4802 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -245,20 +245,20 @@ def test_random_reasonings_init_channels_not_first(): # Transform initializers def test_eye_transform_init_square(): - t = pt.initializers.EyeTransformInitializer() + t = pt.initializers.EyeLinearTransformInitializer() I = t.generate(3, 3) assert torch.allclose(I, torch.eye(3)) def test_eye_transform_init_narrow(): - t = pt.initializers.EyeTransformInitializer() + t = pt.initializers.EyeLinearTransformInitializer() actual = t.generate(3, 2) desired = torch.Tensor([[1, 0], [0, 1], [0, 0]]) assert torch.allclose(actual, desired) def test_eye_transform_init_wide(): - t = pt.initializers.EyeTransformInitializer() + t = pt.initializers.EyeLinearTransformInitializer() actual = t.generate(2, 3) desired = torch.Tensor([[1, 0, 0], [0, 1, 0]]) assert torch.allclose(actual, desired)