feat: add neural additive model for binary classification

examples/binnam_tecator.py (new file, 81 lines)

@@ -0,0 +1,81 @@
|  | """Neural Additive Model (NAM) example for binary classification.""" | ||||||
|  |  | ||||||
|  | import argparse | ||||||
|  |  | ||||||
|  | import prototorch as pt | ||||||
|  | import pytorch_lightning as pl | ||||||
|  | import torch | ||||||
|  | from matplotlib import pyplot as plt | ||||||
|  |  | ||||||
|  | if __name__ == "__main__": | ||||||
|  |     # Command-line arguments | ||||||
|  |     parser = argparse.ArgumentParser() | ||||||
|  |     parser = pl.Trainer.add_argparse_args(parser) | ||||||
|  |     args = parser.parse_args() | ||||||
|  |  | ||||||
|  |     # Dataset | ||||||
|  |     train_ds = pt.datasets.Tecator("~/datasets") | ||||||
|  |  | ||||||
|  |     # Dataloaders | ||||||
|  |     train_loader = torch.utils.data.DataLoader(train_ds, batch_size=64) | ||||||
|  |  | ||||||
|  |     # Hyperparameters | ||||||
|  |     hparams = dict(lr=0.1) | ||||||
|  |  | ||||||
|  |     # Define the feature extractor | ||||||
|  |     class FE(torch.nn.Module): | ||||||
|  |         def __init__(self): | ||||||
|  |             super().__init__() | ||||||
|  |             self.modules_list = torch.nn.ModuleList([ | ||||||
|  |                 torch.nn.Linear(1, 3), | ||||||
|  |                 torch.nn.Sigmoid(), | ||||||
|  |                 torch.nn.Linear(3, 1), | ||||||
|  |                 torch.nn.Sigmoid(), | ||||||
|  |             ]) | ||||||
|  |  | ||||||
|  |         def forward(self, x): | ||||||
|  |             for m in self.modules_list: | ||||||
|  |                 x = m(x) | ||||||
|  |             return x | ||||||
|  |  | ||||||
|  |     # Initialize the model | ||||||
|  |     model = pt.models.BinaryNAM( | ||||||
|  |         hparams, | ||||||
|  |         extractors=torch.nn.ModuleList([FE() for _ in range(100)]), | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |     # Compute intermediate input and output sizes | ||||||
|  |     model.example_input_array = torch.zeros(4, 100) | ||||||
|  |  | ||||||
|  |     # Callbacks | ||||||
|  |     es = pl.callbacks.EarlyStopping( | ||||||
|  |         monitor="train_loss", | ||||||
|  |         min_delta=0.001, | ||||||
|  |         patience=20, | ||||||
|  |         mode="min", | ||||||
|  |         verbose=True, | ||||||
|  |         check_on_train_epoch_end=True, | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |     # Setup trainer | ||||||
|  |     trainer = pl.Trainer.from_argparse_args( | ||||||
|  |         args, | ||||||
|  |         callbacks=[ | ||||||
|  |             es, | ||||||
|  |         ], | ||||||
|  |         terminate_on_nan=True, | ||||||
|  |         weights_summary=None, | ||||||
|  |         accelerator="ddp", | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |     # Training loop | ||||||
|  |     trainer.fit(model, train_loader) | ||||||
|  |  | ||||||
|  |     # Visualize extractor shape functions | ||||||
|  |     fig, axes = plt.subplots(10, 10) | ||||||
|  |     for i, ax in enumerate(axes.flat): | ||||||
|  |         x = torch.linspace(-2, 2, 100)  # TODO use min/max from data | ||||||
|  |         y = model.extractors[i](x.view(100, 1)).squeeze().detach() | ||||||
|  |         ax.plot(x, y) | ||||||
|  |         ax.set(title=f"Feature {i + 1}", xticklabels=[], yticklabels=[]) | ||||||
|  |     plt.show() | ||||||
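
Each extractor in the list above handles a single input column, so it must map a (batch, 1) tensor to a (batch, 1) tensor; the model then needs exactly one extractor per feature (100 here, matching the 100-dimensional example_input_array). Below is a minimal standalone shape check of that contract, assuming only torch is installed; the extractor is rebuilt as an equivalent torch.nn.Sequential so the snippet runs on its own:

    import torch

    # Same per-feature extractor as FE above, written as a Sequential.
    fe = torch.nn.Sequential(
        torch.nn.Linear(1, 3),
        torch.nn.Sigmoid(),
        torch.nn.Linear(3, 1),
        torch.nn.Sigmoid(),
    )

    batch = torch.randn(4, 100)        # dummy batch: 4 samples, 100 features
    column = batch[:, 0].unsqueeze(1)  # one feature column, shape (4, 1)
    print(fe(column).shape)            # torch.Size([4, 1])

Because pl.Trainer.add_argparse_args wires the Trainer's options into the parser, the script also accepts standard Lightning flags such as --max_epochs when run directly.
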
prototorch/models/__init__.py

@@ -19,6 +19,7 @@ from .glvq import (
 )
 from .knn import KNN
 from .lvq import LVQ1, LVQ21, MedianLVQ
+from .nam import BinaryNAM
 from .probabilistic import CELVQ, PLVQ, RSLVQ, SLVQ
 from .unsupervised import GrowingNeuralGas, HeskesSOM, KohonenSOM, NeuralGas
 from .vis import *
prototorch/models/nam.py (new file, 44 lines)

@@ -0,0 +1,44 @@
|  | """ProtoTorch Neural Additive Model.""" | ||||||
|  |  | ||||||
|  | import torch | ||||||
|  | import torchmetrics | ||||||
|  |  | ||||||
|  | from .abstract import ProtoTorchBolt | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class BinaryNAM(ProtoTorchBolt): | ||||||
|  |     """Neural Additive Model for binary classification. | ||||||
|  |  | ||||||
|  |     Paper: https://arxiv.org/abs/2004.13912 | ||||||
|  |     Official implementation: https://github.com/google-research/google-research/tree/master/neural_additive_models | ||||||
|  |  | ||||||
|  |     """ | ||||||
|  |     def __init__(self, hparams: dict, extractors: torch.nn.ModuleList, | ||||||
|  |                  **kwargs): | ||||||
|  |         super().__init__(hparams, **kwargs) | ||||||
|  |         self.extractors = extractors | ||||||
|  |  | ||||||
|  |     def extract(self, x): | ||||||
|  |         """Apply the local extractors batch-wise on features.""" | ||||||
|  |         out = torch.zeros_like(x) | ||||||
|  |         for j in range(x.shape[1]): | ||||||
|  |             out[:, j] = self.extractors[j](x[:, j].unsqueeze(1)).squeeze() | ||||||
|  |         return out | ||||||
|  |  | ||||||
|  |     def forward(self, x): | ||||||
|  |         x = self.extract(x).sum(1) | ||||||
|  |         return torch.nn.functional.sigmoid(x) | ||||||
|  |  | ||||||
|  |     def training_step(self, batch, batch_idx, optimizer_idx=None): | ||||||
|  |         x, y = batch | ||||||
|  |         preds = self(x) | ||||||
|  |         train_loss = torch.nn.functional.binary_cross_entropy(preds, y.float()) | ||||||
|  |         self.log("train_loss", train_loss) | ||||||
|  |         accuracy = torchmetrics.functional.accuracy(preds.int(), y.int()) | ||||||
|  |         self.log("train_acc", | ||||||
|  |                  accuracy, | ||||||
|  |                  on_step=False, | ||||||
|  |                  on_epoch=True, | ||||||
|  |                  prog_bar=True, | ||||||
|  |                  logger=True) | ||||||
|  |         return train_loss | ||||||
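
In other words, the model's logit is additive across features: forward computes sigmoid(sum_j f_j(x_j)), where f_j is the j-th extractor. A hypothetical usage sketch of the class added above, assuming prototorch is installed; the two tiny extractors and the hparams dict are illustrative, mirroring the example script:

    import torch
    import prototorch as pt

    # Two single-feature extractors for a two-feature toy input.
    extractors = torch.nn.ModuleList([
        torch.nn.Sequential(torch.nn.Linear(1, 3), torch.nn.Sigmoid(),
                            torch.nn.Linear(3, 1), torch.nn.Sigmoid())
        for _ in range(2)
    ])
    model = pt.models.BinaryNAM(dict(lr=0.1), extractors=extractors)

    x = torch.randn(5, 2)
    probs = model(x)  # per-sample probabilities, shape (5,)
    # The same result, computed via the additive decomposition explicitly:
    manual = torch.sigmoid(model.extract(x).sum(1))
    assert torch.allclose(probs, manual)

Because each f_j sees only its own feature, plotting f_j over a 1-D grid (as the example script does) recovers the per-feature shape functions that make NAMs interpretable.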