Merge pull request #2 from si-cim/dev

Use Components instead of Prototypes and refactor old examples
This commit is contained in:
Jensun Ravichandran 2021-04-29 17:03:17 +02:00 committed by GitHub
commit c50f139559
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 216 additions and 235 deletions

View File

@ -35,7 +35,7 @@ workon pt
git clone git@github.com:si-cim/prototorch_models.git git clone git@github.com:si-cim/prototorch_models.git
cd prototorch_models cd prototorch_models
git checkout dev git checkout dev
pip install -e .[all] # \[all\] if you are using zsh pip install -e .[all] # \[all\] if you are using zsh or MacOS
``` ```
To assist in the development process, you may also find it useful to install To assist in the development process, you may also find it useful to install

View File

@ -1,63 +1,14 @@
"""GLVQ example using the Iris dataset.""" """GLVQ example using the Iris dataset."""
import numpy as np
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from matplotlib import pyplot as plt from prototorch.components import initializers as cinit
from prototorch.datasets.abstract import NumpyDataset
from prototorch.models.callbacks.visualization import VisGLVQ2D
from prototorch.models.glvq import GLVQ
from sklearn.datasets import load_iris from sklearn.datasets import load_iris
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from prototorch.datasets.abstract import NumpyDataset
from prototorch.models.glvq import GLVQ
class VisualizationCallback(pl.Callback):
def __init__(self,
x_train,
y_train,
title="Prototype Visualization",
cmap="viridis"):
super().__init__()
self.x_train = x_train
self.y_train = y_train
self.title = title
self.fig = plt.figure(self.title)
self.cmap = cmap
def on_epoch_end(self, trainer, pl_module):
protos = pl_module.prototypes
plabels = pl_module.prototype_labels
x_train, y_train = self.x_train, self.y_train
ax = self.fig.gca()
ax.cla()
ax.set_title(self.title)
ax.set_xlabel("Data dimension 1")
ax.set_ylabel("Data dimension 2")
ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k")
ax.scatter(
protos[:, 0],
protos[:, 1],
c=plabels,
cmap=self.cmap,
edgecolor="k",
marker="D",
s=50,
)
x = np.vstack((x_train, protos))
x_min, x_max = x[:, 0].min() - 1, x[:, 0].max() + 1
y_min, y_max = x[:, 1].min() - 1, x[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, 1 / 50),
np.arange(y_min, y_max, 1 / 50))
mesh_input = np.c_[xx.ravel(), yy.ravel()]
y_pred = pl_module.predict(torch.Tensor(mesh_input))
y_pred = y_pred.reshape(xx.shape)
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
ax.set_xlim(left=x_min + 0, right=x_max - 0)
ax.set_ylim(bottom=y_min + 0, top=y_max - 0)
plt.pause(0.1)
if __name__ == "__main__": if __name__ == "__main__":
# Dataset # Dataset
x_train, y_train = load_iris(return_X_y=True) x_train, y_train = load_iris(return_X_y=True)
@ -69,24 +20,21 @@ if __name__ == "__main__":
# Hyperparameters # Hyperparameters
hparams = dict( hparams = dict(
input_dim=x_train.shape[1],
nclasses=3, nclasses=3,
prototypes_per_class=3, prototypes_per_class=2,
prototype_initializer="stratified_mean", prototype_initializer=cinit.StratifiedMeanInitializer(
lr=0.1, torch.Tensor(x_train), torch.Tensor(y_train)),
lr=0.01,
) )
# Initialize the model # Initialize the model
model = GLVQ(hparams, data=[x_train, y_train]) model = GLVQ(hparams, data=[x_train, y_train])
# Model summary
print(model)
# Callbacks
vis = VisualizationCallback(x_train, y_train)
# Setup trainer # Setup trainer
trainer = pl.Trainer(max_epochs=50, callbacks=[vis]) trainer = pl.Trainer(
max_epochs=50,
callbacks=[VisGLVQ2D(x_train, y_train)],
)
# Training loop # Training loop
trainer.fit(model, train_loader) trainer.fit(model, train_loader)

View File

@ -3,63 +3,13 @@
import numpy as np import numpy as np
import pytorch_lightning as pl import pytorch_lightning as pl
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
from prototorch.datasets.abstract import NumpyDataset
from prototorch.models.callbacks.visualization import VisNG2D
from prototorch.models.neural_gas import NeuralGas
from sklearn.datasets import load_iris from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler from sklearn.preprocessing import StandardScaler
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from prototorch.datasets.abstract import NumpyDataset
from prototorch.models.neural_gas import NeuralGas
class VisualizationCallback(pl.Callback):
def __init__(self,
x_train,
y_train,
title="Neural Gas Visualization",
cmap="viridis"):
super().__init__()
self.x_train = x_train
self.y_train = y_train
self.title = title
self.fig = plt.figure(self.title)
self.cmap = cmap
def on_epoch_end(self, trainer, pl_module: NeuralGas):
protos = pl_module.proto_layer.prototypes.detach().cpu().numpy()
cmat = pl_module.topology_layer.cmat.cpu().numpy()
# Visualize the data and the prototypes
ax = self.fig.gca()
ax.cla()
ax.set_title(self.title)
ax.set_xlabel("Data dimension 1")
ax.set_ylabel("Data dimension 2")
ax.scatter(self.x_train[:, 0],
self.x_train[:, 1],
c=self.y_train,
edgecolor="k")
ax.scatter(
protos[:, 0],
protos[:, 1],
c="k",
edgecolor="k",
marker="D",
s=50,
)
# Draw connections
for i in range(len(protos)):
for j in range(len(protos)):
if cmat[i][j]:
ax.plot(
[protos[i, 0], protos[j, 0]],
[protos[i, 1], protos[j, 1]],
"k-",
)
plt.pause(0.01)
if __name__ == "__main__": if __name__ == "__main__":
# Dataset # Dataset
x_train, y_train = load_iris(return_X_y=True) x_train, y_train = load_iris(return_X_y=True)
@ -68,7 +18,6 @@ if __name__ == "__main__":
scaler.fit(x_train) scaler.fit(x_train)
x_train = scaler.transform(x_train) x_train = scaler.transform(x_train)
y_single_class = np.zeros_like(y_train)
train_ds = NumpyDataset(x_train, y_train) train_ds = NumpyDataset(x_train, y_train)
# Dataloaders # Dataloaders
@ -77,20 +26,18 @@ if __name__ == "__main__":
# Hyperparameters # Hyperparameters
hparams = dict( hparams = dict(
input_dim=x_train.shape[1], input_dim=x_train.shape[1],
nclasses=1, num_prototypes=30,
prototypes_per_class=30, lr=0.01,
prototype_initializer="rand",
lr=0.1,
) )
# Initialize the model # Initialize the model
model = NeuralGas(hparams, data=[x_train, y_single_class]) model = NeuralGas(hparams)
# Model summary # Model summary
print(model) print(model)
# Callbacks # Callbacks
vis = VisualizationCallback(x_train, y_train) vis = VisNG2D(x_train, y_train)
# Setup trainer # Setup trainer
trainer = pl.Trainer( trainer = pl.Trainer(

View File

@ -1,70 +1,15 @@
"""Siamese GLVQ example using all four dimensions of the Iris dataset.""" """Siamese GLVQ example using all four dimensions of the Iris dataset."""
import numpy as np
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from matplotlib import pyplot as plt from prototorch.components import (StratifiedMeanInitializer,
StratifiedSelectionInitializer)
from prototorch.datasets.abstract import NumpyDataset
from prototorch.models.callbacks.visualization import VisSiameseGLVQ2D
from prototorch.models.glvq import SiameseGLVQ
from sklearn.datasets import load_iris from sklearn.datasets import load_iris
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from prototorch.datasets.abstract import NumpyDataset
from prototorch.models.glvq import SiameseGLVQ
class VisualizationCallback(pl.Callback):
def __init__(self,
x_train,
y_train,
title="Prototype Visualization",
cmap="viridis"):
super().__init__()
self.x_train = x_train
self.y_train = y_train
self.title = title
self.fig = plt.figure(self.title)
self.cmap = cmap
def on_epoch_end(self, trainer, pl_module):
protos = pl_module.prototypes
plabels = pl_module.prototype_labels
x_train, y_train = self.x_train, self.y_train
x_train = pl_module.backbone(torch.Tensor(x_train)).detach()
protos = pl_module.backbone(torch.Tensor(protos)).detach()
ax = self.fig.gca()
ax.cla()
ax.set_title(self.title)
ax.axis("off")
ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k")
ax.scatter(
protos[:, 0],
protos[:, 1],
c=plabels,
cmap=self.cmap,
edgecolor="k",
marker="D",
s=50,
)
x = np.vstack((x_train, protos))
x_min, x_max = x[:, 0].min() - 0.2, x[:, 0].max() + 0.2
y_min, y_max = x[:, 1].min() - 0.2, x[:, 1].max() + 0.2
xx, yy = np.meshgrid(np.arange(x_min, x_max, 1 / 50),
np.arange(y_min, y_max, 1 / 50))
mesh_input = np.c_[xx.ravel(), yy.ravel()]
y_pred = pl_module.predict_latent(torch.Tensor(mesh_input))
y_pred = y_pred.reshape(xx.shape)
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
ax.set_xlim(left=x_min + 0, right=x_max - 0)
ax.set_ylim(bottom=y_min + 0, top=y_max - 0)
tb = pl_module.logger.experiment
tb.add_figure(
tag=f"{self.title}",
figure=self.fig,
global_step=trainer.current_epoch,
close=False,
)
plt.pause(0.1)
class Backbone(torch.nn.Module): class Backbone(torch.nn.Module):
def __init__(self, input_size=4, hidden_size=10, latent_size=2): def __init__(self, input_size=4, hidden_size=10, latent_size=2):
@ -90,23 +35,24 @@ if __name__ == "__main__":
# Hyperparameters # Hyperparameters
hparams = dict( hparams = dict(
input_dim=x_train.shape[1],
nclasses=3, nclasses=3,
prototypes_per_class=1, prototypes_per_class=1,
prototype_initializer="stratified_mean", prototype_initializer=StratifiedMeanInitializer(
torch.Tensor(x_train), torch.Tensor(y_train)),
lr=0.01, lr=0.01,
) )
# Initialize the model # Initialize the model
model = SiameseGLVQ(hparams, model = SiameseGLVQ(
hparams,
backbone_module=Backbone, backbone_module=Backbone,
data=[x_train, y_train]) )
# Model summary # Model summary
print(model) print(model)
# Callbacks # Callbacks
vis = VisualizationCallback(x_train, y_train) vis = VisSiameseGLVQ2D(x_train, y_train)
# Setup trainer # Setup trainer
trainer = pl.Trainer(max_epochs=100, callbacks=[vis]) trainer = pl.Trainer(max_epochs=100, callbacks=[vis])

View File

@ -1,16 +1,17 @@
import os import os
import numpy as np import numpy as np
import pytorch_lightning as pl
import torch import torch
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
from matplotlib.offsetbox import AnchoredText from matplotlib.offsetbox import AnchoredText
from prototorch.utils.celluloid import Camera from prototorch.utils.celluloid import Camera
from prototorch.utils.colors import color_scheme from prototorch.utils.colors import color_scheme
from prototorch.utils.utils import gif_from_dir, make_directory, prettify_string from prototorch.utils.utils import (gif_from_dir, make_directory,
prettify_string)
class VisWeights(Callback): class VisWeights(pl.Callback):
"""Abstract weight visualization callback.""" """Abstract weight visualization callback."""
def __init__( def __init__(
self, self,
@ -258,3 +259,155 @@ class VisPointProtos(VisWeights):
epoch = trainer.current_epoch epoch = trainer.current_epoch
self._display_logs(self.ax, epoch, logs) self._display_logs(self.ax, epoch, logs)
self._show_and_save(epoch) self._show_and_save(epoch)
class VisGLVQ2D(pl.Callback):
def __init__(self,
x_train,
y_train,
title="Prototype Visualization",
cmap="viridis"):
super().__init__()
self.x_train = x_train
self.y_train = y_train
self.title = title
self.fig = plt.figure(self.title)
self.cmap = cmap
def on_epoch_end(self, trainer, pl_module):
protos = pl_module.prototypes
plabels = pl_module.prototype_labels
x_train, y_train = self.x_train, self.y_train
ax = self.fig.gca()
ax.cla()
ax.set_title(self.title)
ax.axis("off")
ax.set_xlabel("Data dimension 1")
ax.set_ylabel("Data dimension 2")
ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k")
ax.scatter(
protos[:, 0],
protos[:, 1],
c=plabels,
cmap=self.cmap,
edgecolor="k",
marker="D",
s=50,
)
x = np.vstack((x_train, protos))
x_min, x_max = x[:, 0].min() - 1, x[:, 0].max() + 1
y_min, y_max = x[:, 1].min() - 1, x[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, 1 / 50),
np.arange(y_min, y_max, 1 / 50))
mesh_input = np.c_[xx.ravel(), yy.ravel()]
y_pred = pl_module.predict(torch.Tensor(mesh_input))
y_pred = y_pred.reshape(xx.shape)
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
ax.set_xlim(left=x_min + 0, right=x_max - 0)
ax.set_ylim(bottom=y_min + 0, top=y_max - 0)
plt.pause(0.1)
class VisSiameseGLVQ2D(pl.Callback):
def __init__(self,
x_train,
y_train,
title="Prototype Visualization",
cmap="viridis"):
super().__init__()
self.x_train = x_train
self.y_train = y_train
self.title = title
self.fig = plt.figure(self.title)
self.cmap = cmap
def on_epoch_end(self, trainer, pl_module):
protos = pl_module.prototypes
plabels = pl_module.prototype_labels
x_train, y_train = self.x_train, self.y_train
x_train = pl_module.backbone(torch.Tensor(x_train)).detach()
protos = pl_module.backbone(torch.Tensor(protos)).detach()
ax = self.fig.gca()
ax.cla()
ax.set_title(self.title)
ax.axis("off")
ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k")
ax.scatter(
protos[:, 0],
protos[:, 1],
c=plabels,
cmap=self.cmap,
edgecolor="k",
marker="D",
s=50,
)
x = np.vstack((x_train, protos))
x_min, x_max = x[:, 0].min() - 1, x[:, 0].max() + 1
y_min, y_max = x[:, 1].min() - 1, x[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, 1 / 50),
np.arange(y_min, y_max, 1 / 50))
mesh_input = np.c_[xx.ravel(), yy.ravel()]
y_pred = pl_module.predict_latent(torch.Tensor(mesh_input))
y_pred = y_pred.reshape(xx.shape)
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
ax.set_xlim(left=x_min + 0, right=x_max - 0)
ax.set_ylim(bottom=y_min + 0, top=y_max - 0)
tb = pl_module.logger.experiment
tb.add_figure(
tag=f"{self.title}",
figure=self.fig,
global_step=trainer.current_epoch,
close=False,
)
plt.pause(0.1)
class VisNG2D(pl.Callback):
def __init__(self,
x_train,
y_train,
title="Neural Gas Visualization",
cmap="viridis"):
super().__init__()
self.x_train = x_train
self.y_train = y_train
self.title = title
self.fig = plt.figure(self.title)
self.cmap = cmap
def on_epoch_end(self, trainer, pl_module):
protos = pl_module.prototypes
cmat = pl_module.topology_layer.cmat.cpu().numpy()
# Visualize the data and the prototypes
ax = self.fig.gca()
ax.cla()
ax.set_title(self.title)
ax.set_xlabel("Data dimension 1")
ax.set_ylabel("Data dimension 2")
ax.scatter(self.x_train[:, 0],
self.x_train[:, 1],
c=self.y_train,
edgecolor="k")
ax.scatter(
protos[:, 0],
protos[:, 1],
c="k",
edgecolor="k",
marker="D",
s=50,
)
# Draw connections
for i in range(len(protos)):
for j in range(len(protos)):
if cmat[i][j]:
ax.plot(
[protos[i, 0], protos[j, 0]],
[protos[i, 1], protos[j, 1]],
"k-",
)
plt.pause(0.01)

View File

@ -1,14 +1,16 @@
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
import torchmetrics import torchmetrics
from prototorch.components import LabeledComponents
from prototorch.functions.competitions import wtac from prototorch.functions.competitions import wtac
from prototorch.functions.distances import euclidean_distance from prototorch.functions.distances import euclidean_distance
from prototorch.functions.losses import glvq_loss from prototorch.functions.losses import glvq_loss
from prototorch.modules.prototypes import Prototypes1D from prototorch.modules.prototypes import Prototypes1D
from .abstract import AbstractPrototypeModel
class GLVQ(pl.LightningModule):
class GLVQ(AbstractPrototypeModel):
"""Generalized Learning Vector Quantization.""" """Generalized Learning Vector Quantization."""
def __init__(self, hparams, **kwargs): def __init__(self, hparams, **kwargs):
super().__init__() super().__init__()
@ -18,29 +20,18 @@ class GLVQ(pl.LightningModule):
# Default Values # Default Values
self.hparams.setdefault("distance", euclidean_distance) self.hparams.setdefault("distance", euclidean_distance)
self.proto_layer = Prototypes1D( self.proto_layer = LabeledComponents(
input_dim=self.hparams.input_dim, labels=(self.hparams.nclasses, self.hparams.prototypes_per_class),
nclasses=self.hparams.nclasses, initializer=self.hparams.prototype_initializer)
prototypes_per_class=self.hparams.prototypes_per_class,
prototype_initializer=self.hparams.prototype_initializer,
**kwargs)
self.train_acc = torchmetrics.Accuracy() self.train_acc = torchmetrics.Accuracy()
@property
def prototypes(self):
return self.proto_layer.prototypes.detach().numpy()
@property @property
def prototype_labels(self): def prototype_labels(self):
return self.proto_layer.prototype_labels.detach().numpy() return self.proto_layer.component_labels.detach().numpy()
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
return optimizer
def forward(self, x): def forward(self, x):
protos = self.proto_layer.prototypes protos, _ = self.proto_layer()
dis = self.hparams.distance(x, protos) dis = self.hparams.distance(x, protos)
return dis return dis
@ -48,7 +39,7 @@ class GLVQ(pl.LightningModule):
x, y = train_batch x, y = train_batch
x = x.view(x.size(0), -1) x = x.view(x.size(0), -1)
dis = self(x) dis = self(x)
plabels = self.proto_layer.prototype_labels plabels = self.proto_layer.component_labels
mu = glvq_loss(dis, y, prototype_labels=plabels) mu = glvq_loss(dis, y, prototype_labels=plabels)
loss = mu.sum(dim=0) loss = mu.sum(dim=0)
self.log("train_loss", loss) self.log("train_loss", loss)
@ -77,7 +68,7 @@ class GLVQ(pl.LightningModule):
# model.eval() # ?! # model.eval() # ?!
with torch.no_grad(): with torch.no_grad():
d = self(x) d = self(x)
plabels = self.proto_layer.prototype_labels plabels = self.proto_layer.component_labels
y_pred = wtac(d, plabels) y_pred = wtac(d, plabels)
return y_pred.numpy() return y_pred.numpy()
@ -89,7 +80,7 @@ class ImageGLVQ(GLVQ):
clamping after updates. clamping after updates.
""" """
def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx): def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
self.proto_layer.prototypes.data.clamp_(0.0, 1.0) self.proto_layer.components.data.clamp_(0.0, 1.0)
class SiameseGLVQ(GLVQ): class SiameseGLVQ(GLVQ):
@ -115,7 +106,7 @@ class SiameseGLVQ(GLVQ):
def forward(self, x): def forward(self, x):
self.sync_backbones() self.sync_backbones()
protos = self.proto_layer.prototypes protos, _ = self.proto_layer()
latent_x = self.backbone(x) latent_x = self.backbone(x)
latent_protos = self.backbone_dependent(protos) latent_protos = self.backbone_dependent(protos)
@ -126,9 +117,8 @@ class SiameseGLVQ(GLVQ):
def predict_latent(self, x): def predict_latent(self, x):
# model.eval() # ?! # model.eval() # ?!
with torch.no_grad(): with torch.no_grad():
protos = self.proto_layer.prototypes protos, plabels = self.proto_layer()
latent_protos = self.backbone_dependent(protos) latent_protos = self.backbone_dependent(protos)
d = euclidean_distance(x, latent_protos) d = euclidean_distance(x, latent_protos)
plabels = self.proto_layer.prototype_labels
y_pred = wtac(d, plabels) y_pred = wtac(d, plabels)
return y_pred.numpy() return y_pred.numpy()

View File

@ -1,10 +1,13 @@
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from prototorch.components import Components
from prototorch.components import initializers as cinit
from prototorch.functions.distances import euclidean_distance from prototorch.functions.distances import euclidean_distance
from prototorch.modules import Prototypes1D from prototorch.modules import Prototypes1D
from prototorch.modules.losses import NeuralGasEnergy from prototorch.modules.losses import NeuralGasEnergy
from .abstract import AbstractPrototypeModel
class EuclideanDistance(torch.nn.Module): class EuclideanDistance(torch.nn.Module):
def forward(self, x, y): def forward(self, x, y):
@ -34,41 +37,35 @@ class ConnectionTopology(torch.nn.Module):
return f"agelimit: {self.agelimit}" return f"agelimit: {self.agelimit}"
class NeuralGas(pl.LightningModule): class NeuralGas(AbstractPrototypeModel):
def __init__(self, hparams, **kwargs): def __init__(self, hparams, **kwargs):
super().__init__() super().__init__()
self.save_hyperparameters(hparams) self.save_hyperparameters(hparams)
# Default Values # Default Values
self.hparams.setdefault("input_dim", 2)
self.hparams.setdefault("agelimit", 10) self.hparams.setdefault("agelimit", 10)
self.hparams.setdefault("lm", 1) self.hparams.setdefault("lm", 1)
self.hparams.setdefault("prototype_initializer", "zeros") self.hparams.setdefault("prototype_initializer",
cinit.ZerosInitializer(self.hparams.input_dim))
self.proto_layer = Prototypes1D( self.proto_layer = Components(
input_dim=self.hparams.input_dim, self.hparams.num_prototypes,
nclasses=self.hparams.nclasses, initializer=self.hparams.prototype_initializer)
prototypes_per_class=self.hparams.prototypes_per_class,
prototype_initializer=self.hparams.prototype_initializer,
**kwargs,
)
self.distance_layer = EuclideanDistance() self.distance_layer = EuclideanDistance()
self.energy_layer = NeuralGasEnergy(lm=self.hparams.lm) self.energy_layer = NeuralGasEnergy(lm=self.hparams.lm)
self.topology_layer = ConnectionTopology( self.topology_layer = ConnectionTopology(
agelimit=self.hparams.agelimit, agelimit=self.hparams.agelimit,
num_prototypes=len(self.proto_layer.prototypes), num_prototypes=self.hparams.num_prototypes,
) )
def training_step(self, train_batch, batch_idx): def training_step(self, train_batch, batch_idx):
x, _ = train_batch x = train_batch[0]
protos, _ = self.proto_layer() protos = self.proto_layer()
d = self.distance_layer(x, protos) d = self.distance_layer(x, protos)
cost, order = self.energy_layer(d) cost, order = self.energy_layer(d)
self.topology_layer(d) self.topology_layer(d)
return cost return cost
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
return optimizer