From adafb4998593ced0308e5ae093496b20a12fef74 Mon Sep 17 00:00:00 2001 From: julius Date: Tue, 7 Nov 2023 19:17:43 +0100 Subject: [PATCH] masks -> ParameterList(requires_grad=False) --- src/prototorch/models/glvq.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/prototorch/models/glvq.py b/src/prototorch/models/glvq.py index cc187e7..341b4b1 100644 --- a/src/prototorch/models/glvq.py +++ b/src/prototorch/models/glvq.py @@ -250,11 +250,10 @@ class GMLMLVQ(GLVQ): super().__init__(hparams, distance_fn=distance_fn, **kwargs) # Additional parameters - masks = kwargs.get("masks") - for i, _mask in enumerate(masks): - self.register_buffer(f"_mask_{i}", _mask) - self._masks = [self.__getattr__(f"_mask_{i}") for i, _ in enumerate(masks)] - self._omegas = ParameterList([LLTI(mask).generate(1, 1) for mask in masks]) + self._masks = ParameterList( + [Parameter(mask, requires_grad=False) for mask in kwargs.get("masks")] + ) + self._omegas = ParameterList([LLTI(mask).generate(1, 1) for mask in self._masks]) @property def omega_matrices(self):