Update Prototypes1D
parent e063625486
commit a167565857
@@ -2,8 +2,11 @@
 import warnings
 
+import numpy as np
 import torch
 
+from prototorch.functions.competitions import wtac
+from prototorch.functions.distances import sed
 from prototorch.functions.initializers import get_initializer
 
 
@@ -12,14 +15,17 @@ class _Prototypes(torch.nn.Module):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
 
-    def _check_prototype_distribution(self):
+    def _validate_prototype_distribution(self):
         if 0 in self.prototype_distribution:
-            warnings.warn('Are you sure about the 0 in '
+            warnings.warn('Are you sure about the `0` in '
                           '`prototype_distribution`?')
 
     def extra_repr(self):
         return f'prototypes.shape: {tuple(self.prototypes.shape)}'
 
+    def forward(self):
+        return self.prototypes, self.prototype_labels
+
 
 class Prototypes1D(_Prototypes):
     r"""Create a learnable set of one-dimensional prototypes.
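`forward` moves up to the `_Prototypes` base class (its old copy on `Prototypes1D` is removed at the bottom of this diff), so calling any prototype module now returns the pair `(prototypes, prototype_labels)`. A minimal sketch of the call pattern, using only keyword arguments this diff itself handles; the import path is an assumption:

    import torch
    from prototorch.modules.prototypes import Prototypes1D

    # With no `data`, the constructor generates random training data
    # for the initializer (see the `data is None` branch below).
    p = Prototypes1D(input_dim=2, nclasses=3, prototype_initializer='zeros')
    prototypes, prototype_labels = p()  # Module.__call__ dispatches to forward()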
@@ -63,10 +69,12 @@ class Prototypes1D(_Prototypes):
                  prototype_distribution=None,
                  data=None,
                  dtype=torch.float32,
+                 one_hot_labels=False,
                  **kwargs):
 
-        # Convert torch tensors to python lists before processing
-        if torch.is_tensor(prototype_distribution):
-            prototype_distribution = prototype_distribution.tolist()
+        # Convert tensors to python lists before processing
+        if prototype_distribution is not None:
+            if not isinstance(prototype_distribution, list):
+                prototype_distribution = prototype_distribution.tolist()
 
         if data is None:
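The conversion guard is broadened from `torch.is_tensor` to "anything that is not already a `list`", so numpy arrays, which also provide `.tolist()`, convert the same way as tensors. A small sketch of the assumed behavior:

    import numpy as np
    import torch

    def to_list(prototype_distribution):
        # Mirrors the updated guard: lists pass through; anything else
        # that offers .tolist() (tensors, numpy arrays) is converted.
        if prototype_distribution is not None:
            if not isinstance(prototype_distribution, list):
                prototype_distribution = prototype_distribution.tolist()
        return prototype_distribution

    assert to_list([1, 2, 3]) == [1, 2, 3]
    assert to_list(np.array([1, 2, 3])) == [1, 2, 3]
    assert to_list(torch.tensor([1, 2, 3])) == [1, 2, 3]
    assert to_list(None) is None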
@@ -74,13 +82,13 @@ class Prototypes1D(_Prototypes):
                 raise NameError('`input_dim` required if '
                                 'no `data` is provided.')
             if prototype_distribution:
-                nclasses = sum(prototype_distribution)
+                kwargs_nclasses = sum(prototype_distribution)
             else:
                 if 'nclasses' not in kwargs:
                     raise NameError('`prototype_distribution` required if '
                                     'both `data` and `nclasses` are not '
                                     'provided.')
-                nclasses = kwargs.pop('nclasses')
+                kwargs_nclasses = kwargs.pop('nclasses')
             input_dim = kwargs.pop('input_dim')
             if prototype_initializer in [
                     'stratified_mean', 'stratified_random'
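Renaming the kwargs-derived count to `kwargs_nclasses` keeps it from shadowing the `nclasses` that the next hunk recomputes from the actual labels. A trivial illustration of the clash the rename avoids (values are made up):

    nclasses = 3         # before: read from kwargs ...
    nclasses = 2         # ... then silently clobbered by the label-derived count

    kwargs_nclasses = 3  # after: both counts stay available
    nclasses = 2         # derived from torch.unique(y_train, dim=-1)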
@@ -89,18 +97,35 @@ class Prototypes1D(_Prototypes):
                     f'`prototype_initializer`: `{prototype_initializer}` '
                     'requires `data`, but `data` is not provided. '
                     'Using randomly generated data instead.')
-            x_train = torch.rand(nclasses, input_dim)
-            y_train = torch.arange(nclasses)
+            x_train = torch.rand(kwargs_nclasses, input_dim)
+            y_train = torch.arange(kwargs_nclasses)
+            if one_hot_labels:
+                y_train = torch.eye(kwargs_nclasses)[y_train]
             data = [x_train, y_train]
 
         x_train, y_train = data
         x_train = torch.as_tensor(x_train).type(dtype)
-        y_train = torch.as_tensor(y_train).type(dtype)
-        nclasses = torch.unique(y_train).shape[0]
+        y_train = torch.as_tensor(y_train).type(torch.int)
+        nclasses = torch.unique(y_train, dim=-1).shape[-1]
+
+        if nclasses == 1:
+            warnings.warn('Are you sure about having one class only?')
 
         if x_train.ndim != 2:
             raise ValueError('`data[0].ndim != 2`.')
 
+        if y_train.ndim == 2:
+            if y_train.shape[1] == 1 and one_hot_labels:
+                raise ValueError('`one_hot_labels` is set to `True` '
+                                 'but target labels are not one-hot-encoded.')
+            if y_train.shape[1] != 1 and not one_hot_labels:
+                raise ValueError('`one_hot_labels` is set to `False` '
+                                 'but target labels in `data` '
+                                 'are one-hot-encoded.')
+        if y_train.ndim == 1 and one_hot_labels:
+            raise ValueError('`one_hot_labels` is set to `True` '
+                             'but target labels are not one-hot-encoded.')
+
         # Verify input dimension if `input_dim` is provided
         if 'input_dim' in kwargs:
             input_dim = kwargs.pop('input_dim')
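Casting the labels to `torch.int` and counting classes with `torch.unique(y_train, dim=-1).shape[-1]` lets one expression handle both encodings: for 1-D integer labels it counts unique values, and for 2-D one-hot labels it counts the unique columns of the label matrix. A minimal sketch:

    import torch

    y = torch.arange(3)                  # integer labels: [0, 1, 2]
    y_1h = torch.eye(3)[y]               # one-hot labels, shape (3, 3)

    assert torch.unique(y, dim=-1).shape[-1] == 3     # unique values
    assert torch.unique(y_1h, dim=-1).shape[-1] == 3  # unique columns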
@@ -125,17 +150,16 @@ class Prototypes1D(_Prototypes):
         with torch.no_grad():
             self.prototype_distribution = torch.tensor(prototype_distribution)
 
-        self._check_prototype_distribution()
+        self._validate_prototype_distribution()
 
         self.prototype_initializer = get_initializer(prototype_initializer)
         prototypes, prototype_labels = self.prototype_initializer(
             x_train,
             y_train,
-            prototype_distribution=self.prototype_distribution)
+            prototype_distribution=self.prototype_distribution,
+            one_hot=one_hot_labels,
+        )
 
         # Register module parameters
         self.prototypes = torch.nn.Parameter(prototypes)
         self.prototype_labels = prototype_labels
-
-    def forward(self):
-        return self.prototypes, self.prototype_labels
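Taken together, a hedged end-to-end sketch of the new `one_hot_labels` flag; the import path and the `prototypes_per_class` keyword are assumptions based on the library's examples, not shown in this diff:

    import torch
    from prototorch.modules.prototypes import Prototypes1D

    x = torch.rand(30, 2)                         # 30 samples, 2 features
    y = torch.eye(3)[torch.randint(0, 3, (30,))]  # one-hot targets, 3 classes

    p = Prototypes1D(prototype_initializer='stratified_mean',
                     prototypes_per_class=2,
                     data=[x, y],
                     one_hot_labels=True)
    prototypes, prototype_labels = p()  # forward() now defined on _Prototypes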