chore: upgrade pre commit

This commit is contained in:
Alexander Engelsberger 2023-03-02 17:23:41 +00:00
parent 7506614ada
commit c5f0b86114
No known key found for this signature in database
GPG Key ID: 1FF2A4E5A222AFBF
5 changed files with 10 additions and 6 deletions

View File

@ -205,7 +205,7 @@ class SupervisedPrototypeModel(PrototypeModel):
self.log("test_acc", accuracy) self.log("test_acc", accuracy)
class ProtoTorchMixin(object): class ProtoTorchMixin:
"""All mixins are ProtoTorchMixins.""" """All mixins are ProtoTorchMixins."""

View File

@ -39,7 +39,7 @@ def ltangent_distance(x, y, omegas):
:param `torch.tensor` omegas: Three dimensional matrix :param `torch.tensor` omegas: Three dimensional matrix
:rtype: `torch.tensor` :rtype: `torch.tensor`
""" """
x, y = [arr.view(arr.size(0), -1) for arr in (x, y)] x, y = (arr.view(arr.size(0), -1) for arr in (x, y))
p = torch.eye(omegas.shape[-2], device=omegas.device) - torch.bmm( p = torch.eye(omegas.shape[-2], device=omegas.device) - torch.bmm(
omegas, omegas.permute([0, 2, 1])) omegas, omegas.permute([0, 2, 1]))
projected_x = x @ p projected_x = x @ p

View File

@ -145,7 +145,7 @@ class SiameseGLVQ(GLVQ):
def compute_distances(self, x): def compute_distances(self, x):
protos, _ = self.proto_layer() protos, _ = self.proto_layer()
x, protos = [arr.view(arr.size(0), -1) for arr in (x, protos)] x, protos = (arr.view(arr.size(0), -1) for arr in (x, protos))
latent_x = self.backbone(x) latent_x = self.backbone(x)
bb_grad = any([el.requires_grad for el in self.backbone.parameters()]) bb_grad = any([el.requires_grad for el in self.backbone.parameters()])

View File

@ -21,3 +21,7 @@ include_trailing_comma = True
force_grid_wrap = 3 force_grid_wrap = 3
use_parentheses = True use_parentheses = True
line_length = 79 line_length = 79
[mypy]
explicit_package_bases = True
namespace_packages = True

View File

@ -91,9 +91,9 @@ setup(
"Topic :: Software Development :: Libraries", "Topic :: Software Development :: Libraries",
"Topic :: Software Development :: Libraries :: Python Modules", "Topic :: Software Development :: Libraries :: Python Modules",
], ],
entry_points={ #entry_points={
"prototorch.plugins": f"{PLUGIN_NAME} = prototorch.{PLUGIN_NAME}" # "prototorch.plugins": f"{PLUGIN_NAME} = prototorch.{PLUGIN_NAME}"
}, #},
packages=find_namespace_packages(include=["prototorch.*"]), packages=find_namespace_packages(include=["prototorch.*"]),
zip_safe=False, zip_safe=False,
) )