[FEATURE] Add more initializers
This commit is contained in:
parent
549e6a10c1
commit
fc9edeaa97
@ -8,10 +8,10 @@ from torch.nn.parameter import Parameter
|
|||||||
|
|
||||||
from ..utils import parse_distribution
|
from ..utils import parse_distribution
|
||||||
from .initializers import (
|
from .initializers import (
|
||||||
|
AbstractClassAwareCompInitializer,
|
||||||
AbstractComponentsInitializer,
|
AbstractComponentsInitializer,
|
||||||
AbstractLabelsInitializer,
|
AbstractLabelsInitializer,
|
||||||
AbstractReasoningsInitializer,
|
AbstractReasoningsInitializer,
|
||||||
ClassAwareCompInitializer,
|
|
||||||
LabelsInitializer,
|
LabelsInitializer,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -50,7 +50,7 @@ def removeind(ins, attr, indices):
|
|||||||
|
|
||||||
def get_cikwargs(init, distribution):
|
def get_cikwargs(init, distribution):
|
||||||
"""Return appropriate key-word arguments for a component initializer."""
|
"""Return appropriate key-word arguments for a component initializer."""
|
||||||
if isinstance(init, ClassAwareCompInitializer):
|
if isinstance(init, AbstractClassAwareCompInitializer):
|
||||||
cikwargs = dict(distribution=distribution)
|
cikwargs = dict(distribution=distribution)
|
||||||
else:
|
else:
|
||||||
distribution = parse_distribution(distribution)
|
distribution = parse_distribution(distribution)
|
||||||
@ -69,7 +69,7 @@ class AbstractComponents(torch.nn.Module):
|
|||||||
@property
|
@property
|
||||||
def components(self):
|
def components(self):
|
||||||
"""Detached Tensor containing the components."""
|
"""Detached Tensor containing the components."""
|
||||||
return self._components.detach()
|
return self._components.detach().cpu()
|
||||||
|
|
||||||
def _register_components(self, components):
|
def _register_components(self, components):
|
||||||
self.register_parameter("_components", Parameter(components))
|
self.register_parameter("_components", Parameter(components))
|
||||||
@ -259,7 +259,7 @@ class ReasoningComponents(AbstractComponents):
|
|||||||
Dimension NxCx2
|
Dimension NxCx2
|
||||||
|
|
||||||
"""
|
"""
|
||||||
return self._reasonings.detach()
|
return self._reasonings.detach().cpu()
|
||||||
|
|
||||||
def _register_reasonings(self, reasonings):
|
def _register_reasonings(self, reasonings):
|
||||||
self.register_parameter("_reasonings", Parameter(reasonings))
|
self.register_parameter("_reasonings", Parameter(reasonings))
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
"""ProtoTorch code initializers"""
|
"""ProtoTorch code initializers"""
|
||||||
|
|
||||||
|
import warnings
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from collections.abc import Iterable
|
from collections.abc import Iterable
|
||||||
from typing import Union
|
from typing import Union
|
||||||
@ -15,6 +16,24 @@ class AbstractComponentsInitializer(ABC):
|
|||||||
...
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class LiteralCompInitializer(AbstractComponentsInitializer):
|
||||||
|
"""'Generate' the provided components.
|
||||||
|
|
||||||
|
Use this to 'generate' components that were pre-initialized elsewhere.
|
||||||
|
|
||||||
|
"""
|
||||||
|
def __init__(self, components):
|
||||||
|
self.components = components
|
||||||
|
|
||||||
|
def generate(self, num_components: int = 0):
|
||||||
|
"""Ignore `num_components` and simply return `self.components`."""
|
||||||
|
if not isinstance(self.components, torch.Tensor):
|
||||||
|
wmsg = f"Converting components to {torch.Tensor}..."
|
||||||
|
warnings.warn(wmsg)
|
||||||
|
self.components = torch.Tensor(self.components)
|
||||||
|
return self.components
|
||||||
|
|
||||||
|
|
||||||
class ShapeAwareCompInitializer(AbstractComponentsInitializer):
|
class ShapeAwareCompInitializer(AbstractComponentsInitializer):
|
||||||
"""Abstract class for all dimension-aware components initializers."""
|
"""Abstract class for all dimension-aware components initializers."""
|
||||||
def __init__(self, shape: Union[Iterable, int]):
|
def __init__(self, shape: Union[Iterable, int]):
|
||||||
@ -28,88 +47,6 @@ class ShapeAwareCompInitializer(AbstractComponentsInitializer):
|
|||||||
...
|
...
|
||||||
|
|
||||||
|
|
||||||
class DataAwareCompInitializer(AbstractComponentsInitializer):
|
|
||||||
"""Abstract class for all data-aware components initializers.
|
|
||||||
|
|
||||||
Components generated by data-aware components initializers inherit the shape
|
|
||||||
of the provided data.
|
|
||||||
|
|
||||||
"""
|
|
||||||
def __init__(self,
|
|
||||||
data,
|
|
||||||
noise: float = 0.0,
|
|
||||||
transform: callable = torch.nn.Identity()):
|
|
||||||
self.data = data
|
|
||||||
self.noise = noise
|
|
||||||
self.transform = transform
|
|
||||||
|
|
||||||
def generate_end_hook(self, samples):
|
|
||||||
drift = torch.rand_like(samples) * self.noise
|
|
||||||
components = self.transform(samples + drift)
|
|
||||||
return components
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def generate(self, num_components: int):
|
|
||||||
...
|
|
||||||
return self.generate_end_hook(...)
|
|
||||||
|
|
||||||
def __del__(self):
|
|
||||||
del self.data
|
|
||||||
|
|
||||||
|
|
||||||
class ClassAwareCompInitializer(AbstractComponentsInitializer):
|
|
||||||
"""Abstract class for all class-aware components initializers.
|
|
||||||
|
|
||||||
Components generated by class-aware components initializers inherit the shape
|
|
||||||
of the provided data.
|
|
||||||
|
|
||||||
"""
|
|
||||||
def __init__(self,
|
|
||||||
data,
|
|
||||||
noise: float = 0.0,
|
|
||||||
transform: callable = torch.nn.Identity()):
|
|
||||||
self.data, self.targets = parse_data_arg(data)
|
|
||||||
self.noise = noise
|
|
||||||
self.transform = transform
|
|
||||||
self.clabels = torch.unique(self.targets).int().tolist()
|
|
||||||
self.num_classes = len(self.clabels)
|
|
||||||
|
|
||||||
@property
|
|
||||||
@abstractmethod
|
|
||||||
def subinit_type(self) -> DataAwareCompInitializer:
|
|
||||||
...
|
|
||||||
|
|
||||||
def generate(self, distribution: Union[dict, list, tuple]):
|
|
||||||
distribution = parse_distribution(distribution)
|
|
||||||
components = torch.tensor([])
|
|
||||||
for k, v in distribution.items():
|
|
||||||
stratified_data = self.data[self.targets == k]
|
|
||||||
initializer = self.subinit_type(
|
|
||||||
stratified_data,
|
|
||||||
noise=self.noise,
|
|
||||||
transform=self.transform,
|
|
||||||
)
|
|
||||||
samples = initializer.generate(num_components=v)
|
|
||||||
components = torch.cat([components, samples])
|
|
||||||
return components
|
|
||||||
|
|
||||||
def __del__(self):
|
|
||||||
del self.data
|
|
||||||
del self.targets
|
|
||||||
|
|
||||||
|
|
||||||
class LiteralCompInitializer(DataAwareCompInitializer):
|
|
||||||
"""'Generate' the provided components.
|
|
||||||
|
|
||||||
Use this to 'generate' pre-initialized components from elsewhere.
|
|
||||||
|
|
||||||
"""
|
|
||||||
def generate(self, num_components: int):
|
|
||||||
"""Ignore `num_components` and simply return transformed `self.data`."""
|
|
||||||
components = self.transform(self.data)
|
|
||||||
return components
|
|
||||||
|
|
||||||
|
|
||||||
class ZerosCompInitializer(ShapeAwareCompInitializer):
|
class ZerosCompInitializer(ShapeAwareCompInitializer):
|
||||||
"""Generate zeros corresponding to the components shape."""
|
"""Generate zeros corresponding to the components shape."""
|
||||||
def generate(self, num_components: int):
|
def generate(self, num_components: int):
|
||||||
@ -163,7 +100,46 @@ class RandomNormalCompInitializer(OnesCompInitializer):
|
|||||||
return components
|
return components
|
||||||
|
|
||||||
|
|
||||||
class SelectionCompInitializer(DataAwareCompInitializer):
|
class AbstractDataAwareCompInitializer(AbstractComponentsInitializer):
|
||||||
|
"""Abstract class for all data-aware components initializers.
|
||||||
|
|
||||||
|
Components generated by data-aware components initializers inherit the shape
|
||||||
|
of the provided data.
|
||||||
|
|
||||||
|
`data` has to be a torch tensor.
|
||||||
|
|
||||||
|
"""
|
||||||
|
def __init__(self,
|
||||||
|
data: torch.TensorType,
|
||||||
|
noise: float = 0.0,
|
||||||
|
transform: callable = torch.nn.Identity()):
|
||||||
|
self.data = data
|
||||||
|
self.noise = noise
|
||||||
|
self.transform = transform
|
||||||
|
|
||||||
|
def generate_end_hook(self, samples):
|
||||||
|
drift = torch.rand_like(samples) * self.noise
|
||||||
|
components = self.transform(samples + drift)
|
||||||
|
return components
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def generate(self, num_components: int):
|
||||||
|
...
|
||||||
|
return self.generate_end_hook(...)
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
del self.data
|
||||||
|
|
||||||
|
|
||||||
|
class DataAwareCompInitializer(AbstractDataAwareCompInitializer):
|
||||||
|
"""'Generate' the components from the provided data."""
|
||||||
|
def generate(self, num_components: int = 0):
|
||||||
|
"""Ignore `num_components` and simply return transformed `self.data`."""
|
||||||
|
components = self.generate_end_hook(self.data)
|
||||||
|
return components
|
||||||
|
|
||||||
|
|
||||||
|
class SelectionCompInitializer(AbstractDataAwareCompInitializer):
|
||||||
"""Generate components by uniformly sampling from the provided data."""
|
"""Generate components by uniformly sampling from the provided data."""
|
||||||
def generate(self, num_components: int):
|
def generate(self, num_components: int):
|
||||||
indices = torch.LongTensor(num_components).random_(0, len(self.data))
|
indices = torch.LongTensor(num_components).random_(0, len(self.data))
|
||||||
@ -172,7 +148,7 @@ class SelectionCompInitializer(DataAwareCompInitializer):
|
|||||||
return components
|
return components
|
||||||
|
|
||||||
|
|
||||||
class MeanCompInitializer(DataAwareCompInitializer):
|
class MeanCompInitializer(AbstractDataAwareCompInitializer):
|
||||||
"""Generate components by computing the mean of the provided data."""
|
"""Generate components by computing the mean of the provided data."""
|
||||||
def generate(self, num_components: int):
|
def generate(self, num_components: int):
|
||||||
mean = torch.mean(self.data, dim=0)
|
mean = torch.mean(self.data, dim=0)
|
||||||
@ -182,14 +158,74 @@ class MeanCompInitializer(DataAwareCompInitializer):
|
|||||||
return components
|
return components
|
||||||
|
|
||||||
|
|
||||||
class StratifiedSelectionCompInitializer(ClassAwareCompInitializer):
|
class AbstractClassAwareCompInitializer(AbstractDataAwareCompInitializer):
|
||||||
|
"""Abstract class for all class-aware components initializers.
|
||||||
|
|
||||||
|
Components generated by class-aware components initializers inherit the shape
|
||||||
|
of the provided data.
|
||||||
|
|
||||||
|
`data` could be a torch Dataset or DataLoader or a list/tuple of data and
|
||||||
|
target tensors.
|
||||||
|
|
||||||
|
"""
|
||||||
|
def __init__(self,
|
||||||
|
data,
|
||||||
|
noise: float = 0.0,
|
||||||
|
transform: callable = torch.nn.Identity()):
|
||||||
|
self.data, self.targets = parse_data_arg(data)
|
||||||
|
self.noise = noise
|
||||||
|
self.transform = transform
|
||||||
|
self.clabels = torch.unique(self.targets).int().tolist()
|
||||||
|
self.num_classes = len(self.clabels)
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def generate(self, distribution: Union[dict, list, tuple] = []):
|
||||||
|
...
|
||||||
|
return self.generate_end_hook(...)
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
del self.data
|
||||||
|
del self.targets
|
||||||
|
|
||||||
|
|
||||||
|
class ClassAwareCompInitializer(AbstractClassAwareCompInitializer):
|
||||||
|
"""'Generate' components from provided data and requested distribution."""
|
||||||
|
def generate(self, distribution: Union[dict, list, tuple] = []):
|
||||||
|
"""Ignore `distribution` and simply return transformed `self.data`."""
|
||||||
|
components = self.generate_end_hook(self.data)
|
||||||
|
return components
|
||||||
|
|
||||||
|
|
||||||
|
class AbstractStratifiedCompInitializer(AbstractClassAwareCompInitializer):
|
||||||
|
"""Abstract class for all stratified components initializers."""
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def subinit_type(self) -> AbstractDataAwareCompInitializer:
|
||||||
|
...
|
||||||
|
|
||||||
|
def generate(self, distribution: Union[dict, list, tuple]):
|
||||||
|
distribution = parse_distribution(distribution)
|
||||||
|
components = torch.tensor([])
|
||||||
|
for k, v in distribution.items():
|
||||||
|
stratified_data = self.data[self.targets == k]
|
||||||
|
initializer = self.subinit_type(
|
||||||
|
stratified_data,
|
||||||
|
noise=self.noise,
|
||||||
|
transform=self.transform,
|
||||||
|
)
|
||||||
|
samples = initializer.generate(num_components=v)
|
||||||
|
components = torch.cat([components, samples])
|
||||||
|
return components
|
||||||
|
|
||||||
|
|
||||||
|
class StratifiedSelectionCompInitializer(AbstractStratifiedCompInitializer):
|
||||||
"""Generate components using stratified sampling from the provided data."""
|
"""Generate components using stratified sampling from the provided data."""
|
||||||
@property
|
@property
|
||||||
def subinit_type(self):
|
def subinit_type(self):
|
||||||
return SelectionCompInitializer
|
return SelectionCompInitializer
|
||||||
|
|
||||||
|
|
||||||
class StratifiedMeanCompInitializer(ClassAwareCompInitializer):
|
class StratifiedMeanCompInitializer(AbstractStratifiedCompInitializer):
|
||||||
"""Generate components at stratified means of the provided data."""
|
"""Generate components at stratified means of the provided data."""
|
||||||
@property
|
@property
|
||||||
def subinit_type(self):
|
def subinit_type(self):
|
||||||
@ -204,6 +240,38 @@ class AbstractLabelsInitializer(ABC):
|
|||||||
...
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class LiteralLabelsInitializer(AbstractLabelsInitializer):
|
||||||
|
"""'Generate' the provided labels.
|
||||||
|
|
||||||
|
Use this to 'generate' labels that were pre-initialized elsewhere.
|
||||||
|
|
||||||
|
"""
|
||||||
|
def __init__(self, labels):
|
||||||
|
self.labels = labels
|
||||||
|
|
||||||
|
def generate(self, distribution: Union[dict, list, tuple] = []):
|
||||||
|
"""Ignore `distribution` and simply return `self.labels`.
|
||||||
|
|
||||||
|
Convert to long tensor, if necessary.
|
||||||
|
"""
|
||||||
|
labels = self.labels
|
||||||
|
if not isinstance(labels, torch.LongTensor):
|
||||||
|
wmsg = f"Converting labels to {torch.LongTensor}..."
|
||||||
|
warnings.warn(wmsg)
|
||||||
|
labels = torch.LongTensor(labels)
|
||||||
|
return labels
|
||||||
|
|
||||||
|
|
||||||
|
class DataAwareLabelsInitializer(AbstractLabelsInitializer):
|
||||||
|
"""'Generate' the labels from a torch Dataset, DataLoader, or (data, targets) pair."""
|
||||||
|
def __init__(self, data):
|
||||||
|
self.data, self.targets = parse_data_arg(data)
|
||||||
|
|
||||||
|
def generate(self, distribution: Union[dict, list, tuple] = []):
|
||||||
|
"""Ignore `distribution` and simply return `self.targets`."""
|
||||||
|
return self.targets
|
||||||
|
|
||||||
|
|
||||||
class LabelsInitializer(AbstractLabelsInitializer):
|
class LabelsInitializer(AbstractLabelsInitializer):
|
||||||
"""Generate labels from `distribution`."""
|
"""Generate labels from `distribution`."""
|
||||||
def generate(self, distribution: Union[dict, list, tuple]):
|
def generate(self, distribution: Union[dict, list, tuple]):
|
||||||
@ -248,6 +316,27 @@ class AbstractReasoningsInitializer(ABC):
|
|||||||
return generate_end_hook(...)
|
return generate_end_hook(...)
|
||||||
|
|
||||||
|
|
||||||
|
class LiteralReasoningsInitializer(AbstractReasoningsInitializer):
|
||||||
|
"""'Generate' the provided reasonings.
|
||||||
|
|
||||||
|
Use this to 'generate' reasonings that were pre-initialized elsewhere.
|
||||||
|
|
||||||
|
"""
|
||||||
|
def __init__(self, reasonings, **kwargs):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
self.reasonings = reasonings
|
||||||
|
|
||||||
|
def generate(self, distribution: Union[dict, list, tuple] = []):
|
||||||
|
"""Ignore `distribution` and simply return `self.reasonings`."""
|
||||||
|
reasonings = self.reasonings
|
||||||
|
if not isinstance(reasonings, torch.Tensor):
|
||||||
|
wmsg = f"Converting reasonings to {torch.Tensor}..."
|
||||||
|
warnings.warn(wmsg)
|
||||||
|
reasonings = torch.Tensor(reasonings)
|
||||||
|
reasonings = self.generate_end_hook(reasonings)
|
||||||
|
return reasonings
|
||||||
|
|
||||||
|
|
||||||
class ZerosReasoningsInitializer(AbstractReasoningsInitializer):
|
class ZerosReasoningsInitializer(AbstractReasoningsInitializer):
|
||||||
"""Reasonings are all initialized with zeros."""
|
"""Reasonings are all initialized with zeros."""
|
||||||
def generate(self, distribution: Union[dict, list, tuple]):
|
def generate(self, distribution: Union[dict, list, tuple]):
|
||||||
@ -292,23 +381,28 @@ class PurePositiveReasoningsInitializer(AbstractReasoningsInitializer):
|
|||||||
|
|
||||||
|
|
||||||
# Aliases - Components
|
# Aliases - Components
|
||||||
ZCI = ZerosCompInitializer
|
CACI = ClassAwareCompInitializer
|
||||||
OCI = OnesCompInitializer
|
DACI = DataAwareCompInitializer
|
||||||
FVCI = FillValueCompInitializer
|
FVCI = FillValueCompInitializer
|
||||||
LCI = LiteralCompInitializer
|
LCI = LiteralCompInitializer
|
||||||
UCI = UniformCompInitializer
|
MCI = MeanCompInitializer
|
||||||
|
OCI = OnesCompInitializer
|
||||||
RNCI = RandomNormalCompInitializer
|
RNCI = RandomNormalCompInitializer
|
||||||
SCI = SelectionCompInitializer
|
SCI = SelectionCompInitializer
|
||||||
MCI = MeanCompInitializer
|
|
||||||
SSCI = StratifiedSelectionCompInitializer
|
|
||||||
SMCI = StratifiedMeanCompInitializer
|
SMCI = StratifiedMeanCompInitializer
|
||||||
|
SSCI = StratifiedSelectionCompInitializer
|
||||||
|
UCI = UniformCompInitializer
|
||||||
|
ZCI = ZerosCompInitializer
|
||||||
|
|
||||||
# Aliases - Labels
|
# Aliases - Labels
|
||||||
|
DLI = DataAwareLabelsInitializer
|
||||||
LI = LabelsInitializer
|
LI = LabelsInitializer
|
||||||
|
LLI = LiteralLabelsInitializer
|
||||||
OHLI = OneHotLabelsInitializer
|
OHLI = OneHotLabelsInitializer
|
||||||
|
|
||||||
# Aliases - Reasonings
|
# Aliases - Reasonings
|
||||||
ZRI = ZerosReasoningsInitializer
|
LRI = LiteralReasoningsInitializer
|
||||||
ORI = OnesReasoningsInitializer
|
ORI = OnesReasoningsInitializer
|
||||||
RRI = RandomReasoningsInitializer
|
|
||||||
PPRI = PurePositiveReasoningsInitializer
|
PPRI = PurePositiveReasoningsInitializer
|
||||||
|
RRI = RandomReasoningsInitializer
|
||||||
|
ZRI = ZerosReasoningsInitializer
|
||||||
|
@ -67,17 +67,19 @@ def parse_distribution(user_distribution: Union[dict[int, int], dict[str, str],
|
|||||||
elif isinstance(user_distribution, list):
|
elif isinstance(user_distribution, list):
|
||||||
return distribution_from_list(user_distribution, clabels)
|
return distribution_from_list(user_distribution, clabels)
|
||||||
else:
|
else:
|
||||||
msg = f"`distribution` not understood." \
|
msg = f"`distribution` was not understood." \
|
||||||
f"You have provided: {user_distribution}."
|
f"You have provided: {user_distribution}."
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
|
||||||
def parse_data_arg(data_arg: Union[Dataset, DataLoader, list, tuple]):
|
def parse_data_arg(data_arg: Union[Dataset, DataLoader, list, tuple]):
|
||||||
|
"""Return data and target as torch tensors."""
|
||||||
if isinstance(data_arg, Dataset):
|
if isinstance(data_arg, Dataset):
|
||||||
ds_size = len(data_arg)
|
ds_size = len(data_arg)
|
||||||
data_arg = DataLoader(data_arg, batch_size=ds_size)
|
loader = DataLoader(data_arg, batch_size=ds_size)
|
||||||
|
data, targets = next(iter(loader))
|
||||||
|
|
||||||
if isinstance(data_arg, DataLoader):
|
elif isinstance(data_arg, DataLoader):
|
||||||
data = torch.tensor([])
|
data = torch.tensor([])
|
||||||
targets = torch.tensor([])
|
targets = torch.tensor([])
|
||||||
for x, y in data_arg:
|
for x, y in data_arg:
|
||||||
@ -87,11 +89,11 @@ def parse_data_arg(data_arg: Union[Dataset, DataLoader, list, tuple]):
|
|||||||
assert len(data_arg) == 2
|
assert len(data_arg) == 2
|
||||||
data, targets = data_arg
|
data, targets = data_arg
|
||||||
if not isinstance(data, torch.Tensor):
|
if not isinstance(data, torch.Tensor):
|
||||||
wmsg = f"Converting data to {torch.Tensor}."
|
wmsg = f"Converting data to {torch.Tensor}..."
|
||||||
warnings.warn(wmsg)
|
warnings.warn(wmsg)
|
||||||
data = torch.Tensor(data)
|
data = torch.Tensor(data)
|
||||||
if not isinstance(targets, torch.LongTensor):
|
if not isinstance(targets, torch.LongTensor):
|
||||||
wmsg = f"Converting targets to {torch.LongTensor}."
|
wmsg = f"Converting targets to {torch.LongTensor}..."
|
||||||
warnings.warn(wmsg)
|
warnings.warn(wmsg)
|
||||||
targets = torch.LongTensor(targets)
|
targets = torch.LongTensor(targets)
|
||||||
return data, targets
|
return data, targets
|
||||||
|
Loading…
Reference in New Issue
Block a user