From ce3991de9417638824fe8fd9b67c5173f9411b03 Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Fri, 7 May 2021 15:19:22 +0200 Subject: [PATCH] Accept torch datasets to initialize components --- prototorch/components/components.py | 9 ++++---- prototorch/components/initializers.py | 32 ++++++++++++++++++++++----- 2 files changed, 31 insertions(+), 10 deletions(-) diff --git a/prototorch/components/components.py b/prototorch/components/components.py index 1c81ac7..555fb6f 100644 --- a/prototorch/components/components.py +++ b/prototorch/components/components.py @@ -25,15 +25,16 @@ class Components(torch.nn.Module): if initialized_components is not None: self._components = Parameter(initialized_components) if number_of_components is not None or initializer is not None: - warnings.warn( - "Arguments ignored while initializing Components") + wmsg = "Arguments ignored while initializing Components" + warnings.warn(wmsg) else: self._initialize_components(number_of_components, initializer) def _initialize_components(self, number_of_components, initializer): if not isinstance(initializer, ComponentsInitializer): - emsg = f"`initializer` has to be some kind of `ComponentsInitializer`. " \ - f"You provided: {initializer=} instead." + emsg = f"`initializer` has to be some subtype of " \ + f"{ComponentsInitializer}. " \ + f"You have provided: {initializer=} instead." raise TypeError(emsg) self._components = Parameter( initializer.generate(number_of_components)) diff --git a/prototorch/components/initializers.py b/prototorch/components/initializers.py index e739976..d20d0c0 100644 --- a/prototorch/components/initializers.py +++ b/prototorch/components/initializers.py @@ -1,11 +1,30 @@ """ProtoTroch Initializers.""" +import warnings from collections.abc import Iterable import torch +from torch.utils.data import DataLoader, Dataset + + +def parse_init_arg(arg): + if isinstance(arg, Dataset): + data, labels = next(iter(DataLoader(arg, batch_size=len(arg)))) + # data = data.view(len(arg), -1) # flatten + else: + data, labels = arg + if not isinstance(data, torch.Tensor): + wmsg = f"Converting data to {torch.Tensor}." + warnings.warn(wmsg) + data = torch.Tensor(data) + if not isinstance(labels, torch.Tensor): + wmsg = f"Converting labels to {torch.Tensor}." + warnings.warn(wmsg) + labels = torch.Tensor(labels) + return data, labels # Components -class ComponentsInitializer: +class ComponentsInitializer(object): def generate(self, number_of_components): raise NotImplementedError("Subclasses should implement this!") @@ -63,8 +82,9 @@ class MeanInitializer(PositionAwareInitializer): class ClassAwareInitializer(ComponentsInitializer): - def __init__(self, data, labels): + def __init__(self, arg): super().__init__() + data, labels = parse_init_arg(arg) self.data = data self.labels = labels @@ -73,8 +93,8 @@ class ClassAwareInitializer(ComponentsInitializer): class StratifiedMeanInitializer(ClassAwareInitializer): - def __init__(self, data, labels): - super().__init__(data, labels) + def __init__(self, arg): + super().__init__(arg) self.initializers = [] for clabel in self.clabels: @@ -89,8 +109,8 @@ class StratifiedMeanInitializer(ClassAwareInitializer): class StratifiedSelectionInitializer(ClassAwareInitializer): - def __init__(self, data, labels, noise=None): - super().__init__(data, labels) + def __init__(self, arg, *, noise=None): + super().__init__(arg) self.noise = noise self.initializers = []