fix: labels where on cpu in forward pass

This commit is contained in:
Alexander Engelsberger 2021-08-05 09:14:32 +02:00
parent f8ad1d83eb
commit 0af8cf36f8
No known key found for this signature in database
GPG Key ID: BE3F5909FF0D83E3
4 changed files with 7 additions and 10 deletions

View File

@ -136,14 +136,14 @@ class SupervisedPrototypeModel(PrototypeModel):
def forward(self, x): def forward(self, x):
distances = self.compute_distances(x) distances = self.compute_distances(x)
plabels = self.proto_layer.labels _, plabels = self.proto_layer()
winning = stratified_min_pooling(distances, plabels) winning = stratified_min_pooling(distances, plabels)
y_pred = torch.nn.functional.softmin(winning) y_pred = torch.nn.functional.softmin(winning)
return y_pred return y_pred
def predict_from_distances(self, distances): def predict_from_distances(self, distances):
with torch.no_grad(): with torch.no_grad():
plabels = self.proto_layer.labels _, plabels = self.proto_layer()
y_pred = self.competition_layer(distances, plabels) y_pred = self.competition_layer(distances, plabels)
return y_pred return y_pred

View File

@ -55,7 +55,7 @@ class GLVQ(SupervisedPrototypeModel):
def shared_step(self, batch, batch_idx, optimizer_idx=None): def shared_step(self, batch, batch_idx, optimizer_idx=None):
x, y = batch x, y = batch
out = self.compute_distances(x) out = self.compute_distances(x)
plabels = self.proto_layer.labels _, plabels = self.proto_layer()
loss = self.loss(out, y, plabels) loss = self.loss(out, y, plabels)
return out, loss return out, loss

View File

@ -10,9 +10,7 @@ from .glvq import GLVQ
class LVQ1(NonGradientMixin, GLVQ): class LVQ1(NonGradientMixin, GLVQ):
"""Learning Vector Quantization 1.""" """Learning Vector Quantization 1."""
def training_step(self, train_batch, batch_idx, optimizer_idx=None): def training_step(self, train_batch, batch_idx, optimizer_idx=None):
protos = self.proto_layer.components protos, plables = self.proto_layer()
plabels = self.proto_layer.labels
x, y = train_batch x, y = train_batch
dis = self.compute_distances(x) dis = self.compute_distances(x)
# TODO Vectorized implementation # TODO Vectorized implementation
@ -41,8 +39,7 @@ class LVQ1(NonGradientMixin, GLVQ):
class LVQ21(NonGradientMixin, GLVQ): class LVQ21(NonGradientMixin, GLVQ):
"""Learning Vector Quantization 2.1.""" """Learning Vector Quantization 2.1."""
def training_step(self, train_batch, batch_idx, optimizer_idx=None): def training_step(self, train_batch, batch_idx, optimizer_idx=None):
protos = self.proto_layer.components protos, plabels = self.proto_layer()
plabels = self.proto_layer.labels
x, y = train_batch x, y = train_batch
dis = self.compute_distances(x) dis = self.compute_distances(x)

View File

@ -20,7 +20,7 @@ class CELVQ(GLVQ):
def shared_step(self, batch, batch_idx, optimizer_idx=None): def shared_step(self, batch, batch_idx, optimizer_idx=None):
x, y = batch x, y = batch
out = self.compute_distances(x) # [None, num_protos] out = self.compute_distances(x) # [None, num_protos]
plabels = self.proto_layer.labels _, plabels = self.proto_layer()
winning = stratified_min_pooling(out, plabels) # [None, num_classes] winning = stratified_min_pooling(out, plabels) # [None, num_classes]
probs = -1.0 * winning probs = -1.0 * winning
batch_loss = self.loss(probs, y.long()) batch_loss = self.loss(probs, y.long())
@ -54,7 +54,7 @@ class ProbabilisticLVQ(GLVQ):
def training_step(self, batch, batch_idx, optimizer_idx=None): def training_step(self, batch, batch_idx, optimizer_idx=None):
x, y = batch x, y = batch
out = self.forward(x) out = self.forward(x)
plabels = self.proto_layer.labels _, plabels = self.proto_layer()
batch_loss = self.loss(out, y, plabels) batch_loss = self.loss(out, y, plabels)
loss = batch_loss.sum() loss = batch_loss.sum()
return loss return loss