Compare commits

13 Commits
| Author | SHA1 | Date |
|---|---|---|
|  | aeb6417c28 |  |
|  | cb7fb91c95 |  |
|  | 823b05e390 |  |
|  | f8ad1d83eb |  |
|  | 23a3683860 |  |
|  | 4be9fb81eb |  |
|  | 9d38123114 |  |
|  | 0f9f24e36a |  |
|  | 09e3ef1d0e |  |
|  | 7b9b767113 |  |
|  | f56ec44afe |  |
|  | 67a20124e8 |  |
|  | 72af03b991 |  |

.github/ISSUE_TEMPLATE/bug_report.md (new file, +38 lines, vendored)

@@ -0,0 +1,38 @@
+---
+name: Bug report
+about: Create a report to help us improve
+title: ''
+labels: ''
+assignees: ''
+
+---
+
+**Describe the bug**
+A clear and concise description of what the bug is.
+
+**Steps to reproduce the behavior**
+1. ...
+2. Run script '...' or this snippet:
+```python
+import prototorch as pt
+
+...
+```
+3. See errors
+
+**Expected behavior**
+A clear and concise description of what you expected to happen.
+
+**Observed behavior**
+A clear and concise description of what actually happened.
+
+**Screenshots**
+If applicable, add screenshots to help explain your problem.
+
+**System and version information**
+- OS: [e.g. Ubuntu 20.10]
+- ProtoTorch Version: [e.g. 0.4.0]
+- Python Version: [e.g. 3.9.5]
+
+**Additional context**
+Add any other context about the problem here.

.github/ISSUE_TEMPLATE/feature_request.md (new file, +20 lines, vendored)

@@ -0,0 +1,20 @@
+---
+name: Feature request
+about: Suggest an idea for this project
+title: ''
+labels: ''
+assignees: ''
+
+---
+
+**Is your feature request related to a problem? Please describe.**
+A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
+
+**Describe the solution you'd like**
+A clear and concise description of what you want to happen.
+
+**Describe alternatives you've considered**
+A clear and concise description of any alternative solutions or features you've considered.
+
+**Additional context**
+Add any other context or screenshots about the feature request here.

README.md
@@ -36,6 +36,7 @@ be available for use in your Python environment as `prototorch.models`.
 - Soft Learning Vector Quantization (SLVQ)
 - Robust Soft Learning Vector Quantization (RSLVQ)
 - Probabilistic Learning Vector Quantization (PLVQ)
+- Median-LVQ

 ### Other

@@ -51,7 +52,6 @@ be available for use in your Python environment as `prototorch.models`.

 ## Planned models

-- Median-LVQ
 - Generalized Tangent Learning Vector Quantization (GTLVQ)
 - Self-Incremental Learning Vector Quantization (SILVQ)

(File diff suppressed because one or more lines are too long)

examples/binnam_tecator.py (new file, +81 lines)

@@ -0,0 +1,81 @@
+"""Neural Additive Model (NAM) example for binary classification."""
+
+import argparse
+
+import prototorch as pt
+import pytorch_lightning as pl
+import torch
+from matplotlib import pyplot as plt
+
+if __name__ == "__main__":
+    # Command-line arguments
+    parser = argparse.ArgumentParser()
+    parser = pl.Trainer.add_argparse_args(parser)
+    args = parser.parse_args()
+
+    # Dataset
+    train_ds = pt.datasets.Tecator("~/datasets")
+
+    # Dataloaders
+    train_loader = torch.utils.data.DataLoader(train_ds, batch_size=64)
+
+    # Hyperparameters
+    hparams = dict(lr=0.1)
+
+    # Define the feature extractor
+    class FE(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.modules_list = torch.nn.ModuleList([
+                torch.nn.Linear(1, 3),
+                torch.nn.Sigmoid(),
+                torch.nn.Linear(3, 1),
+                torch.nn.Sigmoid(),
+            ])
+
+        def forward(self, x):
+            for m in self.modules_list:
+                x = m(x)
+            return x
+
+    # Initialize the model
+    model = pt.models.BinaryNAM(
+        hparams,
+        extractors=torch.nn.ModuleList([FE() for _ in range(100)]),
+    )
+
+    # Compute intermediate input and output sizes
+    model.example_input_array = torch.zeros(4, 100)
+
+    # Callbacks
+    es = pl.callbacks.EarlyStopping(
+        monitor="train_loss",
+        min_delta=0.001,
+        patience=20,
+        mode="min",
+        verbose=True,
+        check_on_train_epoch_end=True,
+    )
+
+    # Setup trainer
+    trainer = pl.Trainer.from_argparse_args(
+        args,
+        callbacks=[
+            es,
+        ],
+        terminate_on_nan=True,
+        weights_summary=None,
+        accelerator="ddp",
+    )
+
+    # Training loop
+    trainer.fit(model, train_loader)
+
+    # Visualize extractor shape functions
+    fig, axes = plt.subplots(10, 10)
+    for i, ax in enumerate(axes.flat):
+        x = torch.linspace(-2, 2, 100)  # TODO use min/max from data
+        y = model.extractors[i](x.view(100, 1)).squeeze().detach()
+        ax.plot(x, y)
+        ax.set(title=f"Feature {i + 1}", xticklabels=[], yticklabels=[])
+    plt.show()

examples/binnam_xor.py (new file, +86 lines)

@@ -0,0 +1,86 @@
+"""Neural Additive Model (NAM) example for binary classification."""
+
+import argparse
+
+import prototorch as pt
+import pytorch_lightning as pl
+import torch
+from matplotlib import pyplot as plt
+
+if __name__ == "__main__":
+    # Command-line arguments
+    parser = argparse.ArgumentParser()
+    parser = pl.Trainer.add_argparse_args(parser)
+    args = parser.parse_args()
+
+    # Dataset
+    train_ds = pt.datasets.XOR()
+
+    # Dataloaders
+    train_loader = torch.utils.data.DataLoader(train_ds, batch_size=256)
+
+    # Hyperparameters
+    hparams = dict(lr=0.001)
+
+    # Define the feature extractor
+    class FE(torch.nn.Module):
+        def __init__(self, hidden_size=10):
+            super().__init__()
+            self.modules_list = torch.nn.ModuleList([
+                torch.nn.Linear(1, hidden_size),
+                torch.nn.ReLU(),
+                torch.nn.Linear(hidden_size, 1),
+                torch.nn.ReLU(),
+            ])
+
+        def forward(self, x):
+            for m in self.modules_list:
+                x = m(x)
+            return x
+
+    # Initialize the model
+    model = pt.models.BinaryNAM(
+        hparams,
+        extractors=torch.nn.ModuleList([FE(20) for _ in range(2)]),
+    )
+
+    # Compute intermediate input and output sizes
+    model.example_input_array = torch.zeros(4, 2)
+
+    # Summary
+    print(model)
+
+    # Callbacks
+    vis = pt.models.Vis2D(data=train_ds)
+    es = pl.callbacks.EarlyStopping(
+        monitor="train_loss",
+        min_delta=0.001,
+        patience=50,
+        mode="min",
+        verbose=False,
+        check_on_train_epoch_end=True,
+    )
+
+    # Setup trainer
+    trainer = pl.Trainer.from_argparse_args(
+        args,
+        callbacks=[
+            vis,
+            es,
+        ],
+        terminate_on_nan=True,
+        weights_summary="full",
+        accelerator="ddp",
+    )
+
+    # Training loop
+    trainer.fit(model, train_loader)
+
+    # Visualize extractor shape functions
+    fig, axes = plt.subplots(2)
+    for i, ax in enumerate(axes.flat):
+        x = torch.linspace(0, 1, 100)  # TODO use min/max from data
+        y = model.extractors[i](x.view(100, 1)).squeeze().detach()
+        ax.plot(x, y)
+        ax.set(title=f"Feature {i + 1}")
+    plt.show()

@@ -1,12 +1,11 @@
 """GMLVQ example using the MNIST dataset."""

-import torch
-from pytorch_lightning.utilities.cli import LightningCLI
-
 import prototorch as pt
+import torch
 from prototorch.models import ImageGMLVQ
 from prototorch.models.abstract import PrototypeModel
 from prototorch.models.data import MNISTDataModule
+from pytorch_lightning.utilities.cli import LightningCLI


 class ExperimentClass(ImageGMLVQ):

@@ -66,7 +66,7 @@ if __name__ == "__main__":
         args,
         callbacks=[
             vis,
-            # es, # FIXME
+            es,
             pruning,
         ],
         terminate_on_nan=True,

@@ -2,12 +2,11 @@

 import argparse

+import prototorch as pt
 import pytorch_lightning as pl
 import torch
 from sklearn.datasets import load_iris

-import prototorch as pt
-
 if __name__ == "__main__":
     # Command-line arguments
     parser = argparse.ArgumentParser()

examples/median_lvq_iris.py (new file, +52 lines)

@@ -0,0 +1,52 @@
+"""Median-LVQ example using the Iris dataset."""
+
+import argparse
+
+import prototorch as pt
+import pytorch_lightning as pl
+import torch
+
+if __name__ == "__main__":
+    # Command-line arguments
+    parser = argparse.ArgumentParser()
+    parser = pl.Trainer.add_argparse_args(parser)
+    args = parser.parse_args()
+
+    # Dataset
+    train_ds = pt.datasets.Iris(dims=[0, 2])
+
+    # Dataloaders
+    train_loader = torch.utils.data.DataLoader(
+        train_ds,
+        batch_size=len(train_ds),  # MedianLVQ cannot handle mini-batches
+    )
+
+    # Initialize the model
+    model = pt.models.MedianLVQ(
+        hparams=dict(distribution=(3, 2), lr=0.01),
+        prototypes_initializer=pt.initializers.SSCI(train_ds),
+    )
+
+    # Compute intermediate input and output sizes
+    model.example_input_array = torch.zeros(4, 2)
+
+    # Callbacks
+    vis = pt.models.VisGLVQ2D(data=train_ds)
+    es = pl.callbacks.EarlyStopping(
+        monitor="train_acc",
+        min_delta=0.01,
+        patience=5,
+        mode="max",
+        verbose=True,
+        check_on_train_epoch_end=True,
+    )
+
+    # Setup trainer
+    trainer = pl.Trainer.from_argparse_args(
+        args,
+        callbacks=[vis, es],
+        weights_summary="full",
+    )
+
+    # Training loop
+    trainer.fit(model, train_loader)

@@ -37,7 +37,7 @@ if __name__ == "__main__":

     # Setup trainer for GNG
     trainer = pl.Trainer(
-        max_epochs=200,
+        max_epochs=100,
         callbacks=[es],
         weights_summary=None,
     )

@@ -71,11 +71,30 @@ if __name__ == "__main__":

     # Callbacks
     vis = pt.models.VisGLVQ2D(data=train_ds)
+    pruning = pt.models.PruneLoserPrototypes(
+        threshold=0.02,
+        idle_epochs=2,
+        prune_quota_per_epoch=5,
+        frequency=1,
+        verbose=True,
+    )
+    es = pl.callbacks.EarlyStopping(
+        monitor="train_loss",
+        min_delta=0.001,
+        patience=10,
+        mode="min",
+        verbose=True,
+        check_on_train_epoch_end=True,
+    )

     # Setup trainer
     trainer = pl.Trainer.from_argparse_args(
         args,
-        callbacks=[vis],
+        callbacks=[
+            vis,
+            pruning,
+            es,
+        ],
         weights_summary="full",
         accelerator="ddp",
     )

prototorch/models/__init__.py
@@ -19,6 +19,7 @@ from .glvq import (
 )
 from .knn import KNN
 from .lvq import LVQ1, LVQ21, MedianLVQ
+from .nam import BinaryNAM
 from .probabilistic import CELVQ, PLVQ, RSLVQ, SLVQ
 from .unsupervised import GrowingNeuralGas, HeskesSOM, KohonenSOM, NeuralGas
 from .vis import *

prototorch/models/abstract.py
@@ -14,20 +14,8 @@ from ..core.pooling import stratified_min_pooling
 from ..nn.wrappers import LambdaLayer


-class ProtoTorchMixin(object):
-    pass
-
-
 class ProtoTorchBolt(pl.LightningModule):
     """All ProtoTorch models are ProtoTorch Bolts."""
-    def __repr__(self):
-        surep = super().__repr__()
-        indented = "".join([f"\t{line}\n" for line in surep.splitlines()])
-        wrapped = f"ProtoTorch Bolt(\n{indented})"
-        return wrapped
-
-
-class PrototypeModel(ProtoTorchBolt):
     def __init__(self, hparams, **kwargs):
         super().__init__()

@@ -42,22 +30,6 @@ class PrototypeModel(ProtoTorchBolt):
         self.lr_scheduler = kwargs.get("lr_scheduler", None)
         self.lr_scheduler_kwargs = kwargs.get("lr_scheduler_kwargs", dict())

-        distance_fn = kwargs.get("distance_fn", euclidean_distance)
-        self.distance_layer = LambdaLayer(distance_fn)
-
-    @property
-    def num_prototypes(self):
-        return len(self.proto_layer.components)
-
-    @property
-    def prototypes(self):
-        return self.proto_layer.components.detach().cpu()
-
-    @property
-    def components(self):
-        """Only an alias for the prototypes."""
-        return self.prototypes
-
     def configure_optimizers(self):
         optimizer = self.optimizer(self.parameters(), lr=self.hparams.lr)
         if self.lr_scheduler is not None:

@@ -73,7 +45,34 @@ class PrototypeModel(ProtoTorchBolt):

     @final
     def reconfigure_optimizers(self):
-        self.trainer.accelerator_backend.setup_optimizers(self.trainer)
+        self.trainer.accelerator.setup_optimizers(self.trainer)
+
+    def __repr__(self):
+        surep = super().__repr__()
+        indented = "".join([f"\t{line}\n" for line in surep.splitlines()])
+        wrapped = f"ProtoTorch Bolt(\n{indented})"
+        return wrapped
+
+
+class PrototypeModel(ProtoTorchBolt):
+    def __init__(self, hparams, **kwargs):
+        super().__init__(hparams, **kwargs)
+
+        distance_fn = kwargs.get("distance_fn", euclidean_distance)
+        self.distance_layer = LambdaLayer(distance_fn)
+
+    @property
+    def num_prototypes(self):
+        return len(self.proto_layer.components)
+
+    @property
+    def prototypes(self):
+        return self.proto_layer.components.detach().cpu()
+
+    @property
+    def components(self):
+        """Only an alias for the prototypes."""
+        return self.prototypes

     def add_prototypes(self, *args, **kwargs):
         self.proto_layer.add_components(*args, **kwargs)

@@ -167,6 +166,11 @@ class SupervisedPrototypeModel(PrototypeModel):
                  logger=True)


+class ProtoTorchMixin(object):
+    """All mixins are ProtoTorchMixins."""
+    pass
+
+
 class NonGradientMixin(ProtoTorchMixin):
     """Mixin for custom non-gradient optimization."""
     def __init__(self, *args, **kwargs):

prototorch/models/cbc.py
@@ -48,7 +48,7 @@ class CBC(SiameseGLVQ):
         y_pred = self(x)
         num_classes = self.num_classes
         y_true = torch.nn.functional.one_hot(y.long(), num_classes=num_classes)
-        loss = self.loss(y_pred, y_true).mean(dim=0)
+        loss = self.loss(y_pred, y_true).mean()
         return y_pred, loss

     def training_step(self, batch, batch_idx, optimizer_idx=None):
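
Several hunks in this compare make the same correction: PyTorch Lightning expects the value returned from `training_step` (and the helpers feeding it) to be a scalar, but reducing only over `dim=0` leaves a vector whenever the per-sample loss has a second axis, as the one-hot CBC loss above does. A minimal illustration of the difference, with toy numbers that are not from the repository:

```python
import torch

# Toy per-sample, per-class losses, e.g. from a one-hot target comparison.
batch_loss = torch.tensor([[0.2, 0.7],
                           [0.1, 0.4]])

print(batch_loss.mean(dim=0))  # tensor([0.1500, 0.5500]) -- still a vector
print(batch_loss.mean())       # tensor(0.3500) -- the scalar Lightning expects
```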

prototorch/models/data.py
@@ -5,13 +5,12 @@ Mainly used for PytorchLightningCLI configurations.
 """
 from typing import Any, Optional, Type

+import prototorch as pt
 import pytorch_lightning as pl
 from torch.utils.data import DataLoader, Dataset, random_split
 from torchvision import transforms
 from torchvision.datasets import MNIST

-import prototorch as pt
-

 # MNIST
 class MNISTDataModule(pl.LightningDataModule):

prototorch/models/glvq.py
@@ -6,8 +6,8 @@ from torch.nn.parameter import Parameter
 from ..core.competitions import wtac
 from ..core.distances import lomega_distance, omega_distance, squared_euclidean_distance
 from ..core.initializers import EyeTransformInitializer
-from ..core.losses import glvq_loss, lvq1_loss, lvq21_loss
-from ..nn.activations import get_activation
+from ..core.losses import GLVQLoss, lvq1_loss, lvq21_loss
+from ..core.transforms import LinearTransform
 from ..nn.wrappers import LambdaLayer, LossLayer
 from .abstract import ImagePrototypesMixin, SupervisedPrototypeModel

@@ -18,15 +18,16 @@ class GLVQ(SupervisedPrototypeModel):
         super().__init__(hparams, **kwargs)

         # Default hparams
+        self.hparams.setdefault("margin", 0.0)
         self.hparams.setdefault("transfer_fn", "identity")
         self.hparams.setdefault("transfer_beta", 10.0)

-        # Layers
-        transfer_fn = get_activation(self.hparams.transfer_fn)
-        self.transfer_layer = LambdaLayer(transfer_fn)
-
         # Loss
-        self.loss = LossLayer(glvq_loss)
+        self.loss = GLVQLoss(
+            margin=self.hparams.margin,
+            transfer_fn=self.hparams.transfer_fn,
+            beta=self.hparams.transfer_beta,
+        )

     def initialize_prototype_win_ratios(self):
         self.register_buffer(

@@ -55,9 +56,7 @@ class GLVQ(SupervisedPrototypeModel):
         x, y = batch
         out = self.compute_distances(x)
         plabels = self.proto_layer.labels
-        mu = self.loss(out, y, prototype_labels=plabels)
-        batch_loss = self.transfer_layer(mu, beta=self.hparams.transfer_beta)
-        loss = batch_loss.sum(dim=0)
+        loss = self.loss(out, y, plabels)
         return out, loss

     def training_step(self, batch, batch_idx, optimizer_idx=None):
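
For readers unfamiliar with GLVQ: both the old two-layer setup and the new `GLVQLoss` are built around the same per-sample score mu = (dp - dm) / (dp + dm), where dp is the distance to the closest prototype with the correct label and dm the distance to the closest prototype with a wrong label; a transfer function and optional margin are applied on top. A hedged sketch of that score, not the library's implementation, assuming `d` is an (N, P) matrix of distances to P labeled prototypes:

```python
import torch

def glvq_mu(d, y, plabels):
    """Per-sample GLVQ score in [-1, 1]; negative means correctly classified.

    Sketch of the quantity behind GLVQLoss, not the library code itself.
    """
    matching = y.unsqueeze(1) == plabels.unsqueeze(0)  # (N, P) label agreement
    inf = torch.full_like(d, float("inf"))
    dp = torch.where(matching, d, inf).min(dim=1).values   # closest correct
    dm = torch.where(~matching, d, inf).min(dim=1).values  # closest incorrect
    return (dp - dm) / (dp + dm)
```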

@@ -208,18 +207,22 @@ class SiameseGMLVQ(SiameseGLVQ):
         super().__init__(hparams, **kwargs)

         # Override the backbone
-        self.backbone = torch.nn.Linear(self.hparams.input_dim,
-                                        self.hparams.latent_dim,
-                                        bias=False)
+        omega_initializer = kwargs.get("omega_initializer",
+                                       EyeTransformInitializer())
+        self.backbone = LinearTransform(
+            self.hparams.input_dim,
+            self.hparams.output_dim,
+            initializer=omega_initializer,
+        )

     @property
     def omega_matrix(self):
-        return self.backbone.weight.detach().cpu()
+        return self.backbone.weights

     @property
     def lambda_matrix(self):
-        omega = self.backbone.weight  # (latent_dim, input_dim)
-        lam = omega.T @ omega
+        omega = self.backbone.weight  # (input_dim, latent_dim)
+        lam = omega @ omega.T
         return lam.detach().cpu()
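
The swapped multiplication order in `lambda_matrix` follows from the changed weight layout: `torch.nn.Linear` stores its weight as (latent_dim, input_dim), while the omega matrix of the new `LinearTransform` backbone is laid out as (input_dim, latent_dim), so `omega @ omega.T` is what keeps the relevance matrix lambda at (input_dim, input_dim). A quick shape check, with illustrative dimensions:

```python
import torch

input_dim, latent_dim = 100, 2

# New LinearTransform layout: (input_dim, latent_dim)
omega = torch.randn(input_dim, latent_dim)
lam = omega @ omega.T
assert lam.shape == (input_dim, input_dim)

# The old torch.nn.Linear weight was transposed, (latent_dim, input_dim),
# hence the previous omega.T @ omega produced the same (input_dim, input_dim)
# matrix for the old layout.
old_omega = torch.randn(latent_dim, input_dim)
assert (old_omega.T @ old_omega).shape == (input_dim, input_dim)
```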

prototorch/models/lvq.py
@@ -1,6 +1,8 @@
 """LVQ models that are optimized using non-gradient methods."""

 from ..core.losses import _get_dp_dm
+from ..nn.activations import get_activation
+from ..nn.wrappers import LambdaLayer
 from .abstract import NonGradientMixin
 from .glvq import GLVQ

@@ -66,4 +68,61 @@ class LVQ21(NonGradientMixin, GLVQ):


 class MedianLVQ(NonGradientMixin, GLVQ):
-    """Median LVQ"""
+    """Median LVQ
+
+    # TODO Avoid computing distances over and over
+
+    """
+    def __init__(self, hparams, verbose=True, **kwargs):
+        self.verbose = verbose
+        super().__init__(hparams, **kwargs)
+
+        self.transfer_layer = LambdaLayer(
+            get_activation(self.hparams.transfer_fn))
+
+    def _f(self, x, y, protos, plabels):
+        d = self.distance_layer(x, protos)
+        dp, dm = _get_dp_dm(d, y, plabels)
+        mu = (dp - dm) / (dp + dm)
+        invmu = -1.0 * mu
+        f = self.transfer_layer(invmu, beta=self.hparams.transfer_beta) + 1.0
+        return f
+
+    def expectation(self, x, y, protos, plabels):
+        f = self._f(x, y, protos, plabels)
+        gamma = f / f.sum()
+        return gamma
+
+    def lower_bound(self, x, y, protos, plabels, gamma):
+        f = self._f(x, y, protos, plabels)
+        lower_bound = (gamma * f.log()).sum()
+        return lower_bound
+
+    def training_step(self, train_batch, batch_idx, optimizer_idx=None):
+        protos = self.proto_layer.components
+        plabels = self.proto_layer.labels
+
+        x, y = train_batch
+        dis = self.compute_distances(x)
+
+        for i, _ in enumerate(protos):
+            # Expectation step
+            gamma = self.expectation(x, y, protos, plabels)
+            lower_bound = self.lower_bound(x, y, protos, plabels, gamma)
+
+            # Maximization step
+            _protos = protos + 0
+            for k, xk in enumerate(x):
+                _protos[i] = xk
+                _lower_bound = self.lower_bound(x, y, _protos, plabels, gamma)
+                if _lower_bound > lower_bound:
+                    if self.verbose:
+                        print(f"Updating prototype {i} to data {k}...")
+                    self.proto_layer.load_state_dict({"_components": _protos},
+                                                     strict=False)
+                    break
+
+        # Logging
+        self.log_acc(dis, y, tag="train_acc")
+
+        return None
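
The `training_step` above is an EM-style loop rather than a gradient update: `expectation` turns the per-sample scores f into responsibilities gamma, and the maximization step greedily swaps prototype i for the first data point that raises the lower bound. The score in `_f` is built from the same (dp - dm) / (dp + dm) quantity as GLVQ; a toy evaluation with made-up distances, not data from the repository:

```python
import torch

dp = torch.tensor([0.2, 0.9])  # distance to closest correct-class prototype
dm = torch.tensor([0.8, 0.3])  # distance to closest wrong-class prototype
mu = (dp - dm) / (dp + dm)     # tensor([-0.6000,  0.5000])

# _f negates mu, applies the transfer function, and shifts by +1, so
# confidently correct samples (mu close to -1) receive high scores f.
```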

prototorch/models/nam.py (new file, +58 lines)

@@ -0,0 +1,58 @@
+"""ProtoTorch Neural Additive Model."""
+
+import torch
+import torchmetrics
+
+from .abstract import ProtoTorchBolt
+
+
+class BinaryNAM(ProtoTorchBolt):
+    """Neural Additive Model for binary classification.
+
+    Paper: https://arxiv.org/abs/2004.13912
+    Official implementation: https://github.com/google-research/google-research/tree/master/neural_additive_models
+
+    """
+    def __init__(self, hparams: dict, extractors: torch.nn.ModuleList,
+                 **kwargs):
+        super().__init__(hparams, **kwargs)
+
+        # Default hparams
+        self.hparams.setdefault("threshold", 0.5)
+
+        self.extractors = extractors
+        self.linear = torch.nn.Linear(in_features=len(extractors),
+                                      out_features=1,
+                                      bias=True)
+
+    def extract(self, x):
+        """Apply the local extractors batch-wise on features."""
+        out = torch.zeros_like(x)
+        for j in range(x.shape[1]):
+            out[:, j] = self.extractors[j](x[:, j].unsqueeze(1)).squeeze()
+        return out
+
+    def forward(self, x):
+        x = self.extract(x)
+        x = self.linear(x)
+        return torch.sigmoid(x)
+
+    def training_step(self, batch, batch_idx, optimizer_idx=None):
+        x, y = batch
+        preds = self(x).squeeze()
+        train_loss = torch.nn.functional.binary_cross_entropy(preds, y.float())
+        self.log("train_loss", train_loss)
+        accuracy = torchmetrics.functional.accuracy(preds.int(), y.int())
+        self.log("train_acc",
+                 accuracy,
+                 on_step=False,
+                 on_epoch=True,
+                 prog_bar=True,
+                 logger=True)
+        return train_loss
+
+    def predict(self, x):
+        out = self(x)
+        pred = torch.zeros_like(out, device=self.device)
+        pred[out > self.hparams.threshold] = 1
+        return pred
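
The additive structure is easiest to see in a toy forward pass: `extract` routes feature j through `extractors[j]` alone, and only the final `linear` plus sigmoid mixes the per-feature outputs, which is what makes the learned shape functions plottable one feature at a time, as the example scripts above do. A minimal sketch along the lines of `examples/binnam_xor.py`; the `Sequential` extractors and tensor shapes are illustrative, not from the repository:

```python
import torch
import prototorch as pt

# One small extractor per input feature, each mapping (N, 1) -> (N, 1).
extractors = torch.nn.ModuleList([
    torch.nn.Sequential(
        torch.nn.Linear(1, 8),
        torch.nn.ReLU(),
        torch.nn.Linear(8, 1),
    ) for _ in range(2)
])
model = pt.models.BinaryNAM(dict(lr=0.001), extractors=extractors)

x = torch.rand(4, 2)      # batch of 4 samples with 2 features
probs = model(x)          # (4, 1) sigmoid outputs in (0, 1)
preds = model.predict(x)  # 0/1 via hparams["threshold"] (default 0.5)
```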

prototorch/models/probabilistic.py
@@ -1,5 +1,4 @@
 """Probabilistic GLVQ methods"""

 import torch
-
 from ..core.losses import nllr_loss, rslvq_loss

@@ -24,7 +23,7 @@ class CELVQ(GLVQ):
         winning = stratified_min_pooling(out, plabels)  # [None, num_classes]
         probs = -1.0 * winning
         batch_loss = self.loss(probs, y.long())
-        loss = batch_loss.sum(dim=0)
+        loss = batch_loss.sum()
         return out, loss


@@ -32,7 +31,7 @@ class ProbabilisticLVQ(GLVQ):
     def __init__(self, hparams, rejection_confidence=0.0, **kwargs):
         super().__init__(hparams, **kwargs)

-        self.conditional_distribution = None
+        self.conditional_distribution = GaussianPrior(self.hparams.variance)
         self.rejection_confidence = rejection_confidence

     def forward(self, x):

@@ -56,8 +55,9 @@ class ProbabilisticLVQ(GLVQ):
         out = self.forward(x)
         plabels = self.proto_layer.labels
         batch_loss = self.loss(out, y, plabels)
-        loss = batch_loss.sum(dim=0)
-        return loss
+        train_loss = batch_loss.sum()
+        self.log("train_loss", train_loss)
+        return train_loss


 class SLVQ(ProbabilisticLVQ):
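
Returning the summed scalar and logging it under the key `train_loss` is what lets the `EarlyStopping(monitor="train_loss")` callbacks in the example scripts actually observe the metric; an unlogged return value is invisible to callbacks. The pattern, as a hedged standalone sketch rather than the library code:

```python
import pytorch_lightning as pl
import torch

class SketchModule(pl.LightningModule):
    """Minimal module showing the log-then-return pattern (illustrative)."""
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(2, 1)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        batch_loss = torch.nn.functional.mse_loss(self(x), y, reduction="none")
        train_loss = batch_loss.sum()
        # Without this log call, EarlyStopping(monitor="train_loss") never
        # sees the metric and warns that it is unavailable.
        self.log("train_loss", train_loss)
        return train_loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.01)
```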

@@ -65,7 +65,6 @@ class SLVQ(ProbabilisticLVQ):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.loss = LossLayer(nllr_loss)
-        self.conditional_distribution = GaussianPrior(self.hparams.variance)


 class RSLVQ(ProbabilisticLVQ):

@@ -73,7 +72,6 @@ class RSLVQ(ProbabilisticLVQ):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.loss = LossLayer(rslvq_loss)
-        self.conditional_distribution = GaussianPrior(self.hparams.variance)


 class PLVQ(ProbabilisticLVQ, SiameseGMLVQ):

@@ -92,5 +90,5 @@ class PLVQ(ProbabilisticLVQ, SiameseGMLVQ):
     #     x, y = batch
     #     y_pred = self(x)
     #     batch_loss = self.loss(y_pred, y)
-    #     loss = batch_loss.sum(dim=0)
+    #     loss = batch_loss.sum()
     #     return loss

prototorch/models/unsupervised.py
@@ -132,7 +132,7 @@ class GrowingNeuralGas(NeuralGas):
         mask[torch.arange(len(mask)), winner] = 1.0
         dp = d * mask

-        self.errors += torch.sum(dp * dp, dim=0)
+        self.errors += torch.sum(dp * dp)
         self.errors *= self.hparams.step_reduction

         self.topology_layer(d)

prototorch/models/vis.py
@@ -117,6 +117,24 @@ class Vis2DAbstract(pl.Callback):
         plt.close()


+class Vis2D(Vis2DAbstract):
+    def on_epoch_end(self, trainer, pl_module):
+        if not self.precheck(trainer):
+            return True
+
+        x_train, y_train = self.x_train, self.y_train
+        ax = self.setup_ax(xlabel="Data dimension 1",
+                           ylabel="Data dimension 2")
+        self.plot_data(ax, x_train, y_train)
+        mesh_input, xx, yy = mesh2d(x_train, self.border, self.resolution)
+        mesh_input = torch.from_numpy(mesh_input).type_as(x_train)
+        y_pred = pl_module.predict(mesh_input)
+        y_pred = y_pred.cpu().reshape(xx.shape)
+        ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
+
+        self.log_and_display(trainer, pl_module)
+
+
 class VisGLVQ2D(Vis2DAbstract):
     def on_epoch_end(self, trainer, pl_module):
         if not self.precheck(trainer):