diff --git a/src/prototorch/models/glvq.py b/src/prototorch/models/glvq.py index 7709e93..cc187e7 100644 --- a/src/prototorch/models/glvq.py +++ b/src/prototorch/models/glvq.py @@ -1,7 +1,5 @@ """Models based on the GLVQ framework.""" -from typing import LiteralString - import torch from numpy.typing import NDArray from prototorch.core.competitions import wtac @@ -255,7 +253,7 @@ class GMLMLVQ(GLVQ): masks = kwargs.get("masks") for i, _mask in enumerate(masks): self.register_buffer(f"_mask_{i}", _mask) - self._masks = [self.__getattr__(f"_mask_{i}") for i,_ in enumerate(masks)] + self._masks = [self.__getattr__(f"_mask_{i}") for i, _ in enumerate(masks)] self._omegas = ParameterList([LLTI(mask).generate(1, 1) for mask in masks]) @property