Update Prototypes1D

This commit is contained in:
blackfly 2020-04-27 12:44:19 +02:00
parent e063625486
commit a167565857

View File

@ -2,8 +2,11 @@
import warnings import warnings
import numpy as np
import torch import torch
from prototorch.functions.competitions import wtac
from prototorch.functions.distances import sed
from prototorch.functions.initializers import get_initializer from prototorch.functions.initializers import get_initializer
@ -12,14 +15,17 @@ class _Prototypes(torch.nn.Module):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
def _check_prototype_distribution(self): def _validate_prototype_distribution(self):
if 0 in self.prototype_distribution: if 0 in self.prototype_distribution:
warnings.warn('Are you sure about the 0 in ' warnings.warn('Are you sure about the `0` in '
'`prototype_distribution`?') '`prototype_distribution`?')
def extra_repr(self): def extra_repr(self):
return f'prototypes.shape: {tuple(self.prototypes.shape)}' return f'prototypes.shape: {tuple(self.prototypes.shape)}'
def forward(self):
return self.prototypes, self.prototype_labels
class Prototypes1D(_Prototypes): class Prototypes1D(_Prototypes):
r"""Create a learnable set of one-dimensional prototypes. r"""Create a learnable set of one-dimensional prototypes.
@ -63,10 +69,12 @@ class Prototypes1D(_Prototypes):
prototype_distribution=None, prototype_distribution=None,
data=None, data=None,
dtype=torch.float32, dtype=torch.float32,
one_hot_labels=False,
**kwargs): **kwargs):
# Convert torch tensors to python lists before processing # Convert tensors to python lists before processing
if torch.is_tensor(prototype_distribution): if prototype_distribution is not None:
if not isinstance(prototype_distribution, list):
prototype_distribution = prototype_distribution.tolist() prototype_distribution = prototype_distribution.tolist()
if data is None: if data is None:
@ -74,13 +82,13 @@ class Prototypes1D(_Prototypes):
raise NameError('`input_dim` required if ' raise NameError('`input_dim` required if '
'no `data` is provided.') 'no `data` is provided.')
if prototype_distribution: if prototype_distribution:
nclasses = sum(prototype_distribution) kwargs_nclasses = sum(prototype_distribution)
else: else:
if 'nclasses' not in kwargs: if 'nclasses' not in kwargs:
raise NameError('`prototype_distribution` required if ' raise NameError('`prototype_distribution` required if '
'both `data` and `nclasses` are not ' 'both `data` and `nclasses` are not '
'provided.') 'provided.')
nclasses = kwargs.pop('nclasses') kwargs_nclasses = kwargs.pop('nclasses')
input_dim = kwargs.pop('input_dim') input_dim = kwargs.pop('input_dim')
if prototype_initializer in [ if prototype_initializer in [
'stratified_mean', 'stratified_random' 'stratified_mean', 'stratified_random'
@ -89,18 +97,35 @@ class Prototypes1D(_Prototypes):
f'`prototype_initializer`: `{prototype_initializer}` ' f'`prototype_initializer`: `{prototype_initializer}` '
'requires `data`, but `data` is not provided. ' 'requires `data`, but `data` is not provided. '
'Using randomly generated data instead.') 'Using randomly generated data instead.')
x_train = torch.rand(nclasses, input_dim) x_train = torch.rand(kwargs_nclasses, input_dim)
y_train = torch.arange(nclasses) y_train = torch.arange(kwargs_nclasses)
if one_hot_labels:
y_train = torch.eye(kwargs_nclasses)[y_train]
data = [x_train, y_train] data = [x_train, y_train]
x_train, y_train = data x_train, y_train = data
x_train = torch.as_tensor(x_train).type(dtype) x_train = torch.as_tensor(x_train).type(dtype)
y_train = torch.as_tensor(y_train).type(dtype) y_train = torch.as_tensor(y_train).type(torch.int)
nclasses = torch.unique(y_train).shape[0] nclasses = torch.unique(y_train, dim=-1).shape[-1]
if nclasses == 1:
warnings.warn('Are you sure about having one class only?')
if x_train.ndim != 2: if x_train.ndim != 2:
raise ValueError('`data[0].ndim != 2`.') raise ValueError('`data[0].ndim != 2`.')
if y_train.ndim == 2:
if y_train.shape[1] == 1 and one_hot_labels:
raise ValueError('`one_hot_labels` is set to `True` '
'but target labels are not one-hot-encoded.')
if y_train.shape[1] != 1 and not one_hot_labels:
raise ValueError('`one_hot_labels` is set to `False` '
'but target labels in `data` '
'are one-hot-encoded.')
if y_train.ndim == 1 and one_hot_labels:
raise ValueError('`one_hot_labels` is set to `True` '
'but target labels are not one-hot-encoded.')
# Verify input dimension if `input_dim` is provided # Verify input dimension if `input_dim` is provided
if 'input_dim' in kwargs: if 'input_dim' in kwargs:
input_dim = kwargs.pop('input_dim') input_dim = kwargs.pop('input_dim')
@ -125,17 +150,16 @@ class Prototypes1D(_Prototypes):
with torch.no_grad(): with torch.no_grad():
self.prototype_distribution = torch.tensor(prototype_distribution) self.prototype_distribution = torch.tensor(prototype_distribution)
self._check_prototype_distribution() self._validate_prototype_distribution()
self.prototype_initializer = get_initializer(prototype_initializer) self.prototype_initializer = get_initializer(prototype_initializer)
prototypes, prototype_labels = self.prototype_initializer( prototypes, prototype_labels = self.prototype_initializer(
x_train, x_train,
y_train, y_train,
prototype_distribution=self.prototype_distribution) prototype_distribution=self.prototype_distribution,
one_hot=one_hot_labels,
)
# Register module parameters # Register module parameters
self.prototypes = torch.nn.Parameter(prototypes) self.prototypes = torch.nn.Parameter(prototypes)
self.prototype_labels = prototype_labels self.prototype_labels = prototype_labels
def forward(self):
return self.prototypes, self.prototype_labels