Update prototypes.py

Changes:
1. Change single-quotes to double-quotes.
This commit is contained in:
Jensun Ravichandran 2021-04-15 12:35:06 +02:00
parent db842b79bb
commit 101b50f4e6

View File

@ -14,11 +14,11 @@ class _Prototypes(torch.nn.Module):
def _validate_prototype_distribution(self): def _validate_prototype_distribution(self):
if 0 in self.prototype_distribution: if 0 in self.prototype_distribution:
warnings.warn('Are you sure about the `0` in ' warnings.warn("Are you sure about the `0` in "
'`prototype_distribution`?') "`prototype_distribution`?")
def extra_repr(self): def extra_repr(self):
return f'prototypes.shape: {tuple(self.prototypes.shape)}' return f"prototypes.shape: {tuple(self.prototypes.shape)}"
def forward(self): def forward(self):
return self.prototypes, self.prototype_labels return self.prototypes, self.prototype_labels
@ -31,7 +31,7 @@ class Prototypes1D(_Prototypes):
""" """
def __init__(self, def __init__(self,
prototypes_per_class=1, prototypes_per_class=1,
prototype_initializer='ones', prototype_initializer="ones",
prototype_distribution=None, prototype_distribution=None,
data=None, data=None,
dtype=torch.float32, dtype=torch.float32,
@ -44,25 +44,25 @@ class Prototypes1D(_Prototypes):
prototype_distribution = prototype_distribution.tolist() prototype_distribution = prototype_distribution.tolist()
if data is None: if data is None:
if 'input_dim' not in kwargs: if "input_dim" not in kwargs:
raise NameError('`input_dim` required if ' raise NameError("`input_dim` required if "
'no `data` is provided.') "no `data` is provided.")
if prototype_distribution: if prototype_distribution:
kwargs_nclasses = sum(prototype_distribution) kwargs_nclasses = sum(prototype_distribution)
else: else:
if 'nclasses' not in kwargs: if "nclasses" not in kwargs:
raise NameError('`prototype_distribution` required if ' raise NameError("`prototype_distribution` required if "
'both `data` and `nclasses` are not ' "both `data` and `nclasses` are not "
'provided.') "provided.")
kwargs_nclasses = kwargs.pop('nclasses') kwargs_nclasses = kwargs.pop("nclasses")
input_dim = kwargs.pop('input_dim') input_dim = kwargs.pop("input_dim")
if prototype_initializer in [ if prototype_initializer in [
'stratified_mean', 'stratified_random' "stratified_mean", "stratified_random"
]: ]:
warnings.warn( warnings.warn(
f'`prototype_initializer`: `{prototype_initializer}` ' f"`prototype_initializer`: `{prototype_initializer}` "
'requires `data`, but `data` is not provided. ' "requires `data`, but `data` is not provided. "
'Using randomly generated data instead.') "Using randomly generated data instead.")
x_train = torch.rand(kwargs_nclasses, input_dim) x_train = torch.rand(kwargs_nclasses, input_dim)
y_train = torch.arange(kwargs_nclasses) y_train = torch.arange(kwargs_nclasses)
if one_hot_labels: if one_hot_labels:
@ -75,39 +75,39 @@ class Prototypes1D(_Prototypes):
nclasses = torch.unique(y_train, dim=-1).shape[-1] nclasses = torch.unique(y_train, dim=-1).shape[-1]
if nclasses == 1: if nclasses == 1:
warnings.warn('Are you sure about having one class only?') warnings.warn("Are you sure about having one class only?")
if x_train.ndim != 2: if x_train.ndim != 2:
raise ValueError('`data[0].ndim != 2`.') raise ValueError("`data[0].ndim != 2`.")
if y_train.ndim == 2: if y_train.ndim == 2:
if y_train.shape[1] == 1 and one_hot_labels: if y_train.shape[1] == 1 and one_hot_labels:
raise ValueError('`one_hot_labels` is set to `True` ' raise ValueError("`one_hot_labels` is set to `True` "
'but target labels are not one-hot-encoded.') "but target labels are not one-hot-encoded.")
if y_train.shape[1] != 1 and not one_hot_labels: if y_train.shape[1] != 1 and not one_hot_labels:
raise ValueError('`one_hot_labels` is set to `False` ' raise ValueError("`one_hot_labels` is set to `False` "
'but target labels in `data` ' "but target labels in `data` "
'are one-hot-encoded.') "are one-hot-encoded.")
if y_train.ndim == 1 and one_hot_labels: if y_train.ndim == 1 and one_hot_labels:
raise ValueError('`one_hot_labels` is set to `True` ' raise ValueError("`one_hot_labels` is set to `True` "
'but target labels are not one-hot-encoded.') "but target labels are not one-hot-encoded.")
# Verify input dimension if `input_dim` is provided # Verify input dimension if `input_dim` is provided
if 'input_dim' in kwargs: if "input_dim" in kwargs:
input_dim = kwargs.pop('input_dim') input_dim = kwargs.pop("input_dim")
if input_dim != x_train.shape[1]: if input_dim != x_train.shape[1]:
raise ValueError(f'Provided `input_dim`={input_dim} does ' raise ValueError(f"Provided `input_dim`={input_dim} does "
'not match data dimension ' "not match data dimension "
f'`data[0].shape[1]`={x_train.shape[1]}') f"`data[0].shape[1]`={x_train.shape[1]}")
# Verify the number of classes if `nclasses` is provided # Verify the number of classes if `nclasses` is provided
if 'nclasses' in kwargs: if "nclasses" in kwargs:
kwargs_nclasses = kwargs.pop('nclasses') kwargs_nclasses = kwargs.pop("nclasses")
if kwargs_nclasses != nclasses: if kwargs_nclasses != nclasses:
raise ValueError(f'Provided `nclasses={kwargs_nclasses}` does ' raise ValueError(f"Provided `nclasses={kwargs_nclasses}` does "
'not match data labels ' "not match data labels "
'`torch.unique(data[1]).shape[0]`' "`torch.unique(data[1]).shape[0]`"
f'={nclasses}') f"={nclasses}")
super().__init__(**kwargs) super().__init__(**kwargs)