ci: add github actions (#16)
* chore: update pre-commit versions * ci: remove old configurations * ci: copy workflow from prototorch * ci: run precommit for all files * ci: add examples CPU test * ci(test): failing example test * ci: fix workflow definition * ci(test): repeat failing example test * ci: fix workflow definition * ci(test): repeat failing example test II * ci: fix test command * ci: cleanup example test * ci: remove travis badge
This commit is contained in:
committed by
GitHub
parent
62c5974a85
commit
1a17193b35
@@ -14,6 +14,7 @@ from ..nn.wrappers import LambdaLayer
|
||||
|
||||
class ProtoTorchBolt(pl.LightningModule):
|
||||
"""All ProtoTorch models are ProtoTorch Bolts."""
|
||||
|
||||
def __init__(self, hparams, **kwargs):
|
||||
super().__init__()
|
||||
|
||||
@@ -52,6 +53,7 @@ class ProtoTorchBolt(pl.LightningModule):
|
||||
|
||||
|
||||
class PrototypeModel(ProtoTorchBolt):
|
||||
|
||||
def __init__(self, hparams, **kwargs):
|
||||
super().__init__(hparams, **kwargs)
|
||||
|
||||
@@ -81,6 +83,7 @@ class PrototypeModel(ProtoTorchBolt):
|
||||
|
||||
|
||||
class UnsupervisedPrototypeModel(PrototypeModel):
|
||||
|
||||
def __init__(self, hparams, **kwargs):
|
||||
super().__init__(hparams, **kwargs)
|
||||
|
||||
@@ -103,6 +106,7 @@ class UnsupervisedPrototypeModel(PrototypeModel):
|
||||
|
||||
|
||||
class SupervisedPrototypeModel(PrototypeModel):
|
||||
|
||||
def __init__(self, hparams, **kwargs):
|
||||
super().__init__(hparams, **kwargs)
|
||||
|
||||
@@ -178,6 +182,7 @@ class ProtoTorchMixin(object):
|
||||
|
||||
class NonGradientMixin(ProtoTorchMixin):
|
||||
"""Mixin for custom non-gradient optimization."""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.automatic_optimization = False
|
||||
@@ -188,6 +193,7 @@ class NonGradientMixin(ProtoTorchMixin):
|
||||
|
||||
class ImagePrototypesMixin(ProtoTorchMixin):
|
||||
"""Mixin for models with image prototypes."""
|
||||
|
||||
def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
|
||||
"""Constrain the components to the range [0, 1] by clamping after updates."""
|
||||
self.proto_layer.components.data.clamp_(0.0, 1.0)
|
||||
|
Reference in New Issue
Block a user