diff --git a/prototorch/models/abstract.py b/prototorch/models/abstract.py index f6886c6..f990145 100644 --- a/prototorch/models/abstract.py +++ b/prototorch/models/abstract.py @@ -205,7 +205,7 @@ class SupervisedPrototypeModel(PrototypeModel): self.log("test_acc", accuracy) -class ProtoTorchMixin(object): +class ProtoTorchMixin: """All mixins are ProtoTorchMixins.""" diff --git a/prototorch/models/extras.py b/prototorch/models/extras.py index 1dbb082..f80029e 100644 --- a/prototorch/models/extras.py +++ b/prototorch/models/extras.py @@ -39,7 +39,7 @@ def ltangent_distance(x, y, omegas): :param `torch.tensor` omegas: Three dimensional matrix :rtype: `torch.tensor` """ - x, y = [arr.view(arr.size(0), -1) for arr in (x, y)] + x, y = (arr.view(arr.size(0), -1) for arr in (x, y)) p = torch.eye(omegas.shape[-2], device=omegas.device) - torch.bmm( omegas, omegas.permute([0, 2, 1])) projected_x = x @ p diff --git a/prototorch/models/glvq.py b/prototorch/models/glvq.py index 81ddaef..cab60d5 100644 --- a/prototorch/models/glvq.py +++ b/prototorch/models/glvq.py @@ -145,7 +145,7 @@ class SiameseGLVQ(GLVQ): def compute_distances(self, x): protos, _ = self.proto_layer() - x, protos = [arr.view(arr.size(0), -1) for arr in (x, protos)] + x, protos = (arr.view(arr.size(0), -1) for arr in (x, protos)) latent_x = self.backbone(x) bb_grad = any([el.requires_grad for el in self.backbone.parameters()]) diff --git a/setup.cfg b/setup.cfg index e3c8135..ebbcd59 100644 --- a/setup.cfg +++ b/setup.cfg @@ -21,3 +21,7 @@ include_trailing_comma = True force_grid_wrap = 3 use_parentheses = True line_length = 79 + +[mypy] +explicit_package_bases = True +namespace_packages = True diff --git a/setup.py b/setup.py index 41702a6..c999a58 100644 --- a/setup.py +++ b/setup.py @@ -91,9 +91,9 @@ setup( "Topic :: Software Development :: Libraries", "Topic :: Software Development :: Libraries :: Python Modules", ], - entry_points={ - "prototorch.plugins": f"{PLUGIN_NAME} = prototorch.{PLUGIN_NAME}" - }, + #entry_points={ + # "prototorch.plugins": f"{PLUGIN_NAME} = prototorch.{PLUGIN_NAME}" + #}, packages=find_namespace_packages(include=["prototorch.*"]), zip_safe=False, )