chore: upgrade pre commit
This commit is contained in:
parent
7506614ada
commit
c5f0b86114
@ -205,7 +205,7 @@ class SupervisedPrototypeModel(PrototypeModel):
|
|||||||
self.log("test_acc", accuracy)
|
self.log("test_acc", accuracy)
|
||||||
|
|
||||||
|
|
||||||
class ProtoTorchMixin(object):
|
class ProtoTorchMixin:
|
||||||
"""All mixins are ProtoTorchMixins."""
|
"""All mixins are ProtoTorchMixins."""
|
||||||
|
|
||||||
|
|
||||||
|
@ -39,7 +39,7 @@ def ltangent_distance(x, y, omegas):
|
|||||||
:param `torch.tensor` omegas: Three dimensional matrix
|
:param `torch.tensor` omegas: Three dimensional matrix
|
||||||
:rtype: `torch.tensor`
|
:rtype: `torch.tensor`
|
||||||
"""
|
"""
|
||||||
x, y = [arr.view(arr.size(0), -1) for arr in (x, y)]
|
x, y = (arr.view(arr.size(0), -1) for arr in (x, y))
|
||||||
p = torch.eye(omegas.shape[-2], device=omegas.device) - torch.bmm(
|
p = torch.eye(omegas.shape[-2], device=omegas.device) - torch.bmm(
|
||||||
omegas, omegas.permute([0, 2, 1]))
|
omegas, omegas.permute([0, 2, 1]))
|
||||||
projected_x = x @ p
|
projected_x = x @ p
|
||||||
|
@ -145,7 +145,7 @@ class SiameseGLVQ(GLVQ):
|
|||||||
|
|
||||||
def compute_distances(self, x):
|
def compute_distances(self, x):
|
||||||
protos, _ = self.proto_layer()
|
protos, _ = self.proto_layer()
|
||||||
x, protos = [arr.view(arr.size(0), -1) for arr in (x, protos)]
|
x, protos = (arr.view(arr.size(0), -1) for arr in (x, protos))
|
||||||
latent_x = self.backbone(x)
|
latent_x = self.backbone(x)
|
||||||
|
|
||||||
bb_grad = any([el.requires_grad for el in self.backbone.parameters()])
|
bb_grad = any([el.requires_grad for el in self.backbone.parameters()])
|
||||||
|
@ -21,3 +21,7 @@ include_trailing_comma = True
|
|||||||
force_grid_wrap = 3
|
force_grid_wrap = 3
|
||||||
use_parentheses = True
|
use_parentheses = True
|
||||||
line_length = 79
|
line_length = 79
|
||||||
|
|
||||||
|
[mypy]
|
||||||
|
explicit_package_bases = True
|
||||||
|
namespace_packages = True
|
||||||
|
6
setup.py
6
setup.py
@ -91,9 +91,9 @@ setup(
|
|||||||
"Topic :: Software Development :: Libraries",
|
"Topic :: Software Development :: Libraries",
|
||||||
"Topic :: Software Development :: Libraries :: Python Modules",
|
"Topic :: Software Development :: Libraries :: Python Modules",
|
||||||
],
|
],
|
||||||
entry_points={
|
#entry_points={
|
||||||
"prototorch.plugins": f"{PLUGIN_NAME} = prototorch.{PLUGIN_NAME}"
|
# "prototorch.plugins": f"{PLUGIN_NAME} = prototorch.{PLUGIN_NAME}"
|
||||||
},
|
#},
|
||||||
packages=find_namespace_packages(include=["prototorch.*"]),
|
packages=find_namespace_packages(include=["prototorch.*"]),
|
||||||
zip_safe=False,
|
zip_safe=False,
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user