[BUGFIX] Stratified functions work on GPU now

This commit is contained in:
Alexander Engelsberger 2021-06-03 13:19:26 +02:00
parent 1e0a8392a2
commit b724a28a6f

View File

@ -39,7 +39,7 @@ def stratify_with(values: torch.Tensor,
# skip if stratification is trivial # skip if stratification is trivial
return values return values
batch_size = values.size()[0] batch_size = values.size()[0]
winning_values = torch.zeros(num_classes, batch_size) winning_values = torch.zeros(num_classes, batch_size, device=labels.device)
filler = torch.full_like(values.T, fill_value=fill_value) filler = torch.full_like(values.T, fill_value=fill_value)
for i, cl in enumerate(clabels): for i, cl in enumerate(clabels):
matcher = torch.eq(labels.unsqueeze(dim=1), cl) matcher = torch.eq(labels.unsqueeze(dim=1), cl)