Update example

This commit is contained in:
Jensun Ravichandran 2020-08-04 11:30:50 +02:00
parent b138277608
commit 3e6aa6a20b

View File

@ -24,9 +24,10 @@ class Model(torch.nn.Module):
"""GLVQ model."""
super().__init__()
self.p1 = Prototypes1D(input_dim=2,
prototypes_per_class=1,
prototypes_per_class=3,
nclasses=3,
prototype_initializer='zeros')
prototype_initializer='stratified_random',
data=[x_train, y_train])
def forward(self, x):
protos = self.p1.prototypes
@ -88,8 +89,8 @@ for epoch in range(70):
torch_input = torch.Tensor(mesh_input)
d = model(torch_input)[0]
y_pred = np.argmin(d.detach().numpy(),
axis=1) # assume one prototype per class
w_indices = torch.argmin(d, dim=1)
y_pred = torch.index_select(plabels, 0, w_indices)
y_pred = y_pred.reshape(xx.shape)
# Plot voronoi regions