Accept torch datasets to initialize components
This commit is contained in:
parent
47b4b9bcb1
commit
ce3991de94
@ -25,15 +25,16 @@ class Components(torch.nn.Module):
|
||||
if initialized_components is not None:
|
||||
self._components = Parameter(initialized_components)
|
||||
if number_of_components is not None or initializer is not None:
|
||||
warnings.warn(
|
||||
"Arguments ignored while initializing Components")
|
||||
wmsg = "Arguments ignored while initializing Components"
|
||||
warnings.warn(wmsg)
|
||||
else:
|
||||
self._initialize_components(number_of_components, initializer)
|
||||
|
||||
def _initialize_components(self, number_of_components, initializer):
|
||||
if not isinstance(initializer, ComponentsInitializer):
|
||||
emsg = f"`initializer` has to be some kind of `ComponentsInitializer`. " \
|
||||
f"You provided: {initializer=} instead."
|
||||
emsg = f"`initializer` has to be some subtype of " \
|
||||
f"{ComponentsInitializer}. " \
|
||||
f"You have provided: {initializer=} instead."
|
||||
raise TypeError(emsg)
|
||||
self._components = Parameter(
|
||||
initializer.generate(number_of_components))
|
||||
|
@ -1,11 +1,30 @@
|
||||
"""ProtoTroch Initializers."""
|
||||
import warnings
|
||||
from collections.abc import Iterable
|
||||
|
||||
import torch
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
|
||||
|
||||
def parse_init_arg(arg):
|
||||
if isinstance(arg, Dataset):
|
||||
data, labels = next(iter(DataLoader(arg, batch_size=len(arg))))
|
||||
# data = data.view(len(arg), -1) # flatten
|
||||
else:
|
||||
data, labels = arg
|
||||
if not isinstance(data, torch.Tensor):
|
||||
wmsg = f"Converting data to {torch.Tensor}."
|
||||
warnings.warn(wmsg)
|
||||
data = torch.Tensor(data)
|
||||
if not isinstance(labels, torch.Tensor):
|
||||
wmsg = f"Converting labels to {torch.Tensor}."
|
||||
warnings.warn(wmsg)
|
||||
labels = torch.Tensor(labels)
|
||||
return data, labels
|
||||
|
||||
|
||||
# Components
|
||||
class ComponentsInitializer:
|
||||
class ComponentsInitializer(object):
|
||||
def generate(self, number_of_components):
|
||||
raise NotImplementedError("Subclasses should implement this!")
|
||||
|
||||
@ -63,8 +82,9 @@ class MeanInitializer(PositionAwareInitializer):
|
||||
|
||||
|
||||
class ClassAwareInitializer(ComponentsInitializer):
|
||||
def __init__(self, data, labels):
|
||||
def __init__(self, arg):
|
||||
super().__init__()
|
||||
data, labels = parse_init_arg(arg)
|
||||
self.data = data
|
||||
self.labels = labels
|
||||
|
||||
@ -73,8 +93,8 @@ class ClassAwareInitializer(ComponentsInitializer):
|
||||
|
||||
|
||||
class StratifiedMeanInitializer(ClassAwareInitializer):
|
||||
def __init__(self, data, labels):
|
||||
super().__init__(data, labels)
|
||||
def __init__(self, arg):
|
||||
super().__init__(arg)
|
||||
|
||||
self.initializers = []
|
||||
for clabel in self.clabels:
|
||||
@ -89,8 +109,8 @@ class StratifiedMeanInitializer(ClassAwareInitializer):
|
||||
|
||||
|
||||
class StratifiedSelectionInitializer(ClassAwareInitializer):
|
||||
def __init__(self, data, labels, noise=None):
|
||||
super().__init__(data, labels)
|
||||
def __init__(self, arg, *, noise=None):
|
||||
super().__init__(arg)
|
||||
self.noise = noise
|
||||
|
||||
self.initializers = []
|
||||
|
Loading…
Reference in New Issue
Block a user