Accept torch datasets to initialize components

This commit is contained in:
Jensun Ravichandran 2021-05-07 15:19:22 +02:00
parent 47b4b9bcb1
commit ce3991de94
2 changed files with 31 additions and 10 deletions

View File

@ -25,15 +25,16 @@ class Components(torch.nn.Module):
if initialized_components is not None:
self._components = Parameter(initialized_components)
if number_of_components is not None or initializer is not None:
warnings.warn(
"Arguments ignored while initializing Components")
wmsg = "Arguments ignored while initializing Components"
warnings.warn(wmsg)
else:
self._initialize_components(number_of_components, initializer)
# Build `self._components` from an initializer object.
#   number_of_components: forwarded unchanged to `initializer.generate`.
#   initializer: must pass the `isinstance(..., ComponentsInitializer)` check.
# Raises TypeError when `initializer` is not a `ComponentsInitializer`.
def _initialize_components(self, number_of_components, initializer):
if not isinstance(initializer, ComponentsInitializer):
emsg = f"`initializer` has to be some kind of `ComponentsInitializer`. " \
f"You provided: {initializer=} instead."
# NOTE(review): the assignment below replaces the message built just above,
# so only the following text reaches the TypeError.
emsg = f"`initializer` has to be some subtype of " \
f"{ComponentsInitializer}. " \
f"You have provided: {initializer=} instead."
raise TypeError(emsg)
# Wrap whatever the initializer generates in `Parameter` and keep it as
# `self._components`.
self._components = Parameter(
initializer.generate(number_of_components))

View File

@ -1,11 +1,30 @@
"""ProtoTorch Initializers."""
import warnings
from collections.abc import Iterable
import torch
from torch.utils.data import DataLoader, Dataset
def parse_init_arg(arg):
    """Normalize an initializer argument into a ``(data, labels)`` tensor pair.

    Args:
        arg: Either a ``torch.utils.data.Dataset`` whose collated batch
            unpacks into ``(data, labels)``, or any two-item iterable
            ``(data, labels)``.

    Returns:
        A tuple ``(data, labels)`` of ``torch.Tensor`` objects. Inputs that
        already are tensors are returned untouched. The data is not reshaped
        or flattened here.

    Raises:
        ValueError: If ``arg`` is a dataset without any samples.
    """
    if isinstance(arg, Dataset):
        num_samples = len(arg)
        if num_samples == 0:
            # DataLoader rejects batch_size=0 with a message that does not
            # mention the real cause, so fail early with a clear one.
            raise ValueError(
                "Cannot initialize components from an empty dataset.")
        # A single batch spanning the whole dataset lets the DataLoader's
        # collation stack every sample into one tensor per field.
        data, labels = next(iter(DataLoader(arg, batch_size=num_samples)))
    else:
        data, labels = arg

    if not isinstance(data, torch.Tensor):
        wmsg = f"Converting data to {torch.Tensor}."
        warnings.warn(wmsg)
        # Keep the floating-point result the legacy `torch.Tensor(...)`
        # constructor produced, without its "a bare int means a size" pitfall.
        data = torch.as_tensor(data, dtype=torch.get_default_dtype())

    if not isinstance(labels, torch.Tensor):
        wmsg = f"Converting labels to {torch.Tensor}."
        warnings.warn(wmsg)
        # `as_tensor` infers the dtype, so integer class labels stay integers
        # instead of being silently cast to float.
        labels = torch.as_tensor(labels)

    return data, labels
# Components
# Base class for component initializers: `generate` is a placeholder that
# raises NotImplementedError until a subclass overrides it.
class ComponentsInitializer:
class ComponentsInitializer(object):
# `number_of_components` is accepted but not used by this base implementation.
def generate(self, number_of_components):
raise NotImplementedError("Subclasses should implement this!")
@ -63,8 +82,9 @@ class MeanInitializer(PositionAwareInitializer):
class ClassAwareInitializer(ComponentsInitializer):
def __init__(self, data, labels):
def __init__(self, arg):
super().__init__()
data, labels = parse_init_arg(arg)
self.data = data
self.labels = labels
@ -73,8 +93,8 @@ class ClassAwareInitializer(ComponentsInitializer):
class StratifiedMeanInitializer(ClassAwareInitializer):
def __init__(self, data, labels):
super().__init__(data, labels)
def __init__(self, arg):
super().__init__(arg)
self.initializers = []
for clabel in self.clabels:
@ -89,8 +109,8 @@ class StratifiedMeanInitializer(ClassAwareInitializer):
class StratifiedSelectionInitializer(ClassAwareInitializer):
def __init__(self, data, labels, noise=None):
super().__init__(data, labels)
def __init__(self, arg, *, noise=None):
super().__init__(arg)
self.noise = noise
self.initializers = []