diff --git a/src/prototorch/models/glvq.py b/src/prototorch/models/glvq.py index cc187e7..341b4b1 100644 --- a/src/prototorch/models/glvq.py +++ b/src/prototorch/models/glvq.py @@ -250,11 +250,10 @@ class GMLMLVQ(GLVQ): super().__init__(hparams, distance_fn=distance_fn, **kwargs) # Additional parameters - masks = kwargs.get("masks") - for i, _mask in enumerate(masks): - self.register_buffer(f"_mask_{i}", _mask) - self._masks = [self.__getattr__(f"_mask_{i}") for i, _ in enumerate(masks)] - self._omegas = ParameterList([LLTI(mask).generate(1, 1) for mask in masks]) + self._masks = ParameterList( + [Parameter(mask, requires_grad=False) for mask in kwargs.get("masks")] + ) + self._omegas = ParameterList([LLTI(mask).generate(1, 1) for mask in self._masks]) @property def omega_matrices(self):