prototorch_models/prototorch/models/unsupervised.py

85 lines
2.6 KiB
Python
Raw Normal View History

2021-05-21 13:42:45 +00:00
"""Unsupervised prototype learning algorithms."""
2021-06-01 15:19:43 +00:00
import logging
2021-05-21 13:42:45 +00:00
import warnings
import prototorch as pt
2021-06-01 15:19:43 +00:00
import pytorch_lightning as pl
2021-04-23 15:30:23 +00:00
import torch
2021-05-21 13:42:45 +00:00
import torchmetrics
from prototorch.components import Components, LabeledComponents
2021-06-04 20:20:32 +00:00
from prototorch.components.initializers import ZerosInitializer
2021-05-21 13:42:45 +00:00
from prototorch.functions.competitions import knnc
2021-04-23 15:30:23 +00:00
from prototorch.functions.distances import euclidean_distance
from prototorch.modules import LambdaLayer
2021-04-23 15:30:23 +00:00
from prototorch.modules.losses import NeuralGasEnergy
2021-06-01 15:19:43 +00:00
from pytorch_lightning.callbacks import Callback
2021-04-23 15:30:23 +00:00
2021-06-04 20:20:32 +00:00
from .abstract import UnsupervisedPrototypeModel
from .callbacks import GNGCallback
from .extras import ConnectionTopology
2021-04-23 15:30:23 +00:00
2021-06-04 20:20:32 +00:00
class NeuralGas(UnsupervisedPrototypeModel):
    """Neural Gas prototype model with a connection-topology layer.

    Each training step computes distances from the inputs to the prototypes,
    minimizes the cost returned by ``NeuralGasEnergy`` and passes the same
    distance matrix to a ``ConnectionTopology`` layer.

    Hyperparameters read from ``hparams`` (defaults are filled in here
    without overriding user-supplied values):
        input_dim: defaults to 2.
        agelimit: forwarded to ``ConnectionTopology``; defaults to 10.
        lm: forwarded to ``NeuralGasEnergy``; defaults to 1.
        num_prototypes: forwarded to ``ConnectionTopology``; no default is
            set in this class.
    """

    def __init__(self, hparams, **kwargs):
        super().__init__(hparams, **kwargs)

        # Hyperparameters
        self.save_hyperparameters(hparams)

        # Default hparams
        self.hparams.setdefault("input_dim", 2)
        self.hparams.setdefault("agelimit", 10)
        self.hparams.setdefault("lm", 1)

        self.energy_layer = NeuralGasEnergy(lm=self.hparams.lm)
        self.topology_layer = ConnectionTopology(
            agelimit=self.hparams.agelimit,
            num_prototypes=self.hparams.num_prototypes,
        )

    def training_step(self, train_batch, batch_idx):
        """Compute the Neural Gas cost for one batch and update the topology.

        Args:
            train_batch: Either an input tensor, or a list/tuple whose first
                element is the input tensor (e.g. ``(x, y)`` from a labeled
                dataset). Any further elements, such as labels, are ignored.
            batch_idx: Index of the batch within the epoch (unused).

        Returns:
            The cost returned by the energy layer, used as the training loss.
        """
        # Labeled datasets yield (x, y, ...) sequences. An unlabeled batch may
        # be a bare input tensor, and indexing it with [0] would keep only the
        # first sample, so only unpack real sequences.
        if isinstance(train_batch, (list, tuple)):
            x = train_batch[0]
        else:
            x = train_batch

        d = self.compute_distances(x)
        cost, _ = self.energy_layer(d)
        # Called for its side effects on the topology; its output is unused.
        self.topology_layer(d)
        return cost
2021-06-01 15:19:43 +00:00
class GrowingNeuralGas(NeuralGas):
    """Neural Gas that additionally tracks an accumulated error per prototype.

    On every training step, each prototype's ``errors`` entry grows by the sum
    of squared entries of the distance matrix over the samples it wins. Every
    entry is then multiplied by ``step_reduction``. A ``GNGCallback`` is
    attached through ``configure_callbacks``; presumably it reads ``errors``
    to decide where to insert prototypes (verify in ``callbacks.py``).

    Additional hyperparameters (defaults are filled in here without
    overriding user-supplied values):
        step_reduction: factor applied to all errors after every training
            step; defaults to 0.5.
        insert_reduction: passed to ``GNGCallback`` as ``reduction``;
            defaults to 0.1.
        insert_freq: passed to ``GNGCallback`` as ``freq``; defaults to 10.
    """

    def __init__(self, hparams, **kwargs):
        super().__init__(hparams, **kwargs)

        # Defaults
        self.hparams.setdefault("step_reduction", 0.5)
        self.hparams.setdefault("insert_reduction", 0.1)
        self.hparams.setdefault("insert_freq", 10)

        # One accumulated error per prototype. Registered as a buffer so it is
        # part of the state dict and moves with the module across devices.
        errors = torch.zeros(self.hparams.num_prototypes, device=self.device)
        self.register_buffer("errors", errors)

    def training_step(self, train_batch, _batch_idx):
        """Compute the Neural Gas cost, update prototype errors and topology.

        Args:
            train_batch: Either an input tensor, or a list/tuple whose first
                element is the input tensor. Any further elements, such as
                labels, are ignored.
            _batch_idx: Index of the batch within the epoch (unused).

        Returns:
            The cost returned by the energy layer, used as the training loss.
        """
        # Labeled datasets yield (x, y, ...) sequences. An unlabeled batch may
        # be a bare input tensor, which must not be indexed with [0].
        if isinstance(train_batch, (list, tuple)):
            x = train_batch[0]
        else:
            x = train_batch

        d = self.compute_distances(x)
        cost, order = self.energy_layer(d)
        # The first column of ``order`` is taken as each sample's winning
        # prototype index.
        winner = order[:, 0]

        # Error bookkeeping is not part of the loss, so keep it out of autograd.
        # Otherwise the in-place update onto the persistent buffer attaches
        # ``d``'s graph to ``errors``, and that graph keeps growing (retaining
        # every step's tensors) for the whole run.
        with torch.no_grad():
            mask = torch.zeros_like(d)
            mask[torch.arange(len(mask), device=d.device), winner] = 1.0
            dp = d * mask
            self.errors += torch.sum(dp * dp, dim=0)
            self.errors *= self.hparams.step_reduction

        # Called for its side effects on the topology; its output is unused.
        self.topology_layer(d)
        return cost

    def configure_callbacks(self):
        """Attach the ``GNGCallback`` configured from this model's hparams."""
        return [
            GNGCallback(reduction=self.hparams.insert_reduction,
                        freq=self.hparams.insert_freq)
        ]