chore: upgrade pre commit
This commit is contained in:
@@ -205,7 +205,7 @@ class SupervisedPrototypeModel(PrototypeModel):
|
||||
self.log("test_acc", accuracy)
|
||||
|
||||
|
||||
class ProtoTorchMixin(object):
|
||||
class ProtoTorchMixin:
|
||||
"""All mixins are ProtoTorchMixins."""
|
||||
|
||||
|
||||
|
@@ -39,7 +39,7 @@ def ltangent_distance(x, y, omegas):
|
||||
:param `torch.tensor` omegas: Three dimensional matrix
|
||||
:rtype: `torch.tensor`
|
||||
"""
|
||||
x, y = [arr.view(arr.size(0), -1) for arr in (x, y)]
|
||||
x, y = (arr.view(arr.size(0), -1) for arr in (x, y))
|
||||
p = torch.eye(omegas.shape[-2], device=omegas.device) - torch.bmm(
|
||||
omegas, omegas.permute([0, 2, 1]))
|
||||
projected_x = x @ p
|
||||
|
@@ -145,7 +145,7 @@ class SiameseGLVQ(GLVQ):
|
||||
|
||||
def compute_distances(self, x):
|
||||
protos, _ = self.proto_layer()
|
||||
x, protos = [arr.view(arr.size(0), -1) for arr in (x, protos)]
|
||||
x, protos = (arr.view(arr.size(0), -1) for arr in (x, protos))
|
||||
latent_x = self.backbone(x)
|
||||
|
||||
bb_grad = any([el.requires_grad for el in self.backbone.parameters()])
|
||||
|
Reference in New Issue
Block a user