[BUGFIX] examples/lvqmln_iris.py works again

This commit is contained in:
Jensun Ravichandran 2021-06-14 21:00:26 +02:00
parent a44219ee47
commit 7ec5528ade

View File

@ -2,11 +2,10 @@
import argparse import argparse
import prototorch as pt
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
import prototorch as pt
class Backbone(torch.nn.Module): class Backbone(torch.nn.Module):
def __init__(self, input_size=4, hidden_size=10, latent_size=2): def __init__(self, input_size=4, hidden_size=10, latent_size=2):
@ -41,7 +40,7 @@ if __name__ == "__main__":
# Hyperparameters # Hyperparameters
hparams = dict( hparams = dict(
distribution=[1, 2, 2], distribution=[3, 4, 5],
proto_lr=0.001, proto_lr=0.001,
bb_lr=0.001, bb_lr=0.001,
) )
@ -52,7 +51,10 @@ if __name__ == "__main__":
# Initialize the model # Initialize the model
model = pt.models.LVQMLN( model = pt.models.LVQMLN(
hparams, hparams,
prototype_initializer=pt.components.SSI(train_ds, transform=backbone), prototypes_initializer=pt.initializers.SSCI(
train_ds,
transform=backbone,
),
backbone=backbone, backbone=backbone,
) )
@ -67,11 +69,21 @@ if __name__ == "__main__":
resolution=500, resolution=500,
axis_off=True, axis_off=True,
) )
pruning = pt.models.PruneLoserPrototypes(
threshold=0.01,
idle_epochs=20,
prune_quota_per_epoch=2,
frequency=10,
verbose=True,
)
# Setup trainer # Setup trainer
trainer = pl.Trainer.from_argparse_args( trainer = pl.Trainer.from_argparse_args(
args, args,
callbacks=[vis], callbacks=[
vis,
pruning,
],
) )
# Training loop # Training loop