masks -> ParameterList(requires_grad=False)
This commit is contained in:
parent
78f8b6cc00
commit
adafb49985
@ -250,11 +250,10 @@ class GMLMLVQ(GLVQ):
|
||||
super().__init__(hparams, distance_fn=distance_fn, **kwargs)
|
||||
|
||||
# Additional parameters
|
||||
masks = kwargs.get("masks")
|
||||
for i, _mask in enumerate(masks):
|
||||
self.register_buffer(f"_mask_{i}", _mask)
|
||||
self._masks = [self.__getattr__(f"_mask_{i}") for i, _ in enumerate(masks)]
|
||||
self._omegas = ParameterList([LLTI(mask).generate(1, 1) for mask in masks])
|
||||
self._masks = ParameterList(
|
||||
[Parameter(mask, requires_grad=False) for mask in kwargs.get("masks")]
|
||||
)
|
||||
self._omegas = ParameterList([LLTI(mask).generate(1, 1) for mask in self._masks])
|
||||
|
||||
@property
|
||||
def omega_matrices(self):
|
||||
|
Loading…
Reference in New Issue
Block a user