test(githooks): Add githooks for automatic commit checks
This commit is contained in:
@@ -1,10 +1,19 @@
|
||||
import prototorch as pt
|
||||
"""Prototorch Data Modules
|
||||
|
||||
This allows to store the used dataset inside a Lightning Module.
|
||||
Mainly used for PytorchLightningCLI configurations.
|
||||
"""
|
||||
from typing import Any, Optional, Type
|
||||
|
||||
import pytorch_lightning as pl
|
||||
from torch.utils.data import DataLoader, random_split
|
||||
from torch.utils.data import DataLoader, Dataset, random_split
|
||||
from torchvision import transforms
|
||||
from torchvision.datasets import MNIST
|
||||
|
||||
import prototorch as pt
|
||||
|
||||
|
||||
# MNIST
|
||||
class MNISTDataModule(pl.LightningDataModule):
|
||||
def __init__(self, batch_size=32):
|
||||
super().__init__()
|
||||
@@ -49,15 +58,67 @@ class MNISTDataModule(pl.LightningDataModule):
|
||||
return mnist_test
|
||||
|
||||
|
||||
def train_on_mnist(batch_size=256) -> type:
|
||||
class DataClass(pl.LightningModule):
|
||||
datamodule = MNISTDataModule(batch_size=batch_size)
|
||||
# def train_on_mnist(batch_size=256) -> type:
|
||||
# class DataClass(pl.LightningModule):
|
||||
# datamodule = MNISTDataModule(batch_size=batch_size)
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
prototype_initializer = kwargs.pop(
|
||||
"prototype_initializer", pt.components.Zeros((28, 28, 1)))
|
||||
super().__init__(*args,
|
||||
prototype_initializer=prototype_initializer,
|
||||
**kwargs)
|
||||
# def __init__(self, *args, **kwargs):
|
||||
# prototype_initializer = kwargs.pop(
|
||||
# "prototype_initializer", pt.components.Zeros((28, 28, 1)))
|
||||
# super().__init__(*args,
|
||||
# prototype_initializer=prototype_initializer,
|
||||
# **kwargs)
|
||||
|
||||
return DataClass
|
||||
# dc: Type[DataClass] = DataClass
|
||||
# return dc
|
||||
|
||||
|
||||
# ABSTRACT
|
||||
class GeneralDataModule(pl.LightningDataModule):
|
||||
def __init__(self, dataset: Dataset, batch_size: int = 32) -> None:
|
||||
super().__init__()
|
||||
self.train_dataset = dataset
|
||||
self.batch_size = batch_size
|
||||
|
||||
def train_dataloader(self) -> DataLoader:
|
||||
return DataLoader(self.train_dataset, batch_size=self.batch_size)
|
||||
|
||||
|
||||
# def train_on_dataset(dataset: Dataset, batch_size: int = 256):
|
||||
# class DataClass(pl.LightningModule):
|
||||
# datamodule = GeneralDataModule(dataset, batch_size)
|
||||
# datashape = dataset[0][0].shape
|
||||
# example_input_array = torch.zeros_like(dataset[0][0]).unsqueeze(0)
|
||||
|
||||
# def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
# prototype_initializer = kwargs.pop(
|
||||
# "prototype_initializer",
|
||||
# pt.components.Zeros(self.datashape),
|
||||
# )
|
||||
# super().__init__(*args,
|
||||
# prototype_initializer=prototype_initializer,
|
||||
# **kwargs)
|
||||
|
||||
# return DataClass
|
||||
|
||||
# if __name__ == "__main__":
|
||||
# from prototorch.models import GLVQ
|
||||
|
||||
# demo_dataset = pt.datasets.Iris()
|
||||
|
||||
# TrainingClass: Type = train_on_dataset(demo_dataset)
|
||||
|
||||
# class DemoGLVQ(TrainingClass, GLVQ):
|
||||
# """Model Definition."""
|
||||
|
||||
# # Hyperparameters
|
||||
# hparams = dict(
|
||||
# distribution={
|
||||
# "num_classes": 3,
|
||||
# "prototypes_per_class": 4
|
||||
# },
|
||||
# lr=0.01,
|
||||
# )
|
||||
|
||||
# initialized = DemoGLVQ(hparams)
|
||||
# print(initialized)
|
||||
|
@@ -3,11 +3,8 @@
|
||||
import torch
|
||||
from prototorch.functions.activations import get_activation
|
||||
from prototorch.functions.competitions import wtac
|
||||
from prototorch.functions.distances import (
|
||||
lomega_distance,
|
||||
omega_distance,
|
||||
squared_euclidean_distance,
|
||||
)
|
||||
from prototorch.functions.distances import (lomega_distance, omega_distance,
|
||||
squared_euclidean_distance)
|
||||
from prototorch.functions.helper import get_flat
|
||||
from prototorch.functions.losses import glvq_loss, lvq1_loss, lvq21_loss
|
||||
from prototorch.modules import LambdaLayer, LossLayer
|
||||
|
Reference in New Issue
Block a user