Accept torch datasets to initialize components
This commit is contained in:
parent
47b4b9bcb1
commit
ce3991de94
@ -25,15 +25,16 @@ class Components(torch.nn.Module):
|
|||||||
if initialized_components is not None:
|
if initialized_components is not None:
|
||||||
self._components = Parameter(initialized_components)
|
self._components = Parameter(initialized_components)
|
||||||
if number_of_components is not None or initializer is not None:
|
if number_of_components is not None or initializer is not None:
|
||||||
warnings.warn(
|
wmsg = "Arguments ignored while initializing Components"
|
||||||
"Arguments ignored while initializing Components")
|
warnings.warn(wmsg)
|
||||||
else:
|
else:
|
||||||
self._initialize_components(number_of_components, initializer)
|
self._initialize_components(number_of_components, initializer)
|
||||||
|
|
||||||
def _initialize_components(self, number_of_components, initializer):
|
def _initialize_components(self, number_of_components, initializer):
|
||||||
if not isinstance(initializer, ComponentsInitializer):
|
if not isinstance(initializer, ComponentsInitializer):
|
||||||
emsg = f"`initializer` has to be some kind of `ComponentsInitializer`. " \
|
emsg = f"`initializer` has to be some subtype of " \
|
||||||
f"You provided: {initializer=} instead."
|
f"{ComponentsInitializer}. " \
|
||||||
|
f"You have provided: {initializer=} instead."
|
||||||
raise TypeError(emsg)
|
raise TypeError(emsg)
|
||||||
self._components = Parameter(
|
self._components = Parameter(
|
||||||
initializer.generate(number_of_components))
|
initializer.generate(number_of_components))
|
||||||
|
@ -1,11 +1,30 @@
|
|||||||
"""ProtoTroch Initializers."""
|
"""ProtoTroch Initializers."""
|
||||||
|
import warnings
|
||||||
from collections.abc import Iterable
|
from collections.abc import Iterable
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from torch.utils.data import DataLoader, Dataset
|
||||||
|
|
||||||
|
|
||||||
|
def parse_init_arg(arg):
|
||||||
|
if isinstance(arg, Dataset):
|
||||||
|
data, labels = next(iter(DataLoader(arg, batch_size=len(arg))))
|
||||||
|
# data = data.view(len(arg), -1) # flatten
|
||||||
|
else:
|
||||||
|
data, labels = arg
|
||||||
|
if not isinstance(data, torch.Tensor):
|
||||||
|
wmsg = f"Converting data to {torch.Tensor}."
|
||||||
|
warnings.warn(wmsg)
|
||||||
|
data = torch.Tensor(data)
|
||||||
|
if not isinstance(labels, torch.Tensor):
|
||||||
|
wmsg = f"Converting labels to {torch.Tensor}."
|
||||||
|
warnings.warn(wmsg)
|
||||||
|
labels = torch.Tensor(labels)
|
||||||
|
return data, labels
|
||||||
|
|
||||||
|
|
||||||
# Components
|
# Components
|
||||||
class ComponentsInitializer:
|
class ComponentsInitializer(object):
|
||||||
def generate(self, number_of_components):
|
def generate(self, number_of_components):
|
||||||
raise NotImplementedError("Subclasses should implement this!")
|
raise NotImplementedError("Subclasses should implement this!")
|
||||||
|
|
||||||
@ -63,8 +82,9 @@ class MeanInitializer(PositionAwareInitializer):
|
|||||||
|
|
||||||
|
|
||||||
class ClassAwareInitializer(ComponentsInitializer):
|
class ClassAwareInitializer(ComponentsInitializer):
|
||||||
def __init__(self, data, labels):
|
def __init__(self, arg):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
data, labels = parse_init_arg(arg)
|
||||||
self.data = data
|
self.data = data
|
||||||
self.labels = labels
|
self.labels = labels
|
||||||
|
|
||||||
@ -73,8 +93,8 @@ class ClassAwareInitializer(ComponentsInitializer):
|
|||||||
|
|
||||||
|
|
||||||
class StratifiedMeanInitializer(ClassAwareInitializer):
|
class StratifiedMeanInitializer(ClassAwareInitializer):
|
||||||
def __init__(self, data, labels):
|
def __init__(self, arg):
|
||||||
super().__init__(data, labels)
|
super().__init__(arg)
|
||||||
|
|
||||||
self.initializers = []
|
self.initializers = []
|
||||||
for clabel in self.clabels:
|
for clabel in self.clabels:
|
||||||
@ -89,8 +109,8 @@ class StratifiedMeanInitializer(ClassAwareInitializer):
|
|||||||
|
|
||||||
|
|
||||||
class StratifiedSelectionInitializer(ClassAwareInitializer):
|
class StratifiedSelectionInitializer(ClassAwareInitializer):
|
||||||
def __init__(self, data, labels, noise=None):
|
def __init__(self, arg, *, noise=None):
|
||||||
super().__init__(data, labels)
|
super().__init__(arg)
|
||||||
self.noise = noise
|
self.noise = noise
|
||||||
|
|
||||||
self.initializers = []
|
self.initializers = []
|
||||||
|
Loading…
Reference in New Issue
Block a user