From e3f8828da415fd452f403e78e4b733c358d68ae1 Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Fri, 21 May 2021 11:59:57 +0200 Subject: [PATCH] Accept dataloaders for component initialization --- prototorch/components/initializers.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/prototorch/components/initializers.py b/prototorch/components/initializers.py index d8b6529..42940bf 100644 --- a/prototorch/components/initializers.py +++ b/prototorch/components/initializers.py @@ -10,6 +10,12 @@ from torch.utils.data import DataLoader, Dataset def parse_data_arg(data): if isinstance(data, Dataset): data, labels = next(iter(DataLoader(data, batch_size=len(data)))) + elif isinstance(data, DataLoader): + data = torch.tensor([]) + labels = torch.tensor([]) + for x, y in data: + data = torch.cat([data, x]) + labels = torch.cat([labels, y]) else: data, labels = data if not isinstance(data, torch.Tensor):