[WIP] Add labels.py
prototorch/components/__init__.py
@@ -1,2 +1,3 @@
-from prototorch.components.components import *
-from prototorch.components.initializers import *
+from .components import *
+from .initializers import *
+from .labels import *
prototorch/components/components.py
@@ -1,4 +1,4 @@
-"""ProtoTorch components modules."""
+"""ProtoTorch Components."""
 
 import warnings
 
@@ -13,7 +13,7 @@ from torch.nn.parameter import Parameter
 from .initializers import parse_data_arg
 
 
-def get_labels_object(distribution):
+def get_labels_initializer(distribution):
     if isinstance(distribution, dict):
         if "num_classes" in distribution.keys():
             labels = EqualLabelsInitializer(
@@ -119,10 +119,11 @@ class LabeledComponents(Components):
             components, component_labels = parse_data_arg(
                 initialized_components)
             super().__init__(initialized_components=components)
+            # self._labels = component_labels
             self._labels = component_labels
         else:
-            labels = get_labels_object(distribution)
-            self.initial_distribution = labels.distribution
-            _labels = labels.generate()
+            labels_initializer = get_labels_initializer(distribution)
+            self.initial_distribution = labels_initializer.distribution
+            _labels = labels_initializer.generate()
             super().__init__(len(_labels), initializer=initializer)
             self._register_labels(_labels)
@@ -150,8 +151,8 @@ class LabeledComponents(Components):
         _precheck_initializer(initializer)
 
         # Labels
-        labels = get_labels_object(distribution)
-        new_labels = labels.generate()
+        labels_initializer = get_labels_initializer(distribution)
+        new_labels = labels_initializer.generate()
         _labels = torch.cat([self._labels, new_labels])
         self._register_labels(_labels)
 
@@ -196,20 +197,24 @@ class ReasoningComponents(Components):
 
     """
     def __init__(self,
-                 reasonings=None,
+                 distribution=None,
                  initializer=None,
+                 reasoning_initializer=None,
                  *,
                  initialized_components=None):
         if initialized_components is not None:
             components, reasonings = initialized_components
 
             super().__init__(initialized_components=components)
             self.register_parameter("_reasonings", reasonings)
         else:
-            self._initialize_reasonings(reasonings)
-            super().__init__(len(self._reasonings), initializer=initializer)
+            labels_initializer = get_labels_initializer(distribution)
+            self.initial_distribution = labels_initializer.distribution
+            super().__init__(len(self.initial_distribution),
+                             initializer=initializer)
+            reasonings = reasoning_initializer.generate()
+            self._register_reasonings(reasonings)
 
-    def _initialize_reasonings(self, reasonings):
+    def _initialize_reasonings(self, reasoning_initializer):
         if isinstance(reasonings, tuple):
             num_classes, num_components = reasonings
             reasonings = ZeroReasoningsInitializer(num_classes, num_components)
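For orientation, a minimal sketch (not part of the commit) of the label-resolution flow that the LabeledComponents and ReasoningComponents constructors now share; the dict format and the helper's import path follow the new labels.py further below, and the comments only describe the expected behaviour.

from prototorch.components.components import get_labels_initializer

# `distribution` may be a dict, tuple or list; the dispatcher shown in
# labels.py (below) turns it into an EqualLabelsInitializer or an
# UnequalLabelsInitializer.
distribution = {"num_classes": 3, "prototypes_per_class": 2}

labels_initializer = get_labels_initializer(distribution)
initial_distribution = labels_initializer.distribution  # prototypes per class
_labels = labels_initializer.generate()  # one label per component, registered by the module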
prototorch/components/initializers.py
@@ -2,6 +2,7 @@
 import warnings
 from collections.abc import Iterable
 from itertools import chain
+from typing import List
 
 import torch
 from torch.utils.data import DataLoader, Dataset
@@ -179,7 +180,7 @@ class UnequalLabelsInitializer(LabelsInitializer):
         self.clabels = clabels or range(len(self.dist))
 
     @property
-    def distribution(self):
+    def distribution(self) -> List:
         return self.dist
 
     def generate(self):
@@ -194,7 +195,7 @@ class EqualLabelsInitializer(LabelsInitializer):
         self.per_class = per_class
 
     @property
-    def distribution(self):
+    def distribution(self) -> List:
         return self.classes * [self.per_class]
 
     def generate(self):
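A hypothetical spot check (not from the commit) of the two `distribution` properties that now carry the `List` return annotation; the expected values assume the constructors store their arguments as `classes`/`per_class` and `dist`, as the property bodies above suggest.

from prototorch.components.initializers import (EqualLabelsInitializer,
                                                 UnequalLabelsInitializer)

print(EqualLabelsInitializer(3, 2).distribution)      # expected: [2, 2, 2]
print(UnequalLabelsInitializer([4, 1]).distribution)  # expected: [4, 1]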
							
								
								
									
prototorch/components/labels.py (new file, 86 lines)
@@ -0,0 +1,86 @@
+"""ProtoTorch Labels."""
+
+import torch
+from prototorch.components.components import get_labels_initializer
+from prototorch.components.initializers import (ClassAwareInitializer,
+                                                ComponentsInitializer,
+                                                EqualLabelsInitializer,
+                                                UnequalLabelsInitializer)
+from torch.nn.parameter import Parameter
+
+
+def get_labels_initializer(distribution):
+    if isinstance(distribution, dict):
+        if "num_classes" in distribution.keys():
+            labels = EqualLabelsInitializer(
+                distribution["num_classes"],
+                distribution["prototypes_per_class"])
+        else:
+            clabels = list(distribution.keys())
+            dist = list(distribution.values())
+            labels = UnequalLabelsInitializer(dist, clabels)
+    elif isinstance(distribution, tuple):
+        num_classes, prototypes_per_class = distribution
+        labels = EqualLabelsInitializer(num_classes, prototypes_per_class)
+    elif isinstance(distribution, list):
+        labels = UnequalLabelsInitializer(distribution)
+    else:
+        msg = f"`distribution` not understood. " \
+            f"You have provided: {distribution=}."
+        raise ValueError(msg)
+    return labels
+
+
+class Labels(torch.nn.Module):
+    def __init__(self,
+                 distribution=None,
+                 initializer=None,
+                 *,
+                 initialized_labels=None):
+        _labels = self.get_labels(distribution,
+                                  initializer,
+                                  initialized_labels=initialized_labels)
+        self._register_labels(_labels)
+
+    def _register_labels(self, labels):
+        # self.register_buffer("_labels", labels)
+        self.register_parameter("_labels",
+                                Parameter(labels, requires_grad=False))
+
+    def get_labels(self,
+                   distribution=None,
+                   initializer=None,
+                   *,
+                   initialized_labels=None):
+        if initialized_labels is not None:
+            _labels = initialized_labels
+        else:
+            labels_initializer = initializer or get_labels_initializer(
+                distribution)
+            self.initial_distribution = labels_initializer.distribution
+            _labels = labels_initializer.generate()
+        return _labels
+
+    def add_labels(self,
+                   distribution=None,
+                   initializer=None,
+                   *,
+                   initialized_labels=None):
+        new_labels = self.get_labels(distribution,
+                                     initializer,
+                                     initialized_labels=initialized_labels)
+        _labels = torch.cat([self._labels, new_labels])
+        self._register_labels(_labels)
+
+    def remove_labels(self, indices=None):
+        mask = torch.ones(len(self._labels), dtype=torch.bool)
+        mask[indices] = False
+        _labels = self._labels[mask]
+        self._register_labels(_labels)
+
+    @property
+    def labels(self):
+        return self._labels
+
+    def forward(self):
+        return self._labels
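As a standalone illustration (not part of the commit), the masking pattern that Labels.remove_labels applies to drop label entries at the given indices:

import torch

labels = torch.tensor([0, 0, 1, 1, 2, 2])
indices = [1, 4]                                  # positions to drop
mask = torch.ones(len(labels), dtype=torch.bool)  # keep everything by default
mask[indices] = False                             # mark the dropped positions
print(labels[mask])                               # tensor([0, 1, 1, 2])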