chore: fix minor errors and upgrade codebase
This commit is contained in:
parent
6ed1b9a832
commit
ee4cf583e3
@ -1,4 +1,4 @@
|
||||
"""ProtoTorch CBC example using 2D Iris data."""
|
||||
"""ProtoTorch GMLVQ example using Iris data."""
|
||||
|
||||
import torch
|
||||
|
||||
|
@ -11,7 +11,7 @@ def squared_euclidean_distance(x, y):
|
||||
**Alias:**
|
||||
``prototorch.functions.distances.sed``
|
||||
"""
|
||||
x, y = [arr.view(arr.size(0), -1) for arr in (x, y)]
|
||||
x, y = (arr.view(arr.size(0), -1) for arr in (x, y))
|
||||
expanded_x = x.unsqueeze(dim=1)
|
||||
batchwise_difference = y - expanded_x
|
||||
differences_raised = torch.pow(batchwise_difference, 2)
|
||||
@ -27,14 +27,14 @@ def euclidean_distance(x, y):
|
||||
:returns: Distance Tensor of shape :math:`X \times Y`
|
||||
:rtype: `torch.tensor`
|
||||
"""
|
||||
x, y = [arr.view(arr.size(0), -1) for arr in (x, y)]
|
||||
x, y = (arr.view(arr.size(0), -1) for arr in (x, y))
|
||||
distances_raised = squared_euclidean_distance(x, y)
|
||||
distances = torch.sqrt(distances_raised)
|
||||
return distances
|
||||
|
||||
|
||||
def euclidean_distance_v2(x, y):
|
||||
x, y = [arr.view(arr.size(0), -1) for arr in (x, y)]
|
||||
x, y = (arr.view(arr.size(0), -1) for arr in (x, y))
|
||||
diff = y - x.unsqueeze(1)
|
||||
pairwise_distances = (diff @ diff.permute((0, 2, 1))).sqrt()
|
||||
# Passing `dim1=-2` and `dim2=-1` to `diagonal()` takes the
|
||||
@ -54,7 +54,7 @@ def lpnorm_distance(x, y, p):
|
||||
|
||||
:param p: p parameter of the lp norm
|
||||
"""
|
||||
x, y = [arr.view(arr.size(0), -1) for arr in (x, y)]
|
||||
x, y = (arr.view(arr.size(0), -1) for arr in (x, y))
|
||||
distances = torch.cdist(x, y, p=p)
|
||||
return distances
|
||||
|
||||
@ -66,7 +66,7 @@ def omega_distance(x, y, omega):
|
||||
|
||||
:param `torch.tensor` omega: Two dimensional matrix
|
||||
"""
|
||||
x, y = [arr.view(arr.size(0), -1) for arr in (x, y)]
|
||||
x, y = (arr.view(arr.size(0), -1) for arr in (x, y))
|
||||
projected_x = x @ omega
|
||||
projected_y = y @ omega
|
||||
distances = squared_euclidean_distance(projected_x, projected_y)
|
||||
@ -80,7 +80,7 @@ def lomega_distance(x, y, omegas):
|
||||
|
||||
:param `torch.tensor` omegas: Three dimensional matrix
|
||||
"""
|
||||
x, y = [arr.view(arr.size(0), -1) for arr in (x, y)]
|
||||
x, y = (arr.view(arr.size(0), -1) for arr in (x, y))
|
||||
projected_x = x @ omegas
|
||||
projected_y = torch.diagonal(y @ omegas).T
|
||||
expanded_y = torch.unsqueeze(projected_y, dim=1)
|
||||
|
@ -21,7 +21,7 @@ def cosine_similarity(x, y):
|
||||
Expected dimension of x is 2.
|
||||
Expected dimension of y is 2.
|
||||
"""
|
||||
x, y = [arr.view(arr.size(0), -1) for arr in (x, y)]
|
||||
x, y = (arr.view(arr.size(0), -1) for arr in (x, y))
|
||||
norm_x = x.pow(2).sum(1).sqrt()
|
||||
norm_y = y.pow(2).sum(1).sqrt()
|
||||
norm_mat = norm_x.unsqueeze(-1) @ norm_y.unsqueeze(-1).T
|
||||
|
@ -5,6 +5,7 @@ from typing import (
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Union,
|
||||
)
|
||||
|
||||
@ -18,7 +19,7 @@ def generate_mesh(
|
||||
maxima: torch.TensorType,
|
||||
border: float = 1.0,
|
||||
resolution: int = 100,
|
||||
device: torch.device = None,
|
||||
device: Optional[torch.device] = None,
|
||||
):
|
||||
# Apply Border
|
||||
ptp = maxima - minima
|
||||
@ -55,14 +56,15 @@ def mesh2d(x=None, border: float = 1.0, resolution: int = 100):
|
||||
|
||||
|
||||
def distribution_from_list(list_dist: List[int],
|
||||
clabels: Iterable[int] = None):
|
||||
clabels: Optional[Iterable[int]] = None):
|
||||
clabels = clabels or list(range(len(list_dist)))
|
||||
distribution = dict(zip(clabels, list_dist))
|
||||
return distribution
|
||||
|
||||
|
||||
def parse_distribution(user_distribution,
|
||||
clabels: Iterable[int] = None) -> Dict[int, int]:
|
||||
def parse_distribution(
|
||||
user_distribution,
|
||||
clabels: Optional[Iterable[int]] = None) -> Dict[int, int]:
|
||||
"""Parse user-provided distribution.
|
||||
|
||||
Return a dictionary with integer keys that represent the class labels and
|
||||
|
2
setup.py
2
setup.py
@ -15,7 +15,7 @@ from setuptools import find_packages, setup
|
||||
PROJECT_URL = "https://github.com/si-cim/prototorch"
|
||||
DOWNLOAD_URL = "https://github.com/si-cim/prototorch.git"
|
||||
|
||||
with open("README.md", "r") as fh:
|
||||
with open("README.md") as fh:
|
||||
long_description = fh.read()
|
||||
|
||||
INSTALL_REQUIRES = [
|
||||
|
@ -1,7 +1,6 @@
|
||||
"""ProtoTorch datasets test suite"""
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
Loading…
Reference in New Issue
Block a user