Add more models

This commit is contained in:
Jensun Ravichandran 2021-04-29 23:37:22 +02:00
parent db7bb7619f
commit 6dd9b1492c
2 changed files with 80 additions and 26 deletions

View File

@ -48,13 +48,17 @@ To assist in the development process, you may also find it useful to install
- Neural Gas
## Work in Progress
- CBC
- LVQMLN
- GMLVQ
- Limited-Rank GMLVQ
## Planned models
- GMLVQ
- Local-Matrix GMLVQ
- Limited-Rank GMLVQ
- GTLVQ
- RSLVQ
- PLVQ
- LVQMLN
- SILVQ
- KNN

View File

@ -37,32 +37,28 @@ class GLVQ(AbstractPrototypeModel):
def training_step(self, train_batch, batch_idx):
x, y = train_batch
x = x.view(x.size(0), -1)
x = x.view(x.size(0), -1) # flatten
dis = self(x)
plabels = self.proto_layer.component_labels
mu = glvq_loss(dis, y, prototype_labels=plabels)
loss = mu.sum(dim=0)
self.log("train_loss", loss)
# Compute training accuracy
with torch.no_grad():
preds = wtac(dis, plabels)
# self.train_acc.update(preds.int(), y.int())
self.train_acc(
preds.int(),
y.int()) # FloatTensors are assumed to be class probabilities
self.log(
"acc",
# `.int()` because FloatTensors are assumed to be class probabilities
self.train_acc(preds.int(), y.int())
# Logging
self.log("train_loss", loss)
self.log("acc",
self.train_acc,
on_step=False,
on_epoch=True,
prog_bar=True,
logger=True,
)
return loss
logger=True)
# def training_epoch_end(self, outs):
# # Calling `self.train_acc.compute()` is
# # automatically done by setting `on_epoch=True` when logging in `self.training_step(...)`
# self.log("train_acc_epoch", self.train_acc.compute())
return loss
def predict(self, x):
# model.eval() # ?!
@ -76,8 +72,9 @@ class GLVQ(AbstractPrototypeModel):
class ImageGLVQ(GLVQ):
"""GLVQ for training on image data.
GLVQ model that constrains the prototypes to the range [0, 1] by
clamping after updates.
GLVQ model that constrains the prototypes to the range [0, 1] by clamping
after updates.
"""
def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
self.proto_layer.components.data.clamp_(0.0, 1.0)
@ -89,6 +86,7 @@ class SiameseGLVQ(GLVQ):
GLVQ model that applies an arbitrary transformation on the inputs and the
prototypes before computing the distances between them. The weights in the
transformation pipeline are only learned from the inputs.
"""
def __init__(self,
hparams,
@ -107,14 +105,18 @@ class SiameseGLVQ(GLVQ):
def forward(self, x):
self.sync_backbones()
protos, _ = self.proto_layer()
latent_x = self.backbone(x)
latent_protos = self.backbone_dependent(protos)
dis = euclidean_distance(latent_x, latent_protos)
return dis
def predict_latent(self, x):
"""Predict `x` assuming it is already embedded in the latent space.
Only the prototypes are embedded in the latent space using the
backbone.
"""
# model.eval() # ?!
with torch.no_grad():
protos, plabels = self.proto_layer()
@ -122,3 +124,51 @@ class SiameseGLVQ(GLVQ):
d = euclidean_distance(x, latent_protos)
y_pred = wtac(d, plabels)
return y_pred.numpy()
class GMLVQ(GLVQ):
"""Generalized Matrix Learning Vector Quantization."""
def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs)
self.omega_layer = torch.nn.Linear(self.hparams.input_dim,
self.latent_dim,
bias=False)
def forward(self, x):
protos, _ = self.proto_layer()
latent_x = self.omega_layer(x)
latent_protos = self.omega_layer(protos)
dis = euclidean_distance(latent_x, latent_protos)
return dis
class LVQMLN(GLVQ):
"""Learning Vector Quantization Multi-Layer Network.
GLVQ model that applies an arbitrary transformation on the inputs, BUT NOT
on the prototypes before computing the distances between them. This of
course, means that the prototypes no longer live the input space, but
rather in the embedding space.
"""
def __init__(self,
hparams,
backbone_module=torch.nn.Identity,
backbone_params={},
**kwargs):
super().__init__(hparams, **kwargs)
self.backbone = backbone_module(**backbone_params)
def forward(self, x):
latent_protos, _ = self.proto_layer()
latent_x = self.backbone(x)
dis = euclidean_distance(latent_x, latent_protos)
return dis
def predict_latent(self, x):
"""Predict `x` assuming it is already embedded in the latent space."""
with torch.no_grad():
latent_protos, plabels = self.proto_layer()
d = euclidean_distance(x, latent_protos)
y_pred = wtac(d, plabels)
return y_pred.numpy()