Update example
This commit is contained in:
parent
b138277608
commit
3e6aa6a20b
@ -24,9 +24,10 @@ class Model(torch.nn.Module):
|
|||||||
"""GLVQ model."""
|
"""GLVQ model."""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.p1 = Prototypes1D(input_dim=2,
|
self.p1 = Prototypes1D(input_dim=2,
|
||||||
prototypes_per_class=1,
|
prototypes_per_class=3,
|
||||||
nclasses=3,
|
nclasses=3,
|
||||||
prototype_initializer='zeros')
|
prototype_initializer='stratified_random',
|
||||||
|
data=[x_train, y_train])
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
protos = self.p1.prototypes
|
protos = self.p1.prototypes
|
||||||
@ -88,8 +89,8 @@ for epoch in range(70):
|
|||||||
|
|
||||||
torch_input = torch.Tensor(mesh_input)
|
torch_input = torch.Tensor(mesh_input)
|
||||||
d = model(torch_input)[0]
|
d = model(torch_input)[0]
|
||||||
y_pred = np.argmin(d.detach().numpy(),
|
w_indices = torch.argmin(d, dim=1)
|
||||||
axis=1) # assume one prototype per class
|
y_pred = torch.index_select(plabels, 0, w_indices)
|
||||||
y_pred = y_pred.reshape(xx.shape)
|
y_pred = y_pred.reshape(xx.shape)
|
||||||
|
|
||||||
# Plot voronoi regions
|
# Plot voronoi regions
|
||||||
|
Loading…
Reference in New Issue
Block a user