[WIP] Add labels.py

This commit is contained in:
Jensun Ravichandran 2021-06-11 18:48:43 +02:00
parent c0c0044a42
commit 24903b761c
4 changed files with 108 additions and 15 deletions

View File

@ -1,2 +1,3 @@
from prototorch.components.components import *
from prototorch.components.initializers import *
from .components import *
from .initializers import *
from .labels import *

View File

@ -1,4 +1,4 @@
"""ProtoTorch components modules."""
"""ProtoTorch Components."""
import warnings
@ -13,7 +13,7 @@ from torch.nn.parameter import Parameter
from .initializers import parse_data_arg
def get_labels_object(distribution):
def get_labels_initializer(distribution):
if isinstance(distribution, dict):
if "num_classes" in distribution.keys():
labels = EqualLabelsInitializer(
@ -119,10 +119,11 @@ class LabeledComponents(Components):
components, component_labels = parse_data_arg(
initialized_components)
super().__init__(initialized_components=components)
# self._labels = component_labels
self._labels = component_labels
else:
labels = get_labels_object(distribution)
self.initial_distribution = labels.distribution
labels_initializer = get_labels_initializer(distribution)
self.initial_distribution = labels_initializer.distribution
_labels = labels.generate()
super().__init__(len(_labels), initializer=initializer)
self._register_labels(_labels)
@ -150,8 +151,8 @@ class LabeledComponents(Components):
_precheck_initializer(initializer)
# Labels
labels = get_labels_object(distribution)
new_labels = labels.generate()
labels_initializer = get_labels_initializer(distribution)
new_labels = labels_initializer.generate()
_labels = torch.cat([self._labels, new_labels])
self._register_labels(_labels)
@ -196,20 +197,24 @@ class ReasoningComponents(Components):
"""
def __init__(self,
reasonings=None,
distribution=None,
initializer=None,
reasoning_initializer=None,
*,
initialized_components=None):
if initialized_components is not None:
components, reasonings = initialized_components
super().__init__(initialized_components=components)
self.register_parameter("_reasonings", reasonings)
else:
self._initialize_reasonings(reasonings)
super().__init__(len(self._reasonings), initializer=initializer)
labels_initializer = get_labels_initializer(distribution)
self.initial_distribution = labels_initializer.distribution
super().__init__(len(self.initial_distribution),
initializer=initializer)
reasonings = reasoning_initializer.generate()
self._register_reasonings(reasonings)
def _initialize_reasonings(self, reasonings):
def _initialize_reasonings(self, reasoning_initializer):
if isinstance(reasonings, tuple):
num_classes, num_components = reasonings
reasonings = ZeroReasoningsInitializer(num_classes, num_components)

View File

@ -2,6 +2,7 @@
import warnings
from collections.abc import Iterable
from itertools import chain
from typing import List
import torch
from torch.utils.data import DataLoader, Dataset
@ -179,7 +180,7 @@ class UnequalLabelsInitializer(LabelsInitializer):
self.clabels = clabels or range(len(self.dist))
@property
def distribution(self):
def distribution(self) -> List:
return self.dist
def generate(self):
@ -194,7 +195,7 @@ class EqualLabelsInitializer(LabelsInitializer):
self.per_class = per_class
@property
def distribution(self):
def distribution(self) -> List:
return self.classes * [self.per_class]
def generate(self):

View File

@ -0,0 +1,86 @@
"""ProtoTorch Labels."""
import torch
from prototorch.components.components import get_labels_initializer
from prototorch.components.initializers import (ClassAwareInitializer,
ComponentsInitializer,
EqualLabelsInitializer,
UnequalLabelsInitializer)
from torch.nn.parameter import Parameter
def get_labels_initializer(distribution):
if isinstance(distribution, dict):
if "num_classes" in distribution.keys():
labels = EqualLabelsInitializer(
distribution["num_classes"],
distribution["prototypes_per_class"])
else:
clabels = list(distribution.keys())
dist = list(distribution.values())
labels = UnequalLabelsInitializer(dist, clabels)
elif isinstance(distribution, tuple):
num_classes, prototypes_per_class = distribution
labels = EqualLabelsInitializer(num_classes, prototypes_per_class)
elif isinstance(distribution, list):
labels = UnequalLabelsInitializer(distribution)
else:
msg = f"`distribution` not understood." \
f"You have provided: {distribution=}."
raise ValueError(msg)
return labels
class Labels(torch.nn.Module):
def __init__(self,
distribution=None,
initializer=None,
*,
initialized_labels=None):
_labels = self.get_labels(distribution,
initializer,
initialized_labels=initialized_labels)
self._register_labels(_labels)
def _register_labels(self, labels):
# self.register_buffer("_labels", labels)
self.register_parameter("_labels",
Parameter(labels, requires_grad=False))
def get_labels(self,
distribution=None,
initializer=None,
*,
initialized_labels=None):
if initialized_labels is not None:
_labels = initialized_labels
else:
labels_initializer = initializer or get_labels_initializer(
distribution)
self.initial_distribution = labels_initializer.distribution
_labels = labels_initializer.generate()
return _labels
def add_labels(self,
distribution=None,
initializer=None,
*,
initialized_labels=None):
new_labels = self.get_labels(distribution,
initializer,
initialized_labels=initialized_labels)
_labels = torch.cat([self._labels, new_labels])
self._register_labels(_labels)
def remove_labels(self, indices=None):
mask = torch.ones(len(self._labels, dtype=torch.bool))
mask[indices] = False
_labels = self._labels[mask]
self._register_labels(_labels)
@property
def labels(self):
return self._labels
def forward(self):
return self._labels