feat(compatibility): Python 3.6 compatibility
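
typing.Final and typing.final only exist from Python 3.8 onward; this change drops
the import, the Final annotation and the @final decorators so the module also
imports on Python 3.6.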
@@ -1,7 +1,5 @@
 """Abstract classes to be inherited by prototorch models."""
 
-from typing import Final, final
-
 import pytorch_lightning as pl
 import torch
 import torchmetrics
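Had the Final/final annotations been worth keeping, one alternative to deleting them outright would be a guarded import with a fallback. The sketch below is only an illustration of that option, not part of this commit, and assumes the optional typing_extensions backport:

# Hypothetical compatibility shim, not what this commit does.
try:
    from typing import Final, final  # available on Python 3.8+
except ImportError:
    try:
        from typing_extensions import Final, final  # backport, if installed
    except ImportError:
        Final = None  # placeholder; only static type checkers read it here

        def final(obj):
            # No-op stand-in so @final decorators keep working at runtime.
            return obj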
@@ -43,7 +41,6 @@ class ProtoTorchBolt(pl.LightningModule):
         else:
             return optimizer
 
-    @final
     def reconfigure_optimizers(self):
         self.trainer.accelerator.setup_optimizers(self.trainer)
 
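Dropping @final has no runtime effect either: typing.final is enforced only by static type checkers, so the only thing lost is the diagnostic sketched below (the class names are made up for illustration):

from typing import final


class Base:
    @final
    def reconfigure_optimizers(self):
        ...


class Child(Base):
    # mypy/pyright flag this override of a final method; plain Python runs it anyway.
    def reconfigure_optimizers(self):
        ...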
@@ -175,7 +172,7 @@ class NonGradientMixin(ProtoTorchMixin):
     """Mixin for custom non-gradient optimization."""
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self.automatic_optimization: Final = False
+        self.automatic_optimization = False
 
     def training_step(self, train_batch, batch_idx, optimizer_idx=None):
         raise NotImplementedError
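The Final annotation on automatic_optimization is dropped for the same reason; the behaviour is unchanged. For context, NonGradientMixin switches Lightning to manual optimization, so subclasses implement their own update inside training_step instead of relying on backward() and optimizer.step(). A standalone sketch of that pattern, with a made-up model and update rule purely for illustration:

import pytorch_lightning as pl
import torch


class RunningMeanModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False  # Lightning will not call backward()/step()
        self.center = torch.nn.Parameter(torch.zeros(2), requires_grad=False)

    def training_step(self, batch, batch_idx):
        x, _ = batch
        with torch.no_grad():
            # Hand-written, gradient-free update: move the center towards the batch mean.
            self.center.add_(0.1 * (x.mean(dim=0) - self.center))

    def configure_optimizers(self):
        return None  # no optimizer needed for a non-gradient update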
@@ -183,7 +180,6 @@ class NonGradientMixin(ProtoTorchMixin):
 
 class ImagePrototypesMixin(ProtoTorchMixin):
     """Mixin for models with image prototypes."""
-    @final
     def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
         """Constrain the components to the range [0, 1] by clamping after updates."""
         self.proto_layer.components.data.clamp_(0.0, 1.0)
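The clamp in on_train_batch_end keeps image prototypes inside the valid pixel range after every optimizer update. A minimal illustration of the in-place clamp_ with made-up values:

import torch

proto = torch.tensor([[-0.2, 0.5, 1.3]])  # prototype pixels pushed out of range by an update
proto.clamp_(0.0, 1.0)                    # same in-place operation the mixin applies
print(proto)                              # tensor([[0.0000, 0.5000, 1.0000]])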