masks -> ParameterList(requires_grad=False)

This commit is contained in:
julius 2023-11-07 19:17:43 +01:00
parent 78f8b6cc00
commit adafb49985
Signed by untrusted user who does not match committer: julius
GPG Key ID: 8AA3791362A8084A

View File

@ -250,11 +250,10 @@ class GMLMLVQ(GLVQ):
super().__init__(hparams, distance_fn=distance_fn, **kwargs)
# Additional parameters
masks = kwargs.get("masks")
for i, _mask in enumerate(masks):
self.register_buffer(f"_mask_{i}", _mask)
self._masks = [self.__getattr__(f"_mask_{i}") for i, _ in enumerate(masks)]
self._omegas = ParameterList([LLTI(mask).generate(1, 1) for mask in masks])
self._masks = ParameterList(
[Parameter(mask, requires_grad=False) for mask in kwargs.get("masks")]
)
self._omegas = ParameterList([LLTI(mask).generate(1, 1) for mask in self._masks])
@property
def omega_matrices(self):