[BUGFIX] Use _forward
in LVQ1 and LVQ21
This commit is contained in:
parent
2cc11ae2e3
commit
139109804f
@ -281,7 +281,7 @@ class LVQ1(NonGradientGLVQ):
|
||||
# TODO Vectorized implementation
|
||||
|
||||
for xi, yi in zip(x, y):
|
||||
d = self(xi.view(1, -1))
|
||||
d = self._forward(xi.view(1, -1))
|
||||
preds = wtac(d, plabels)
|
||||
w = d.argmin(1)
|
||||
if yi == preds:
|
||||
@ -312,7 +312,7 @@ class LVQ21(NonGradientGLVQ):
|
||||
for xi, yi in zip(x, y):
|
||||
xi = xi.view(1, -1)
|
||||
yi = yi.view(1, )
|
||||
d = self(xi)
|
||||
d = self._forward(xi)
|
||||
(_, wp), (_, wn) = _get_dp_dm(d, yi, plabels, with_indices=True)
|
||||
shiftp = xi - protos[wp]
|
||||
shiftn = protos[wn] - xi
|
||||
|
Loading…
Reference in New Issue
Block a user