Compare commits

...

3 Commits

Author                SHA1        Message                                                     Date
Jensun Ravichandran   aeb6417c28  refactor: minor changes in probabilistic.py                 2021-08-06 13:49:29 +02:00
Jensun Ravichandran   cb7fb91c95  feat: add binnam_xor.py                                     2021-07-15 18:19:28 +02:00
Jensun Ravichandran   823b05e390  feat: add neural additive model for binary classification  2021-07-14 20:07:34 +02:00
6 changed files with 248 additions and 6 deletions

81  examples/binnam_tecator.py  Normal file

@@ -0,0 +1,81 @@
"""Neural Additive Model (NAM) example for binary classification."""
import argparse
import prototorch as pt
import pytorch_lightning as pl
import torch
from matplotlib import pyplot as plt
if __name__ == "__main__":
# Command-line arguments
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
args = parser.parse_args()
# Dataset
train_ds = pt.datasets.Tecator("~/datasets")
# Dataloaders
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=64)
# Hyperparameters
hparams = dict(lr=0.1)
# Define the feature extractor
class FE(torch.nn.Module):
def __init__(self):
super().__init__()
self.modules_list = torch.nn.ModuleList([
torch.nn.Linear(1, 3),
torch.nn.Sigmoid(),
torch.nn.Linear(3, 1),
torch.nn.Sigmoid(),
])
def forward(self, x):
for m in self.modules_list:
x = m(x)
return x
# Initialize the model
model = pt.models.BinaryNAM(
hparams,
extractors=torch.nn.ModuleList([FE() for _ in range(100)]),
)
# Compute intermediate input and output sizes
model.example_input_array = torch.zeros(4, 100)
# Callbacks
es = pl.callbacks.EarlyStopping(
monitor="train_loss",
min_delta=0.001,
patience=20,
mode="min",
verbose=True,
check_on_train_epoch_end=True,
)
# Setup trainer
trainer = pl.Trainer.from_argparse_args(
args,
callbacks=[
es,
],
terminate_on_nan=True,
weights_summary=None,
accelerator="ddp",
)
# Training loop
trainer.fit(model, train_loader)
# Visualize extractor shape functions
fig, axes = plt.subplots(10, 10)
for i, ax in enumerate(axes.flat):
x = torch.linspace(-2, 2, 100) # TODO use min/max from data
y = model.extractors[i](x.view(100, 1)).squeeze().detach()
ax.plot(x, y)
ax.set(title=f"Feature {i + 1}", xticklabels=[], yticklabels=[])
plt.show()
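The additive structure in this example is worth spelling out: each of the 100 extractors sees exactly one input channel, and a single linear layer combines the per-feature outputs before a sigmoid. A minimal standalone sketch of that computation, using plain torch and mirroring the sizes above (illustrative, not part of the diff):

import torch

# One tiny per-feature extractor per input channel, mirroring FE above.
extractors = torch.nn.ModuleList([
    torch.nn.Sequential(
        torch.nn.Linear(1, 3), torch.nn.Sigmoid(),
        torch.nn.Linear(3, 1), torch.nn.Sigmoid(),
    ) for _ in range(100)
])
linear = torch.nn.Linear(100, 1)  # combines the per-feature outputs

x = torch.rand(4, 100)  # batch of 4 samples with 100 features
# f_j(x_j) for every feature j, then a weighted sum and a sigmoid:
contributions = torch.cat(
    [extractors[j](x[:, j:j + 1]) for j in range(100)], dim=1)
prob = torch.sigmoid(linear(contributions))  # shape (4, 1)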

86  examples/binnam_xor.py  Normal file

@@ -0,0 +1,86 @@
"""Neural Additive Model (NAM) example for binary classification."""

import argparse

import prototorch as pt
import pytorch_lightning as pl
import torch
from matplotlib import pyplot as plt

if __name__ == "__main__":
    # Command-line arguments
    parser = argparse.ArgumentParser()
    parser = pl.Trainer.add_argparse_args(parser)
    args = parser.parse_args()

    # Dataset
    train_ds = pt.datasets.XOR()

    # Dataloaders
    train_loader = torch.utils.data.DataLoader(train_ds, batch_size=256)

    # Hyperparameters
    hparams = dict(lr=0.001)

    # Define the feature extractor
    class FE(torch.nn.Module):
        def __init__(self, hidden_size=10):
            super().__init__()
            self.modules_list = torch.nn.ModuleList([
                torch.nn.Linear(1, hidden_size),
                torch.nn.ReLU(),
                torch.nn.Linear(hidden_size, 1),
                torch.nn.ReLU(),
            ])

        def forward(self, x):
            for m in self.modules_list:
                x = m(x)
            return x

    # Initialize the model
    model = pt.models.BinaryNAM(
        hparams,
        extractors=torch.nn.ModuleList([FE(20) for _ in range(2)]),
    )

    # Compute intermediate input and output sizes
    model.example_input_array = torch.zeros(4, 2)

    # Summary
    print(model)

    # Callbacks
    vis = pt.models.Vis2D(data=train_ds)
    es = pl.callbacks.EarlyStopping(
        monitor="train_loss",
        min_delta=0.001,
        patience=50,
        mode="min",
        verbose=False,
        check_on_train_epoch_end=True,
    )

    # Setup trainer
    trainer = pl.Trainer.from_argparse_args(
        args,
        callbacks=[
            vis,
            es,
        ],
        terminate_on_nan=True,
        weights_summary="full",
        accelerator="ddp",
    )

    # Training loop
    trainer.fit(model, train_loader)

    # Visualize extractor shape functions
    fig, axes = plt.subplots(2)
    for i, ax in enumerate(axes.flat):
        x = torch.linspace(0, 1, 100)  # TODO use min/max from data
        y = model.extractors[i](x.view(100, 1)).squeeze().detach()
        ax.plot(x, y)
        ax.set(title=f"Feature {i + 1}")

    plt.show()
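Once trained, predict (defined in prototorch/models/nam.py below) turns the sigmoid output into hard 0/1 labels using hparams["threshold"]. A hedged continuation of the script above, probing the four corners of the unit square (the probe coordinates are illustrative and assume inputs roughly span [0, 1], as the linspace(0, 1, ...) plotting range above suggests):

# Continues the script above; "model" is the trained BinaryNAM.
probes = torch.tensor([
    [0.1, 0.1],
    [0.1, 0.9],
    [0.9, 0.1],
    [0.9, 0.9],
])
with torch.no_grad():
    probs = model(probes).squeeze()           # sigmoid outputs in (0, 1)
    labels = model.predict(probes).squeeze()  # hard 0/1 after thresholding
print(probs, labels)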

prototorch/models/__init__.py

@@ -19,6 +19,7 @@ from .glvq import (
 )
 from .knn import KNN
 from .lvq import LVQ1, LVQ21, MedianLVQ
+from .nam import BinaryNAM
 from .probabilistic import CELVQ, PLVQ, RSLVQ, SLVQ
 from .unsupervised import GrowingNeuralGas, HeskesSOM, KohonenSOM, NeuralGas
 from .vis import *
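This one-line export is what makes the pt.models.BinaryNAM spelling in the examples above resolve:

import prototorch as pt

# Re-exported via prototorch/models/__init__.py after this change:
model_cls = pt.models.BinaryNAM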

58  prototorch/models/nam.py  Normal file

@@ -0,0 +1,58 @@
"""ProtoTorch Neural Additive Model."""

import torch
import torchmetrics

from .abstract import ProtoTorchBolt


class BinaryNAM(ProtoTorchBolt):
    """Neural Additive Model for binary classification.

    Paper: https://arxiv.org/abs/2004.13912
    Official implementation: https://github.com/google-research/google-research/tree/master/neural_additive_models

    """
    def __init__(self, hparams: dict, extractors: torch.nn.ModuleList,
                 **kwargs):
        super().__init__(hparams, **kwargs)

        # Default hparams
        self.hparams.setdefault("threshold", 0.5)

        self.extractors = extractors
        self.linear = torch.nn.Linear(in_features=len(extractors),
                                      out_features=1,
                                      bias=True)

    def extract(self, x):
        """Apply the local extractors batch-wise on features."""
        out = torch.zeros_like(x)
        for j in range(x.shape[1]):
            out[:, j] = self.extractors[j](x[:, j].unsqueeze(1)).squeeze()
        return out

    def forward(self, x):
        x = self.extract(x)
        x = self.linear(x)
        return torch.sigmoid(x)

    def training_step(self, batch, batch_idx, optimizer_idx=None):
        x, y = batch
        preds = self(x).squeeze()
        train_loss = torch.nn.functional.binary_cross_entropy(preds, y.float())
        self.log("train_loss", train_loss)
        accuracy = torchmetrics.functional.accuracy(preds.int(), y.int())
        self.log("train_acc",
                 accuracy,
                 on_step=False,
                 on_epoch=True,
                 prog_bar=True,
                 logger=True)
        return train_loss

    def predict(self, x):
        out = self(x)
        pred = torch.zeros_like(out, device=self.device)
        pred[out > self.hparams.threshold] = 1
        return pred
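A hedged usage sketch of the class above, mirroring the XOR example (it assumes only what the diff shows: ProtoTorchBolt accepts a plain hparams dict, and forward/predict behave as defined here):

import torch
import prototorch as pt

# One small extractor per input feature, as in the examples.
extractors = torch.nn.ModuleList([
    torch.nn.Sequential(
        torch.nn.Linear(1, 10), torch.nn.ReLU(),
        torch.nn.Linear(10, 1), torch.nn.ReLU(),
    ) for _ in range(2)
])
model = pt.models.BinaryNAM(dict(lr=0.001), extractors=extractors)

x = torch.rand(8, 2)      # batch of 8 two-feature samples
probs = model(x)          # sigmoid probabilities, shape (8, 1)
preds = model.predict(x)  # hard 0/1 labels via hparams["threshold"]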

prototorch/models/probabilistic.py

@@ -1,5 +1,4 @@
 """Probabilistic GLVQ methods"""
 import torch
 from ..core.losses import nllr_loss, rslvq_loss
@@ -32,7 +31,7 @@ class ProbabilisticLVQ(GLVQ):

     def __init__(self, hparams, rejection_confidence=0.0, **kwargs):
         super().__init__(hparams, **kwargs)
-        self.conditional_distribution = None
+        self.conditional_distribution = GaussianPrior(self.hparams.variance)
         self.rejection_confidence = rejection_confidence

     def forward(self, x):
@@ -56,8 +55,9 @@ class ProbabilisticLVQ(GLVQ):
         out = self.forward(x)
         plabels = self.proto_layer.labels
         batch_loss = self.loss(out, y, plabels)
-        loss = batch_loss.sum()
-        return loss
+        train_loss = batch_loss.sum()
+        self.log("train_loss", train_loss)
+        return train_loss


 class SLVQ(ProbabilisticLVQ):
@@ -65,7 +65,6 @@ class SLVQ(ProbabilisticLVQ):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.loss = LossLayer(nllr_loss)
-        self.conditional_distribution = GaussianPrior(self.hparams.variance)


 class RSLVQ(ProbabilisticLVQ):
@@ -73,7 +72,6 @@ class RSLVQ(ProbabilisticLVQ):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.loss = LossLayer(rslvq_loss)
-        self.conditional_distribution = GaussianPrior(self.hparams.variance)


 class PLVQ(ProbabilisticLVQ, SiameseGMLVQ):
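Both edits consolidate rather than change behavior: the GaussianPrior that SLVQ and RSLVQ previously each constructed now lives once in the shared ProbabilisticLVQ constructor, and the summed batch loss is logged as "train_loss", which the EarlyStopping callbacks in the examples above monitor. A minimal sketch of the consolidation pattern, using hypothetical stand-in classes rather than ProtoTorch code:

class Prior:
    """Hypothetical stand-in for GaussianPrior."""
    def __init__(self, variance):
        self.variance = variance

class Base:
    """Stand-in for ProbabilisticLVQ: builds the prior once."""
    def __init__(self, variance):
        self.conditional_distribution = Prior(variance)

class SLVQLike(Base):
    """Stand-in for SLVQ: inherits the prior, no duplicate setup."""
    def __init__(self, variance):
        super().__init__(variance)

print(SLVQLike(0.1).conditional_distribution.variance)  # -> 0.1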

prototorch/models/vis.py

@@ -117,6 +117,24 @@ class Vis2DAbstract(pl.Callback):
         plt.close()


+class Vis2D(Vis2DAbstract):
+
+    def on_epoch_end(self, trainer, pl_module):
+        if not self.precheck(trainer):
+            return True
+
+        x_train, y_train = self.x_train, self.y_train
+        ax = self.setup_ax(xlabel="Data dimension 1",
+                           ylabel="Data dimension 2")
+        self.plot_data(ax, x_train, y_train)
+        mesh_input, xx, yy = mesh2d(x_train, self.border, self.resolution)
+        mesh_input = torch.from_numpy(mesh_input).type_as(x_train)
+        y_pred = pl_module.predict(mesh_input)
+        y_pred = y_pred.cpu().reshape(xx.shape)
+        ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
+
+        self.log_and_display(trainer, pl_module)
+
+
 class VisGLVQ2D(Vis2DAbstract):

     def on_epoch_end(self, trainer, pl_module):
         if not self.precheck(trainer):
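This new Vis2D callback is the one binnam_xor.py registers via pt.models.Vis2D(data=train_ds): each epoch it evaluates pl_module.predict on a 2-D mesh and shades the resulting 0/1 decision regions. A standalone sketch of that mesh evaluation, with a toy predictor standing in for the Lightning module (illustrative; not the mesh2d helper itself):

import numpy as np
import torch

# Toy stand-in for pl_module.predict: class 1 above the diagonal.
def predict(points):
    return (points[:, 0] + points[:, 1] > 1.0).float()

xx, yy = np.meshgrid(np.linspace(0, 1, 50), np.linspace(0, 1, 50))
mesh_input = torch.from_numpy(
    np.stack([xx.ravel(), yy.ravel()], axis=1)).float()
y_pred = predict(mesh_input).reshape(xx.shape)  # (50, 50) grid of labels
# ax.contourf(xx, yy, y_pred, alpha=0.35) would then shade the regions.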