[BUGFIX] Use _forward
in LVQ1 and LVQ21
This commit is contained in:
parent
2cc11ae2e3
commit
139109804f
@ -281,7 +281,7 @@ class LVQ1(NonGradientGLVQ):
|
|||||||
# TODO Vectorized implementation
|
# TODO Vectorized implementation
|
||||||
|
|
||||||
for xi, yi in zip(x, y):
|
for xi, yi in zip(x, y):
|
||||||
d = self(xi.view(1, -1))
|
d = self._forward(xi.view(1, -1))
|
||||||
preds = wtac(d, plabels)
|
preds = wtac(d, plabels)
|
||||||
w = d.argmin(1)
|
w = d.argmin(1)
|
||||||
if yi == preds:
|
if yi == preds:
|
||||||
@ -312,7 +312,7 @@ class LVQ21(NonGradientGLVQ):
|
|||||||
for xi, yi in zip(x, y):
|
for xi, yi in zip(x, y):
|
||||||
xi = xi.view(1, -1)
|
xi = xi.view(1, -1)
|
||||||
yi = yi.view(1, )
|
yi = yi.view(1, )
|
||||||
d = self(xi)
|
d = self._forward(xi)
|
||||||
(_, wp), (_, wn) = _get_dp_dm(d, yi, plabels, with_indices=True)
|
(_, wp), (_, wn) = _get_dp_dm(d, yi, plabels, with_indices=True)
|
||||||
shiftp = xi - protos[wp]
|
shiftp = xi - protos[wp]
|
||||||
shiftn = protos[wn] - xi
|
shiftn = protos[wn] - xi
|
||||||
|
Loading…
Reference in New Issue
Block a user