[FEATURE] Add more initializers
This commit is contained in:
parent
549e6a10c1
commit
fc9edeaa97
@ -8,10 +8,10 @@ from torch.nn.parameter import Parameter
|
||||
|
||||
from ..utils import parse_distribution
|
||||
from .initializers import (
|
||||
AbstractClassAwareCompInitializer,
|
||||
AbstractComponentsInitializer,
|
||||
AbstractLabelsInitializer,
|
||||
AbstractReasoningsInitializer,
|
||||
ClassAwareCompInitializer,
|
||||
LabelsInitializer,
|
||||
)
|
||||
|
||||
@ -50,7 +50,7 @@ def removeind(ins, attr, indices):
|
||||
|
||||
def get_cikwargs(init, distribution):
|
||||
"""Return appropriate key-word arguments for a component initializer."""
|
||||
if isinstance(init, ClassAwareCompInitializer):
|
||||
if isinstance(init, AbstractClassAwareCompInitializer):
|
||||
cikwargs = dict(distribution=distribution)
|
||||
else:
|
||||
distribution = parse_distribution(distribution)
|
||||
@ -69,7 +69,7 @@ class AbstractComponents(torch.nn.Module):
|
||||
@property
|
||||
def components(self):
|
||||
"""Detached Tensor containing the components."""
|
||||
return self._components.detach()
|
||||
return self._components.detach().cpu()
|
||||
|
||||
def _register_components(self, components):
|
||||
self.register_parameter("_components", Parameter(components))
|
||||
@ -259,7 +259,7 @@ class ReasoningComponents(AbstractComponents):
|
||||
Dimension NxCx2
|
||||
|
||||
"""
|
||||
return self._reasonings.detach()
|
||||
return self._reasonings.detach().cpu()
|
||||
|
||||
def _register_reasonings(self, reasonings):
|
||||
self.register_parameter("_reasonings", Parameter(reasonings))
|
||||
|
@ -1,5 +1,6 @@
|
||||
"""ProtoTorch code initializers"""
|
||||
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Iterable
|
||||
from typing import Union
|
||||
@ -15,6 +16,24 @@ class AbstractComponentsInitializer(ABC):
|
||||
...
|
||||
|
||||
|
||||
class LiteralCompInitializer(AbstractComponentsInitializer):
|
||||
"""'Generate' the provided components.
|
||||
|
||||
Use this to 'generate' pre-initialized components elsewhere.
|
||||
|
||||
"""
|
||||
def __init__(self, components):
|
||||
self.components = components
|
||||
|
||||
def generate(self, num_components: int = 0):
|
||||
"""Ignore `num_components` and simply return `self.components`."""
|
||||
if not isinstance(self.components, torch.Tensor):
|
||||
wmsg = f"Converting components to {torch.Tensor}..."
|
||||
warnings.warn(wmsg)
|
||||
self.components = torch.Tensor(self.components)
|
||||
return self.components
|
||||
|
||||
|
||||
class ShapeAwareCompInitializer(AbstractComponentsInitializer):
|
||||
"""Abstract class for all dimension-aware components initializers."""
|
||||
def __init__(self, shape: Union[Iterable, int]):
|
||||
@ -28,88 +47,6 @@ class ShapeAwareCompInitializer(AbstractComponentsInitializer):
|
||||
...
|
||||
|
||||
|
||||
class DataAwareCompInitializer(AbstractComponentsInitializer):
|
||||
"""Abstract class for all data-aware components initializers.
|
||||
|
||||
Components generated by data-aware components initializers inherit the shape
|
||||
of the provided data.
|
||||
|
||||
"""
|
||||
def __init__(self,
|
||||
data,
|
||||
noise: float = 0.0,
|
||||
transform: callable = torch.nn.Identity()):
|
||||
self.data = data
|
||||
self.noise = noise
|
||||
self.transform = transform
|
||||
|
||||
def generate_end_hook(self, samples):
|
||||
drift = torch.rand_like(samples) * self.noise
|
||||
components = self.transform(samples + drift)
|
||||
return components
|
||||
|
||||
@abstractmethod
|
||||
def generate(self, num_components: int):
|
||||
...
|
||||
return self.generate_end_hook(...)
|
||||
|
||||
def __del__(self):
|
||||
del self.data
|
||||
|
||||
|
||||
class ClassAwareCompInitializer(AbstractComponentsInitializer):
|
||||
"""Abstract class for all class-aware components initializers.
|
||||
|
||||
Components generated by class-aware components initializers inherit the shape
|
||||
of the provided data.
|
||||
|
||||
"""
|
||||
def __init__(self,
|
||||
data,
|
||||
noise: float = 0.0,
|
||||
transform: callable = torch.nn.Identity()):
|
||||
self.data, self.targets = parse_data_arg(data)
|
||||
self.noise = noise
|
||||
self.transform = transform
|
||||
self.clabels = torch.unique(self.targets).int().tolist()
|
||||
self.num_classes = len(self.clabels)
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def subinit_type(self) -> DataAwareCompInitializer:
|
||||
...
|
||||
|
||||
def generate(self, distribution: Union[dict, list, tuple]):
|
||||
distribution = parse_distribution(distribution)
|
||||
components = torch.tensor([])
|
||||
for k, v in distribution.items():
|
||||
stratified_data = self.data[self.targets == k]
|
||||
initializer = self.subinit_type(
|
||||
stratified_data,
|
||||
noise=self.noise,
|
||||
transform=self.transform,
|
||||
)
|
||||
samples = initializer.generate(num_components=v)
|
||||
components = torch.cat([components, samples])
|
||||
return components
|
||||
|
||||
def __del__(self):
|
||||
del self.data
|
||||
del self.targets
|
||||
|
||||
|
||||
class LiteralCompInitializer(DataAwareCompInitializer):
|
||||
"""'Generate' the provided components.
|
||||
|
||||
Use this to 'generate' pre-initialized components from elsewhere.
|
||||
|
||||
"""
|
||||
def generate(self, num_components: int):
|
||||
"""Ignore `num_components` and simply return transformed `self.data`."""
|
||||
components = self.transform(self.data)
|
||||
return components
|
||||
|
||||
|
||||
class ZerosCompInitializer(ShapeAwareCompInitializer):
|
||||
"""Generate zeros corresponding to the components shape."""
|
||||
def generate(self, num_components: int):
|
||||
@ -163,7 +100,46 @@ class RandomNormalCompInitializer(OnesCompInitializer):
|
||||
return components
|
||||
|
||||
|
||||
class SelectionCompInitializer(DataAwareCompInitializer):
|
||||
class AbstractDataAwareCompInitializer(AbstractComponentsInitializer):
|
||||
"""Abstract class for all data-aware components initializers.
|
||||
|
||||
Components generated by data-aware components initializers inherit the shape
|
||||
of the provided data.
|
||||
|
||||
`data` has to be a torch tensor.
|
||||
|
||||
"""
|
||||
def __init__(self,
|
||||
data: torch.TensorType,
|
||||
noise: float = 0.0,
|
||||
transform: callable = torch.nn.Identity()):
|
||||
self.data = data
|
||||
self.noise = noise
|
||||
self.transform = transform
|
||||
|
||||
def generate_end_hook(self, samples):
|
||||
drift = torch.rand_like(samples) * self.noise
|
||||
components = self.transform(samples + drift)
|
||||
return components
|
||||
|
||||
@abstractmethod
|
||||
def generate(self, num_components: int):
|
||||
...
|
||||
return self.generate_end_hook(...)
|
||||
|
||||
def __del__(self):
|
||||
del self.data
|
||||
|
||||
|
||||
class DataAwareCompInitializer(AbstractDataAwareCompInitializer):
|
||||
"""'Generate' the components from the provided data."""
|
||||
def generate(self, num_components: int = 0):
|
||||
"""Ignore `num_components` and simply return transformed `self.data`."""
|
||||
components = self.generate_end_hook(self.data)
|
||||
return components
|
||||
|
||||
|
||||
class SelectionCompInitializer(AbstractDataAwareCompInitializer):
|
||||
"""Generate components by uniformly sampling from the provided data."""
|
||||
def generate(self, num_components: int):
|
||||
indices = torch.LongTensor(num_components).random_(0, len(self.data))
|
||||
@ -172,7 +148,7 @@ class SelectionCompInitializer(DataAwareCompInitializer):
|
||||
return components
|
||||
|
||||
|
||||
class MeanCompInitializer(DataAwareCompInitializer):
|
||||
class MeanCompInitializer(AbstractDataAwareCompInitializer):
|
||||
"""Generate components by computing the mean of the provided data."""
|
||||
def generate(self, num_components: int):
|
||||
mean = torch.mean(self.data, dim=0)
|
||||
@ -182,14 +158,74 @@ class MeanCompInitializer(DataAwareCompInitializer):
|
||||
return components
|
||||
|
||||
|
||||
class StratifiedSelectionCompInitializer(ClassAwareCompInitializer):
|
||||
class AbstractClassAwareCompInitializer(AbstractDataAwareCompInitializer):
|
||||
"""Abstract class for all class-aware components initializers.
|
||||
|
||||
Components generated by class-aware components initializers inherit the shape
|
||||
of the provided data.
|
||||
|
||||
`data` could be a torch Dataset or DataLoader or a list/tuple of data and
|
||||
target tensors.
|
||||
|
||||
"""
|
||||
def __init__(self,
|
||||
data,
|
||||
noise: float = 0.0,
|
||||
transform: callable = torch.nn.Identity()):
|
||||
self.data, self.targets = parse_data_arg(data)
|
||||
self.noise = noise
|
||||
self.transform = transform
|
||||
self.clabels = torch.unique(self.targets).int().tolist()
|
||||
self.num_classes = len(self.clabels)
|
||||
|
||||
@abstractmethod
|
||||
def generate(self, distribution: Union[dict, list, tuple] = []):
|
||||
...
|
||||
return self.generate_end_hook(...)
|
||||
|
||||
def __del__(self):
|
||||
del self.data
|
||||
del self.targets
|
||||
|
||||
|
||||
class ClassAwareCompInitializer(AbstractClassAwareCompInitializer):
|
||||
"""'Generate' components from provided data and requested distribution."""
|
||||
def generate(self, distribution: Union[dict, list, tuple] = []):
|
||||
"""Ignore `distribution` and simply return transformed `self.data`."""
|
||||
components = self.generate_end_hook(self.data)
|
||||
return components
|
||||
|
||||
|
||||
class AbstractStratifiedCompInitializer(AbstractClassAwareCompInitializer):
|
||||
"""Abstract class for all stratified components initializers."""
|
||||
@property
|
||||
@abstractmethod
|
||||
def subinit_type(self) -> AbstractDataAwareCompInitializer:
|
||||
...
|
||||
|
||||
def generate(self, distribution: Union[dict, list, tuple]):
|
||||
distribution = parse_distribution(distribution)
|
||||
components = torch.tensor([])
|
||||
for k, v in distribution.items():
|
||||
stratified_data = self.data[self.targets == k]
|
||||
initializer = self.subinit_type(
|
||||
stratified_data,
|
||||
noise=self.noise,
|
||||
transform=self.transform,
|
||||
)
|
||||
samples = initializer.generate(num_components=v)
|
||||
components = torch.cat([components, samples])
|
||||
return components
|
||||
|
||||
|
||||
class StratifiedSelectionCompInitializer(AbstractStratifiedCompInitializer):
|
||||
"""Generate components using stratified sampling from the provided data."""
|
||||
@property
|
||||
def subinit_type(self):
|
||||
return SelectionCompInitializer
|
||||
|
||||
|
||||
class StratifiedMeanCompInitializer(ClassAwareCompInitializer):
|
||||
class StratifiedMeanCompInitializer(AbstractStratifiedCompInitializer):
|
||||
"""Generate components at stratified means of the provided data."""
|
||||
@property
|
||||
def subinit_type(self):
|
||||
@ -204,6 +240,38 @@ class AbstractLabelsInitializer(ABC):
|
||||
...
|
||||
|
||||
|
||||
class LiteralLabelsInitializer(AbstractLabelsInitializer):
|
||||
"""'Generate' the provided labels.
|
||||
|
||||
Use this to 'generate' pre-initialized labels elsewhere.
|
||||
|
||||
"""
|
||||
def __init__(self, labels):
|
||||
self.labels = labels
|
||||
|
||||
def generate(self, distribution: Union[dict, list, tuple] = []):
|
||||
"""Ignore `distribution` and simply return `self.labels`.
|
||||
|
||||
Convert to long tensor, if necessary.
|
||||
"""
|
||||
labels = self.labels
|
||||
if not isinstance(labels, torch.LongTensor):
|
||||
wmsg = f"Converting labels to {torch.LongTensor}..."
|
||||
warnings.warn(wmsg)
|
||||
labels = torch.LongTensor(labels)
|
||||
return labels
|
||||
|
||||
|
||||
class DataAwareLabelsInitializer(AbstractLabelsInitializer):
|
||||
"""'Generate' the labels from a torch Dataset."""
|
||||
def __init__(self, data):
|
||||
self.data, self.targets = parse_data_arg(data)
|
||||
|
||||
def generate(self, distribution: Union[dict, list, tuple] = []):
|
||||
"""Ignore `num_components` and simply return `self.targets`."""
|
||||
return self.targets
|
||||
|
||||
|
||||
class LabelsInitializer(AbstractLabelsInitializer):
|
||||
"""Generate labels from `distribution`."""
|
||||
def generate(self, distribution: Union[dict, list, tuple]):
|
||||
@ -248,6 +316,27 @@ class AbstractReasoningsInitializer(ABC):
|
||||
return generate_end_hook(...)
|
||||
|
||||
|
||||
class LiteralReasoningsInitializer(AbstractReasoningsInitializer):
|
||||
"""'Generate' the provided reasonings.
|
||||
|
||||
Use this to 'generate' pre-initialized reasonings elsewhere.
|
||||
|
||||
"""
|
||||
def __init__(self, reasonings, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.reasonings = reasonings
|
||||
|
||||
def generate(self, distribution: Union[dict, list, tuple] = []):
|
||||
"""Ignore `distributuion` and simply return self.reasonings."""
|
||||
reasonings = self.reasonings
|
||||
if not isinstance(reasonings, torch.Tensor):
|
||||
wmsg = f"Converting reasonings to {torch.Tensor}..."
|
||||
warnings.warn(wmsg)
|
||||
reasonings = torch.Tensor(reasonings)
|
||||
reasonings = self.generate_end_hook(reasonings)
|
||||
return reasonings
|
||||
|
||||
|
||||
class ZerosReasoningsInitializer(AbstractReasoningsInitializer):
|
||||
"""Reasonings are all initialized with zeros."""
|
||||
def generate(self, distribution: Union[dict, list, tuple]):
|
||||
@ -292,23 +381,28 @@ class PurePositiveReasoningsInitializer(AbstractReasoningsInitializer):
|
||||
|
||||
|
||||
# Aliases - Components
|
||||
ZCI = ZerosCompInitializer
|
||||
OCI = OnesCompInitializer
|
||||
CACI = ClassAwareCompInitializer
|
||||
DACI = DataAwareCompInitializer
|
||||
FVCI = FillValueCompInitializer
|
||||
LCI = LiteralCompInitializer
|
||||
UCI = UniformCompInitializer
|
||||
MCI = MeanCompInitializer
|
||||
OCI = OnesCompInitializer
|
||||
RNCI = RandomNormalCompInitializer
|
||||
SCI = SelectionCompInitializer
|
||||
MCI = MeanCompInitializer
|
||||
SSCI = StratifiedSelectionCompInitializer
|
||||
SMCI = StratifiedMeanCompInitializer
|
||||
SSCI = StratifiedSelectionCompInitializer
|
||||
UCI = UniformCompInitializer
|
||||
ZCI = ZerosCompInitializer
|
||||
|
||||
# Aliases - Labels
|
||||
DLI = DataAwareLabelsInitializer
|
||||
LI = LabelsInitializer
|
||||
LLI = LiteralLabelsInitializer
|
||||
OHLI = OneHotLabelsInitializer
|
||||
|
||||
# Aliases - Reasonings
|
||||
ZRI = ZerosReasoningsInitializer
|
||||
LRI = LiteralReasoningsInitializer
|
||||
ORI = OnesReasoningsInitializer
|
||||
RRI = RandomReasoningsInitializer
|
||||
PPRI = PurePositiveReasoningsInitializer
|
||||
RRI = RandomReasoningsInitializer
|
||||
ZRI = ZerosReasoningsInitializer
|
||||
|
@ -67,17 +67,19 @@ def parse_distribution(user_distribution: Union[dict[int, int], dict[str, str],
|
||||
elif isinstance(user_distribution, list):
|
||||
return distribution_from_list(user_distribution, clabels)
|
||||
else:
|
||||
msg = f"`distribution` not understood." \
|
||||
msg = f"`distribution` was not understood." \
|
||||
f"You have provided: {user_distribution}."
|
||||
raise ValueError(msg)
|
||||
|
||||
|
||||
def parse_data_arg(data_arg: Union[Dataset, DataLoader, list, tuple]):
|
||||
"""Return data and target as torch tensors."""
|
||||
if isinstance(data_arg, Dataset):
|
||||
ds_size = len(data_arg)
|
||||
data_arg = DataLoader(data_arg, batch_size=ds_size)
|
||||
loader = DataLoader(data_arg, batch_size=ds_size)
|
||||
data, targets = next(iter(loader))
|
||||
|
||||
if isinstance(data_arg, DataLoader):
|
||||
elif isinstance(data_arg, DataLoader):
|
||||
data = torch.tensor([])
|
||||
targets = torch.tensor([])
|
||||
for x, y in data_arg:
|
||||
@ -87,11 +89,11 @@ def parse_data_arg(data_arg: Union[Dataset, DataLoader, list, tuple]):
|
||||
assert len(data_arg) == 2
|
||||
data, targets = data_arg
|
||||
if not isinstance(data, torch.Tensor):
|
||||
wmsg = f"Converting data to {torch.Tensor}."
|
||||
wmsg = f"Converting data to {torch.Tensor}..."
|
||||
warnings.warn(wmsg)
|
||||
data = torch.Tensor(data)
|
||||
if not isinstance(targets, torch.LongTensor):
|
||||
wmsg = f"Converting targets to {torch.LongTensor}."
|
||||
wmsg = f"Converting targets to {torch.LongTensor}..."
|
||||
warnings.warn(wmsg)
|
||||
targets = torch.LongTensor(targets)
|
||||
return data, targets
|
||||
|
Loading…
Reference in New Issue
Block a user