Compare commits
	
		
			3 Commits
		
	
	
		
			feature/ux
			...
			wip/nam
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 
						 | 
					aeb6417c28 | ||
| 
						 | 
					cb7fb91c95 | ||
| 
						 | 
					823b05e390 | 
							
								
								
									
										81
									
								
								examples/binnam_tecator.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										81
									
								
								examples/binnam_tecator.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,81 @@
 | 
			
		||||
"""Neural Additive Model (NAM) example for binary classification."""
 | 
			
		||||
 | 
			
		||||
import argparse
 | 
			
		||||
 | 
			
		||||
import prototorch as pt
 | 
			
		||||
import pytorch_lightning as pl
 | 
			
		||||
import torch
 | 
			
		||||
from matplotlib import pyplot as plt
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    # Command-line arguments
 | 
			
		||||
    parser = argparse.ArgumentParser()
 | 
			
		||||
    parser = pl.Trainer.add_argparse_args(parser)
 | 
			
		||||
    args = parser.parse_args()
 | 
			
		||||
 | 
			
		||||
    # Dataset
 | 
			
		||||
    train_ds = pt.datasets.Tecator("~/datasets")
 | 
			
		||||
 | 
			
		||||
    # Dataloaders
 | 
			
		||||
    train_loader = torch.utils.data.DataLoader(train_ds, batch_size=64)
 | 
			
		||||
 | 
			
		||||
    # Hyperparameters
 | 
			
		||||
    hparams = dict(lr=0.1)
 | 
			
		||||
 | 
			
		||||
    # Define the feature extractor
 | 
			
		||||
    class FE(torch.nn.Module):
 | 
			
		||||
        def __init__(self):
 | 
			
		||||
            super().__init__()
 | 
			
		||||
            self.modules_list = torch.nn.ModuleList([
 | 
			
		||||
                torch.nn.Linear(1, 3),
 | 
			
		||||
                torch.nn.Sigmoid(),
 | 
			
		||||
                torch.nn.Linear(3, 1),
 | 
			
		||||
                torch.nn.Sigmoid(),
 | 
			
		||||
            ])
 | 
			
		||||
 | 
			
		||||
        def forward(self, x):
 | 
			
		||||
            for m in self.modules_list:
 | 
			
		||||
                x = m(x)
 | 
			
		||||
            return x
 | 
			
		||||
 | 
			
		||||
    # Initialize the model
 | 
			
		||||
    model = pt.models.BinaryNAM(
 | 
			
		||||
        hparams,
 | 
			
		||||
        extractors=torch.nn.ModuleList([FE() for _ in range(100)]),
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    # Compute intermediate input and output sizes
 | 
			
		||||
    model.example_input_array = torch.zeros(4, 100)
 | 
			
		||||
 | 
			
		||||
    # Callbacks
 | 
			
		||||
    es = pl.callbacks.EarlyStopping(
 | 
			
		||||
        monitor="train_loss",
 | 
			
		||||
        min_delta=0.001,
 | 
			
		||||
        patience=20,
 | 
			
		||||
        mode="min",
 | 
			
		||||
        verbose=True,
 | 
			
		||||
        check_on_train_epoch_end=True,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    # Setup trainer
 | 
			
		||||
    trainer = pl.Trainer.from_argparse_args(
 | 
			
		||||
        args,
 | 
			
		||||
        callbacks=[
 | 
			
		||||
            es,
 | 
			
		||||
        ],
 | 
			
		||||
        terminate_on_nan=True,
 | 
			
		||||
        weights_summary=None,
 | 
			
		||||
        accelerator="ddp",
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    # Training loop
 | 
			
		||||
    trainer.fit(model, train_loader)
 | 
			
		||||
 | 
			
		||||
    # Visualize extractor shape functions
 | 
			
		||||
    fig, axes = plt.subplots(10, 10)
 | 
			
		||||
    for i, ax in enumerate(axes.flat):
 | 
			
		||||
        x = torch.linspace(-2, 2, 100)  # TODO use min/max from data
 | 
			
		||||
        y = model.extractors[i](x.view(100, 1)).squeeze().detach()
 | 
			
		||||
        ax.plot(x, y)
 | 
			
		||||
        ax.set(title=f"Feature {i + 1}", xticklabels=[], yticklabels=[])
 | 
			
		||||
    plt.show()
 | 
			
		||||
							
								
								
									
										86
									
								
								examples/binnam_xor.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										86
									
								
								examples/binnam_xor.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,86 @@
 | 
			
		||||
"""Neural Additive Model (NAM) example for binary classification."""
 | 
			
		||||
 | 
			
		||||
import argparse
 | 
			
		||||
 | 
			
		||||
import prototorch as pt
 | 
			
		||||
import pytorch_lightning as pl
 | 
			
		||||
import torch
 | 
			
		||||
from matplotlib import pyplot as plt
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    # Command-line arguments
 | 
			
		||||
    parser = argparse.ArgumentParser()
 | 
			
		||||
    parser = pl.Trainer.add_argparse_args(parser)
 | 
			
		||||
    args = parser.parse_args()
 | 
			
		||||
 | 
			
		||||
    # Dataset
 | 
			
		||||
    train_ds = pt.datasets.XOR()
 | 
			
		||||
 | 
			
		||||
    # Dataloaders
 | 
			
		||||
    train_loader = torch.utils.data.DataLoader(train_ds, batch_size=256)
 | 
			
		||||
 | 
			
		||||
    # Hyperparameters
 | 
			
		||||
    hparams = dict(lr=0.001)
 | 
			
		||||
 | 
			
		||||
    # Define the feature extractor
 | 
			
		||||
    class FE(torch.nn.Module):
 | 
			
		||||
        def __init__(self, hidden_size=10):
 | 
			
		||||
            super().__init__()
 | 
			
		||||
            self.modules_list = torch.nn.ModuleList([
 | 
			
		||||
                torch.nn.Linear(1, hidden_size),
 | 
			
		||||
                torch.nn.ReLU(),
 | 
			
		||||
                torch.nn.Linear(hidden_size, 1),
 | 
			
		||||
                torch.nn.ReLU(),
 | 
			
		||||
            ])
 | 
			
		||||
 | 
			
		||||
        def forward(self, x):
 | 
			
		||||
            for m in self.modules_list:
 | 
			
		||||
                x = m(x)
 | 
			
		||||
            return x
 | 
			
		||||
 | 
			
		||||
    # Initialize the model
 | 
			
		||||
    model = pt.models.BinaryNAM(
 | 
			
		||||
        hparams,
 | 
			
		||||
        extractors=torch.nn.ModuleList([FE(20) for _ in range(2)]),
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    # Compute intermediate input and output sizes
 | 
			
		||||
    model.example_input_array = torch.zeros(4, 2)
 | 
			
		||||
 | 
			
		||||
    # Summary
 | 
			
		||||
    print(model)
 | 
			
		||||
 | 
			
		||||
    # Callbacks
 | 
			
		||||
    vis = pt.models.Vis2D(data=train_ds)
 | 
			
		||||
    es = pl.callbacks.EarlyStopping(
 | 
			
		||||
        monitor="train_loss",
 | 
			
		||||
        min_delta=0.001,
 | 
			
		||||
        patience=50,
 | 
			
		||||
        mode="min",
 | 
			
		||||
        verbose=False,
 | 
			
		||||
        check_on_train_epoch_end=True,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    # Setup trainer
 | 
			
		||||
    trainer = pl.Trainer.from_argparse_args(
 | 
			
		||||
        args,
 | 
			
		||||
        callbacks=[
 | 
			
		||||
            vis,
 | 
			
		||||
            es,
 | 
			
		||||
        ],
 | 
			
		||||
        terminate_on_nan=True,
 | 
			
		||||
        weights_summary="full",
 | 
			
		||||
        accelerator="ddp",
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    # Training loop
 | 
			
		||||
    trainer.fit(model, train_loader)
 | 
			
		||||
 | 
			
		||||
    # Visualize extractor shape functions
 | 
			
		||||
    fig, axes = plt.subplots(2)
 | 
			
		||||
    for i, ax in enumerate(axes.flat):
 | 
			
		||||
        x = torch.linspace(0, 1, 100)  # TODO use min/max from data
 | 
			
		||||
        y = model.extractors[i](x.view(100, 1)).squeeze().detach()
 | 
			
		||||
        ax.plot(x, y)
 | 
			
		||||
        ax.set(title=f"Feature {i + 1}")
 | 
			
		||||
    plt.show()
 | 
			
		||||
@@ -19,6 +19,7 @@ from .glvq import (
 | 
			
		||||
)
 | 
			
		||||
from .knn import KNN
 | 
			
		||||
from .lvq import LVQ1, LVQ21, MedianLVQ
 | 
			
		||||
from .nam import BinaryNAM
 | 
			
		||||
from .probabilistic import CELVQ, PLVQ, RSLVQ, SLVQ
 | 
			
		||||
from .unsupervised import GrowingNeuralGas, HeskesSOM, KohonenSOM, NeuralGas
 | 
			
		||||
from .vis import *
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										58
									
								
								prototorch/models/nam.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										58
									
								
								prototorch/models/nam.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,58 @@
 | 
			
		||||
"""ProtoTorch Neural Additive Model."""
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
import torchmetrics
 | 
			
		||||
 | 
			
		||||
from .abstract import ProtoTorchBolt
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class BinaryNAM(ProtoTorchBolt):
 | 
			
		||||
    """Neural Additive Model for binary classification.
 | 
			
		||||
 | 
			
		||||
    Paper: https://arxiv.org/abs/2004.13912
 | 
			
		||||
    Official implementation: https://github.com/google-research/google-research/tree/master/neural_additive_models
 | 
			
		||||
 | 
			
		||||
    """
 | 
			
		||||
    def __init__(self, hparams: dict, extractors: torch.nn.ModuleList,
 | 
			
		||||
                 **kwargs):
 | 
			
		||||
        super().__init__(hparams, **kwargs)
 | 
			
		||||
 | 
			
		||||
        # Default hparams
 | 
			
		||||
        self.hparams.setdefault("threshold", 0.5)
 | 
			
		||||
 | 
			
		||||
        self.extractors = extractors
 | 
			
		||||
        self.linear = torch.nn.Linear(in_features=len(extractors),
 | 
			
		||||
                                      out_features=1,
 | 
			
		||||
                                      bias=True)
 | 
			
		||||
 | 
			
		||||
    def extract(self, x):
 | 
			
		||||
        """Apply the local extractors batch-wise on features."""
 | 
			
		||||
        out = torch.zeros_like(x)
 | 
			
		||||
        for j in range(x.shape[1]):
 | 
			
		||||
            out[:, j] = self.extractors[j](x[:, j].unsqueeze(1)).squeeze()
 | 
			
		||||
        return out
 | 
			
		||||
 | 
			
		||||
    def forward(self, x):
 | 
			
		||||
        x = self.extract(x)
 | 
			
		||||
        x = self.linear(x)
 | 
			
		||||
        return torch.sigmoid(x)
 | 
			
		||||
 | 
			
		||||
    def training_step(self, batch, batch_idx, optimizer_idx=None):
 | 
			
		||||
        x, y = batch
 | 
			
		||||
        preds = self(x).squeeze()
 | 
			
		||||
        train_loss = torch.nn.functional.binary_cross_entropy(preds, y.float())
 | 
			
		||||
        self.log("train_loss", train_loss)
 | 
			
		||||
        accuracy = torchmetrics.functional.accuracy(preds.int(), y.int())
 | 
			
		||||
        self.log("train_acc",
 | 
			
		||||
                 accuracy,
 | 
			
		||||
                 on_step=False,
 | 
			
		||||
                 on_epoch=True,
 | 
			
		||||
                 prog_bar=True,
 | 
			
		||||
                 logger=True)
 | 
			
		||||
        return train_loss
 | 
			
		||||
 | 
			
		||||
    def predict(self, x):
 | 
			
		||||
        out = self(x)
 | 
			
		||||
        pred = torch.zeros_like(out, device=self.device)
 | 
			
		||||
        pred[out > self.hparams.threshold] = 1
 | 
			
		||||
        return pred
 | 
			
		||||
@@ -1,5 +1,4 @@
 | 
			
		||||
"""Probabilistic GLVQ methods"""
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
 | 
			
		||||
from ..core.losses import nllr_loss, rslvq_loss
 | 
			
		||||
@@ -32,7 +31,7 @@ class ProbabilisticLVQ(GLVQ):
 | 
			
		||||
    def __init__(self, hparams, rejection_confidence=0.0, **kwargs):
 | 
			
		||||
        super().__init__(hparams, **kwargs)
 | 
			
		||||
 | 
			
		||||
        self.conditional_distribution = None
 | 
			
		||||
        self.conditional_distribution = GaussianPrior(self.hparams.variance)
 | 
			
		||||
        self.rejection_confidence = rejection_confidence
 | 
			
		||||
 | 
			
		||||
    def forward(self, x):
 | 
			
		||||
@@ -56,8 +55,9 @@ class ProbabilisticLVQ(GLVQ):
 | 
			
		||||
        out = self.forward(x)
 | 
			
		||||
        plabels = self.proto_layer.labels
 | 
			
		||||
        batch_loss = self.loss(out, y, plabels)
 | 
			
		||||
        loss = batch_loss.sum()
 | 
			
		||||
        return loss
 | 
			
		||||
        train_loss = batch_loss.sum()
 | 
			
		||||
        self.log("train_loss", train_loss)
 | 
			
		||||
        return train_loss
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class SLVQ(ProbabilisticLVQ):
 | 
			
		||||
@@ -65,7 +65,6 @@ class SLVQ(ProbabilisticLVQ):
 | 
			
		||||
    def __init__(self, *args, **kwargs):
 | 
			
		||||
        super().__init__(*args, **kwargs)
 | 
			
		||||
        self.loss = LossLayer(nllr_loss)
 | 
			
		||||
        self.conditional_distribution = GaussianPrior(self.hparams.variance)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class RSLVQ(ProbabilisticLVQ):
 | 
			
		||||
@@ -73,7 +72,6 @@ class RSLVQ(ProbabilisticLVQ):
 | 
			
		||||
    def __init__(self, *args, **kwargs):
 | 
			
		||||
        super().__init__(*args, **kwargs)
 | 
			
		||||
        self.loss = LossLayer(rslvq_loss)
 | 
			
		||||
        self.conditional_distribution = GaussianPrior(self.hparams.variance)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class PLVQ(ProbabilisticLVQ, SiameseGMLVQ):
 | 
			
		||||
 
 | 
			
		||||
@@ -117,6 +117,24 @@ class Vis2DAbstract(pl.Callback):
 | 
			
		||||
        plt.close()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Vis2D(Vis2DAbstract):
 | 
			
		||||
    def on_epoch_end(self, trainer, pl_module):
 | 
			
		||||
        if not self.precheck(trainer):
 | 
			
		||||
            return True
 | 
			
		||||
 | 
			
		||||
        x_train, y_train = self.x_train, self.y_train
 | 
			
		||||
        ax = self.setup_ax(xlabel="Data dimension 1",
 | 
			
		||||
                           ylabel="Data dimension 2")
 | 
			
		||||
        self.plot_data(ax, x_train, y_train)
 | 
			
		||||
        mesh_input, xx, yy = mesh2d(x_train, self.border, self.resolution)
 | 
			
		||||
        mesh_input = torch.from_numpy(mesh_input).type_as(x_train)
 | 
			
		||||
        y_pred = pl_module.predict(mesh_input)
 | 
			
		||||
        y_pred = y_pred.cpu().reshape(xx.shape)
 | 
			
		||||
        ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
 | 
			
		||||
 | 
			
		||||
        self.log_and_display(trainer, pl_module)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class VisGLVQ2D(Vis2DAbstract):
 | 
			
		||||
    def on_epoch_end(self, trainer, pl_module):
 | 
			
		||||
        if not self.precheck(trainer):
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user