feat(compatibility): Python3.6 compatibility
This commit is contained in:
@@ -1,7 +1,5 @@
|
||||
"""`models` plugin for the `prototorch` package."""
|
||||
|
||||
from importlib.metadata import PackageNotFoundError, version
|
||||
|
||||
from .callbacks import PrototypeConvergence, PruneLoserPrototypes
|
||||
from .cbc import CBC, ImageCBC
|
||||
from .glvq import (
|
||||
|
@@ -1,7 +1,5 @@
|
||||
"""Abstract classes to be inherited by prototorch models."""
|
||||
|
||||
from typing import Final, final
|
||||
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
import torchmetrics
|
||||
@@ -43,7 +41,6 @@ class ProtoTorchBolt(pl.LightningModule):
|
||||
else:
|
||||
return optimizer
|
||||
|
||||
@final
|
||||
def reconfigure_optimizers(self):
|
||||
self.trainer.accelerator.setup_optimizers(self.trainer)
|
||||
|
||||
@@ -175,7 +172,7 @@ class NonGradientMixin(ProtoTorchMixin):
|
||||
"""Mixin for custom non-gradient optimization."""
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.automatic_optimization: Final = False
|
||||
self.automatic_optimization = False
|
||||
|
||||
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
|
||||
raise NotImplementedError
|
||||
@@ -183,7 +180,6 @@ class NonGradientMixin(ProtoTorchMixin):
|
||||
|
||||
class ImagePrototypesMixin(ProtoTorchMixin):
|
||||
"""Mixin for models with image prototypes."""
|
||||
@final
|
||||
def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
|
||||
"""Constrain the components to the range [0, 1] by clamping after updates."""
|
||||
self.proto_layer.components.data.clamp_(0.0, 1.0)
|
||||
|
@@ -55,7 +55,7 @@ class PruneLoserPrototypes(pl.Callback):
|
||||
distribution = dict(zip(labels.tolist(), counts.tolist()))
|
||||
if self.verbose:
|
||||
print(f"Re-adding pruned prototypes...")
|
||||
print(f"{distribution=}")
|
||||
print(f"distribution={distribution}")
|
||||
pl_module.add_prototypes(
|
||||
distribution=distribution,
|
||||
components_initializer=self.prototypes_initializer)
|
||||
|
@@ -112,7 +112,8 @@ class SiameseGLVQ(GLVQ):
|
||||
proto_opt = self.optimizer(self.proto_layer.parameters(),
|
||||
lr=self.hparams.proto_lr)
|
||||
# Only add a backbone optimizer if backbone has trainable parameters
|
||||
if (bb_params := list(self.backbone.parameters())):
|
||||
bb_params = list(self.backbone.parameters())
|
||||
if (bb_params):
|
||||
bb_opt = self.optimizer(bb_params, lr=self.hparams.bb_lr)
|
||||
optimizers = [proto_opt, bb_opt]
|
||||
else:
|
||||
|
@@ -28,8 +28,8 @@ class LVQ1(NonGradientMixin, GLVQ):
|
||||
self.proto_layer.load_state_dict({"_components": updated_protos},
|
||||
strict=False)
|
||||
|
||||
print(f"{dis=}")
|
||||
print(f"{y=}")
|
||||
print(f"dis={dis}")
|
||||
print(f"y={y}")
|
||||
# Logging
|
||||
self.log_acc(dis, y, tag="train_acc")
|
||||
|
||||
|
@@ -251,8 +251,6 @@ class VisImgComp(Vis2DAbstract):
|
||||
size=self.embedding_data,
|
||||
replace=False)
|
||||
data = self.x_train[ind]
|
||||
# print(f"{data.shape=}")
|
||||
# print(f"{self.y_train[ind].shape=}")
|
||||
tb.add_embedding(data.view(len(ind), -1),
|
||||
label_img=data,
|
||||
global_step=None,
|
||||
|
Reference in New Issue
Block a user