feat: add neural additive model for binary classification
This commit is contained in:
parent
f8ad1d83eb
commit
823b05e390
81
examples/binnam_tecator.py
Normal file
81
examples/binnam_tecator.py
Normal file
@ -0,0 +1,81 @@
|
|||||||
|
"""Neural Additive Model (NAM) example for binary classification."""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
import prototorch as pt
|
||||||
|
import pytorch_lightning as pl
|
||||||
|
import torch
|
||||||
|
from matplotlib import pyplot as plt
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# Command-line arguments
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser = pl.Trainer.add_argparse_args(parser)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# Dataset
|
||||||
|
train_ds = pt.datasets.Tecator("~/datasets")
|
||||||
|
|
||||||
|
# Dataloaders
|
||||||
|
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=64)
|
||||||
|
|
||||||
|
# Hyperparameters
|
||||||
|
hparams = dict(lr=0.1)
|
||||||
|
|
||||||
|
# Define the feature extractor
|
||||||
|
class FE(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.modules_list = torch.nn.ModuleList([
|
||||||
|
torch.nn.Linear(1, 3),
|
||||||
|
torch.nn.Sigmoid(),
|
||||||
|
torch.nn.Linear(3, 1),
|
||||||
|
torch.nn.Sigmoid(),
|
||||||
|
])
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
for m in self.modules_list:
|
||||||
|
x = m(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
# Initialize the model
|
||||||
|
model = pt.models.BinaryNAM(
|
||||||
|
hparams,
|
||||||
|
extractors=torch.nn.ModuleList([FE() for _ in range(100)]),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Compute intermediate input and output sizes
|
||||||
|
model.example_input_array = torch.zeros(4, 100)
|
||||||
|
|
||||||
|
# Callbacks
|
||||||
|
es = pl.callbacks.EarlyStopping(
|
||||||
|
monitor="train_loss",
|
||||||
|
min_delta=0.001,
|
||||||
|
patience=20,
|
||||||
|
mode="min",
|
||||||
|
verbose=True,
|
||||||
|
check_on_train_epoch_end=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Setup trainer
|
||||||
|
trainer = pl.Trainer.from_argparse_args(
|
||||||
|
args,
|
||||||
|
callbacks=[
|
||||||
|
es,
|
||||||
|
],
|
||||||
|
terminate_on_nan=True,
|
||||||
|
weights_summary=None,
|
||||||
|
accelerator="ddp",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Training loop
|
||||||
|
trainer.fit(model, train_loader)
|
||||||
|
|
||||||
|
# Visualize extractor shape functions
|
||||||
|
fig, axes = plt.subplots(10, 10)
|
||||||
|
for i, ax in enumerate(axes.flat):
|
||||||
|
x = torch.linspace(-2, 2, 100) # TODO use min/max from data
|
||||||
|
y = model.extractors[i](x.view(100, 1)).squeeze().detach()
|
||||||
|
ax.plot(x, y)
|
||||||
|
ax.set(title=f"Feature {i + 1}", xticklabels=[], yticklabels=[])
|
||||||
|
plt.show()
|
@ -19,6 +19,7 @@ from .glvq import (
|
|||||||
)
|
)
|
||||||
from .knn import KNN
|
from .knn import KNN
|
||||||
from .lvq import LVQ1, LVQ21, MedianLVQ
|
from .lvq import LVQ1, LVQ21, MedianLVQ
|
||||||
|
from .nam import BinaryNAM
|
||||||
from .probabilistic import CELVQ, PLVQ, RSLVQ, SLVQ
|
from .probabilistic import CELVQ, PLVQ, RSLVQ, SLVQ
|
||||||
from .unsupervised import GrowingNeuralGas, HeskesSOM, KohonenSOM, NeuralGas
|
from .unsupervised import GrowingNeuralGas, HeskesSOM, KohonenSOM, NeuralGas
|
||||||
from .vis import *
|
from .vis import *
|
||||||
|
44
prototorch/models/nam.py
Normal file
44
prototorch/models/nam.py
Normal file
@ -0,0 +1,44 @@
|
|||||||
|
"""ProtoTorch Neural Additive Model."""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torchmetrics
|
||||||
|
|
||||||
|
from .abstract import ProtoTorchBolt
|
||||||
|
|
||||||
|
|
||||||
|
class BinaryNAM(ProtoTorchBolt):
|
||||||
|
"""Neural Additive Model for binary classification.
|
||||||
|
|
||||||
|
Paper: https://arxiv.org/abs/2004.13912
|
||||||
|
Official implementation: https://github.com/google-research/google-research/tree/master/neural_additive_models
|
||||||
|
|
||||||
|
"""
|
||||||
|
def __init__(self, hparams: dict, extractors: torch.nn.ModuleList,
|
||||||
|
**kwargs):
|
||||||
|
super().__init__(hparams, **kwargs)
|
||||||
|
self.extractors = extractors
|
||||||
|
|
||||||
|
def extract(self, x):
|
||||||
|
"""Apply the local extractors batch-wise on features."""
|
||||||
|
out = torch.zeros_like(x)
|
||||||
|
for j in range(x.shape[1]):
|
||||||
|
out[:, j] = self.extractors[j](x[:, j].unsqueeze(1)).squeeze()
|
||||||
|
return out
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.extract(x).sum(1)
|
||||||
|
return torch.nn.functional.sigmoid(x)
|
||||||
|
|
||||||
|
def training_step(self, batch, batch_idx, optimizer_idx=None):
|
||||||
|
x, y = batch
|
||||||
|
preds = self(x)
|
||||||
|
train_loss = torch.nn.functional.binary_cross_entropy(preds, y.float())
|
||||||
|
self.log("train_loss", train_loss)
|
||||||
|
accuracy = torchmetrics.functional.accuracy(preds.int(), y.int())
|
||||||
|
self.log("train_acc",
|
||||||
|
accuracy,
|
||||||
|
on_step=False,
|
||||||
|
on_epoch=True,
|
||||||
|
prog_bar=True,
|
||||||
|
logger=True)
|
||||||
|
return train_loss
|
Loading…
Reference in New Issue
Block a user