Compare commits
3 Commits
Author | SHA1 | Date | |
---|---|---|---|
|
aeb6417c28 | ||
|
cb7fb91c95 | ||
|
823b05e390 |
81
examples/binnam_tecator.py
Normal file
81
examples/binnam_tecator.py
Normal file
@ -0,0 +1,81 @@
|
||||
"""Neural Additive Model (NAM) example for binary classification."""
|
||||
|
||||
import argparse
|
||||
|
||||
import prototorch as pt
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Command-line arguments
|
||||
parser = argparse.ArgumentParser()
|
||||
parser = pl.Trainer.add_argparse_args(parser)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Dataset
|
||||
train_ds = pt.datasets.Tecator("~/datasets")
|
||||
|
||||
# Dataloaders
|
||||
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=64)
|
||||
|
||||
# Hyperparameters
|
||||
hparams = dict(lr=0.1)
|
||||
|
||||
# Define the feature extractor
|
||||
class FE(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.modules_list = torch.nn.ModuleList([
|
||||
torch.nn.Linear(1, 3),
|
||||
torch.nn.Sigmoid(),
|
||||
torch.nn.Linear(3, 1),
|
||||
torch.nn.Sigmoid(),
|
||||
])
|
||||
|
||||
def forward(self, x):
|
||||
for m in self.modules_list:
|
||||
x = m(x)
|
||||
return x
|
||||
|
||||
# Initialize the model
|
||||
model = pt.models.BinaryNAM(
|
||||
hparams,
|
||||
extractors=torch.nn.ModuleList([FE() for _ in range(100)]),
|
||||
)
|
||||
|
||||
# Compute intermediate input and output sizes
|
||||
model.example_input_array = torch.zeros(4, 100)
|
||||
|
||||
# Callbacks
|
||||
es = pl.callbacks.EarlyStopping(
|
||||
monitor="train_loss",
|
||||
min_delta=0.001,
|
||||
patience=20,
|
||||
mode="min",
|
||||
verbose=True,
|
||||
check_on_train_epoch_end=True,
|
||||
)
|
||||
|
||||
# Setup trainer
|
||||
trainer = pl.Trainer.from_argparse_args(
|
||||
args,
|
||||
callbacks=[
|
||||
es,
|
||||
],
|
||||
terminate_on_nan=True,
|
||||
weights_summary=None,
|
||||
accelerator="ddp",
|
||||
)
|
||||
|
||||
# Training loop
|
||||
trainer.fit(model, train_loader)
|
||||
|
||||
# Visualize extractor shape functions
|
||||
fig, axes = plt.subplots(10, 10)
|
||||
for i, ax in enumerate(axes.flat):
|
||||
x = torch.linspace(-2, 2, 100) # TODO use min/max from data
|
||||
y = model.extractors[i](x.view(100, 1)).squeeze().detach()
|
||||
ax.plot(x, y)
|
||||
ax.set(title=f"Feature {i + 1}", xticklabels=[], yticklabels=[])
|
||||
plt.show()
|
86
examples/binnam_xor.py
Normal file
86
examples/binnam_xor.py
Normal file
@ -0,0 +1,86 @@
|
||||
"""Neural Additive Model (NAM) example for binary classification."""
|
||||
|
||||
import argparse
|
||||
|
||||
import prototorch as pt
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Command-line arguments
|
||||
parser = argparse.ArgumentParser()
|
||||
parser = pl.Trainer.add_argparse_args(parser)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Dataset
|
||||
train_ds = pt.datasets.XOR()
|
||||
|
||||
# Dataloaders
|
||||
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=256)
|
||||
|
||||
# Hyperparameters
|
||||
hparams = dict(lr=0.001)
|
||||
|
||||
# Define the feature extractor
|
||||
class FE(torch.nn.Module):
|
||||
def __init__(self, hidden_size=10):
|
||||
super().__init__()
|
||||
self.modules_list = torch.nn.ModuleList([
|
||||
torch.nn.Linear(1, hidden_size),
|
||||
torch.nn.ReLU(),
|
||||
torch.nn.Linear(hidden_size, 1),
|
||||
torch.nn.ReLU(),
|
||||
])
|
||||
|
||||
def forward(self, x):
|
||||
for m in self.modules_list:
|
||||
x = m(x)
|
||||
return x
|
||||
|
||||
# Initialize the model
|
||||
model = pt.models.BinaryNAM(
|
||||
hparams,
|
||||
extractors=torch.nn.ModuleList([FE(20) for _ in range(2)]),
|
||||
)
|
||||
|
||||
# Compute intermediate input and output sizes
|
||||
model.example_input_array = torch.zeros(4, 2)
|
||||
|
||||
# Summary
|
||||
print(model)
|
||||
|
||||
# Callbacks
|
||||
vis = pt.models.Vis2D(data=train_ds)
|
||||
es = pl.callbacks.EarlyStopping(
|
||||
monitor="train_loss",
|
||||
min_delta=0.001,
|
||||
patience=50,
|
||||
mode="min",
|
||||
verbose=False,
|
||||
check_on_train_epoch_end=True,
|
||||
)
|
||||
|
||||
# Setup trainer
|
||||
trainer = pl.Trainer.from_argparse_args(
|
||||
args,
|
||||
callbacks=[
|
||||
vis,
|
||||
es,
|
||||
],
|
||||
terminate_on_nan=True,
|
||||
weights_summary="full",
|
||||
accelerator="ddp",
|
||||
)
|
||||
|
||||
# Training loop
|
||||
trainer.fit(model, train_loader)
|
||||
|
||||
# Visualize extractor shape functions
|
||||
fig, axes = plt.subplots(2)
|
||||
for i, ax in enumerate(axes.flat):
|
||||
x = torch.linspace(0, 1, 100) # TODO use min/max from data
|
||||
y = model.extractors[i](x.view(100, 1)).squeeze().detach()
|
||||
ax.plot(x, y)
|
||||
ax.set(title=f"Feature {i + 1}")
|
||||
plt.show()
|
@ -19,6 +19,7 @@ from .glvq import (
|
||||
)
|
||||
from .knn import KNN
|
||||
from .lvq import LVQ1, LVQ21, MedianLVQ
|
||||
from .nam import BinaryNAM
|
||||
from .probabilistic import CELVQ, PLVQ, RSLVQ, SLVQ
|
||||
from .unsupervised import GrowingNeuralGas, HeskesSOM, KohonenSOM, NeuralGas
|
||||
from .vis import *
|
||||
|
58
prototorch/models/nam.py
Normal file
58
prototorch/models/nam.py
Normal file
@ -0,0 +1,58 @@
|
||||
"""ProtoTorch Neural Additive Model."""
|
||||
|
||||
import torch
|
||||
import torchmetrics
|
||||
|
||||
from .abstract import ProtoTorchBolt
|
||||
|
||||
|
||||
class BinaryNAM(ProtoTorchBolt):
|
||||
"""Neural Additive Model for binary classification.
|
||||
|
||||
Paper: https://arxiv.org/abs/2004.13912
|
||||
Official implementation: https://github.com/google-research/google-research/tree/master/neural_additive_models
|
||||
|
||||
"""
|
||||
def __init__(self, hparams: dict, extractors: torch.nn.ModuleList,
|
||||
**kwargs):
|
||||
super().__init__(hparams, **kwargs)
|
||||
|
||||
# Default hparams
|
||||
self.hparams.setdefault("threshold", 0.5)
|
||||
|
||||
self.extractors = extractors
|
||||
self.linear = torch.nn.Linear(in_features=len(extractors),
|
||||
out_features=1,
|
||||
bias=True)
|
||||
|
||||
def extract(self, x):
|
||||
"""Apply the local extractors batch-wise on features."""
|
||||
out = torch.zeros_like(x)
|
||||
for j in range(x.shape[1]):
|
||||
out[:, j] = self.extractors[j](x[:, j].unsqueeze(1)).squeeze()
|
||||
return out
|
||||
|
||||
def forward(self, x):
|
||||
x = self.extract(x)
|
||||
x = self.linear(x)
|
||||
return torch.sigmoid(x)
|
||||
|
||||
def training_step(self, batch, batch_idx, optimizer_idx=None):
|
||||
x, y = batch
|
||||
preds = self(x).squeeze()
|
||||
train_loss = torch.nn.functional.binary_cross_entropy(preds, y.float())
|
||||
self.log("train_loss", train_loss)
|
||||
accuracy = torchmetrics.functional.accuracy(preds.int(), y.int())
|
||||
self.log("train_acc",
|
||||
accuracy,
|
||||
on_step=False,
|
||||
on_epoch=True,
|
||||
prog_bar=True,
|
||||
logger=True)
|
||||
return train_loss
|
||||
|
||||
def predict(self, x):
|
||||
out = self(x)
|
||||
pred = torch.zeros_like(out, device=self.device)
|
||||
pred[out > self.hparams.threshold] = 1
|
||||
return pred
|
@ -1,5 +1,4 @@
|
||||
"""Probabilistic GLVQ methods"""
|
||||
|
||||
import torch
|
||||
|
||||
from ..core.losses import nllr_loss, rslvq_loss
|
||||
@ -32,7 +31,7 @@ class ProbabilisticLVQ(GLVQ):
|
||||
def __init__(self, hparams, rejection_confidence=0.0, **kwargs):
|
||||
super().__init__(hparams, **kwargs)
|
||||
|
||||
self.conditional_distribution = None
|
||||
self.conditional_distribution = GaussianPrior(self.hparams.variance)
|
||||
self.rejection_confidence = rejection_confidence
|
||||
|
||||
def forward(self, x):
|
||||
@ -56,8 +55,9 @@ class ProbabilisticLVQ(GLVQ):
|
||||
out = self.forward(x)
|
||||
plabels = self.proto_layer.labels
|
||||
batch_loss = self.loss(out, y, plabels)
|
||||
loss = batch_loss.sum()
|
||||
return loss
|
||||
train_loss = batch_loss.sum()
|
||||
self.log("train_loss", train_loss)
|
||||
return train_loss
|
||||
|
||||
|
||||
class SLVQ(ProbabilisticLVQ):
|
||||
@ -65,7 +65,6 @@ class SLVQ(ProbabilisticLVQ):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.loss = LossLayer(nllr_loss)
|
||||
self.conditional_distribution = GaussianPrior(self.hparams.variance)
|
||||
|
||||
|
||||
class RSLVQ(ProbabilisticLVQ):
|
||||
@ -73,7 +72,6 @@ class RSLVQ(ProbabilisticLVQ):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.loss = LossLayer(rslvq_loss)
|
||||
self.conditional_distribution = GaussianPrior(self.hparams.variance)
|
||||
|
||||
|
||||
class PLVQ(ProbabilisticLVQ, SiameseGMLVQ):
|
||||
|
@ -117,6 +117,24 @@ class Vis2DAbstract(pl.Callback):
|
||||
plt.close()
|
||||
|
||||
|
||||
class Vis2D(Vis2DAbstract):
|
||||
def on_epoch_end(self, trainer, pl_module):
|
||||
if not self.precheck(trainer):
|
||||
return True
|
||||
|
||||
x_train, y_train = self.x_train, self.y_train
|
||||
ax = self.setup_ax(xlabel="Data dimension 1",
|
||||
ylabel="Data dimension 2")
|
||||
self.plot_data(ax, x_train, y_train)
|
||||
mesh_input, xx, yy = mesh2d(x_train, self.border, self.resolution)
|
||||
mesh_input = torch.from_numpy(mesh_input).type_as(x_train)
|
||||
y_pred = pl_module.predict(mesh_input)
|
||||
y_pred = y_pred.cpu().reshape(xx.shape)
|
||||
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
|
||||
|
||||
self.log_and_display(trainer, pl_module)
|
||||
|
||||
|
||||
class VisGLVQ2D(Vis2DAbstract):
|
||||
def on_epoch_end(self, trainer, pl_module):
|
||||
if not self.precheck(trainer):
|
||||
|
Loading…
Reference in New Issue
Block a user