Update prototypes.py
Changes: 1. Change single quotes to double quotes.
parent db842b79bb
commit 101b50f4e6
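Since Python treats single- and double-quoted string literals identically, this commit changes code style only, not behavior. The quick check below (not part of the commit) illustrates that the rewritten literals compare equal either way:

# Quote style has no effect on string values or f-string results in Python.
assert ('Are you sure about having one class only?'
        == "Are you sure about having one class only?")
assert f'={3}' == f"={3}"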
@@ -14,11 +14,11 @@ class _Prototypes(torch.nn.Module):
 
     def _validate_prototype_distribution(self):
         if 0 in self.prototype_distribution:
-            warnings.warn('Are you sure about the `0` in '
-                          '`prototype_distribution`?')
+            warnings.warn("Are you sure about the `0` in "
+                          "`prototype_distribution`?")
 
     def extra_repr(self):
-        return f'prototypes.shape: {tuple(self.prototypes.shape)}'
+        return f"prototypes.shape: {tuple(self.prototypes.shape)}"
 
     def forward(self):
         return self.prototypes, self.prototype_labels
@@ -31,7 +31,7 @@ class Prototypes1D(_Prototypes):
     """
     def __init__(self,
                  prototypes_per_class=1,
-                 prototype_initializer='ones',
+                 prototype_initializer="ones",
                  prototype_distribution=None,
                  data=None,
                  dtype=torch.float32,
@@ -44,25 +44,25 @@ class Prototypes1D(_Prototypes):
             prototype_distribution = prototype_distribution.tolist()
 
         if data is None:
-            if 'input_dim' not in kwargs:
-                raise NameError('`input_dim` required if '
-                                'no `data` is provided.')
+            if "input_dim" not in kwargs:
+                raise NameError("`input_dim` required if "
+                                "no `data` is provided.")
             if prototype_distribution:
                 kwargs_nclasses = sum(prototype_distribution)
             else:
-                if 'nclasses' not in kwargs:
-                    raise NameError('`prototype_distribution` required if '
-                                    'both `data` and `nclasses` are not '
-                                    'provided.')
-                kwargs_nclasses = kwargs.pop('nclasses')
-            input_dim = kwargs.pop('input_dim')
+                if "nclasses" not in kwargs:
+                    raise NameError("`prototype_distribution` required if "
+                                    "both `data` and `nclasses` are not "
+                                    "provided.")
+                kwargs_nclasses = kwargs.pop("nclasses")
+            input_dim = kwargs.pop("input_dim")
             if prototype_initializer in [
-                    'stratified_mean', 'stratified_random'
+                    "stratified_mean", "stratified_random"
             ]:
                 warnings.warn(
-                    f'`prototype_initializer`: `{prototype_initializer}` '
-                    'requires `data`, but `data` is not provided. '
-                    'Using randomly generated data instead.')
+                    f"`prototype_initializer`: `{prototype_initializer}` "
+                    "requires `data`, but `data` is not provided. "
+                    "Using randomly generated data instead.")
             x_train = torch.rand(kwargs_nclasses, input_dim)
             y_train = torch.arange(kwargs_nclasses)
             if one_hot_labels:
@@ -75,39 +75,39 @@ class Prototypes1D(_Prototypes):
         nclasses = torch.unique(y_train, dim=-1).shape[-1]
 
         if nclasses == 1:
-            warnings.warn('Are you sure about having one class only?')
+            warnings.warn("Are you sure about having one class only?")
 
         if x_train.ndim != 2:
-            raise ValueError('`data[0].ndim != 2`.')
+            raise ValueError("`data[0].ndim != 2`.")
 
         if y_train.ndim == 2:
             if y_train.shape[1] == 1 and one_hot_labels:
-                raise ValueError('`one_hot_labels` is set to `True` '
-                                 'but target labels are not one-hot-encoded.')
+                raise ValueError("`one_hot_labels` is set to `True` "
+                                 "but target labels are not one-hot-encoded.")
             if y_train.shape[1] != 1 and not one_hot_labels:
-                raise ValueError('`one_hot_labels` is set to `False` '
-                                 'but target labels in `data` '
-                                 'are one-hot-encoded.')
+                raise ValueError("`one_hot_labels` is set to `False` "
+                                 "but target labels in `data` "
+                                 "are one-hot-encoded.")
         if y_train.ndim == 1 and one_hot_labels:
-            raise ValueError('`one_hot_labels` is set to `True` '
-                             'but target labels are not one-hot-encoded.')
+            raise ValueError("`one_hot_labels` is set to `True` "
+                             "but target labels are not one-hot-encoded.")
 
         # Verify input dimension if `input_dim` is provided
-        if 'input_dim' in kwargs:
-            input_dim = kwargs.pop('input_dim')
+        if "input_dim" in kwargs:
+            input_dim = kwargs.pop("input_dim")
             if input_dim != x_train.shape[1]:
-                raise ValueError(f'Provided `input_dim`={input_dim} does '
-                                 'not match data dimension '
-                                 f'`data[0].shape[1]`={x_train.shape[1]}')
+                raise ValueError(f"Provided `input_dim`={input_dim} does "
+                                 "not match data dimension "
+                                 f"`data[0].shape[1]`={x_train.shape[1]}")
 
         # Verify the number of classes if `nclasses` is provided
-        if 'nclasses' in kwargs:
-            kwargs_nclasses = kwargs.pop('nclasses')
+        if "nclasses" in kwargs:
+            kwargs_nclasses = kwargs.pop("nclasses")
             if kwargs_nclasses != nclasses:
-                raise ValueError(f'Provided `nclasses={kwargs_nclasses}` does '
-                                 'not match data labels '
-                                 '`torch.unique(data[1]).shape[0]`'
-                                 f'={nclasses}')
+                raise ValueError(f"Provided `nclasses={kwargs_nclasses}` does "
+                                 "not match data labels "
+                                 "`torch.unique(data[1]).shape[0]`"
+                                 f"={nclasses}")
 
         super().__init__(**kwargs)
 
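For context on the constructor logic touched above, here is a minimal usage sketch. It is not taken from this repository's documentation or tests: the import path is a placeholder and the keyword behavior is inferred from the diff alone.

# Sketch only: the import path below is a placeholder, not confirmed by this commit.
import torch
from prototypes import Prototypes1D  # assumed module/package location

# Without `data`, the constructor requires `input_dim` plus either
# `prototype_distribution` or `nclasses`, otherwise it raises NameError;
# it then generates random placeholder data internally (see the
# `if data is None:` branch in the diff).
protos = Prototypes1D(input_dim=4, nclasses=3, prototype_initializer="ones")

# With `data`, any `input_dim`/`nclasses` kwargs are validated against the
# tensors; a mismatch raises the ValueErrors shown in the last hunk.
x_train = torch.rand(9, 4)
y_train = torch.arange(9) % 3  # labels 0, 1, 2 -> nclasses == 3
protos = Prototypes1D(data=[x_train, y_train], input_dim=4, nclasses=3)

prototypes, prototype_labels = protos()  # forward() returns both tensors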