[BUGFIX] examples/lvqmln_iris.py works again

This commit is contained in:
Jensun Ravichandran 2021-06-14 21:00:26 +02:00
parent a44219ee47
commit 7ec5528ade

View File

@ -2,11 +2,10 @@
import argparse
import prototorch as pt
import pytorch_lightning as pl
import torch
import prototorch as pt
class Backbone(torch.nn.Module):
def __init__(self, input_size=4, hidden_size=10, latent_size=2):
@ -41,7 +40,7 @@ if __name__ == "__main__":
# Hyperparameters
hparams = dict(
distribution=[1, 2, 2],
distribution=[3, 4, 5],
proto_lr=0.001,
bb_lr=0.001,
)
@ -52,7 +51,10 @@ if __name__ == "__main__":
# Initialize the model
model = pt.models.LVQMLN(
hparams,
prototype_initializer=pt.components.SSI(train_ds, transform=backbone),
prototypes_initializer=pt.initializers.SSCI(
train_ds,
transform=backbone,
),
backbone=backbone,
)
@ -67,11 +69,21 @@ if __name__ == "__main__":
resolution=500,
axis_off=True,
)
pruning = pt.models.PruneLoserPrototypes(
threshold=0.01,
idle_epochs=20,
prune_quota_per_epoch=2,
frequency=10,
verbose=True,
)
# Setup trainer
trainer = pl.Trainer.from_argparse_args(
args,
callbacks=[vis],
callbacks=[
vis,
pruning,
],
)
# Training loop