[FEATURE] Add more initializers

Jensun Ravichandran 2021-06-14 19:53:02 +02:00
parent 549e6a10c1
commit fc9edeaa97
3 changed files with 198 additions and 102 deletions

View File

@ -8,10 +8,10 @@ from torch.nn.parameter import Parameter
from ..utils import parse_distribution
from .initializers import (
AbstractClassAwareCompInitializer,
AbstractComponentsInitializer,
AbstractLabelsInitializer,
AbstractReasoningsInitializer,
ClassAwareCompInitializer,
LabelsInitializer,
)
@ -50,7 +50,7 @@ def removeind(ins, attr, indices):
def get_cikwargs(init, distribution):
"""Return appropriate keyword arguments for a component initializer."""
if isinstance(init, ClassAwareCompInitializer):
if isinstance(init, AbstractClassAwareCompInitializer):
cikwargs = dict(distribution=distribution)
else:
distribution = parse_distribution(distribution)
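In effect, class-aware initializers receive the full `distribution`, while every other initializer is driven by a plain component count. A minimal sketch of the two resulting `generate` call styles (the import path and the example data are assumptions; the truncated `else` branch above is not reproduced):

    import torch
    from prototorch.core.initializers import (  # import path assumed
        RandomNormalCompInitializer,
        StratifiedMeanCompInitializer,
    )

    data = torch.randn(20, 2)
    targets = torch.randint(0, 2, (20, ))

    # Class-aware: `generate` consumes a class distribution.
    ca = StratifiedMeanCompInitializer([data, targets])
    ca.generate(distribution={0: 2, 1: 2})  # four components of dim 2

    # Everything else: `generate` only needs a component count.
    sa = RandomNormalCompInitializer(shape=(2, ))
    sa.generate(num_components=4)  # four components of dim 2
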
@ -69,7 +69,7 @@ class AbstractComponents(torch.nn.Module):
@property
def components(self):
"""Detached Tensor containing the components."""
return self._components.detach()
return self._components.detach().cpu()
def _register_components(self, components):
self.register_parameter("_components", Parameter(components))
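Returning a detached CPU copy means the snapshot can be handed to numpy or matplotlib no matter where the parameter itself lives. A minimal sketch of the tensor behaviour this relies on (plain PyTorch, no ProtoTorch classes involved):

    import torch

    comps = torch.randn(3, 2, requires_grad=True)
    snapshot = comps.detach().cpu()  # what the property now returns
    assert not snapshot.requires_grad
    snapshot.numpy()  # fine; comps.numpy() would raise because of autograd
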
@ -259,7 +259,7 @@ class ReasoningComponents(AbstractComponents):
Dimension NxCx2
"""
return self._reasonings.detach()
return self._reasonings.detach().cpu()
def _register_reasonings(self, reasonings):
self.register_parameter("_reasonings", Parameter(reasonings))

View File

@ -1,5 +1,6 @@
"""ProtoTorch code initializers"""
import warnings
from abc import ABC, abstractmethod
from collections.abc import Iterable
from typing import Union
@ -15,6 +16,24 @@ class AbstractComponentsInitializer(ABC):
...
class LiteralCompInitializer(AbstractComponentsInitializer):
"""'Generate' the provided components.
Use this to 'generate' pre-initialized components elsewhere.
"""
def __init__(self, components):
self.components = components
def generate(self, num_components: int = 0):
"""Ignore `num_components` and simply return `self.components`."""
if not isinstance(self.components, torch.Tensor):
wmsg = f"Converting components to {torch.Tensor}..."
warnings.warn(wmsg)
self.components = torch.Tensor(self.components)
return self.components
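A short usage sketch for the relocated `LiteralCompInitializer`; the conversion warning comes from the check above (the import path is an assumption):

    import torch
    from prototorch.core.initializers import LiteralCompInitializer  # path assumed

    protos = [[0.0, 0.0], [1.0, 1.0]]  # not a tensor yet
    components = LiteralCompInitializer(protos).generate()  # warns, then converts
    assert isinstance(components, torch.Tensor)
    assert components.shape == (2, 2)
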
class ShapeAwareCompInitializer(AbstractComponentsInitializer):
"""Abstract class for all dimension-aware components initializers."""
def __init__(self, shape: Union[Iterable, int]):
@ -28,88 +47,6 @@ class ShapeAwareCompInitializer(AbstractComponentsInitializer):
...
class DataAwareCompInitializer(AbstractComponentsInitializer):
"""Abstract class for all data-aware components initializers.
Components generated by data-aware components initializers inherit the shape
of the provided data.
"""
def __init__(self,
data,
noise: float = 0.0,
transform: callable = torch.nn.Identity()):
self.data = data
self.noise = noise
self.transform = transform
def generate_end_hook(self, samples):
drift = torch.rand_like(samples) * self.noise
components = self.transform(samples + drift)
return components
@abstractmethod
def generate(self, num_components: int):
...
return self.generate_end_hook(...)
def __del__(self):
del self.data
class ClassAwareCompInitializer(AbstractComponentsInitializer):
"""Abstract class for all class-aware components initializers.
Components generated by class-aware components initializers inherit the shape
of the provided data.
"""
def __init__(self,
data,
noise: float = 0.0,
transform: callable = torch.nn.Identity()):
self.data, self.targets = parse_data_arg(data)
self.noise = noise
self.transform = transform
self.clabels = torch.unique(self.targets).int().tolist()
self.num_classes = len(self.clabels)
@property
@abstractmethod
def subinit_type(self) -> DataAwareCompInitializer:
...
def generate(self, distribution: Union[dict, list, tuple]):
distribution = parse_distribution(distribution)
components = torch.tensor([])
for k, v in distribution.items():
stratified_data = self.data[self.targets == k]
initializer = self.subinit_type(
stratified_data,
noise=self.noise,
transform=self.transform,
)
samples = initializer.generate(num_components=v)
components = torch.cat([components, samples])
return components
def __del__(self):
del self.data
del self.targets
class LiteralCompInitializer(DataAwareCompInitializer):
"""'Generate' the provided components.
Use this to 'generate' pre-initialized components from elsewhere.
"""
def generate(self, num_components: int):
"""Ignore `num_components` and simply return transformed `self.data`."""
components = self.transform(self.data)
return components
class ZerosCompInitializer(ShapeAwareCompInitializer):
"""Generate zeros corresponding to the components shape."""
def generate(self, num_components: int):
@ -163,7 +100,46 @@ class RandomNormalCompInitializer(OnesCompInitializer):
return components
class SelectionCompInitializer(DataAwareCompInitializer):
class AbstractDataAwareCompInitializer(AbstractComponentsInitializer):
"""Abstract class for all data-aware components initializers.
Components generated by data-aware components initializers inherit the shape
of the provided data.
`data` has to be a torch tensor.
"""
def __init__(self,
data: torch.TensorType,
noise: float = 0.0,
transform: callable = torch.nn.Identity()):
self.data = data
self.noise = noise
self.transform = transform
def generate_end_hook(self, samples):
drift = torch.rand_like(samples) * self.noise
components = self.transform(samples + drift)
return components
@abstractmethod
def generate(self, num_components: int):
...
return self.generate_end_hook(...)
def __del__(self):
del self.data
class DataAwareCompInitializer(AbstractDataAwareCompInitializer):
"""'Generate' the components from the provided data."""
def generate(self, num_components: int = 0):
"""Ignore `num_components` and simply return transformed `self.data`."""
components = self.generate_end_hook(self.data)
return components
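A brief sketch of the data-aware behaviour: the generated components inherit the shape of `data`, and `generate_end_hook` adds uniform drift scaled by `noise` before applying `transform` (import path and example data are assumptions):

    import torch
    from prototorch.core.initializers import DataAwareCompInitializer  # path assumed

    data = torch.rand(10, 3)
    init = DataAwareCompInitializer(data, noise=0.1)
    components = init.generate()  # data plus uniform drift in [0, 0.1)
    assert components.shape == data.shape  # shape is inherited from the data
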
class SelectionCompInitializer(AbstractDataAwareCompInitializer):
"""Generate components by uniformly sampling from the provided data."""
def generate(self, num_components: int):
indices = torch.LongTensor(num_components).random_(0, len(self.data))
@ -172,7 +148,7 @@ class SelectionCompInitializer(DataAwareCompInitializer):
return components
class MeanCompInitializer(DataAwareCompInitializer):
class MeanCompInitializer(AbstractDataAwareCompInitializer):
"""Generate components by computing the mean of the provided data."""
def generate(self, num_components: int):
mean = torch.mean(self.data, dim=0)
@ -182,14 +158,74 @@ class MeanCompInitializer(DataAwareCompInitializer):
return components
class StratifiedSelectionCompInitializer(ClassAwareCompInitializer):
class AbstractClassAwareCompInitializer(AbstractDataAwareCompInitializer):
"""Abstract class for all class-aware components initializers.
Components generated by class-aware components initializers inherit the shape
of the provided data.
`data` could be a torch Dataset or DataLoader or a list/tuple of data and
target tensors.
"""
def __init__(self,
data,
noise: float = 0.0,
transform: callable = torch.nn.Identity()):
self.data, self.targets = parse_data_arg(data)
self.noise = noise
self.transform = transform
self.clabels = torch.unique(self.targets).int().tolist()
self.num_classes = len(self.clabels)
@abstractmethod
def generate(self, distribution: Union[dict, list, tuple] = []):
...
return self.generate_end_hook(...)
def __del__(self):
del self.data
del self.targets
class ClassAwareCompInitializer(AbstractClassAwareCompInitializer):
"""'Generate' components from provided data and requested distribution."""
def generate(self, distribution: Union[dict, list, tuple] = []):
"""Ignore `distribution` and simply return transformed `self.data`."""
components = self.generate_end_hook(self.data)
return components
class AbstractStratifiedCompInitializer(AbstractClassAwareCompInitializer):
"""Abstract class for all stratified components initializers."""
@property
@abstractmethod
def subinit_type(self) -> AbstractDataAwareCompInitializer:
...
def generate(self, distribution: Union[dict, list, tuple]):
distribution = parse_distribution(distribution)
components = torch.tensor([])
for k, v in distribution.items():
stratified_data = self.data[self.targets == k]
initializer = self.subinit_type(
stratified_data,
noise=self.noise,
transform=self.transform,
)
samples = initializer.generate(num_components=v)
components = torch.cat([components, samples])
return components
class StratifiedSelectionCompInitializer(AbstractStratifiedCompInitializer):
"""Generate components using stratified sampling from the provided data."""
@property
def subinit_type(self):
return SelectionCompInitializer
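A usage sketch contrasting the plain class-aware initializer with its stratified counterpart (import path and example data are assumptions):

    import torch
    from prototorch.core.initializers import (  # import path assumed
        ClassAwareCompInitializer,
        StratifiedSelectionCompInitializer,
    )

    data = torch.rand(30, 4)
    targets = torch.cat([torch.zeros(15), torch.ones(15)]).long()

    # Ignores the distribution and returns all 30 (optionally noised) samples.
    all_comps = ClassAwareCompInitializer([data, targets]).generate()
    assert all_comps.shape == (30, 4)

    # Draws 3 random samples from class 0 and 1 from class 1.
    ssci = StratifiedSelectionCompInitializer([data, targets])
    strat_comps = ssci.generate(distribution={0: 3, 1: 1})
    assert strat_comps.shape == (4, 4)
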
class StratifiedMeanCompInitializer(ClassAwareCompInitializer):
class StratifiedMeanCompInitializer(AbstractStratifiedCompInitializer):
"""Generate components at stratified means of the provided data."""
@property
def subinit_type(self):
@ -204,6 +240,38 @@ class AbstractLabelsInitializer(ABC):
...
class LiteralLabelsInitializer(AbstractLabelsInitializer):
"""'Generate' the provided labels.
Use this to 'generate' pre-initialized labels elsewhere.
"""
def __init__(self, labels):
self.labels = labels
def generate(self, distribution: Union[dict, list, tuple] = []):
"""Ignore `distribution` and simply return `self.labels`.
Convert to long tensor, if necessary.
"""
labels = self.labels
if not isinstance(labels, torch.LongTensor):
wmsg = f"Converting labels to {torch.LongTensor}..."
warnings.warn(wmsg)
labels = torch.LongTensor(labels)
return labels
class DataAwareLabelsInitializer(AbstractLabelsInitializer):
"""'Generate' the labels from a torch Dataset."""
def __init__(self, data):
self.data, self.targets = parse_data_arg(data)
def generate(self, distribution: Union[dict, list, tuple] = []):
"""Ignore `distribution` and simply return `self.targets`."""
return self.targets
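A short sketch of the two label initializers above: literal labels are passed through (with a dtype warning if necessary), while data-aware labels are simply the dataset targets (import path and example data are assumptions):

    import torch
    from prototorch.core.initializers import (  # import path assumed
        DataAwareLabelsInitializer,
        LiteralLabelsInitializer,
    )

    labels = LiteralLabelsInitializer([0, 0, 1, 1]).generate()  # warns, converts
    assert labels.dtype == torch.long

    data = torch.rand(4, 2)
    targets = torch.LongTensor([0, 0, 1, 1])
    assert torch.equal(DataAwareLabelsInitializer([data, targets]).generate(), targets)
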
class LabelsInitializer(AbstractLabelsInitializer):
"""Generate labels from `distribution`."""
def generate(self, distribution: Union[dict, list, tuple]):
@ -248,6 +316,27 @@ class AbstractReasoningsInitializer(ABC):
return self.generate_end_hook(...)
class LiteralReasoningsInitializer(AbstractReasoningsInitializer):
"""'Generate' the provided reasonings.
Use this to 'generate' pre-initialized reasonings elsewhere.
"""
def __init__(self, reasonings, **kwargs):
super().__init__(**kwargs)
self.reasonings = reasonings
def generate(self, distribution: Union[dict, list, tuple] = []):
"""Ignore `distribution` and simply return `self.reasonings`."""
reasonings = self.reasonings
if not isinstance(reasonings, torch.Tensor):
wmsg = f"Converting reasonings to {torch.Tensor}..."
warnings.warn(wmsg)
reasonings = torch.Tensor(reasonings)
reasonings = self.generate_end_hook(reasonings)
return reasonings
class ZerosReasoningsInitializer(AbstractReasoningsInitializer):
"""Reasonings are all initialized with zeros."""
def generate(self, distribution: Union[dict, list, tuple]):
@ -292,23 +381,28 @@ class PurePositiveReasoningsInitializer(AbstractReasoningsInitializer):
# Aliases - Components
ZCI = ZerosCompInitializer
OCI = OnesCompInitializer
CACI = ClassAwareCompInitializer
DACI = DataAwareCompInitializer
FVCI = FillValueCompInitializer
LCI = LiteralCompInitializer
UCI = UniformCompInitializer
MCI = MeanCompInitializer
OCI = OnesCompInitializer
RNCI = RandomNormalCompInitializer
SCI = SelectionCompInitializer
MCI = MeanCompInitializer
SSCI = StratifiedSelectionCompInitializer
SMCI = StratifiedMeanCompInitializer
SSCI = StratifiedSelectionCompInitializer
UCI = UniformCompInitializer
ZCI = ZerosCompInitializer
# Aliases - Labels
DLI = DataAwareLabelsInitializer
LI = LabelsInitializer
LLI = LiteralLabelsInitializer
OHLI = OneHotLabelsInitializer
# Aliases - Reasonings
ZRI = ZerosReasoningsInitializer
LRI = LiteralReasoningsInitializer
ORI = OnesReasoningsInitializer
RRI = RandomReasoningsInitializer
PPRI = PurePositiveReasoningsInitializer
RRI = RandomReasoningsInitializer
ZRI = ZerosReasoningsInitializer
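The aliases are plain name bindings, so the short forms and the full class names are interchangeable; a one-line check (import path assumed):

    from prototorch.core.initializers import (  # import path assumed
        LI, LabelsInitializer, SMCI, StratifiedMeanCompInitializer)

    assert SMCI is StratifiedMeanCompInitializer
    assert LI is LabelsInitializer
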

View File

@ -67,17 +67,19 @@ def parse_distribution(user_distribution: Union[dict[int, int], dict[str, str],
elif isinstance(user_distribution, list):
return distribution_from_list(user_distribution, clabels)
else:
msg = f"`distribution` not understood." \
msg = f"`distribution` was not understood. " \
f"You have provided: {user_distribution}."
raise ValueError(msg)
def parse_data_arg(data_arg: Union[Dataset, DataLoader, list, tuple]):
"""Return data and target as torch tensors."""
if isinstance(data_arg, Dataset):
ds_size = len(data_arg)
data_arg = DataLoader(data_arg, batch_size=ds_size)
loader = DataLoader(data_arg, batch_size=ds_size)
data, targets = next(iter(loader))
if isinstance(data_arg, DataLoader):
elif isinstance(data_arg, DataLoader):
data = torch.tensor([])
targets = torch.tensor([])
for x, y in data_arg:
@ -87,11 +89,11 @@ def parse_data_arg(data_arg: Union[Dataset, DataLoader, list, tuple]):
assert len(data_arg) == 2
data, targets = data_arg
if not isinstance(data, torch.Tensor):
wmsg = f"Converting data to {torch.Tensor}."
wmsg = f"Converting data to {torch.Tensor}..."
warnings.warn(wmsg)
data = torch.Tensor(data)
if not isinstance(targets, torch.LongTensor):
wmsg = f"Converting targets to {torch.LongTensor}."
wmsg = f"Converting targets to {torch.LongTensor}..."
warnings.warn(wmsg)
targets = torch.LongTensor(targets)
return data, targets
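`parse_data_arg` accepts a Dataset, a DataLoader, or a list/tuple of tensors; after this change a Dataset is drained through a single full-batch DataLoader instead of falling through to the DataLoader branch. A hedged sketch of the Dataset and tensor-pair forms (the import path for the ProtoTorch utility is an assumption):

    import torch
    from torch.utils.data import TensorDataset
    from prototorch.utils import parse_data_arg  # import path assumed

    x = torch.rand(6, 2)
    y = torch.LongTensor([0, 0, 0, 1, 1, 1])

    # A Dataset is drained through one full-batch DataLoader internally...
    data, targets = parse_data_arg(TensorDataset(x, y))
    assert data.shape == (6, 2) and torch.equal(targets, y)

    # ...and a plain [data, targets] pair is accepted as well.
    data, targets = parse_data_arg([x, y])
    assert torch.equal(data, x) and torch.equal(targets, y)
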