feat(compatibility): Python3.6 compatibility

This commit is contained in:
Alexander Engelsberger
2021-08-30 17:15:40 +02:00
parent d7834e2cc0
commit 7b93cd4ad5
9 changed files with 33 additions and 17 deletions

View File

@@ -1,7 +1,5 @@
"""Abstract classes to be inherited by prototorch models."""
from typing import Final, final
import pytorch_lightning as pl
import torch
import torchmetrics
@@ -43,7 +41,6 @@ class ProtoTorchBolt(pl.LightningModule):
else:
return optimizer
@final
def reconfigure_optimizers(self):
self.trainer.accelerator.setup_optimizers(self.trainer)
@@ -175,7 +172,7 @@ class NonGradientMixin(ProtoTorchMixin):
"""Mixin for custom non-gradient optimization."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.automatic_optimization: Final = False
self.automatic_optimization = False
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
raise NotImplementedError
@@ -183,7 +180,6 @@ class NonGradientMixin(ProtoTorchMixin):
class ImagePrototypesMixin(ProtoTorchMixin):
"""Mixin for models with image prototypes."""
@final
def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
"""Constrain the components to the range [0, 1] by clamping after updates."""
self.proto_layer.components.data.clamp_(0.0, 1.0)