Compare commits
5 Commits
Author | SHA1 | Date | |
---|---|---|---|
|
87fa3f0729 | ||
|
08db94d507 | ||
|
8ecf9948b2 | ||
|
c5f0b86114 | ||
|
7506614ada |
@@ -1,5 +1,5 @@
|
|||||||
[bumpversion]
|
[bumpversion]
|
||||||
current_version = 0.5.2
|
current_version = 0.5.4
|
||||||
commit = True
|
commit = True
|
||||||
tag = True
|
tag = True
|
||||||
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)
|
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)
|
||||||
|
@@ -3,7 +3,7 @@
|
|||||||
|
|
||||||
repos:
|
repos:
|
||||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||||
rev: v4.2.0
|
rev: v4.4.0
|
||||||
hooks:
|
hooks:
|
||||||
- id: trailing-whitespace
|
- id: trailing-whitespace
|
||||||
- id: end-of-file-fixer
|
- id: end-of-file-fixer
|
||||||
@@ -13,17 +13,17 @@ repos:
|
|||||||
- id: check-case-conflict
|
- id: check-case-conflict
|
||||||
|
|
||||||
- repo: https://github.com/myint/autoflake
|
- repo: https://github.com/myint/autoflake
|
||||||
rev: v1.4
|
rev: v2.0.1
|
||||||
hooks:
|
hooks:
|
||||||
- id: autoflake
|
- id: autoflake
|
||||||
|
|
||||||
- repo: http://github.com/PyCQA/isort
|
- repo: http://github.com/PyCQA/isort
|
||||||
rev: 5.10.1
|
rev: 5.12.0
|
||||||
hooks:
|
hooks:
|
||||||
- id: isort
|
- id: isort
|
||||||
|
|
||||||
- repo: https://github.com/pre-commit/mirrors-mypy
|
- repo: https://github.com/pre-commit/mirrors-mypy
|
||||||
rev: v0.950
|
rev: v1.0.1
|
||||||
hooks:
|
hooks:
|
||||||
- id: mypy
|
- id: mypy
|
||||||
files: prototorch
|
files: prototorch
|
||||||
@@ -35,14 +35,14 @@ repos:
|
|||||||
- id: yapf
|
- id: yapf
|
||||||
|
|
||||||
- repo: https://github.com/pre-commit/pygrep-hooks
|
- repo: https://github.com/pre-commit/pygrep-hooks
|
||||||
rev: v1.9.0
|
rev: v1.10.0
|
||||||
hooks:
|
hooks:
|
||||||
- id: python-use-type-annotations
|
- id: python-use-type-annotations
|
||||||
- id: python-no-log-warn
|
- id: python-no-log-warn
|
||||||
- id: python-check-blanket-noqa
|
- id: python-check-blanket-noqa
|
||||||
|
|
||||||
- repo: https://github.com/asottile/pyupgrade
|
- repo: https://github.com/asottile/pyupgrade
|
||||||
rev: v2.32.1
|
rev: v3.3.1
|
||||||
hooks:
|
hooks:
|
||||||
- id: pyupgrade
|
- id: pyupgrade
|
||||||
|
|
||||||
|
@@ -23,7 +23,7 @@ author = "Jensun Ravichandran"
|
|||||||
|
|
||||||
# The full version, including alpha/beta/rc tags
|
# The full version, including alpha/beta/rc tags
|
||||||
#
|
#
|
||||||
release = "0.5.2"
|
release = "0.5.4"
|
||||||
|
|
||||||
# -- General configuration ---------------------------------------------------
|
# -- General configuration ---------------------------------------------------
|
||||||
|
|
||||||
|
@@ -36,4 +36,4 @@ from .unsupervised import (
|
|||||||
)
|
)
|
||||||
from .vis import *
|
from .vis import *
|
||||||
|
|
||||||
__version__ = "0.5.2"
|
__version__ = "0.5.4"
|
||||||
|
@@ -205,7 +205,7 @@ class SupervisedPrototypeModel(PrototypeModel):
|
|||||||
self.log("test_acc", accuracy)
|
self.log("test_acc", accuracy)
|
||||||
|
|
||||||
|
|
||||||
class ProtoTorchMixin(object):
|
class ProtoTorchMixin:
|
||||||
"""All mixins are ProtoTorchMixins."""
|
"""All mixins are ProtoTorchMixins."""
|
||||||
|
|
||||||
|
|
||||||
|
@@ -39,7 +39,7 @@ def ltangent_distance(x, y, omegas):
|
|||||||
:param `torch.tensor` omegas: Three dimensional matrix
|
:param `torch.tensor` omegas: Three dimensional matrix
|
||||||
:rtype: `torch.tensor`
|
:rtype: `torch.tensor`
|
||||||
"""
|
"""
|
||||||
x, y = [arr.view(arr.size(0), -1) for arr in (x, y)]
|
x, y = (arr.view(arr.size(0), -1) for arr in (x, y))
|
||||||
p = torch.eye(omegas.shape[-2], device=omegas.device) - torch.bmm(
|
p = torch.eye(omegas.shape[-2], device=omegas.device) - torch.bmm(
|
||||||
omegas, omegas.permute([0, 2, 1]))
|
omegas, omegas.permute([0, 2, 1]))
|
||||||
projected_x = x @ p
|
projected_x = x @ p
|
||||||
|
@@ -145,7 +145,7 @@ class SiameseGLVQ(GLVQ):
|
|||||||
|
|
||||||
def compute_distances(self, x):
|
def compute_distances(self, x):
|
||||||
protos, _ = self.proto_layer()
|
protos, _ = self.proto_layer()
|
||||||
x, protos = [arr.view(arr.size(0), -1) for arr in (x, protos)]
|
x, protos = (arr.view(arr.size(0), -1) for arr in (x, protos))
|
||||||
latent_x = self.backbone(x)
|
latent_x = self.backbone(x)
|
||||||
|
|
||||||
bb_grad = any([el.requires_grad for el in self.backbone.parameters()])
|
bb_grad = any([el.requires_grad for el in self.backbone.parameters()])
|
||||||
|
@@ -21,3 +21,7 @@ include_trailing_comma = True
|
|||||||
force_grid_wrap = 3
|
force_grid_wrap = 3
|
||||||
use_parentheses = True
|
use_parentheses = True
|
||||||
line_length = 79
|
line_length = 79
|
||||||
|
|
||||||
|
[mypy]
|
||||||
|
explicit_package_bases = True
|
||||||
|
namespace_packages = True
|
||||||
|
6
setup.py
6
setup.py
@@ -18,13 +18,13 @@ PLUGIN_NAME = "models"
|
|||||||
PROJECT_URL = "https://github.com/si-cim/prototorch_models"
|
PROJECT_URL = "https://github.com/si-cim/prototorch_models"
|
||||||
DOWNLOAD_URL = "https://github.com/si-cim/prototorch_models.git"
|
DOWNLOAD_URL = "https://github.com/si-cim/prototorch_models.git"
|
||||||
|
|
||||||
with open("README.md", "r") as fh:
|
with open("README.md") as fh:
|
||||||
long_description = fh.read()
|
long_description = fh.read()
|
||||||
|
|
||||||
INSTALL_REQUIRES = [
|
INSTALL_REQUIRES = [
|
||||||
"prototorch>=0.7.3",
|
"prototorch>=0.7.3",
|
||||||
"pytorch_lightning>=1.6.0",
|
"pytorch_lightning>=1.6.0",
|
||||||
"torchmetrics",
|
"torchmetrics<0.10.0",
|
||||||
"protobuf<3.20.0",
|
"protobuf<3.20.0",
|
||||||
]
|
]
|
||||||
CLI = [
|
CLI = [
|
||||||
@@ -55,7 +55,7 @@ ALL = CLI + DEV + DOCS + EXAMPLES + TESTS
|
|||||||
|
|
||||||
setup(
|
setup(
|
||||||
name=safe_name("prototorch_" + PLUGIN_NAME),
|
name=safe_name("prototorch_" + PLUGIN_NAME),
|
||||||
version="0.5.2",
|
version="0.5.4",
|
||||||
description="Pre-packaged prototype-based "
|
description="Pre-packaged prototype-based "
|
||||||
"machine learning models using ProtoTorch and PyTorch-Lightning.",
|
"machine learning models using ProtoTorch and PyTorch-Lightning.",
|
||||||
long_description=long_description,
|
long_description=long_description,
|
||||||
|
Reference in New Issue
Block a user