Custom non-gradient training
This commit is contained in:
		@@ -72,6 +72,8 @@ git checkout dev
 | 
				
			|||||||
pip install -e .[all]  # \[all\] if you are using zsh or MacOS
 | 
					pip install -e .[all]  # \[all\] if you are using zsh or MacOS
 | 
				
			||||||
```
 | 
					```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					**Note: Please avoid installing Tensorflow in this environment.**
 | 
				
			||||||
 | 
					
 | 
				
			||||||
To assist in the development process, you may also find it useful to install
 | 
					To assist in the development process, you may also find it useful to install
 | 
				
			||||||
`yapf`, `isort` and `autoflake`. You can install them easily with `pip`.
 | 
					`yapf`, `isort` and `autoflake`. You can install them easily with `pip`.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,8 +1,8 @@
 | 
				
			|||||||
from importlib.metadata import PackageNotFoundError, version
 | 
					from importlib.metadata import PackageNotFoundError, version
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from .cbc import CBC
 | 
					from .cbc import CBC
 | 
				
			||||||
from .glvq import (GLVQ, GMLVQ, GRLVQ, LVQ1, LVQ21, LVQMLN, ImageGLVQ,
 | 
					from .glvq import (GLVQ, GMLVQ, GRLVQ, GLVQ1, GLVQ21, LVQ1, LVQ21, LVQMLN,
 | 
				
			||||||
                   ImageGMLVQ, SiameseGLVQ)
 | 
					                   ImageGLVQ, ImageGMLVQ, SiameseGLVQ)
 | 
				
			||||||
from .knn import KNN
 | 
					from .knn import KNN
 | 
				
			||||||
from .neural_gas import NeuralGas
 | 
					from .neural_gas import NeuralGas
 | 
				
			||||||
from .vis import *
 | 
					from .vis import *
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -6,7 +6,8 @@ from prototorch.functions.competitions import wtac
 | 
				
			|||||||
from prototorch.functions.distances import (euclidean_distance, omega_distance,
 | 
					from prototorch.functions.distances import (euclidean_distance, omega_distance,
 | 
				
			||||||
                                            sed)
 | 
					                                            sed)
 | 
				
			||||||
from prototorch.functions.helper import get_flat
 | 
					from prototorch.functions.helper import get_flat
 | 
				
			||||||
from prototorch.functions.losses import glvq_loss, lvq1_loss, lvq21_loss
 | 
					from prototorch.functions.losses import (_get_dp_dm, _get_matcher, glvq_loss,
 | 
				
			||||||
 | 
					                                         lvq1_loss, lvq21_loss)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from .abstract import (AbstractPrototypeModel, PrototypeImageModel,
 | 
					from .abstract import (AbstractPrototypeModel, PrototypeImageModel,
 | 
				
			||||||
                       SiamesePrototypeModel)
 | 
					                       SiamesePrototypeModel)
 | 
				
			||||||
@@ -33,6 +34,7 @@ class GLVQ(AbstractPrototypeModel):
 | 
				
			|||||||
        # Default Values
 | 
					        # Default Values
 | 
				
			||||||
        self.hparams.setdefault("transfer_function", "identity")
 | 
					        self.hparams.setdefault("transfer_function", "identity")
 | 
				
			||||||
        self.hparams.setdefault("transfer_beta", 10.0)
 | 
					        self.hparams.setdefault("transfer_beta", 10.0)
 | 
				
			||||||
 | 
					        self.hparams.setdefault("lr", 0.01)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self.proto_layer = LabeledComponents(
 | 
					        self.proto_layer = LabeledComponents(
 | 
				
			||||||
            distribution=self.hparams.distribution,
 | 
					            distribution=self.hparams.distribution,
 | 
				
			||||||
@@ -52,6 +54,23 @@ class GLVQ(AbstractPrototypeModel):
 | 
				
			|||||||
        dis = self.distance_fn(x, protos)
 | 
					        dis = self.distance_fn(x, protos)
 | 
				
			||||||
        return dis
 | 
					        return dis
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def log_acc(self, distances, targets):
 | 
				
			||||||
 | 
					        plabels = self.proto_layer.component_labels
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Compute training accuracy
 | 
				
			||||||
 | 
					        with torch.no_grad():
 | 
				
			||||||
 | 
					            preds = wtac(distances, plabels)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.train_acc(preds.int(), targets.int())
 | 
				
			||||||
 | 
					        # `.int()` because FloatTensors are assumed to be class probabilities
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.log("acc",
 | 
				
			||||||
 | 
					                 self.train_acc,
 | 
				
			||||||
 | 
					                 on_step=False,
 | 
				
			||||||
 | 
					                 on_epoch=True,
 | 
				
			||||||
 | 
					                 prog_bar=True,
 | 
				
			||||||
 | 
					                 logger=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def training_step(self, train_batch, batch_idx, optimizer_idx=None):
 | 
					    def training_step(self, train_batch, batch_idx, optimizer_idx=None):
 | 
				
			||||||
        x, y = train_batch
 | 
					        x, y = train_batch
 | 
				
			||||||
        dis = self(x)
 | 
					        dis = self(x)
 | 
				
			||||||
@@ -61,21 +80,9 @@ class GLVQ(AbstractPrototypeModel):
 | 
				
			|||||||
                                            beta=self.hparams.transfer_beta)
 | 
					                                            beta=self.hparams.transfer_beta)
 | 
				
			||||||
        loss = batch_loss.sum(dim=0)
 | 
					        loss = batch_loss.sum(dim=0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # Compute training accuracy
 | 
					 | 
				
			||||||
        with torch.no_grad():
 | 
					 | 
				
			||||||
            preds = wtac(dis, plabels)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        self.train_acc(preds.int(), y.int())
 | 
					 | 
				
			||||||
        # `.int()` because FloatTensors are assumed to be class probabilities
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Logging
 | 
					        # Logging
 | 
				
			||||||
        self.log("train_loss", loss)
 | 
					        self.log("train_loss", loss)
 | 
				
			||||||
        self.log("acc",
 | 
					        self.log_acc(dis, y)
 | 
				
			||||||
                 self.train_acc,
 | 
					 | 
				
			||||||
                 on_step=False,
 | 
					 | 
				
			||||||
                 on_epoch=True,
 | 
					 | 
				
			||||||
                 prog_bar=True,
 | 
					 | 
				
			||||||
                 logger=True)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        return loss
 | 
					        return loss
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -87,6 +94,10 @@ class GLVQ(AbstractPrototypeModel):
 | 
				
			|||||||
            y_pred = wtac(d, plabels)
 | 
					            y_pred = wtac(d, plabels)
 | 
				
			||||||
        return y_pred
 | 
					        return y_pred
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __repr__(self):
 | 
				
			||||||
 | 
					        super_repr = super().__repr__()
 | 
				
			||||||
 | 
					        return f"{super_repr}"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class SiameseGLVQ(SiamesePrototypeModel, GLVQ):
 | 
					class SiameseGLVQ(SiamesePrototypeModel, GLVQ):
 | 
				
			||||||
    """GLVQ in a Siamese setting.
 | 
					    """GLVQ in a Siamese setting.
 | 
				
			||||||
@@ -198,7 +209,77 @@ class LVQMLN(SiamesePrototypeModel, GLVQ):
 | 
				
			|||||||
        return dis
 | 
					        return dis
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class LVQ1(GLVQ):
 | 
					class NonGradientGLVQ(GLVQ):
 | 
				
			||||||
 | 
					    def __init__(self, *args, **kwargs):
 | 
				
			||||||
 | 
					        super().__init__(*args, **kwargs)
 | 
				
			||||||
 | 
					        self.automatic_optimization = False
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def training_step(self, train_batch, batch_idx, optimizer_idx=None):
 | 
				
			||||||
 | 
					        raise NotImplementedError
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class LVQ1(NonGradientGLVQ):
 | 
				
			||||||
 | 
					    def training_step(self, train_batch, batch_idx, optimizer_idx=None):
 | 
				
			||||||
 | 
					        protos = self.proto_layer.components
 | 
				
			||||||
 | 
					        plabels = self.proto_layer.component_labels
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        x, y = train_batch
 | 
				
			||||||
 | 
					        dis = self(x)
 | 
				
			||||||
 | 
					        # TODO Vectorized implementation
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        for xi, yi in zip(x, y):
 | 
				
			||||||
 | 
					            d = self(xi.view(1, -1))
 | 
				
			||||||
 | 
					            preds = wtac(d, plabels)
 | 
				
			||||||
 | 
					            w = d.argmin(1)
 | 
				
			||||||
 | 
					            if yi == preds:
 | 
				
			||||||
 | 
					                shift = xi - protos[w]
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                shift = protos[w] - xi
 | 
				
			||||||
 | 
					            updated_protos = protos + 0.0
 | 
				
			||||||
 | 
					            updated_protos[w] = protos[w] + (self.hparams.lr * shift)
 | 
				
			||||||
 | 
					            self.proto_layer.load_state_dict({"_components": updated_protos},
 | 
				
			||||||
 | 
					                                             strict=False)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Logging
 | 
				
			||||||
 | 
					        self.log_acc(dis, y)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class LVQ21(NonGradientGLVQ):
 | 
				
			||||||
 | 
					    def training_step(self, train_batch, batch_idx, optimizer_idx=None):
 | 
				
			||||||
 | 
					        protos = self.proto_layer.components
 | 
				
			||||||
 | 
					        plabels = self.proto_layer.component_labels
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        x, y = train_batch
 | 
				
			||||||
 | 
					        dis = self(x)
 | 
				
			||||||
 | 
					        # TODO Vectorized implementation
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        for xi, yi in zip(x, y):
 | 
				
			||||||
 | 
					            xi = xi.view(1, -1)
 | 
				
			||||||
 | 
					            yi = yi.view(1, )
 | 
				
			||||||
 | 
					            d = self(xi)
 | 
				
			||||||
 | 
					            preds = wtac(d, plabels)
 | 
				
			||||||
 | 
					            (dp, wp), (dn, wn) = _get_dp_dm(d, yi, plabels, with_indices=True)
 | 
				
			||||||
 | 
					            shiftp = xi - protos[wp]
 | 
				
			||||||
 | 
					            shiftn = protos[wn] - xi
 | 
				
			||||||
 | 
					            updated_protos = protos + 0.0
 | 
				
			||||||
 | 
					            updated_protos[wp] = protos[wp] + (self.hparams.lr * shiftp)
 | 
				
			||||||
 | 
					            updated_protos[wn] = protos[wn] + (self.hparams.lr * shiftn)
 | 
				
			||||||
 | 
					            self.proto_layer.load_state_dict({"_components": updated_protos},
 | 
				
			||||||
 | 
					                                             strict=False)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Logging
 | 
				
			||||||
 | 
					        self.log_acc(dis, y)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class MedianLVQ(NonGradientGLVQ):
 | 
				
			||||||
 | 
					    ...
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class GLVQ1(GLVQ):
 | 
				
			||||||
    """Learning Vector Quantization 1."""
 | 
					    """Learning Vector Quantization 1."""
 | 
				
			||||||
    def __init__(self, hparams, **kwargs):
 | 
					    def __init__(self, hparams, **kwargs):
 | 
				
			||||||
        super().__init__(hparams, **kwargs)
 | 
					        super().__init__(hparams, **kwargs)
 | 
				
			||||||
@@ -206,7 +287,7 @@ class LVQ1(GLVQ):
 | 
				
			|||||||
        self.optimizer = torch.optim.SGD
 | 
					        self.optimizer = torch.optim.SGD
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class LVQ21(GLVQ):
 | 
					class GLVQ21(GLVQ):
 | 
				
			||||||
    """Learning Vector Quantization 2.1."""
 | 
					    """Learning Vector Quantization 2.1."""
 | 
				
			||||||
    def __init__(self, hparams, **kwargs):
 | 
					    def __init__(self, hparams, **kwargs):
 | 
				
			||||||
        super().__init__(hparams, **kwargs)
 | 
					        super().__init__(hparams, **kwargs)
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user