chore: rename clc-lc to proto-Y-architecture
This commit is contained in:
parent
02954044d7
commit
dc4f31d700
@ -20,7 +20,7 @@ import torch
|
|||||||
from torchmetrics import Accuracy, Metric
|
from torchmetrics import Accuracy, Metric
|
||||||
|
|
||||||
|
|
||||||
class CLCCScheme(pl.LightningModule):
|
class BaseYArchitecture(pl.LightningModule):
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class HyperParameters:
|
class HyperParameters:
|
@ -1,20 +1,12 @@
|
|||||||
from typing import Optional, Type
|
from typing import Optional, Type
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import prototorch as pt
|
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
import torchmetrics
|
import torchmetrics
|
||||||
from prototorch.core import SMCI
|
from prototorch.models.proto_y_architecture.base import BaseYArchitecture
|
||||||
from prototorch.models.clcc.clcc_glvq import GLVQ
|
|
||||||
from prototorch.models.clcc.clcc_scheme import CLCCScheme
|
|
||||||
from prototorch.models.vis import Vis2DAbstract
|
from prototorch.models.vis import Vis2DAbstract
|
||||||
from prototorch.utils.utils import mesh2d
|
from prototorch.utils.utils import mesh2d
|
||||||
from pytorch_lightning.callbacks import EarlyStopping
|
|
||||||
from torch.utils.data import DataLoader
|
|
||||||
|
|
||||||
# NEW STUFF
|
|
||||||
# ##############################################################################
|
|
||||||
|
|
||||||
|
|
||||||
class LogTorchmetricCallback(pl.Callback):
|
class LogTorchmetricCallback(pl.Callback):
|
||||||
@ -34,7 +26,7 @@ class LogTorchmetricCallback(pl.Callback):
|
|||||||
def setup(
|
def setup(
|
||||||
self,
|
self,
|
||||||
trainer: pl.Trainer,
|
trainer: pl.Trainer,
|
||||||
pl_module: CLCCScheme,
|
pl_module: BaseYArchitecture,
|
||||||
stage: Optional[str] = None,
|
stage: Optional[str] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
if self.on == "prediction":
|
if self.on == "prediction":
|
||||||
@ -69,65 +61,3 @@ class VisGLVQ2D(Vis2DAbstract):
|
|||||||
y_pred = pl_module.predict(mesh_input)
|
y_pred = pl_module.predict(mesh_input)
|
||||||
y_pred = y_pred.cpu().reshape(xx.shape)
|
y_pred = y_pred.cpu().reshape(xx.shape)
|
||||||
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
|
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
|
||||||
|
|
||||||
|
|
||||||
# TODO: Pruning
|
|
||||||
|
|
||||||
# ##############################################################################
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
# Dataset
|
|
||||||
train_ds = pt.datasets.Iris(dims=[0, 2])
|
|
||||||
train_ds.targets[train_ds.targets == 2.0] = 1.0
|
|
||||||
# Dataloaders
|
|
||||||
train_loader = DataLoader(
|
|
||||||
train_ds,
|
|
||||||
batch_size=64,
|
|
||||||
num_workers=0,
|
|
||||||
shuffle=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
components_initializer = SMCI(train_ds)
|
|
||||||
#components_initializer = RandomNormalCompInitializer(2)
|
|
||||||
|
|
||||||
hyperparameters = GLVQ.HyperParameters(
|
|
||||||
lr=0.5,
|
|
||||||
distribution=dict(
|
|
||||||
num_classes=2,
|
|
||||||
per_class=1,
|
|
||||||
),
|
|
||||||
component_initializer=components_initializer,
|
|
||||||
)
|
|
||||||
|
|
||||||
model = GLVQ(hyperparameters)
|
|
||||||
|
|
||||||
print(model)
|
|
||||||
|
|
||||||
# Callbacks
|
|
||||||
vis = VisGLVQ2D(data=train_ds)
|
|
||||||
recall = LogTorchmetricCallback(
|
|
||||||
'recall',
|
|
||||||
torchmetrics.Recall,
|
|
||||||
num_classes=2,
|
|
||||||
)
|
|
||||||
|
|
||||||
es = EarlyStopping(
|
|
||||||
monitor="recall",
|
|
||||||
min_delta=0.001,
|
|
||||||
patience=15,
|
|
||||||
mode="max",
|
|
||||||
check_on_train_epoch_end=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Train
|
|
||||||
trainer = pl.Trainer(
|
|
||||||
callbacks=[
|
|
||||||
vis,
|
|
||||||
recall,
|
|
||||||
es,
|
|
||||||
],
|
|
||||||
gpus=0,
|
|
||||||
max_epochs=200,
|
|
||||||
log_every_n_steps=1,
|
|
||||||
)
|
|
||||||
trainer.fit(model, train_loader)
|
|
@ -10,11 +10,11 @@ from prototorch.core.initializers import (
|
|||||||
LabelsInitializer,
|
LabelsInitializer,
|
||||||
)
|
)
|
||||||
from prototorch.core.losses import GLVQLoss
|
from prototorch.core.losses import GLVQLoss
|
||||||
from prototorch.models.clcc.clcc_scheme import CLCCScheme
|
from prototorch.models.proto_y_architecture.base import BaseYArchitecture
|
||||||
from prototorch.nn.wrappers import LambdaLayer
|
from prototorch.nn.wrappers import LambdaLayer
|
||||||
|
|
||||||
|
|
||||||
class SupervisedScheme(CLCCScheme):
|
class SupervisedScheme(BaseYArchitecture):
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class HyperParameters:
|
class HyperParameters:
|
@ -0,0 +1,89 @@
|
|||||||
|
import prototorch as pt
|
||||||
|
import pytorch_lightning as pl
|
||||||
|
import torchmetrics
|
||||||
|
from prototorch.core import SMCI
|
||||||
|
from prototorch.models.proto_y_architecture.callbacks import (
|
||||||
|
LogTorchmetricCallback,
|
||||||
|
VisGLVQ2D,
|
||||||
|
)
|
||||||
|
from prototorch.models.proto_y_architecture.glvq import GLVQ
|
||||||
|
from pytorch_lightning.callbacks import EarlyStopping
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
# ##############################################################################
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
|
||||||
|
# ------------------------------------------------------------
|
||||||
|
# DATA
|
||||||
|
# ------------------------------------------------------------
|
||||||
|
|
||||||
|
# Dataset
|
||||||
|
train_ds = pt.datasets.Iris(dims=[0, 2])
|
||||||
|
train_ds.targets[train_ds.targets == 2.0] = 1.0
|
||||||
|
|
||||||
|
# Dataloader
|
||||||
|
train_loader = DataLoader(
|
||||||
|
train_ds,
|
||||||
|
batch_size=64,
|
||||||
|
num_workers=0,
|
||||||
|
shuffle=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------
|
||||||
|
# HYPERPARAMETERS
|
||||||
|
# ------------------------------------------------------------
|
||||||
|
|
||||||
|
# Select Initializer
|
||||||
|
components_initializer = SMCI(train_ds)
|
||||||
|
|
||||||
|
# Define Hyperparameters
|
||||||
|
hyperparameters = GLVQ.HyperParameters(
|
||||||
|
lr=0.5,
|
||||||
|
distribution=dict(
|
||||||
|
num_classes=2,
|
||||||
|
per_class=1,
|
||||||
|
),
|
||||||
|
component_initializer=components_initializer,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create Model
|
||||||
|
model = GLVQ(hyperparameters)
|
||||||
|
print(model)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------
|
||||||
|
# TRAINING
|
||||||
|
# ------------------------------------------------------------
|
||||||
|
|
||||||
|
# Controlling Callbacks
|
||||||
|
stopping_criterion = LogTorchmetricCallback(
|
||||||
|
'recall',
|
||||||
|
torchmetrics.Recall,
|
||||||
|
num_classes=2,
|
||||||
|
)
|
||||||
|
|
||||||
|
es = EarlyStopping(
|
||||||
|
monitor=stopping_criterion.name,
|
||||||
|
min_delta=0.001,
|
||||||
|
patience=15,
|
||||||
|
mode="max",
|
||||||
|
check_on_train_epoch_end=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Visualization Callback
|
||||||
|
vis = VisGLVQ2D(data=train_ds)
|
||||||
|
|
||||||
|
# Define trainer
|
||||||
|
trainer = pl.Trainer(
|
||||||
|
callbacks=[
|
||||||
|
vis,
|
||||||
|
stopping_criterion,
|
||||||
|
es,
|
||||||
|
],
|
||||||
|
gpus=0,
|
||||||
|
max_epochs=200,
|
||||||
|
log_every_n_steps=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Train
|
||||||
|
trainer.fit(model, train_loader)
|
Loading…
Reference in New Issue
Block a user