[WIP] Update CBC implementation to use SiameseGLVQ
@@ -49,11 +49,21 @@ class GLVQ(AbstractPrototypeModel):
     def prototype_labels(self):
         return self.proto_layer.component_labels.detach().cpu()
 
-    def forward(self, x):
+    @property
+    def num_classes(self):
+        return len(self.proto_layer.distribution)
+
+    def _forward(self, x):
         protos, _ = self.proto_layer()
         distances = self.distance_fn(x, protos)
         return distances
 
+    def forward(self, x):
+        distances = self._forward(x)
+        y_pred = self.predict_from_distances(distances)
+        y_pred = torch.eye(self.num_classes, device=self.device)[y_pred.int()]
+        return y_pred
+
     def predict_from_distances(self, distances):
         with torch.no_grad():
             plabels = self.proto_layer.component_labels
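Note on the hunk above: the old forward (which returned raw distances) becomes _forward, and a new forward wraps it so that calling the model now yields one-hot class predictions via predict_from_distances and torch.eye. A minimal, self-contained sketch of that split, with an invented ToyGLVQ class purely for illustration (not the library's implementation):

import torch

class ToyGLVQ(torch.nn.Module):
    """Illustrative stand-in for the refactored GLVQ interface."""

    def __init__(self, prototypes, prototype_labels, num_classes):
        super().__init__()
        self.prototypes = torch.nn.Parameter(prototypes)
        self.register_buffer("plabels", prototype_labels)
        self.num_classes = num_classes

    def _forward(self, x):
        # Raw squared Euclidean distances; used internally for the loss.
        return torch.cdist(x, self.prototypes)**2

    def predict_from_distances(self, distances):
        # Winner-takes-all over prototypes, mapped to class labels.
        return self.plabels[distances.argmin(dim=1)]

    def forward(self, x):
        # Public call now returns one-hot predictions, mirroring
        # torch.eye(self.num_classes, ...)[y_pred.int()] in the diff above.
        y_pred = self.predict_from_distances(self._forward(x))
        return torch.eye(self.num_classes, device=x.device)[y_pred.long()]

model = ToyGLVQ(torch.randn(4, 2), torch.tensor([0, 0, 1, 1]), num_classes=2)
print(model(torch.randn(5, 2)).shape)           # torch.Size([5, 2]), one-hot
print(model._forward(torch.randn(5, 2)).shape)  # torch.Size([5, 4]), distances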
@@ -62,7 +72,7 @@ class GLVQ(AbstractPrototypeModel):
 
     def predict(self, x):
         with torch.no_grad():
-            distances = self(x)
+            distances = self._forward(x)
             y_pred = self.predict_from_distances(distances)
         return y_pred
 
@@ -80,7 +90,7 @@ class GLVQ(AbstractPrototypeModel):
 
     def shared_step(self, batch, batch_idx, optimizer_idx=None):
         x, y = batch
-        out = self(x)
+        out = self._forward(x)
         plabels = self.proto_layer.component_labels
         mu = self.loss(out, y, prototype_labels=plabels)
         batch_loss = self.transfer_fn(mu, beta=self.hparams.transfer_beta)
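In shared_step above, the loss now consumes the raw distances from _forward together with the prototype labels. For GLVQ this cost is conventionally the classifier score mu = (d+ - d-) / (d+ + d-), where d+ and d- are the distances to the closest prototype of the correct class and of any other class. A sketch of that textbook formulation, assuming the standard definition rather than the library's exact self.loss:

import torch

def glvq_mu(distances, targets, prototype_labels):
    # distances: (batch, num_protos); targets: (batch,); prototype_labels: (num_protos,)
    matching = targets.unsqueeze(1) == prototype_labels.unsqueeze(0)
    inf = torch.full_like(distances, float("inf"))
    d_plus = torch.where(matching, distances, inf).min(dim=1).values
    d_minus = torch.where(~matching, distances, inf).min(dim=1).values
    # mu lies in (-1, 1); negative values mean the correct class wins.
    return (d_plus - d_minus) / (d_plus + d_minus)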
@@ -89,6 +99,7 @@ class GLVQ(AbstractPrototypeModel):
 
     def training_step(self, batch, batch_idx, optimizer_idx=None):
         out, train_loss = self.shared_step(batch, batch_idx, optimizer_idx)
+        self.log("train_loss", train_loss)
         self.log_acc(out, batch[-1], tag="train_acc")
         return train_loss
 
@@ -137,23 +148,22 @@ class SiameseGLVQ(SiamesePrototypeModel, GLVQ):
         self.both_path_gradients = both_path_gradients
         self.distance_fn = kwargs.get("distance_fn", sed)
 
-    def forward(self, x):
+    def _forward(self, x):
         protos, _ = self.proto_layer()
         latent_x = self.backbone(x)
         self.backbone.requires_grad_(self.both_path_gradients)
         latent_protos = self.backbone(protos)
         self.backbone.requires_grad_(True)
-        dis = self.distance_fn(latent_x, latent_protos)
-        return dis
+        distances = self.distance_fn(latent_x, latent_protos)
+        return distances
 
 
-class GRLVQ(SiamesePrototypeModel, GLVQ):
+class GRLVQ(SiameseGLVQ):
     """Generalized Relevance Learning Vector Quantization."""
     def __init__(self, hparams, **kwargs):
         super().__init__(hparams, **kwargs)
         self.relevances = torch.nn.parameter.Parameter(
             torch.ones(self.hparams.input_dim))
-        self.distance_fn = kwargs.get("distance_fn", sed)
 
     @property
     def relevance_profile(self):
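The requires_grad_ toggling in SiameseGLVQ._forward above gates the gradient flow: the backbone is briefly frozen while the prototypes are embedded, so with both_path_gradients=False the backbone is trained only through the data branch, while the prototypes still receive gradients. A small standalone sketch of that mechanism (variable names invented for illustration):

import torch

backbone = torch.nn.Linear(2, 2, bias=False)
protos = torch.nn.Parameter(torch.randn(4, 2))
x = torch.randn(3, 2)
both_path_gradients = False

latent_x = backbone(x)
backbone.requires_grad_(both_path_gradients)  # freeze backbone for the prototype pass
latent_protos = backbone(protos)
backbone.requires_grad_(True)                 # restore for the next training step

distances = torch.cdist(latent_x, latent_protos)**2
distances.sum().backward()
print(backbone.weight.grad is not None)  # True: gradient arrives via latent_x only
print(protos.grad is not None)           # True: prototypes are updated regardless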
@@ -163,20 +173,19 @@ class GRLVQ(SiamesePrototypeModel, GLVQ):
         """Namespace hook for the visualization callbacks to work."""
         return x @ torch.diag(self.relevances)
 
-    def forward(self, x):
+    def _forward(self, x):
         protos, _ = self.proto_layer()
-        dis = omega_distance(x, protos, torch.diag(self.relevances))
-        return dis
+        distances = omega_distance(x, protos, torch.diag(self.relevances))
+        return distances
 
 
-class GMLVQ(SiamesePrototypeModel, GLVQ):
+class GMLVQ(SiameseGLVQ):
     """Generalized Matrix Learning Vector Quantization."""
     def __init__(self, hparams, **kwargs):
         super().__init__(hparams, **kwargs)
         self.backbone = torch.nn.Linear(self.hparams.input_dim,
                                         self.hparams.latent_dim,
                                         bias=False)
-        self.distance_fn = kwargs.get("distance_fn", sed)
 
     @property
     def omega_matrix(self):
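GRLVQ's _forward above builds a diagonal Omega from the relevance vector and hands it to omega_distance; GMLVQ (next hunk) instead learns a full rectangular Omega as a linear backbone. The quantity behind both is usually d_Omega(x, w) = ||Omega (x - w)||^2; a sketch of that definition (not the library's exact omega_distance):

import torch

def omega_distance_sketch(x, protos, omega):
    # x: (batch, dim), protos: (num_protos, dim), omega: (latent_dim, dim)
    diff = x.unsqueeze(1) - protos.unsqueeze(0)   # (batch, num_protos, dim)
    projected = diff @ omega.T                    # (batch, num_protos, latent_dim)
    return (projected**2).sum(dim=-1)             # (batch, num_protos)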
@@ -198,16 +207,18 @@ class GMLVQ(SiamesePrototypeModel, GLVQ):
         plt.colorbar()
         plt.show(block=True)
 
-    def forward(self, x):
+    def _forward(self, x):
         protos, _ = self.proto_layer()
         x, protos = get_flat(x, protos)
         latent_x = self.backbone(x)
+        self.backbone.requires_grad_(self.both_path_gradients)
         latent_protos = self.backbone(protos)
-        dis = self.distance_fn(latent_x, latent_protos)
-        return dis
+        self.backbone.requires_grad_(True)
+        distances = self.distance_fn(latent_x, latent_protos)
+        return distances
 
 
-class LVQMLN(SiamesePrototypeModel, GLVQ):
+class LVQMLN(SiameseGLVQ):
     """Learning Vector Quantization Multi-Layer Network.
 
     GLVQ model that applies an arbitrary transformation on the inputs, BUT NOT
@@ -216,17 +227,11 @@ class LVQMLN(SiamesePrototypeModel, GLVQ):
     rather in the embedding space.
 
     """
-    def __init__(self, hparams, backbone=torch.nn.Identity(), **kwargs):
-        super().__init__(hparams, **kwargs)
-        self.backbone = backbone
-
-        self.distance_fn = kwargs.get("distance_fn", sed)
-
-    def forward(self, x):
+    def _forward(self, x):
         latent_protos, _ = self.proto_layer()
         latent_x = self.backbone(x)
-        dis = self.distance_fn(latent_x, latent_protos)
-        return dis
+        distances = self.distance_fn(latent_x, latent_protos)
+        return distances
 
 
 class NonGradientGLVQ(GLVQ):
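Across the last several hunks, GRLVQ, GMLVQ and LVQMLN drop their own distance_fn setup (and, for LVQMLN, the whole __init__) and are re-parented onto SiameseGLVQ, overriding only _forward. A schematic sketch of that pattern with invented class names (not the library's actual classes):

import torch

def sed(x, y):
    # Stand-in for the library's squared Euclidean distance.
    return torch.cdist(x, y)**2

class SiameseBase:
    # Plays the role of SiameseGLVQ: owns distance_fn and the shared API.
    def __init__(self, **kwargs):
        self.distance_fn = kwargs.get("distance_fn", sed)

    def forward(self, x):
        return self._forward(x)  # subclasses customize _forward only

class Variant(SiameseBase):
    # Plays the role of GRLVQ/GMLVQ/LVQMLN after the re-parenting.
    def __init__(self, prototypes, **kwargs):
        super().__init__(**kwargs)  # distance_fn now comes from the parent
        self.prototypes = prototypes

    def _forward(self, x):
        return self.distance_fn(x, self.prototypes)

model = Variant(torch.randn(3, 2))
print(model.forward(torch.randn(5, 2)).shape)  # torch.Size([5, 3])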
@@ -244,7 +249,7 @@ class LVQ1(NonGradientGLVQ):
         plabels = self.proto_layer.component_labels
 
         x, y = train_batch
-        dis = self(x)
+        dis = self._forward(x)
         # TODO Vectorized implementation
 
         for xi, yi in zip(x, y):
@@ -272,7 +277,7 @@ class LVQ21(NonGradientGLVQ):
         plabels = self.proto_layer.component_labels
 
         x, y = train_batch
-        dis = self(x)
+        dis = self._forward(x)
         # TODO Vectorized implementation
 
         for xi, yi in zip(x, y):
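In the two non-gradient models above, dis = self(x) becomes dis = self._forward(x), because forward now returns one-hot predictions while the per-sample update loop needs raw distances. For context, the classic LVQ1 rule such a loop typically applies (textbook form, sketch only, not the library's exact update):

import torch

def lvq1_update(xi, yi, protos, plabels, lr=0.1):
    # Move the winning prototype toward xi if labels match, away otherwise.
    d = ((protos - xi)**2).sum(dim=1)   # distances to all prototypes
    w = int(d.argmin())                 # winner index
    sign = 1.0 if plabels[w] == yi else -1.0
    protos[w] += sign * lr * (xi - protos[w])
    return protos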