[WIP] Add labels.py
This commit is contained in:
parent
c0c0044a42
commit
24903b761c
@ -1,2 +1,3 @@
|
|||||||
from prototorch.components.components import *
|
from .components import *
|
||||||
from prototorch.components.initializers import *
|
from .initializers import *
|
||||||
|
from .labels import *
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
"""ProtoTorch components modules."""
|
"""ProtoTorch Components."""
|
||||||
|
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
@ -13,7 +13,7 @@ from torch.nn.parameter import Parameter
|
|||||||
from .initializers import parse_data_arg
|
from .initializers import parse_data_arg
|
||||||
|
|
||||||
|
|
||||||
def get_labels_object(distribution):
|
def get_labels_initializer(distribution):
|
||||||
if isinstance(distribution, dict):
|
if isinstance(distribution, dict):
|
||||||
if "num_classes" in distribution.keys():
|
if "num_classes" in distribution.keys():
|
||||||
labels = EqualLabelsInitializer(
|
labels = EqualLabelsInitializer(
|
||||||
@ -119,10 +119,11 @@ class LabeledComponents(Components):
|
|||||||
components, component_labels = parse_data_arg(
|
components, component_labels = parse_data_arg(
|
||||||
initialized_components)
|
initialized_components)
|
||||||
super().__init__(initialized_components=components)
|
super().__init__(initialized_components=components)
|
||||||
|
# self._labels = component_labels
|
||||||
self._labels = component_labels
|
self._labels = component_labels
|
||||||
else:
|
else:
|
||||||
labels = get_labels_object(distribution)
|
labels_initializer = get_labels_initializer(distribution)
|
||||||
self.initial_distribution = labels.distribution
|
self.initial_distribution = labels_initializer.distribution
|
||||||
_labels = labels.generate()
|
_labels = labels.generate()
|
||||||
super().__init__(len(_labels), initializer=initializer)
|
super().__init__(len(_labels), initializer=initializer)
|
||||||
self._register_labels(_labels)
|
self._register_labels(_labels)
|
||||||
@ -150,8 +151,8 @@ class LabeledComponents(Components):
|
|||||||
_precheck_initializer(initializer)
|
_precheck_initializer(initializer)
|
||||||
|
|
||||||
# Labels
|
# Labels
|
||||||
labels = get_labels_object(distribution)
|
labels_initializer = get_labels_initializer(distribution)
|
||||||
new_labels = labels.generate()
|
new_labels = labels_initializer.generate()
|
||||||
_labels = torch.cat([self._labels, new_labels])
|
_labels = torch.cat([self._labels, new_labels])
|
||||||
self._register_labels(_labels)
|
self._register_labels(_labels)
|
||||||
|
|
||||||
@ -196,20 +197,24 @@ class ReasoningComponents(Components):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
reasonings=None,
|
distribution=None,
|
||||||
initializer=None,
|
initializer=None,
|
||||||
|
reasoning_initializer=None,
|
||||||
*,
|
*,
|
||||||
initialized_components=None):
|
initialized_components=None):
|
||||||
if initialized_components is not None:
|
if initialized_components is not None:
|
||||||
components, reasonings = initialized_components
|
components, reasonings = initialized_components
|
||||||
|
|
||||||
super().__init__(initialized_components=components)
|
super().__init__(initialized_components=components)
|
||||||
self.register_parameter("_reasonings", reasonings)
|
self.register_parameter("_reasonings", reasonings)
|
||||||
else:
|
else:
|
||||||
self._initialize_reasonings(reasonings)
|
labels_initializer = get_labels_initializer(distribution)
|
||||||
super().__init__(len(self._reasonings), initializer=initializer)
|
self.initial_distribution = labels_initializer.distribution
|
||||||
|
super().__init__(len(self.initial_distribution),
|
||||||
|
initializer=initializer)
|
||||||
|
reasonings = reasoning_initializer.generate()
|
||||||
|
self._register_reasonings(reasonings)
|
||||||
|
|
||||||
def _initialize_reasonings(self, reasonings):
|
def _initialize_reasonings(self, reasoning_initializer):
|
||||||
if isinstance(reasonings, tuple):
|
if isinstance(reasonings, tuple):
|
||||||
num_classes, num_components = reasonings
|
num_classes, num_components = reasonings
|
||||||
reasonings = ZeroReasoningsInitializer(num_classes, num_components)
|
reasonings = ZeroReasoningsInitializer(num_classes, num_components)
|
||||||
|
@ -2,6 +2,7 @@
|
|||||||
import warnings
|
import warnings
|
||||||
from collections.abc import Iterable
|
from collections.abc import Iterable
|
||||||
from itertools import chain
|
from itertools import chain
|
||||||
|
from typing import List
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.utils.data import DataLoader, Dataset
|
from torch.utils.data import DataLoader, Dataset
|
||||||
@ -179,7 +180,7 @@ class UnequalLabelsInitializer(LabelsInitializer):
|
|||||||
self.clabels = clabels or range(len(self.dist))
|
self.clabels = clabels or range(len(self.dist))
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def distribution(self):
|
def distribution(self) -> List:
|
||||||
return self.dist
|
return self.dist
|
||||||
|
|
||||||
def generate(self):
|
def generate(self):
|
||||||
@ -194,7 +195,7 @@ class EqualLabelsInitializer(LabelsInitializer):
|
|||||||
self.per_class = per_class
|
self.per_class = per_class
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def distribution(self):
|
def distribution(self) -> List:
|
||||||
return self.classes * [self.per_class]
|
return self.classes * [self.per_class]
|
||||||
|
|
||||||
def generate(self):
|
def generate(self):
|
||||||
|
86
prototorch/components/labels.py
Normal file
86
prototorch/components/labels.py
Normal file
@ -0,0 +1,86 @@
|
|||||||
|
"""ProtoTorch Labels."""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from prototorch.components.components import get_labels_initializer
|
||||||
|
from prototorch.components.initializers import (ClassAwareInitializer,
|
||||||
|
ComponentsInitializer,
|
||||||
|
EqualLabelsInitializer,
|
||||||
|
UnequalLabelsInitializer)
|
||||||
|
from torch.nn.parameter import Parameter
|
||||||
|
|
||||||
|
|
||||||
|
def get_labels_initializer(distribution):
|
||||||
|
if isinstance(distribution, dict):
|
||||||
|
if "num_classes" in distribution.keys():
|
||||||
|
labels = EqualLabelsInitializer(
|
||||||
|
distribution["num_classes"],
|
||||||
|
distribution["prototypes_per_class"])
|
||||||
|
else:
|
||||||
|
clabels = list(distribution.keys())
|
||||||
|
dist = list(distribution.values())
|
||||||
|
labels = UnequalLabelsInitializer(dist, clabels)
|
||||||
|
elif isinstance(distribution, tuple):
|
||||||
|
num_classes, prototypes_per_class = distribution
|
||||||
|
labels = EqualLabelsInitializer(num_classes, prototypes_per_class)
|
||||||
|
elif isinstance(distribution, list):
|
||||||
|
labels = UnequalLabelsInitializer(distribution)
|
||||||
|
else:
|
||||||
|
msg = f"`distribution` not understood." \
|
||||||
|
f"You have provided: {distribution=}."
|
||||||
|
raise ValueError(msg)
|
||||||
|
return labels
|
||||||
|
|
||||||
|
|
||||||
|
class Labels(torch.nn.Module):
|
||||||
|
def __init__(self,
|
||||||
|
distribution=None,
|
||||||
|
initializer=None,
|
||||||
|
*,
|
||||||
|
initialized_labels=None):
|
||||||
|
_labels = self.get_labels(distribution,
|
||||||
|
initializer,
|
||||||
|
initialized_labels=initialized_labels)
|
||||||
|
self._register_labels(_labels)
|
||||||
|
|
||||||
|
def _register_labels(self, labels):
|
||||||
|
# self.register_buffer("_labels", labels)
|
||||||
|
self.register_parameter("_labels",
|
||||||
|
Parameter(labels, requires_grad=False))
|
||||||
|
|
||||||
|
def get_labels(self,
|
||||||
|
distribution=None,
|
||||||
|
initializer=None,
|
||||||
|
*,
|
||||||
|
initialized_labels=None):
|
||||||
|
if initialized_labels is not None:
|
||||||
|
_labels = initialized_labels
|
||||||
|
else:
|
||||||
|
labels_initializer = initializer or get_labels_initializer(
|
||||||
|
distribution)
|
||||||
|
self.initial_distribution = labels_initializer.distribution
|
||||||
|
_labels = labels_initializer.generate()
|
||||||
|
return _labels
|
||||||
|
|
||||||
|
def add_labels(self,
|
||||||
|
distribution=None,
|
||||||
|
initializer=None,
|
||||||
|
*,
|
||||||
|
initialized_labels=None):
|
||||||
|
new_labels = self.get_labels(distribution,
|
||||||
|
initializer,
|
||||||
|
initialized_labels=initialized_labels)
|
||||||
|
_labels = torch.cat([self._labels, new_labels])
|
||||||
|
self._register_labels(_labels)
|
||||||
|
|
||||||
|
def remove_labels(self, indices=None):
|
||||||
|
mask = torch.ones(len(self._labels, dtype=torch.bool))
|
||||||
|
mask[indices] = False
|
||||||
|
_labels = self._labels[mask]
|
||||||
|
self._register_labels(_labels)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def labels(self):
|
||||||
|
return self._labels
|
||||||
|
|
||||||
|
def forward(self):
|
||||||
|
return self._labels
|
Loading…
Reference in New Issue
Block a user