Fix generator bug in stratified_random initializer
This commit is contained in:
parent
2b82830590
commit
3cfbc49254
@ -69,16 +69,13 @@ def stratified_mean(x_train, y_train, prototype_distribution):
|
|||||||
|
|
||||||
@register_initializer
|
@register_initializer
|
||||||
def stratified_random(x_train, y_train, prototype_distribution):
|
def stratified_random(x_train, y_train, prototype_distribution):
|
||||||
gen = torch.manual_seed(torch.initial_seed())
|
|
||||||
nprotos = torch.sum(prototype_distribution)
|
nprotos = torch.sum(prototype_distribution)
|
||||||
pdim = x_train.shape[1]
|
pdim = x_train.shape[1]
|
||||||
protos = torch.empty(nprotos, pdim)
|
protos = torch.empty(nprotos, pdim)
|
||||||
plabels = labels_from(prototype_distribution)
|
plabels = labels_from(prototype_distribution)
|
||||||
for i, l in enumerate(plabels):
|
for i, l in enumerate(plabels):
|
||||||
xl = x_train[y_train == l]
|
xl = x_train[y_train == l]
|
||||||
rand_index = torch.zeros(1).long().random_(0,
|
rand_index = torch.zeros(1).long().random_(0, xl.shape[0] - 1)
|
||||||
xl.shape[1] - 1,
|
|
||||||
generator=gen)
|
|
||||||
random_xl = xl[rand_index]
|
random_xl = xl[rand_index]
|
||||||
protos[i] = random_xl
|
protos[i] = random_xl
|
||||||
return protos, plabels
|
return protos, plabels
|
||||||
|
Loading…
Reference in New Issue
Block a user