Custom non-gradient training
@@ -72,6 +72,8 @@ git checkout dev
 pip install -e .[all]  # \[all\] if you are using zsh or MacOS
 ```
 
+**Note: Please avoid installing Tensorflow in this environment.**
+
 To assist in the development process, you may also find it useful to install
 `yapf`, `isort` and `autoflake`. You can install them easily with `pip`.
 
@@ -1,8 +1,8 @@
 from importlib.metadata import PackageNotFoundError, version
 
 from .cbc import CBC
-from .glvq import (GLVQ, GMLVQ, GRLVQ, LVQ1, LVQ21, LVQMLN, ImageGLVQ,
-                   ImageGMLVQ, SiameseGLVQ)
+from .glvq import (GLVQ, GMLVQ, GRLVQ, GLVQ1, GLVQ21, LVQ1, LVQ21, LVQMLN,
+                   ImageGLVQ, ImageGMLVQ, SiameseGLVQ)
 from .knn import KNN
 from .neural_gas import NeuralGas
 from .vis import *
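After this change, `LVQ1` and `LVQ21` name the new non-gradient models, while the previous gradient-based classes remain available as `GLVQ1` and `GLVQ21`. A quick sketch of the resulting imports; the absolute package path is not shown in the diff and is assumed here:

```python
# Assumed top-level package path; the diff only shows the relative
# `from .glvq import ...` form.
from prototorch.models import GLVQ1, GLVQ21  # gradient-based LVQ1/LVQ2.1 losses
from prototorch.models import LVQ1, LVQ21    # non-gradient updates added in this commit
```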
@@ -6,7 +6,8 @@ from prototorch.functions.competitions import wtac
 from prototorch.functions.distances import (euclidean_distance, omega_distance,
                                             sed)
 from prototorch.functions.helper import get_flat
-from prototorch.functions.losses import glvq_loss, lvq1_loss, lvq21_loss
+from prototorch.functions.losses import (_get_dp_dm, _get_matcher, glvq_loss,
+                                         lvq1_loss, lvq21_loss)
 
 from .abstract import (AbstractPrototypeModel, PrototypeImageModel,
                        SiamesePrototypeModel)
@@ -33,6 +34,7 @@ class GLVQ(AbstractPrototypeModel):
         # Default Values
         self.hparams.setdefault("transfer_function", "identity")
         self.hparams.setdefault("transfer_beta", 10.0)
+        self.hparams.setdefault("lr", 0.01)
 
         self.proto_layer = LabeledComponents(
             distribution=self.hparams.distribution,
@@ -52,6 +54,23 @@ class GLVQ(AbstractPrototypeModel):
         dis = self.distance_fn(x, protos)
         return dis
 
+    def log_acc(self, distances, targets):
+        plabels = self.proto_layer.component_labels
+
+        # Compute training accuracy
+        with torch.no_grad():
+            preds = wtac(distances, plabels)
+
+        self.train_acc(preds.int(), targets.int())
+        # `.int()` because FloatTensors are assumed to be class probabilities
+
+        self.log("acc",
+                 self.train_acc,
+                 on_step=False,
+                 on_epoch=True,
+                 prog_bar=True,
+                 logger=True)
+
     def training_step(self, train_batch, batch_idx, optimizer_idx=None):
         x, y = train_batch
         dis = self(x)
@@ -61,21 +80,9 @@ class GLVQ(AbstractPrototypeModel):
                                             beta=self.hparams.transfer_beta)
         loss = batch_loss.sum(dim=0)
 
-        # Compute training accuracy
-        with torch.no_grad():
-            preds = wtac(dis, plabels)
-
-        self.train_acc(preds.int(), y.int())
-        # `.int()` because FloatTensors are assumed to be class probabilities
-
         # Logging
         self.log("train_loss", loss)
-        self.log("acc",
-                 self.train_acc,
-                 on_step=False,
-                 on_epoch=True,
-                 prog_bar=True,
-                 logger=True)
+        self.log_acc(dis, y)
 
         return loss
 
@@ -87,6 +94,10 @@ class GLVQ(AbstractPrototypeModel):
             y_pred = wtac(d, plabels)
         return y_pred
 
+    def __repr__(self):
+        super_repr = super().__repr__()
+        return f"{super_repr}"
+
 
 class SiameseGLVQ(SiamesePrototypeModel, GLVQ):
     """GLVQ in a Siamese setting.
@@ -198,7 +209,77 @@ class LVQMLN(SiamesePrototypeModel, GLVQ):
         return dis
 
 
-class LVQ1(GLVQ):
+class NonGradientGLVQ(GLVQ):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.automatic_optimization = False
+
+    def training_step(self, train_batch, batch_idx, optimizer_idx=None):
+        raise NotImplementedError
+
+
+class LVQ1(NonGradientGLVQ):
+    def training_step(self, train_batch, batch_idx, optimizer_idx=None):
+        protos = self.proto_layer.components
+        plabels = self.proto_layer.component_labels
+
+        x, y = train_batch
+        dis = self(x)
+        # TODO Vectorized implementation
+
+        for xi, yi in zip(x, y):
+            d = self(xi.view(1, -1))
+            preds = wtac(d, plabels)
+            w = d.argmin(1)
+            if yi == preds:
+                shift = xi - protos[w]
+            else:
+                shift = protos[w] - xi
+            updated_protos = protos + 0.0
+            updated_protos[w] = protos[w] + (self.hparams.lr * shift)
+            self.proto_layer.load_state_dict({"_components": updated_protos},
+                                             strict=False)
+
+        # Logging
+        self.log_acc(dis, y)
+
+        return None
+
+
+class LVQ21(NonGradientGLVQ):
+    def training_step(self, train_batch, batch_idx, optimizer_idx=None):
+        protos = self.proto_layer.components
+        plabels = self.proto_layer.component_labels
+
+        x, y = train_batch
+        dis = self(x)
+        # TODO Vectorized implementation
+
+        for xi, yi in zip(x, y):
+            xi = xi.view(1, -1)
+            yi = yi.view(1, )
+            d = self(xi)
+            preds = wtac(d, plabels)
+            (dp, wp), (dn, wn) = _get_dp_dm(d, yi, plabels, with_indices=True)
+            shiftp = xi - protos[wp]
+            shiftn = protos[wn] - xi
+            updated_protos = protos + 0.0
+            updated_protos[wp] = protos[wp] + (self.hparams.lr * shiftp)
+            updated_protos[wn] = protos[wn] + (self.hparams.lr * shiftn)
+            self.proto_layer.load_state_dict({"_components": updated_protos},
+                                             strict=False)
+
+        # Logging
+        self.log_acc(dis, y)
+
+        return None
+
+
+class MedianLVQ(NonGradientGLVQ):
+    ...
+
+
+class GLVQ1(GLVQ):
     """Learning Vector Quantization 1."""
     def __init__(self, hparams, **kwargs):
         super().__init__(hparams, **kwargs)
@@ -206,7 +287,7 @@ class LVQ1(GLVQ):
         self.optimizer = torch.optim.SGD
 
 
-class LVQ21(GLVQ):
+class GLVQ21(GLVQ):
     """Learning Vector Quantization 2.1."""
     def __init__(self, hparams, **kwargs):
         super().__init__(hparams, **kwargs)
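The two `training_step` loops above implement the textbook update rules: LVQ1 moves only the overall winner `w`, towards the sample when the prediction is correct and away from it otherwise, while LVQ2.1 simultaneously attracts the closest prototype with the correct label (`wp`) and repels the closest one with a wrong label (`wn`). Because `NonGradientGLVQ` sets `self.automatic_optimization = False`, PyTorch Lightning skips `backward()` and `optimizer.step()`, and the prototypes are overwritten directly through `load_state_dict`. A minimal usage sketch under assumptions: the models are exported from `prototorch.models`, and the hparams dict takes the `distribution` and `lr` keys used in the code above; the toy data and every name not in the diff are illustrative:

```python
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader, TensorDataset

from prototorch.models import LVQ1  # assumed import path, see the __init__.py diff

# Toy two-class problem (illustrative only)
x = torch.randn(128, 2)
y = (x[:, 0] > 0).long()
train_loader = DataLoader(TensorDataset(x, y), batch_size=32)

# `distribution` is assumed to follow the prototorch convention of
# prototypes per class; additional initializer kwargs may be required.
model = LVQ1(hparams=dict(distribution=[1, 1], lr=0.01))

# With automatic optimization disabled, the Trainer only drives the batch
# loop; all prototype updates happen inside LVQ1.training_step.
trainer = pl.Trainer(max_epochs=10)
trainer.fit(model, train_loader)
```

Since the loop walks the batch one sample at a time, the behavior is that of classical online LVQ; the batch size only controls how many single-sample updates happen per Lightning step.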