[WIP] Add labels.py

This commit is contained in:
Jensun Ravichandran 2021-06-11 18:48:43 +02:00
parent c0c0044a42
commit 24903b761c
4 changed files with 108 additions and 15 deletions

View File

@ -1,2 +1,3 @@
from prototorch.components.components import * from .components import *
from prototorch.components.initializers import * from .initializers import *
from .labels import *

View File

@ -1,4 +1,4 @@
"""ProtoTorch components modules.""" """ProtoTorch Components."""
import warnings import warnings
@ -13,7 +13,7 @@ from torch.nn.parameter import Parameter
from .initializers import parse_data_arg from .initializers import parse_data_arg
def get_labels_object(distribution): def get_labels_initializer(distribution):
if isinstance(distribution, dict): if isinstance(distribution, dict):
if "num_classes" in distribution.keys(): if "num_classes" in distribution.keys():
labels = EqualLabelsInitializer( labels = EqualLabelsInitializer(
@ -119,10 +119,11 @@ class LabeledComponents(Components):
components, component_labels = parse_data_arg( components, component_labels = parse_data_arg(
initialized_components) initialized_components)
super().__init__(initialized_components=components) super().__init__(initialized_components=components)
# self._labels = component_labels
self._labels = component_labels self._labels = component_labels
else: else:
labels = get_labels_object(distribution) labels_initializer = get_labels_initializer(distribution)
self.initial_distribution = labels.distribution self.initial_distribution = labels_initializer.distribution
_labels = labels.generate() _labels = labels.generate()
super().__init__(len(_labels), initializer=initializer) super().__init__(len(_labels), initializer=initializer)
self._register_labels(_labels) self._register_labels(_labels)
@ -150,8 +151,8 @@ class LabeledComponents(Components):
_precheck_initializer(initializer) _precheck_initializer(initializer)
# Labels # Labels
labels = get_labels_object(distribution) labels_initializer = get_labels_initializer(distribution)
new_labels = labels.generate() new_labels = labels_initializer.generate()
_labels = torch.cat([self._labels, new_labels]) _labels = torch.cat([self._labels, new_labels])
self._register_labels(_labels) self._register_labels(_labels)
@ -196,20 +197,24 @@ class ReasoningComponents(Components):
""" """
def __init__(self, def __init__(self,
reasonings=None, distribution=None,
initializer=None, initializer=None,
reasoning_initializer=None,
*, *,
initialized_components=None): initialized_components=None):
if initialized_components is not None: if initialized_components is not None:
components, reasonings = initialized_components components, reasonings = initialized_components
super().__init__(initialized_components=components) super().__init__(initialized_components=components)
self.register_parameter("_reasonings", reasonings) self.register_parameter("_reasonings", reasonings)
else: else:
self._initialize_reasonings(reasonings) labels_initializer = get_labels_initializer(distribution)
super().__init__(len(self._reasonings), initializer=initializer) self.initial_distribution = labels_initializer.distribution
super().__init__(len(self.initial_distribution),
initializer=initializer)
reasonings = reasoning_initializer.generate()
self._register_reasonings(reasonings)
def _initialize_reasonings(self, reasonings): def _initialize_reasonings(self, reasoning_initializer):
if isinstance(reasonings, tuple): if isinstance(reasonings, tuple):
num_classes, num_components = reasonings num_classes, num_components = reasonings
reasonings = ZeroReasoningsInitializer(num_classes, num_components) reasonings = ZeroReasoningsInitializer(num_classes, num_components)

View File

@ -2,6 +2,7 @@
import warnings import warnings
from collections.abc import Iterable from collections.abc import Iterable
from itertools import chain from itertools import chain
from typing import List
import torch import torch
from torch.utils.data import DataLoader, Dataset from torch.utils.data import DataLoader, Dataset
@ -179,7 +180,7 @@ class UnequalLabelsInitializer(LabelsInitializer):
self.clabels = clabels or range(len(self.dist)) self.clabels = clabels or range(len(self.dist))
@property @property
def distribution(self): def distribution(self) -> List:
return self.dist return self.dist
def generate(self): def generate(self):
@ -194,7 +195,7 @@ class EqualLabelsInitializer(LabelsInitializer):
self.per_class = per_class self.per_class = per_class
@property @property
def distribution(self): def distribution(self) -> List:
return self.classes * [self.per_class] return self.classes * [self.per_class]
def generate(self): def generate(self):

View File

@ -0,0 +1,86 @@
"""ProtoTorch Labels."""
import torch
from prototorch.components.components import get_labels_initializer
from prototorch.components.initializers import (ClassAwareInitializer,
ComponentsInitializer,
EqualLabelsInitializer,
UnequalLabelsInitializer)
from torch.nn.parameter import Parameter
def get_labels_initializer(distribution):
if isinstance(distribution, dict):
if "num_classes" in distribution.keys():
labels = EqualLabelsInitializer(
distribution["num_classes"],
distribution["prototypes_per_class"])
else:
clabels = list(distribution.keys())
dist = list(distribution.values())
labels = UnequalLabelsInitializer(dist, clabels)
elif isinstance(distribution, tuple):
num_classes, prototypes_per_class = distribution
labels = EqualLabelsInitializer(num_classes, prototypes_per_class)
elif isinstance(distribution, list):
labels = UnequalLabelsInitializer(distribution)
else:
msg = f"`distribution` not understood." \
f"You have provided: {distribution=}."
raise ValueError(msg)
return labels
class Labels(torch.nn.Module):
def __init__(self,
distribution=None,
initializer=None,
*,
initialized_labels=None):
_labels = self.get_labels(distribution,
initializer,
initialized_labels=initialized_labels)
self._register_labels(_labels)
def _register_labels(self, labels):
# self.register_buffer("_labels", labels)
self.register_parameter("_labels",
Parameter(labels, requires_grad=False))
def get_labels(self,
distribution=None,
initializer=None,
*,
initialized_labels=None):
if initialized_labels is not None:
_labels = initialized_labels
else:
labels_initializer = initializer or get_labels_initializer(
distribution)
self.initial_distribution = labels_initializer.distribution
_labels = labels_initializer.generate()
return _labels
def add_labels(self,
distribution=None,
initializer=None,
*,
initialized_labels=None):
new_labels = self.get_labels(distribution,
initializer,
initialized_labels=initialized_labels)
_labels = torch.cat([self._labels, new_labels])
self._register_labels(_labels)
def remove_labels(self, indices=None):
mask = torch.ones(len(self._labels, dtype=torch.bool))
mask[indices] = False
_labels = self._labels[mask]
self._register_labels(_labels)
@property
def labels(self):
return self._labels
def forward(self):
return self._labels