Fix generator bug in stratified_random initializer
This commit is contained in:
		@@ -69,16 +69,13 @@ def stratified_mean(x_train, y_train, prototype_distribution):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
@register_initializer
 | 
					@register_initializer
 | 
				
			||||||
def stratified_random(x_train, y_train, prototype_distribution):
 | 
					def stratified_random(x_train, y_train, prototype_distribution):
 | 
				
			||||||
    gen = torch.manual_seed(torch.initial_seed())
 | 
					 | 
				
			||||||
    nprotos = torch.sum(prototype_distribution)
 | 
					    nprotos = torch.sum(prototype_distribution)
 | 
				
			||||||
    pdim = x_train.shape[1]
 | 
					    pdim = x_train.shape[1]
 | 
				
			||||||
    protos = torch.empty(nprotos, pdim)
 | 
					    protos = torch.empty(nprotos, pdim)
 | 
				
			||||||
    plabels = labels_from(prototype_distribution)
 | 
					    plabels = labels_from(prototype_distribution)
 | 
				
			||||||
    for i, l in enumerate(plabels):
 | 
					    for i, l in enumerate(plabels):
 | 
				
			||||||
        xl = x_train[y_train == l]
 | 
					        xl = x_train[y_train == l]
 | 
				
			||||||
        rand_index = torch.zeros(1).long().random_(0,
 | 
					        rand_index = torch.zeros(1).long().random_(0, xl.shape[0] - 1)
 | 
				
			||||||
                                                   xl.shape[1] - 1,
 | 
					 | 
				
			||||||
                                                   generator=gen)
 | 
					 | 
				
			||||||
        random_xl = xl[rand_index]
 | 
					        random_xl = xl[rand_index]
 | 
				
			||||||
        protos[i] = random_xl
 | 
					        protos[i] = random_xl
 | 
				
			||||||
    return protos, plabels
 | 
					    return protos, plabels
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user