From ee30d4da5bd58e51c65adb09e7974e688792017f Mon Sep 17 00:00:00 2001 From: Alexander Engelsberger Date: Fri, 21 May 2021 16:00:20 +0200 Subject: [PATCH] [BUGFIX] Initializers can handle Dataloaders now --- prototorch/components/initializers.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/prototorch/components/initializers.py b/prototorch/components/initializers.py index 42940bf..c8e1059 100644 --- a/prototorch/components/initializers.py +++ b/prototorch/components/initializers.py @@ -7,17 +7,18 @@ import torch from torch.utils.data import DataLoader, Dataset -def parse_data_arg(data): - if isinstance(data, Dataset): - data, labels = next(iter(DataLoader(data, batch_size=len(data)))) - elif isinstance(data, DataLoader): +def parse_data_arg(data_arg): + if isinstance(data_arg, Dataset): + data_arg = DataLoader(data_arg, batch_size=len(data_arg)) + + if isinstance(data_arg, DataLoader): data = torch.tensor([]) labels = torch.tensor([]) - for x, y in data: + for x, y in data_arg: data = torch.cat([data, x]) labels = torch.cat([labels, y]) else: - data, labels = data + data, labels = data_arg if not isinstance(data, torch.Tensor): wmsg = f"Converting data to {torch.Tensor}." warnings.warn(wmsg)