Merge branch 'dev' of github.com:si-cim/prototorch_models into dev
This commit is contained in:
		| @@ -1,12 +1,11 @@ | ||||
| """GLVQ example using the MNIST dataset.""" | ||||
|  | ||||
| from prototorch.models import ImageGLVQ | ||||
| from prototorch.models.data import train_on_mnist | ||||
| from pytorch_lightning.utilities.cli import LightningCLI | ||||
|  | ||||
| from mnist import TrainOnMNIST | ||||
|  | ||||
|  | ||||
| class GLVQMNIST(TrainOnMNIST, ImageGLVQ): | ||||
| class GLVQMNIST(train_on_mnist(batch_size=64), ImageGLVQ): | ||||
|     """Model Definition.""" | ||||
|  | ||||
|  | ||||
|   | ||||
| @@ -33,11 +33,12 @@ if __name__ == "__main__": | ||||
|     ) | ||||
|  | ||||
|     # Initialize the model | ||||
|     model = pt.models.probabilistic.RSLVQ( | ||||
|     model = pt.models.probabilistic.LikelihoodRatioLVQ( | ||||
|         #model = pt.models.probabilistic.RSLVQ( | ||||
|         hparams, | ||||
|         optimizer=torch.optim.Adam, | ||||
|         prototype_initializer=pt.components.SSI(train_ds, noise=2), | ||||
|         #prototype_initializer=pt.components.UniformInitializer(2), | ||||
|         #prototype_initializer=pt.components.SSI(train_ds, noise=2), | ||||
|         prototype_initializer=pt.components.UniformInitializer(2), | ||||
|     ) | ||||
|  | ||||
|     # Callbacks | ||||
|   | ||||
| @@ -10,15 +10,12 @@ class MNISTDataModule(pl.LightningDataModule): | ||||
|         super().__init__() | ||||
|         self.batch_size = batch_size | ||||
| 
 | ||||
|     # When doing distributed training, Datamodules have two optional arguments for | ||||
|     # granular control over download/prepare/splitting data: | ||||
| 
 | ||||
|     # OPTIONAL, called only on 1 GPU/machine | ||||
|     # Download mnist dataset as side-effect, only called on the first cpu | ||||
|     def prepare_data(self): | ||||
|         MNIST("~/datasets", train=True, download=True) | ||||
|         MNIST("~/datasets", train=False, download=True) | ||||
| 
 | ||||
|     # OPTIONAL, called for every GPU/machine (assigning state is OK) | ||||
|     # called for every GPU/machine (assigning state is OK) | ||||
|     def setup(self, stage=None): | ||||
|         # Transforms | ||||
|         transform = transforms.Compose([ | ||||
| @@ -28,13 +25,17 @@ class MNISTDataModule(pl.LightningDataModule): | ||||
|         if stage in (None, "fit"): | ||||
|             mnist_train = MNIST("~/datasets", train=True, transform=transform) | ||||
|             self.mnist_train, self.mnist_val = random_split( | ||||
|                 mnist_train, [55000, 5000]) | ||||
|                 mnist_train, | ||||
|                 [55000, 5000], | ||||
|             ) | ||||
|         if stage == (None, "test"): | ||||
|             self.mnist_test = MNIST("~/datasets", | ||||
|                                     train=False, | ||||
|                                     transform=transform) | ||||
|             self.mnist_test = MNIST( | ||||
|                 "~/datasets", | ||||
|                 train=False, | ||||
|                 transform=transform, | ||||
|             ) | ||||
| 
 | ||||
|     # Return the dataloader for each split | ||||
|     # Dataloaders | ||||
|     def train_dataloader(self): | ||||
|         mnist_train = DataLoader(self.mnist_train, batch_size=self.batch_size) | ||||
|         return mnist_train | ||||
| @@ -48,8 +49,11 @@ class MNISTDataModule(pl.LightningDataModule): | ||||
|         return mnist_test | ||||
| 
 | ||||
| 
 | ||||
| class TrainOnMNIST(pl.LightningModule): | ||||
|     datamodule = MNISTDataModule(batch_size=256) | ||||
| def train_on_mnist(batch_size=256) -> type: | ||||
|     class DataClass(pl.LightningModule): | ||||
|         datamodule = MNISTDataModule(batch_size=batch_size) | ||||
| 
 | ||||
|     def prototype_initializer(self, **kwargs): | ||||
|         return pt.components.Zeros((28, 28, 1)) | ||||
|         def prototype_initializer(self, **kwargs): | ||||
|             return pt.components.Zeros((28, 28, 1)) | ||||
| 
 | ||||
|     return DataClass | ||||
| @@ -1,100 +1,26 @@ | ||||
| """Probabilistic GLVQ methods""" | ||||
|  | ||||
| import torch | ||||
| from prototorch.functions.competitions import stratified_sum | ||||
| from prototorch.functions.losses import (log_likelihood_ratio_loss, | ||||
|                                          robust_soft_loss) | ||||
| from prototorch.functions.transform import gaussian | ||||
|  | ||||
| from .glvq import GLVQ | ||||
|  | ||||
|  | ||||
| # HELPER | ||||
| # TODO: Refactor into general files, if useful | ||||
| def probability(distance, variance): | ||||
|     return torch.exp(-(distance * distance) / (2 * variance)) | ||||
|  | ||||
|  | ||||
| def grouped_sum(value: torch.Tensor, | ||||
|                 labels: torch.LongTensor) -> (torch.Tensor, torch.LongTensor): | ||||
|     """Group-wise average for (sparse) grouped tensors | ||||
|  | ||||
|     Args: | ||||
|         value (torch.Tensor): values to average (# samples, latent dimension) | ||||
|         labels (torch.LongTensor): labels for embedding parameters (# samples,) | ||||
|  | ||||
|     Returns: | ||||
|         result (torch.Tensor): (# unique labels, latent dimension) | ||||
|         new_labels (torch.LongTensor): (# unique labels,) | ||||
|  | ||||
|     Examples: | ||||
|         >>> samples = torch.Tensor([ | ||||
|                              [0.15, 0.15, 0.15],    #-> group / class 1 | ||||
|                              [0.2,  0.2,  0.2 ],    #-> group / class 3 | ||||
|                              [0.4,  0.4,  0.4 ],    #-> group / class 3 | ||||
|                              [0.0,  0.0,  0.0 ]     #-> group / class 0 | ||||
|                       ]) | ||||
|         >>> labels = torch.LongTensor([1, 5, 5, 0]) | ||||
|         >>> result, new_labels = groupby_mean(samples, labels) | ||||
|  | ||||
|         >>> result | ||||
|         tensor([[0.0000, 0.0000, 0.0000], | ||||
|                 [0.1500, 0.1500, 0.1500], | ||||
|                 [0.3000, 0.3000, 0.3000]]) | ||||
|  | ||||
|         >>> new_labels | ||||
|         tensor([0, 1, 5]) | ||||
|     """ | ||||
|     uniques = labels.unique(sorted=True).tolist() | ||||
|     labels = labels.tolist() | ||||
|  | ||||
|     key_val = {key: val for key, val in zip(uniques, range(len(uniques)))} | ||||
|     labels = torch.LongTensor(list(map(key_val.get, labels))) | ||||
|  | ||||
|     labels = labels.view(labels.size(0), 1).expand(-1, value.size(1)) | ||||
|  | ||||
|     unique_labels = labels.unique(dim=0) | ||||
|     result = torch.zeros_like(unique_labels, dtype=torch.float).scatter_add_( | ||||
|         0, labels, value) | ||||
|     return result.T | ||||
|  | ||||
|  | ||||
| def likelihood_loss(probabilities, target, prototype_labels): | ||||
|     uniques = prototype_labels.unique(sorted=True).tolist() | ||||
|     labels = target.tolist() | ||||
|  | ||||
|     key_val = {key: val for key, val in zip(uniques, range(len(uniques)))} | ||||
|     target_indices = torch.LongTensor(list(map(key_val.get, labels))) | ||||
|  | ||||
|     whole_probability = probabilities.sum(dim=1) | ||||
|     correct_probability = probabilities[torch.arange(len(probabilities)), | ||||
|                                         target_indices] | ||||
|     wrong_probability = whole_probability - correct_probability | ||||
|  | ||||
|     likelihood = correct_probability / wrong_probability | ||||
|     log_likelihood = torch.log(likelihood) | ||||
|     return log_likelihood | ||||
|  | ||||
|  | ||||
| def robust_soft_loss(probabilities, target, prototype_labels): | ||||
|     uniques = prototype_labels.unique(sorted=True).tolist() | ||||
|     labels = target.tolist() | ||||
|  | ||||
|     key_val = {key: val for key, val in zip(uniques, range(len(uniques)))} | ||||
|     target_indices = torch.LongTensor(list(map(key_val.get, labels))) | ||||
|  | ||||
|     whole_probability = probabilities.sum(dim=1) | ||||
|     correct_probability = probabilities[torch.arange(len(probabilities)), | ||||
|                                         target_indices] | ||||
|  | ||||
|     likelihood = correct_probability / whole_probability | ||||
|     log_likelihood = torch.log(likelihood) | ||||
|     return log_likelihood | ||||
|  | ||||
|  | ||||
| class LikelihoodRatioLVQ(GLVQ): | ||||
|     """Learning Vector Quantization based on Likelihood Ratios | ||||
|     """ | ||||
|     def __init__(self, hparams, **kwargs): | ||||
| class ProbabilisticLVQ(GLVQ): | ||||
|     def __init__(self, hparams, rejection_confidence=1.0, **kwargs): | ||||
|         super().__init__(hparams, **kwargs) | ||||
|  | ||||
|         self.conditional_distribution = probability | ||||
|         self.conditional_distribution = gaussian | ||||
|         self.rejection_confidence = rejection_confidence | ||||
|  | ||||
|     def predict(self, x): | ||||
|         probabilities = self.forward(x) | ||||
|         confidence, prediction = torch.max(probabilities, dim=1) | ||||
|         prediction[confidence < self.rejection_confidence] = -1 | ||||
|         return prediction | ||||
|  | ||||
|     def forward(self, x): | ||||
|         distances = self._forward(x) | ||||
| @@ -104,7 +30,7 @@ class LikelihoodRatioLVQ(GLVQ): | ||||
|         posterior = conditional * prior | ||||
|  | ||||
|         plabels = torch.LongTensor(self.proto_layer.component_labels) | ||||
|         y_pred = grouped_sum(posterior.T, plabels) | ||||
|         y_pred = stratified_sum(posterior.T, plabels) | ||||
|  | ||||
|         return y_pred | ||||
|  | ||||
| @@ -112,52 +38,26 @@ class LikelihoodRatioLVQ(GLVQ): | ||||
|         X, y = batch | ||||
|         out = self.forward(X) | ||||
|         plabels = self.proto_layer.component_labels | ||||
|         batch_loss = -likelihood_loss(out, y, prototype_labels=plabels) | ||||
|         batch_loss = -self.loss_fn(out, y, plabels) | ||||
|         loss = batch_loss.sum(dim=0) | ||||
|  | ||||
|         return loss | ||||
|  | ||||
|     def predict(self, x): | ||||
|         probabilities = self.forward(x) | ||||
|         confidence, prediction = torch.max(probabilities, dim=1) | ||||
|         prediction[confidence < 0.1] = -1 | ||||
|         return prediction | ||||
|  | ||||
|  | ||||
| class RSLVQ(GLVQ): | ||||
| class LikelihoodRatioLVQ(ProbabilisticLVQ): | ||||
|     """Learning Vector Quantization based on Likelihood Ratios | ||||
|     """ | ||||
|     def __init__(self, hparams, **kwargs): | ||||
|         super().__init__(hparams, **kwargs) | ||||
|  | ||||
|         self.conditional_distribution = probability | ||||
|  | ||||
|     def forward(self, x): | ||||
|         distances = self._forward(x) | ||||
|         conditional = self.conditional_distribution(distances, | ||||
|                                                     self.hparams.variance) | ||||
|         prior = 1.0 / torch.Tensor(self.proto_layer.distribution).sum().item() | ||||
|         posterior = conditional * prior | ||||
|  | ||||
|         plabels = torch.LongTensor(self.proto_layer.component_labels) | ||||
|         y_pred = grouped_sum(posterior.T, plabels) | ||||
|  | ||||
|         return y_pred | ||||
|  | ||||
|     def training_step(self, batch, batch_idx, optimizer_idx=None): | ||||
|         X, y = batch | ||||
|         out = self.forward(X) | ||||
|         plabels = self.proto_layer.component_labels | ||||
|         batch_loss = -robust_soft_loss(out, y, prototype_labels=plabels) | ||||
|         loss = batch_loss.sum(dim=0) | ||||
|  | ||||
|         return loss | ||||
|  | ||||
|     def predict(self, x): | ||||
|         probabilities = self.forward(x) | ||||
|         confidence, prediction = torch.max(probabilities, dim=1) | ||||
|         #prediction[confidence < 0.1] = -1 | ||||
|         return prediction | ||||
|     def __init__(self, *args, **kwargs): | ||||
|         super().__init__(*args, **kwargs) | ||||
|         self.loss_fn = log_likelihood_ratio_loss | ||||
|  | ||||
|  | ||||
| __all__ = ["LikelihoodRatioLVQ", "probability", "grouped_sum"] | ||||
| class RSLVQ(ProbabilisticLVQ): | ||||
|     """Learning Vector Quantization based on Likelihood Ratios | ||||
|     """ | ||||
|     def __init__(self, *args, **kwargs): | ||||
|         super().__init__(*args, **kwargs) | ||||
|         self.loss_fn = robust_soft_loss | ||||
|  | ||||
|  | ||||
| __all__ = ["LikelihoodRatioLVQ", "RSLVQ"] | ||||
|   | ||||
		Reference in New Issue
	
	Block a user