chore: rename clc-lc to proto-Y-architecture

This commit is contained in:
Alexander Engelsberger 2022-05-18 14:11:46 +02:00
parent 02954044d7
commit dc4f31d700
No known key found for this signature in database
GPG Key ID: 72E54A9DAE51EB96
5 changed files with 94 additions and 75 deletions

View File

@ -20,7 +20,7 @@ import torch
from torchmetrics import Accuracy, Metric
class CLCCScheme(pl.LightningModule):
class BaseYArchitecture(pl.LightningModule):
@dataclass
class HyperParameters:

View File

@ -1,20 +1,12 @@
from typing import Optional, Type
import numpy as np
import prototorch as pt
import pytorch_lightning as pl
import torch
import torchmetrics
from prototorch.core import SMCI
from prototorch.models.clcc.clcc_glvq import GLVQ
from prototorch.models.clcc.clcc_scheme import CLCCScheme
from prototorch.models.proto_y_architecture.base import BaseYArchitecture
from prototorch.models.vis import Vis2DAbstract
from prototorch.utils.utils import mesh2d
from pytorch_lightning.callbacks import EarlyStopping
from torch.utils.data import DataLoader
# NEW STUFF
# ##############################################################################
class LogTorchmetricCallback(pl.Callback):
@ -34,7 +26,7 @@ class LogTorchmetricCallback(pl.Callback):
def setup(
self,
trainer: pl.Trainer,
pl_module: CLCCScheme,
pl_module: BaseYArchitecture,
stage: Optional[str] = None,
) -> None:
if self.on == "prediction":
@ -69,65 +61,3 @@ class VisGLVQ2D(Vis2DAbstract):
y_pred = pl_module.predict(mesh_input)
y_pred = y_pred.cpu().reshape(xx.shape)
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
# TODO: Pruning
# ##############################################################################
if __name__ == "__main__":
# Dataset
train_ds = pt.datasets.Iris(dims=[0, 2])
train_ds.targets[train_ds.targets == 2.0] = 1.0
# Dataloaders
train_loader = DataLoader(
train_ds,
batch_size=64,
num_workers=0,
shuffle=True,
)
components_initializer = SMCI(train_ds)
#components_initializer = RandomNormalCompInitializer(2)
hyperparameters = GLVQ.HyperParameters(
lr=0.5,
distribution=dict(
num_classes=2,
per_class=1,
),
component_initializer=components_initializer,
)
model = GLVQ(hyperparameters)
print(model)
# Callbacks
vis = VisGLVQ2D(data=train_ds)
recall = LogTorchmetricCallback(
'recall',
torchmetrics.Recall,
num_classes=2,
)
es = EarlyStopping(
monitor="recall",
min_delta=0.001,
patience=15,
mode="max",
check_on_train_epoch_end=True,
)
# Train
trainer = pl.Trainer(
callbacks=[
vis,
recall,
es,
],
gpus=0,
max_epochs=200,
log_every_n_steps=1,
)
trainer.fit(model, train_loader)

View File

@ -10,11 +10,11 @@ from prototorch.core.initializers import (
LabelsInitializer,
)
from prototorch.core.losses import GLVQLoss
from prototorch.models.clcc.clcc_scheme import CLCCScheme
from prototorch.models.proto_y_architecture.base import BaseYArchitecture
from prototorch.nn.wrappers import LambdaLayer
class SupervisedScheme(CLCCScheme):
class SupervisedScheme(BaseYArchitecture):
@dataclass
class HyperParameters:

View File

@ -0,0 +1,89 @@
import prototorch as pt
import pytorch_lightning as pl
import torchmetrics
from prototorch.core import SMCI
from prototorch.models.proto_y_architecture.callbacks import (
LogTorchmetricCallback,
VisGLVQ2D,
)
from prototorch.models.proto_y_architecture.glvq import GLVQ
from pytorch_lightning.callbacks import EarlyStopping
from torch.utils.data import DataLoader
# ##############################################################################
if __name__ == "__main__":
# ------------------------------------------------------------
# DATA
# ------------------------------------------------------------
# Dataset
train_ds = pt.datasets.Iris(dims=[0, 2])
train_ds.targets[train_ds.targets == 2.0] = 1.0
# Dataloader
train_loader = DataLoader(
train_ds,
batch_size=64,
num_workers=0,
shuffle=True,
)
# ------------------------------------------------------------
# HYPERPARAMETERS
# ------------------------------------------------------------
# Select Initializer
components_initializer = SMCI(train_ds)
# Define Hyperparameters
hyperparameters = GLVQ.HyperParameters(
lr=0.5,
distribution=dict(
num_classes=2,
per_class=1,
),
component_initializer=components_initializer,
)
# Create Model
model = GLVQ(hyperparameters)
print(model)
# ------------------------------------------------------------
# TRAINING
# ------------------------------------------------------------
# Controlling Callbacks
stopping_criterion = LogTorchmetricCallback(
'recall',
torchmetrics.Recall,
num_classes=2,
)
es = EarlyStopping(
monitor=stopping_criterion.name,
min_delta=0.001,
patience=15,
mode="max",
check_on_train_epoch_end=True,
)
# Visualization Callback
vis = VisGLVQ2D(data=train_ds)
# Define trainer
trainer = pl.Trainer(
callbacks=[
vis,
stopping_criterion,
es,
],
gpus=0,
max_epochs=200,
log_every_n_steps=1,
)
# Train
trainer.fit(model, train_loader)