From 823b05e39067d55722806b57acfa669589088097 Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Wed, 14 Jul 2021 20:07:34 +0200 Subject: [PATCH] feat: add neural additive model for binary classification --- examples/binnam_tecator.py | 81 +++++++++++++++++++++++++++++++++++ prototorch/models/__init__.py | 1 + prototorch/models/nam.py | 44 +++++++++++++++++++ 3 files changed, 126 insertions(+) create mode 100644 examples/binnam_tecator.py create mode 100644 prototorch/models/nam.py diff --git a/examples/binnam_tecator.py b/examples/binnam_tecator.py new file mode 100644 index 0000000..8f8cc92 --- /dev/null +++ b/examples/binnam_tecator.py @@ -0,0 +1,81 @@ +"""Neural Additive Model (NAM) example for binary classification.""" + +import argparse + +import prototorch as pt +import pytorch_lightning as pl +import torch +from matplotlib import pyplot as plt + +if __name__ == "__main__": + # Command-line arguments + parser = argparse.ArgumentParser() + parser = pl.Trainer.add_argparse_args(parser) + args = parser.parse_args() + + # Dataset + train_ds = pt.datasets.Tecator("~/datasets") + + # Dataloaders + train_loader = torch.utils.data.DataLoader(train_ds, batch_size=64) + + # Hyperparameters + hparams = dict(lr=0.1) + + # Define the feature extractor + class FE(torch.nn.Module): + def __init__(self): + super().__init__() + self.modules_list = torch.nn.ModuleList([ + torch.nn.Linear(1, 3), + torch.nn.Sigmoid(), + torch.nn.Linear(3, 1), + torch.nn.Sigmoid(), + ]) + + def forward(self, x): + for m in self.modules_list: + x = m(x) + return x + + # Initialize the model + model = pt.models.BinaryNAM( + hparams, + extractors=torch.nn.ModuleList([FE() for _ in range(100)]), + ) + + # Compute intermediate input and output sizes + model.example_input_array = torch.zeros(4, 100) + + # Callbacks + es = pl.callbacks.EarlyStopping( + monitor="train_loss", + min_delta=0.001, + patience=20, + mode="min", + verbose=True, + check_on_train_epoch_end=True, + ) + + # Setup trainer + trainer = pl.Trainer.from_argparse_args( + args, + callbacks=[ + es, + ], + terminate_on_nan=True, + weights_summary=None, + accelerator="ddp", + ) + + # Training loop + trainer.fit(model, train_loader) + + # Visualize extractor shape functions + fig, axes = plt.subplots(10, 10) + for i, ax in enumerate(axes.flat): + x = torch.linspace(-2, 2, 100) # TODO use min/max from data + y = model.extractors[i](x.view(100, 1)).squeeze().detach() + ax.plot(x, y) + ax.set(title=f"Feature {i + 1}", xticklabels=[], yticklabels=[]) + plt.show() diff --git a/prototorch/models/__init__.py b/prototorch/models/__init__.py index 728c922..41c834d 100644 --- a/prototorch/models/__init__.py +++ b/prototorch/models/__init__.py @@ -19,6 +19,7 @@ from .glvq import ( ) from .knn import KNN from .lvq import LVQ1, LVQ21, MedianLVQ +from .nam import BinaryNAM from .probabilistic import CELVQ, PLVQ, RSLVQ, SLVQ from .unsupervised import GrowingNeuralGas, HeskesSOM, KohonenSOM, NeuralGas from .vis import * diff --git a/prototorch/models/nam.py b/prototorch/models/nam.py new file mode 100644 index 0000000..cb4efd2 --- /dev/null +++ b/prototorch/models/nam.py @@ -0,0 +1,44 @@ +"""ProtoTorch Neural Additive Model.""" + +import torch +import torchmetrics + +from .abstract import ProtoTorchBolt + + +class BinaryNAM(ProtoTorchBolt): + """Neural Additive Model for binary classification. + + Paper: https://arxiv.org/abs/2004.13912 + Official implementation: https://github.com/google-research/google-research/tree/master/neural_additive_models + + """ + def __init__(self, hparams: dict, extractors: torch.nn.ModuleList, + **kwargs): + super().__init__(hparams, **kwargs) + self.extractors = extractors + + def extract(self, x): + """Apply the local extractors batch-wise on features.""" + out = torch.zeros_like(x) + for j in range(x.shape[1]): + out[:, j] = self.extractors[j](x[:, j].unsqueeze(1)).squeeze() + return out + + def forward(self, x): + x = self.extract(x).sum(1) + return torch.nn.functional.sigmoid(x) + + def training_step(self, batch, batch_idx, optimizer_idx=None): + x, y = batch + preds = self(x) + train_loss = torch.nn.functional.binary_cross_entropy(preds, y.float()) + self.log("train_loss", train_loss) + accuracy = torchmetrics.functional.accuracy(preds.int(), y.int()) + self.log("train_acc", + accuracy, + on_step=False, + on_epoch=True, + prog_bar=True, + logger=True) + return train_loss