[FEATURE] Add more initializers

This commit is contained in:
Jensun Ravichandran 2021-06-14 19:53:02 +02:00
parent 549e6a10c1
commit fc9edeaa97
3 changed files with 198 additions and 102 deletions

View File

@ -8,10 +8,10 @@ from torch.nn.parameter import Parameter
from ..utils import parse_distribution from ..utils import parse_distribution
from .initializers import ( from .initializers import (
AbstractClassAwareCompInitializer,
AbstractComponentsInitializer, AbstractComponentsInitializer,
AbstractLabelsInitializer, AbstractLabelsInitializer,
AbstractReasoningsInitializer, AbstractReasoningsInitializer,
ClassAwareCompInitializer,
LabelsInitializer, LabelsInitializer,
) )
@ -50,7 +50,7 @@ def removeind(ins, attr, indices):
def get_cikwargs(init, distribution): def get_cikwargs(init, distribution):
"""Return appropriate key-word arguments for a component initializer.""" """Return appropriate key-word arguments for a component initializer."""
if isinstance(init, ClassAwareCompInitializer): if isinstance(init, AbstractClassAwareCompInitializer):
cikwargs = dict(distribution=distribution) cikwargs = dict(distribution=distribution)
else: else:
distribution = parse_distribution(distribution) distribution = parse_distribution(distribution)
@ -69,7 +69,7 @@ class AbstractComponents(torch.nn.Module):
@property @property
def components(self): def components(self):
"""Detached Tensor containing the components.""" """Detached Tensor containing the components."""
return self._components.detach() return self._components.detach().cpu()
def _register_components(self, components): def _register_components(self, components):
self.register_parameter("_components", Parameter(components)) self.register_parameter("_components", Parameter(components))
@ -259,7 +259,7 @@ class ReasoningComponents(AbstractComponents):
Dimension NxCx2 Dimension NxCx2
""" """
return self._reasonings.detach() return self._reasonings.detach().cpu()
def _register_reasonings(self, reasonings): def _register_reasonings(self, reasonings):
self.register_parameter("_reasonings", Parameter(reasonings)) self.register_parameter("_reasonings", Parameter(reasonings))

View File

@ -1,5 +1,6 @@
"""ProtoTorch code initializers""" """ProtoTorch code initializers"""
import warnings
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Iterable from collections.abc import Iterable
from typing import Union from typing import Union
@ -15,6 +16,24 @@ class AbstractComponentsInitializer(ABC):
... ...
class LiteralCompInitializer(AbstractComponentsInitializer):
"""'Generate' the provided components.
Use this to 'generate' pre-initialized components elsewhere.
"""
def __init__(self, components):
self.components = components
def generate(self, num_components: int = 0):
"""Ignore `num_components` and simply return `self.components`."""
if not isinstance(self.components, torch.Tensor):
wmsg = f"Converting components to {torch.Tensor}..."
warnings.warn(wmsg)
self.components = torch.Tensor(self.components)
return self.components
class ShapeAwareCompInitializer(AbstractComponentsInitializer): class ShapeAwareCompInitializer(AbstractComponentsInitializer):
"""Abstract class for all dimension-aware components initializers.""" """Abstract class for all dimension-aware components initializers."""
def __init__(self, shape: Union[Iterable, int]): def __init__(self, shape: Union[Iterable, int]):
@ -28,88 +47,6 @@ class ShapeAwareCompInitializer(AbstractComponentsInitializer):
... ...
class DataAwareCompInitializer(AbstractComponentsInitializer):
"""Abstract class for all data-aware components initializers.
Components generated by data-aware components initializers inherit the shape
of the provided data.
"""
def __init__(self,
data,
noise: float = 0.0,
transform: callable = torch.nn.Identity()):
self.data = data
self.noise = noise
self.transform = transform
def generate_end_hook(self, samples):
drift = torch.rand_like(samples) * self.noise
components = self.transform(samples + drift)
return components
@abstractmethod
def generate(self, num_components: int):
...
return self.generate_end_hook(...)
def __del__(self):
del self.data
class ClassAwareCompInitializer(AbstractComponentsInitializer):
"""Abstract class for all class-aware components initializers.
Components generated by class-aware components initializers inherit the shape
of the provided data.
"""
def __init__(self,
data,
noise: float = 0.0,
transform: callable = torch.nn.Identity()):
self.data, self.targets = parse_data_arg(data)
self.noise = noise
self.transform = transform
self.clabels = torch.unique(self.targets).int().tolist()
self.num_classes = len(self.clabels)
@property
@abstractmethod
def subinit_type(self) -> DataAwareCompInitializer:
...
def generate(self, distribution: Union[dict, list, tuple]):
distribution = parse_distribution(distribution)
components = torch.tensor([])
for k, v in distribution.items():
stratified_data = self.data[self.targets == k]
initializer = self.subinit_type(
stratified_data,
noise=self.noise,
transform=self.transform,
)
samples = initializer.generate(num_components=v)
components = torch.cat([components, samples])
return components
def __del__(self):
del self.data
del self.targets
class LiteralCompInitializer(DataAwareCompInitializer):
"""'Generate' the provided components.
Use this to 'generate' pre-initialized components from elsewhere.
"""
def generate(self, num_components: int):
"""Ignore `num_components` and simply return transformed `self.data`."""
components = self.transform(self.data)
return components
class ZerosCompInitializer(ShapeAwareCompInitializer): class ZerosCompInitializer(ShapeAwareCompInitializer):
"""Generate zeros corresponding to the components shape.""" """Generate zeros corresponding to the components shape."""
def generate(self, num_components: int): def generate(self, num_components: int):
@ -163,7 +100,46 @@ class RandomNormalCompInitializer(OnesCompInitializer):
return components return components
class SelectionCompInitializer(DataAwareCompInitializer): class AbstractDataAwareCompInitializer(AbstractComponentsInitializer):
"""Abstract class for all data-aware components initializers.
Components generated by data-aware components initializers inherit the shape
of the provided data.
`data` has to be a torch tensor.
"""
def __init__(self,
data: torch.TensorType,
noise: float = 0.0,
transform: callable = torch.nn.Identity()):
self.data = data
self.noise = noise
self.transform = transform
def generate_end_hook(self, samples):
drift = torch.rand_like(samples) * self.noise
components = self.transform(samples + drift)
return components
@abstractmethod
def generate(self, num_components: int):
...
return self.generate_end_hook(...)
def __del__(self):
del self.data
class DataAwareCompInitializer(AbstractDataAwareCompInitializer):
"""'Generate' the components from the provided data."""
def generate(self, num_components: int = 0):
"""Ignore `num_components` and simply return transformed `self.data`."""
components = self.generate_end_hook(self.data)
return components
class SelectionCompInitializer(AbstractDataAwareCompInitializer):
"""Generate components by uniformly sampling from the provided data.""" """Generate components by uniformly sampling from the provided data."""
def generate(self, num_components: int): def generate(self, num_components: int):
indices = torch.LongTensor(num_components).random_(0, len(self.data)) indices = torch.LongTensor(num_components).random_(0, len(self.data))
@ -172,7 +148,7 @@ class SelectionCompInitializer(DataAwareCompInitializer):
return components return components
class MeanCompInitializer(DataAwareCompInitializer): class MeanCompInitializer(AbstractDataAwareCompInitializer):
"""Generate components by computing the mean of the provided data.""" """Generate components by computing the mean of the provided data."""
def generate(self, num_components: int): def generate(self, num_components: int):
mean = torch.mean(self.data, dim=0) mean = torch.mean(self.data, dim=0)
@ -182,14 +158,74 @@ class MeanCompInitializer(DataAwareCompInitializer):
return components return components
class StratifiedSelectionCompInitializer(ClassAwareCompInitializer): class AbstractClassAwareCompInitializer(AbstractDataAwareCompInitializer):
"""Abstract class for all class-aware components initializers.
Components generated by class-aware components initializers inherit the shape
of the provided data.
`data` could be a torch Dataset or DataLoader or a list/tuple of data and
target tensors.
"""
def __init__(self,
data,
noise: float = 0.0,
transform: callable = torch.nn.Identity()):
self.data, self.targets = parse_data_arg(data)
self.noise = noise
self.transform = transform
self.clabels = torch.unique(self.targets).int().tolist()
self.num_classes = len(self.clabels)
@abstractmethod
def generate(self, distribution: Union[dict, list, tuple] = []):
...
return self.generate_end_hook(...)
def __del__(self):
del self.data
del self.targets
class ClassAwareCompInitializer(AbstractClassAwareCompInitializer):
"""'Generate' components from provided data and requested distribution."""
def generate(self, distribution: Union[dict, list, tuple] = []):
"""Ignore `distribution` and simply return transformed `self.data`."""
components = self.generate_end_hook(self.data)
return components
class AbstractStratifiedCompInitializer(AbstractClassAwareCompInitializer):
"""Abstract class for all stratified components initializers."""
@property
@abstractmethod
def subinit_type(self) -> AbstractDataAwareCompInitializer:
...
def generate(self, distribution: Union[dict, list, tuple]):
distribution = parse_distribution(distribution)
components = torch.tensor([])
for k, v in distribution.items():
stratified_data = self.data[self.targets == k]
initializer = self.subinit_type(
stratified_data,
noise=self.noise,
transform=self.transform,
)
samples = initializer.generate(num_components=v)
components = torch.cat([components, samples])
return components
class StratifiedSelectionCompInitializer(AbstractStratifiedCompInitializer):
"""Generate components using stratified sampling from the provided data.""" """Generate components using stratified sampling from the provided data."""
@property @property
def subinit_type(self): def subinit_type(self):
return SelectionCompInitializer return SelectionCompInitializer
class StratifiedMeanCompInitializer(ClassAwareCompInitializer): class StratifiedMeanCompInitializer(AbstractStratifiedCompInitializer):
"""Generate components at stratified means of the provided data.""" """Generate components at stratified means of the provided data."""
@property @property
def subinit_type(self): def subinit_type(self):
@ -204,6 +240,38 @@ class AbstractLabelsInitializer(ABC):
... ...
class LiteralLabelsInitializer(AbstractLabelsInitializer):
"""'Generate' the provided labels.
Use this to 'generate' pre-initialized labels elsewhere.
"""
def __init__(self, labels):
self.labels = labels
def generate(self, distribution: Union[dict, list, tuple] = []):
"""Ignore `distribution` and simply return `self.labels`.
Convert to long tensor, if necessary.
"""
labels = self.labels
if not isinstance(labels, torch.LongTensor):
wmsg = f"Converting labels to {torch.LongTensor}..."
warnings.warn(wmsg)
labels = torch.LongTensor(labels)
return labels
class DataAwareLabelsInitializer(AbstractLabelsInitializer):
"""'Generate' the labels from a torch Dataset."""
def __init__(self, data):
self.data, self.targets = parse_data_arg(data)
def generate(self, distribution: Union[dict, list, tuple] = []):
"""Ignore `num_components` and simply return `self.targets`."""
return self.targets
class LabelsInitializer(AbstractLabelsInitializer): class LabelsInitializer(AbstractLabelsInitializer):
"""Generate labels from `distribution`.""" """Generate labels from `distribution`."""
def generate(self, distribution: Union[dict, list, tuple]): def generate(self, distribution: Union[dict, list, tuple]):
@ -248,6 +316,27 @@ class AbstractReasoningsInitializer(ABC):
return generate_end_hook(...) return generate_end_hook(...)
class LiteralReasoningsInitializer(AbstractReasoningsInitializer):
"""'Generate' the provided reasonings.
Use this to 'generate' pre-initialized reasonings elsewhere.
"""
def __init__(self, reasonings, **kwargs):
super().__init__(**kwargs)
self.reasonings = reasonings
def generate(self, distribution: Union[dict, list, tuple] = []):
"""Ignore `distributuion` and simply return self.reasonings."""
reasonings = self.reasonings
if not isinstance(reasonings, torch.Tensor):
wmsg = f"Converting reasonings to {torch.Tensor}..."
warnings.warn(wmsg)
reasonings = torch.Tensor(reasonings)
reasonings = self.generate_end_hook(reasonings)
return reasonings
class ZerosReasoningsInitializer(AbstractReasoningsInitializer): class ZerosReasoningsInitializer(AbstractReasoningsInitializer):
"""Reasonings are all initialized with zeros.""" """Reasonings are all initialized with zeros."""
def generate(self, distribution: Union[dict, list, tuple]): def generate(self, distribution: Union[dict, list, tuple]):
@ -292,23 +381,28 @@ class PurePositiveReasoningsInitializer(AbstractReasoningsInitializer):
# Aliases - Components # Aliases - Components
ZCI = ZerosCompInitializer CACI = ClassAwareCompInitializer
OCI = OnesCompInitializer DACI = DataAwareCompInitializer
FVCI = FillValueCompInitializer FVCI = FillValueCompInitializer
LCI = LiteralCompInitializer LCI = LiteralCompInitializer
UCI = UniformCompInitializer MCI = MeanCompInitializer
OCI = OnesCompInitializer
RNCI = RandomNormalCompInitializer RNCI = RandomNormalCompInitializer
SCI = SelectionCompInitializer SCI = SelectionCompInitializer
MCI = MeanCompInitializer
SSCI = StratifiedSelectionCompInitializer
SMCI = StratifiedMeanCompInitializer SMCI = StratifiedMeanCompInitializer
SSCI = StratifiedSelectionCompInitializer
UCI = UniformCompInitializer
ZCI = ZerosCompInitializer
# Aliases - Labels # Aliases - Labels
DLI = DataAwareLabelsInitializer
LI = LabelsInitializer LI = LabelsInitializer
LLI = LiteralLabelsInitializer
OHLI = OneHotLabelsInitializer OHLI = OneHotLabelsInitializer
# Aliases - Reasonings # Aliases - Reasonings
ZRI = ZerosReasoningsInitializer LRI = LiteralReasoningsInitializer
ORI = OnesReasoningsInitializer ORI = OnesReasoningsInitializer
RRI = RandomReasoningsInitializer
PPRI = PurePositiveReasoningsInitializer PPRI = PurePositiveReasoningsInitializer
RRI = RandomReasoningsInitializer
ZRI = ZerosReasoningsInitializer

View File

@ -67,17 +67,19 @@ def parse_distribution(user_distribution: Union[dict[int, int], dict[str, str],
elif isinstance(user_distribution, list): elif isinstance(user_distribution, list):
return distribution_from_list(user_distribution, clabels) return distribution_from_list(user_distribution, clabels)
else: else:
msg = f"`distribution` not understood." \ msg = f"`distribution` was not understood." \
f"You have provided: {user_distribution}." f"You have provided: {user_distribution}."
raise ValueError(msg) raise ValueError(msg)
def parse_data_arg(data_arg: Union[Dataset, DataLoader, list, tuple]): def parse_data_arg(data_arg: Union[Dataset, DataLoader, list, tuple]):
"""Return data and target as torch tensors."""
if isinstance(data_arg, Dataset): if isinstance(data_arg, Dataset):
ds_size = len(data_arg) ds_size = len(data_arg)
data_arg = DataLoader(data_arg, batch_size=ds_size) loader = DataLoader(data_arg, batch_size=ds_size)
data, targets = next(iter(loader))
if isinstance(data_arg, DataLoader): elif isinstance(data_arg, DataLoader):
data = torch.tensor([]) data = torch.tensor([])
targets = torch.tensor([]) targets = torch.tensor([])
for x, y in data_arg: for x, y in data_arg:
@ -87,11 +89,11 @@ def parse_data_arg(data_arg: Union[Dataset, DataLoader, list, tuple]):
assert len(data_arg) == 2 assert len(data_arg) == 2
data, targets = data_arg data, targets = data_arg
if not isinstance(data, torch.Tensor): if not isinstance(data, torch.Tensor):
wmsg = f"Converting data to {torch.Tensor}." wmsg = f"Converting data to {torch.Tensor}..."
warnings.warn(wmsg) warnings.warn(wmsg)
data = torch.Tensor(data) data = torch.Tensor(data)
if not isinstance(targets, torch.LongTensor): if not isinstance(targets, torch.LongTensor):
wmsg = f"Converting targets to {torch.LongTensor}." wmsg = f"Converting targets to {torch.LongTensor}..."
warnings.warn(wmsg) warnings.warn(wmsg)
targets = torch.LongTensor(targets) targets = torch.LongTensor(targets)
return data, targets return data, targets