Accept dataloaders for component initialization
This commit is contained in:
parent
30adbf705c
commit
e3f8828da4
@ -10,6 +10,12 @@ from torch.utils.data import DataLoader, Dataset
|
||||
def parse_data_arg(data):
|
||||
if isinstance(data, Dataset):
|
||||
data, labels = next(iter(DataLoader(data, batch_size=len(data))))
|
||||
elif isinstance(data, DataLoader):
|
||||
data = torch.tensor([])
|
||||
labels = torch.tensor([])
|
||||
for x, y in data:
|
||||
data = torch.cat([data, x])
|
||||
labels = torch.cat([labels, y])
|
||||
else:
|
||||
data, labels = data
|
||||
if not isinstance(data, torch.Tensor):
|
||||
|
Loading…
Reference in New Issue
Block a user