chore: rename clc-lc to proto-Y-architecture
This commit is contained in:
		@@ -20,7 +20,7 @@ import torch
 | 
				
			|||||||
from torchmetrics import Accuracy, Metric
 | 
					from torchmetrics import Accuracy, Metric
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class CLCCScheme(pl.LightningModule):
 | 
					class BaseYArchitecture(pl.LightningModule):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @dataclass
 | 
					    @dataclass
 | 
				
			||||||
    class HyperParameters:
 | 
					    class HyperParameters:
 | 
				
			||||||
@@ -1,20 +1,12 @@
 | 
				
			|||||||
from typing import Optional, Type
 | 
					from typing import Optional, Type
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import numpy as np
 | 
					import numpy as np
 | 
				
			||||||
import prototorch as pt
 | 
					 | 
				
			||||||
import pytorch_lightning as pl
 | 
					import pytorch_lightning as pl
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
import torchmetrics
 | 
					import torchmetrics
 | 
				
			||||||
from prototorch.core import SMCI
 | 
					from prototorch.models.proto_y_architecture.base import BaseYArchitecture
 | 
				
			||||||
from prototorch.models.clcc.clcc_glvq import GLVQ
 | 
					 | 
				
			||||||
from prototorch.models.clcc.clcc_scheme import CLCCScheme
 | 
					 | 
				
			||||||
from prototorch.models.vis import Vis2DAbstract
 | 
					from prototorch.models.vis import Vis2DAbstract
 | 
				
			||||||
from prototorch.utils.utils import mesh2d
 | 
					from prototorch.utils.utils import mesh2d
 | 
				
			||||||
from pytorch_lightning.callbacks import EarlyStopping
 | 
					 | 
				
			||||||
from torch.utils.data import DataLoader
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
# NEW STUFF
 | 
					 | 
				
			||||||
# ##############################################################################
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class LogTorchmetricCallback(pl.Callback):
 | 
					class LogTorchmetricCallback(pl.Callback):
 | 
				
			||||||
@@ -34,7 +26,7 @@ class LogTorchmetricCallback(pl.Callback):
 | 
				
			|||||||
    def setup(
 | 
					    def setup(
 | 
				
			||||||
        self,
 | 
					        self,
 | 
				
			||||||
        trainer: pl.Trainer,
 | 
					        trainer: pl.Trainer,
 | 
				
			||||||
        pl_module: CLCCScheme,
 | 
					        pl_module: BaseYArchitecture,
 | 
				
			||||||
        stage: Optional[str] = None,
 | 
					        stage: Optional[str] = None,
 | 
				
			||||||
    ) -> None:
 | 
					    ) -> None:
 | 
				
			||||||
        if self.on == "prediction":
 | 
					        if self.on == "prediction":
 | 
				
			||||||
@@ -69,65 +61,3 @@ class VisGLVQ2D(Vis2DAbstract):
 | 
				
			|||||||
        y_pred = pl_module.predict(mesh_input)
 | 
					        y_pred = pl_module.predict(mesh_input)
 | 
				
			||||||
        y_pred = y_pred.cpu().reshape(xx.shape)
 | 
					        y_pred = y_pred.cpu().reshape(xx.shape)
 | 
				
			||||||
        ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
 | 
					        ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
# TODO: Pruning
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
# ##############################################################################
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
if __name__ == "__main__":
 | 
					 | 
				
			||||||
    # Dataset
 | 
					 | 
				
			||||||
    train_ds = pt.datasets.Iris(dims=[0, 2])
 | 
					 | 
				
			||||||
    train_ds.targets[train_ds.targets == 2.0] = 1.0
 | 
					 | 
				
			||||||
    # Dataloaders
 | 
					 | 
				
			||||||
    train_loader = DataLoader(
 | 
					 | 
				
			||||||
        train_ds,
 | 
					 | 
				
			||||||
        batch_size=64,
 | 
					 | 
				
			||||||
        num_workers=0,
 | 
					 | 
				
			||||||
        shuffle=True,
 | 
					 | 
				
			||||||
    )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    components_initializer = SMCI(train_ds)
 | 
					 | 
				
			||||||
    #components_initializer = RandomNormalCompInitializer(2)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    hyperparameters = GLVQ.HyperParameters(
 | 
					 | 
				
			||||||
        lr=0.5,
 | 
					 | 
				
			||||||
        distribution=dict(
 | 
					 | 
				
			||||||
            num_classes=2,
 | 
					 | 
				
			||||||
            per_class=1,
 | 
					 | 
				
			||||||
        ),
 | 
					 | 
				
			||||||
        component_initializer=components_initializer,
 | 
					 | 
				
			||||||
    )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    model = GLVQ(hyperparameters)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    print(model)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # Callbacks
 | 
					 | 
				
			||||||
    vis = VisGLVQ2D(data=train_ds)
 | 
					 | 
				
			||||||
    recall = LogTorchmetricCallback(
 | 
					 | 
				
			||||||
        'recall',
 | 
					 | 
				
			||||||
        torchmetrics.Recall,
 | 
					 | 
				
			||||||
        num_classes=2,
 | 
					 | 
				
			||||||
    )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    es = EarlyStopping(
 | 
					 | 
				
			||||||
        monitor="recall",
 | 
					 | 
				
			||||||
        min_delta=0.001,
 | 
					 | 
				
			||||||
        patience=15,
 | 
					 | 
				
			||||||
        mode="max",
 | 
					 | 
				
			||||||
        check_on_train_epoch_end=True,
 | 
					 | 
				
			||||||
    )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # Train
 | 
					 | 
				
			||||||
    trainer = pl.Trainer(
 | 
					 | 
				
			||||||
        callbacks=[
 | 
					 | 
				
			||||||
            vis,
 | 
					 | 
				
			||||||
            recall,
 | 
					 | 
				
			||||||
            es,
 | 
					 | 
				
			||||||
        ],
 | 
					 | 
				
			||||||
        gpus=0,
 | 
					 | 
				
			||||||
        max_epochs=200,
 | 
					 | 
				
			||||||
        log_every_n_steps=1,
 | 
					 | 
				
			||||||
    )
 | 
					 | 
				
			||||||
    trainer.fit(model, train_loader)
 | 
					 | 
				
			||||||
@@ -10,11 +10,11 @@ from prototorch.core.initializers import (
 | 
				
			|||||||
    LabelsInitializer,
 | 
					    LabelsInitializer,
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
from prototorch.core.losses import GLVQLoss
 | 
					from prototorch.core.losses import GLVQLoss
 | 
				
			||||||
from prototorch.models.clcc.clcc_scheme import CLCCScheme
 | 
					from prototorch.models.proto_y_architecture.base import BaseYArchitecture
 | 
				
			||||||
from prototorch.nn.wrappers import LambdaLayer
 | 
					from prototorch.nn.wrappers import LambdaLayer
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class SupervisedScheme(CLCCScheme):
 | 
					class SupervisedScheme(BaseYArchitecture):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @dataclass
 | 
					    @dataclass
 | 
				
			||||||
    class HyperParameters:
 | 
					    class HyperParameters:
 | 
				
			||||||
@@ -0,0 +1,89 @@
 | 
				
			|||||||
 | 
					import prototorch as pt
 | 
				
			||||||
 | 
					import pytorch_lightning as pl
 | 
				
			||||||
 | 
					import torchmetrics
 | 
				
			||||||
 | 
					from prototorch.core import SMCI
 | 
				
			||||||
 | 
					from prototorch.models.proto_y_architecture.callbacks import (
 | 
				
			||||||
 | 
					    LogTorchmetricCallback,
 | 
				
			||||||
 | 
					    VisGLVQ2D,
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					from prototorch.models.proto_y_architecture.glvq import GLVQ
 | 
				
			||||||
 | 
					from pytorch_lightning.callbacks import EarlyStopping
 | 
				
			||||||
 | 
					from torch.utils.data import DataLoader
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# ##############################################################################
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					if __name__ == "__main__":
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # ------------------------------------------------------------
 | 
				
			||||||
 | 
					    # DATA
 | 
				
			||||||
 | 
					    # ------------------------------------------------------------
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Dataset
 | 
				
			||||||
 | 
					    train_ds = pt.datasets.Iris(dims=[0, 2])
 | 
				
			||||||
 | 
					    train_ds.targets[train_ds.targets == 2.0] = 1.0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Dataloader
 | 
				
			||||||
 | 
					    train_loader = DataLoader(
 | 
				
			||||||
 | 
					        train_ds,
 | 
				
			||||||
 | 
					        batch_size=64,
 | 
				
			||||||
 | 
					        num_workers=0,
 | 
				
			||||||
 | 
					        shuffle=True,
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # ------------------------------------------------------------
 | 
				
			||||||
 | 
					    # HYPERPARAMETERS
 | 
				
			||||||
 | 
					    # ------------------------------------------------------------
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Select Initializer
 | 
				
			||||||
 | 
					    components_initializer = SMCI(train_ds)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Define Hyperparameters
 | 
				
			||||||
 | 
					    hyperparameters = GLVQ.HyperParameters(
 | 
				
			||||||
 | 
					        lr=0.5,
 | 
				
			||||||
 | 
					        distribution=dict(
 | 
				
			||||||
 | 
					            num_classes=2,
 | 
				
			||||||
 | 
					            per_class=1,
 | 
				
			||||||
 | 
					        ),
 | 
				
			||||||
 | 
					        component_initializer=components_initializer,
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Create Model
 | 
				
			||||||
 | 
					    model = GLVQ(hyperparameters)
 | 
				
			||||||
 | 
					    print(model)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # ------------------------------------------------------------
 | 
				
			||||||
 | 
					    # TRAINING
 | 
				
			||||||
 | 
					    # ------------------------------------------------------------
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Controlling Callbacks
 | 
				
			||||||
 | 
					    stopping_criterion = LogTorchmetricCallback(
 | 
				
			||||||
 | 
					        'recall',
 | 
				
			||||||
 | 
					        torchmetrics.Recall,
 | 
				
			||||||
 | 
					        num_classes=2,
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    es = EarlyStopping(
 | 
				
			||||||
 | 
					        monitor=stopping_criterion.name,
 | 
				
			||||||
 | 
					        min_delta=0.001,
 | 
				
			||||||
 | 
					        patience=15,
 | 
				
			||||||
 | 
					        mode="max",
 | 
				
			||||||
 | 
					        check_on_train_epoch_end=True,
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Visualization Callback
 | 
				
			||||||
 | 
					    vis = VisGLVQ2D(data=train_ds)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Define trainer
 | 
				
			||||||
 | 
					    trainer = pl.Trainer(
 | 
				
			||||||
 | 
					        callbacks=[
 | 
				
			||||||
 | 
					            vis,
 | 
				
			||||||
 | 
					            stopping_criterion,
 | 
				
			||||||
 | 
					            es,
 | 
				
			||||||
 | 
					        ],
 | 
				
			||||||
 | 
					        gpus=0,
 | 
				
			||||||
 | 
					        max_epochs=200,
 | 
				
			||||||
 | 
					        log_every_n_steps=1,
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Train
 | 
				
			||||||
 | 
					    trainer.fit(model, train_loader)
 | 
				
			||||||
		Reference in New Issue
	
	Block a user