remove accidental LiteralString import
This commit is contained in:
parent
c6f718a1d4
commit
78f8b6cc00
@ -1,7 +1,5 @@
|
|||||||
"""Models based on the GLVQ framework."""
|
"""Models based on the GLVQ framework."""
|
||||||
|
|
||||||
from typing import LiteralString
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from numpy.typing import NDArray
|
from numpy.typing import NDArray
|
||||||
from prototorch.core.competitions import wtac
|
from prototorch.core.competitions import wtac
|
||||||
@ -255,7 +253,7 @@ class GMLMLVQ(GLVQ):
|
|||||||
masks = kwargs.get("masks")
|
masks = kwargs.get("masks")
|
||||||
for i, _mask in enumerate(masks):
|
for i, _mask in enumerate(masks):
|
||||||
self.register_buffer(f"_mask_{i}", _mask)
|
self.register_buffer(f"_mask_{i}", _mask)
|
||||||
self._masks = [self.__getattr__(f"_mask_{i}") for i,_ in enumerate(masks)]
|
self._masks = [self.__getattr__(f"_mask_{i}") for i, _ in enumerate(masks)]
|
||||||
self._omegas = ParameterList([LLTI(mask).generate(1, 1) for mask in masks])
|
self._omegas = ParameterList([LLTI(mask).generate(1, 1) for mask in masks])
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
Loading…
Reference in New Issue
Block a user