masks -> ParameterList(requires_grad=False)
This commit is contained in:
parent
78f8b6cc00
commit
adafb49985
@ -250,11 +250,10 @@ class GMLMLVQ(GLVQ):
|
|||||||
super().__init__(hparams, distance_fn=distance_fn, **kwargs)
|
super().__init__(hparams, distance_fn=distance_fn, **kwargs)
|
||||||
|
|
||||||
# Additional parameters
|
# Additional parameters
|
||||||
masks = kwargs.get("masks")
|
self._masks = ParameterList(
|
||||||
for i, _mask in enumerate(masks):
|
[Parameter(mask, requires_grad=False) for mask in kwargs.get("masks")]
|
||||||
self.register_buffer(f"_mask_{i}", _mask)
|
)
|
||||||
self._masks = [self.__getattr__(f"_mask_{i}") for i, _ in enumerate(masks)]
|
self._omegas = ParameterList([LLTI(mask).generate(1, 1) for mask in self._masks])
|
||||||
self._omegas = ParameterList([LLTI(mask).generate(1, 1) for mask in masks])
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def omega_matrices(self):
|
def omega_matrices(self):
|
||||||
|
Loading…
Reference in New Issue
Block a user