Update example
This commit is contained in:
parent
b138277608
commit
3e6aa6a20b
@ -24,9 +24,10 @@ class Model(torch.nn.Module):
|
||||
"""GLVQ model."""
|
||||
super().__init__()
|
||||
self.p1 = Prototypes1D(input_dim=2,
|
||||
prototypes_per_class=1,
|
||||
prototypes_per_class=3,
|
||||
nclasses=3,
|
||||
prototype_initializer='zeros')
|
||||
prototype_initializer='stratified_random',
|
||||
data=[x_train, y_train])
|
||||
|
||||
def forward(self, x):
|
||||
protos = self.p1.prototypes
|
||||
@ -88,8 +89,8 @@ for epoch in range(70):
|
||||
|
||||
torch_input = torch.Tensor(mesh_input)
|
||||
d = model(torch_input)[0]
|
||||
y_pred = np.argmin(d.detach().numpy(),
|
||||
axis=1) # assume one prototype per class
|
||||
w_indices = torch.argmin(d, dim=1)
|
||||
y_pred = torch.index_select(plabels, 0, w_indices)
|
||||
y_pred = y_pred.reshape(xx.shape)
|
||||
|
||||
# Plot voronoi regions
|
||||
|
Loading…
Reference in New Issue
Block a user