[BUGFIX] Use _forward in LVQ1 and LVQ21

This commit is contained in:
Jensun Ravichandran 2021-05-25 17:43:37 +02:00
parent 2cc11ae2e3
commit 139109804f

View File

@ -281,7 +281,7 @@ class LVQ1(NonGradientGLVQ):
# TODO Vectorized implementation # TODO Vectorized implementation
for xi, yi in zip(x, y): for xi, yi in zip(x, y):
d = self(xi.view(1, -1)) d = self._forward(xi.view(1, -1))
preds = wtac(d, plabels) preds = wtac(d, plabels)
w = d.argmin(1) w = d.argmin(1)
if yi == preds: if yi == preds:
@ -312,7 +312,7 @@ class LVQ21(NonGradientGLVQ):
for xi, yi in zip(x, y): for xi, yi in zip(x, y):
xi = xi.view(1, -1) xi = xi.view(1, -1)
yi = yi.view(1, ) yi = yi.view(1, )
d = self(xi) d = self._forward(xi)
(_, wp), (_, wn) = _get_dp_dm(d, yi, plabels, with_indices=True) (_, wp), (_, wn) = _get_dp_dm(d, yi, plabels, with_indices=True)
shiftp = xi - protos[wp] shiftp = xi - protos[wp]
shiftn = protos[wn] - xi shiftn = protos[wn] - xi