Rename rslvq example

This commit is contained in:
Jensun Ravichandran 2021-05-31 17:56:45 +02:00
parent 27eccf44d4
commit 1636c84778
3 changed files with 15 additions and 18 deletions

View File

@ -13,6 +13,9 @@ if __name__ == "__main__":
parser = pl.Trainer.add_argparse_args(parser) parser = pl.Trainer.add_argparse_args(parser)
args = parser.parse_args() args = parser.parse_args()
# Reproducibility
pl.utilities.seed.seed_everything(seed=42)
# Dataset # Dataset
train_ds = pt.datasets.Iris(dims=[0, 2]) train_ds = pt.datasets.Iris(dims=[0, 2])
@ -20,20 +23,17 @@ if __name__ == "__main__":
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=64) train_loader = torch.utils.data.DataLoader(train_ds, batch_size=64)
# Hyperparameters # Hyperparameters
num_classes = 3
prototypes_per_class = 2
hparams = dict( hparams = dict(
distribution=(num_classes, prototypes_per_class), distribution=[2, 2, 3],
lr=0.05, lr=0.05,
variance=1.0, variance=0.3,
) )
# Initialize the model # Initialize the model
model = pt.models.probabilistic.LikelihoodRatioLVQ( model = pt.models.probabilistic.RSLVQ(
hparams, hparams,
optimizer=torch.optim.Adam, optimizer=torch.optim.Adam,
# prototype_initializer=pt.components.UniformInitializer(2), prototype_initializer=pt.components.SSI(train_ds, noise=0.2),
prototype_initializer=pt.components.SMI(train_ds),
) )
print(model) print(model)
@ -45,6 +45,9 @@ if __name__ == "__main__":
trainer = pl.Trainer.from_argparse_args( trainer = pl.Trainer.from_argparse_args(
args, args,
callbacks=[vis], callbacks=[vis],
terminate_on_nan=True,
weights_summary=None,
# accelerator="ddp",
) )
# Training loop # Training loop

View File

@ -1,6 +1,6 @@
from importlib.metadata import PackageNotFoundError, version from importlib.metadata import PackageNotFoundError, version
from . import probabilistic from .probabilistic import LikelihoodRatioLVQ, RSLVQ
from .cbc import CBC, ImageCBC from .cbc import CBC, ImageCBC
from .glvq import (CELVQ, GLVQ, GLVQ1, GLVQ21, GMLVQ, GRLVQ, LVQMLN, ImageGLVQ, from .glvq import (CELVQ, GLVQ, GLVQ1, GLVQ21, GMLVQ, GRLVQ, LVQMLN, ImageGLVQ,
ImageGMLVQ, SiameseGLVQ) ImageGMLVQ, SiameseGLVQ)

View File

@ -2,15 +2,14 @@
import torch import torch
from prototorch.functions.competitions import stratified_sum from prototorch.functions.competitions import stratified_sum
from prototorch.functions.losses import (log_likelihood_ratio_loss, from prototorch.functions.losses import log_likelihood_ratio_loss, robust_soft_loss
robust_soft_loss)
from prototorch.functions.transform import gaussian from prototorch.functions.transform import gaussian
from .glvq import GLVQ from .glvq import GLVQ
class ProbabilisticLVQ(GLVQ): class ProbabilisticLVQ(GLVQ):
def __init__(self, hparams, rejection_confidence=1.0, **kwargs): def __init__(self, hparams, rejection_confidence=0.0, **kwargs):
super().__init__(hparams, **kwargs) super().__init__(hparams, **kwargs)
self.conditional_distribution = gaussian self.conditional_distribution = gaussian
@ -45,19 +44,14 @@ class ProbabilisticLVQ(GLVQ):
class LikelihoodRatioLVQ(ProbabilisticLVQ): class LikelihoodRatioLVQ(ProbabilisticLVQ):
"""Learning Vector Quantization based on Likelihood Ratios """Learning Vector Quantization based on Likelihood Ratios."""
"""
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.loss_fn = log_likelihood_ratio_loss self.loss_fn = log_likelihood_ratio_loss
class RSLVQ(ProbabilisticLVQ): class RSLVQ(ProbabilisticLVQ):
"""Learning Vector Quantization based on Likelihood Ratios """Robust Soft Learning Vector Quantization."""
"""
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.loss_fn = robust_soft_loss self.loss_fn = robust_soft_loss
__all__ = ["LikelihoodRatioLVQ", "RSLVQ"]