[BUGFIX] Stratified functions work on GPU now
This commit is contained in:
parent
1e0a8392a2
commit
b724a28a6f
@ -39,7 +39,7 @@ def stratify_with(values: torch.Tensor,
|
||||
# skip if stratification is trivial
|
||||
return values
|
||||
batch_size = values.size()[0]
|
||||
winning_values = torch.zeros(num_classes, batch_size)
|
||||
winning_values = torch.zeros(num_classes, batch_size, device=labels.device)
|
||||
filler = torch.full_like(values.T, fill_value=fill_value)
|
||||
for i, cl in enumerate(clabels):
|
||||
matcher = torch.eq(labels.unsqueeze(dim=1), cl)
|
||||
|
Loading…
Reference in New Issue
Block a user