chore: upgrade pre commit
This commit is contained in:
		| @@ -205,7 +205,7 @@ class SupervisedPrototypeModel(PrototypeModel): | ||||
|         self.log("test_acc", accuracy) | ||||
|  | ||||
|  | ||||
| class ProtoTorchMixin(object): | ||||
| class ProtoTorchMixin: | ||||
|     """All mixins are ProtoTorchMixins.""" | ||||
|  | ||||
|  | ||||
|   | ||||
| @@ -39,7 +39,7 @@ def ltangent_distance(x, y, omegas): | ||||
|     :param `torch.tensor` omegas: Three dimensional matrix | ||||
|     :rtype: `torch.tensor` | ||||
|     """ | ||||
|     x, y = [arr.view(arr.size(0), -1) for arr in (x, y)] | ||||
|     x, y = (arr.view(arr.size(0), -1) for arr in (x, y)) | ||||
|     p = torch.eye(omegas.shape[-2], device=omegas.device) - torch.bmm( | ||||
|         omegas, omegas.permute([0, 2, 1])) | ||||
|     projected_x = x @ p | ||||
|   | ||||
| @@ -145,7 +145,7 @@ class SiameseGLVQ(GLVQ): | ||||
|  | ||||
|     def compute_distances(self, x): | ||||
|         protos, _ = self.proto_layer() | ||||
|         x, protos = [arr.view(arr.size(0), -1) for arr in (x, protos)] | ||||
|         x, protos = (arr.view(arr.size(0), -1) for arr in (x, protos)) | ||||
|         latent_x = self.backbone(x) | ||||
|  | ||||
|         bb_grad = any([el.requires_grad for el in self.backbone.parameters()]) | ||||
|   | ||||
| @@ -21,3 +21,7 @@ include_trailing_comma = True | ||||
| force_grid_wrap = 3 | ||||
| use_parentheses = True | ||||
| line_length = 79 | ||||
|  | ||||
| [mypy] | ||||
| explicit_package_bases = True | ||||
| namespace_packages = True | ||||
|   | ||||
							
								
								
									
										6
									
								
								setup.py
									
									
									
									
									
								
							
							
						
						
									
										6
									
								
								setup.py
									
									
									
									
									
								
							| @@ -91,9 +91,9 @@ setup( | ||||
|         "Topic :: Software Development :: Libraries", | ||||
|         "Topic :: Software Development :: Libraries :: Python Modules", | ||||
|     ], | ||||
|     entry_points={ | ||||
|         "prototorch.plugins": f"{PLUGIN_NAME} = prototorch.{PLUGIN_NAME}" | ||||
|     }, | ||||
|     #entry_points={ | ||||
|     #    "prototorch.plugins": f"{PLUGIN_NAME} = prototorch.{PLUGIN_NAME}" | ||||
|     #}, | ||||
|     packages=find_namespace_packages(include=["prototorch.*"]), | ||||
|     zip_safe=False, | ||||
| ) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user