feat: add neural additive model for binary classification

This commit is contained in:
Jensun Ravichandran 2021-07-14 20:07:34 +02:00
parent f8ad1d83eb
commit 823b05e390
No known key found for this signature in database
GPG Key ID: 3331B0F18B6D4D93
3 changed files with 126 additions and 0 deletions

View File

@ -0,0 +1,81 @@
"""Neural Additive Model (NAM) example for binary classification."""
import argparse
import prototorch as pt
import pytorch_lightning as pl
import torch
from matplotlib import pyplot as plt
if __name__ == "__main__":
# Command-line arguments
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
args = parser.parse_args()
# Dataset
train_ds = pt.datasets.Tecator("~/datasets")
# Dataloaders
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=64)
# Hyperparameters
hparams = dict(lr=0.1)
# Define the feature extractor
class FE(torch.nn.Module):
def __init__(self):
super().__init__()
self.modules_list = torch.nn.ModuleList([
torch.nn.Linear(1, 3),
torch.nn.Sigmoid(),
torch.nn.Linear(3, 1),
torch.nn.Sigmoid(),
])
def forward(self, x):
for m in self.modules_list:
x = m(x)
return x
# Initialize the model
model = pt.models.BinaryNAM(
hparams,
extractors=torch.nn.ModuleList([FE() for _ in range(100)]),
)
# Compute intermediate input and output sizes
model.example_input_array = torch.zeros(4, 100)
# Callbacks
es = pl.callbacks.EarlyStopping(
monitor="train_loss",
min_delta=0.001,
patience=20,
mode="min",
verbose=True,
check_on_train_epoch_end=True,
)
# Setup trainer
trainer = pl.Trainer.from_argparse_args(
args,
callbacks=[
es,
],
terminate_on_nan=True,
weights_summary=None,
accelerator="ddp",
)
# Training loop
trainer.fit(model, train_loader)
# Visualize extractor shape functions
fig, axes = plt.subplots(10, 10)
for i, ax in enumerate(axes.flat):
x = torch.linspace(-2, 2, 100) # TODO use min/max from data
y = model.extractors[i](x.view(100, 1)).squeeze().detach()
ax.plot(x, y)
ax.set(title=f"Feature {i + 1}", xticklabels=[], yticklabels=[])
plt.show()

View File

@ -19,6 +19,7 @@ from .glvq import (
)
from .knn import KNN
from .lvq import LVQ1, LVQ21, MedianLVQ
from .nam import BinaryNAM
from .probabilistic import CELVQ, PLVQ, RSLVQ, SLVQ
from .unsupervised import GrowingNeuralGas, HeskesSOM, KohonenSOM, NeuralGas
from .vis import *

44
prototorch/models/nam.py Normal file
View File

@ -0,0 +1,44 @@
"""ProtoTorch Neural Additive Model."""
import torch
import torchmetrics
from .abstract import ProtoTorchBolt
class BinaryNAM(ProtoTorchBolt):
"""Neural Additive Model for binary classification.
Paper: https://arxiv.org/abs/2004.13912
Official implementation: https://github.com/google-research/google-research/tree/master/neural_additive_models
"""
def __init__(self, hparams: dict, extractors: torch.nn.ModuleList,
**kwargs):
super().__init__(hparams, **kwargs)
self.extractors = extractors
def extract(self, x):
"""Apply the local extractors batch-wise on features."""
out = torch.zeros_like(x)
for j in range(x.shape[1]):
out[:, j] = self.extractors[j](x[:, j].unsqueeze(1)).squeeze()
return out
def forward(self, x):
x = self.extract(x).sum(1)
return torch.nn.functional.sigmoid(x)
def training_step(self, batch, batch_idx, optimizer_idx=None):
x, y = batch
preds = self(x)
train_loss = torch.nn.functional.binary_cross_entropy(preds, y.float())
self.log("train_loss", train_loss)
accuracy = torchmetrics.functional.accuracy(preds.int(), y.int())
self.log("train_acc",
accuracy,
on_step=False,
on_epoch=True,
prog_bar=True,
logger=True)
return train_loss