[BUGFIX] examples/lvqmln_iris.py works again
				
					
				
			This commit is contained in:
		@@ -2,11 +2,10 @@
 | 
			
		||||
 | 
			
		||||
import argparse
 | 
			
		||||
 | 
			
		||||
import prototorch as pt
 | 
			
		||||
import pytorch_lightning as pl
 | 
			
		||||
import torch
 | 
			
		||||
 | 
			
		||||
import prototorch as pt
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Backbone(torch.nn.Module):
 | 
			
		||||
    def __init__(self, input_size=4, hidden_size=10, latent_size=2):
 | 
			
		||||
@@ -41,7 +40,7 @@ if __name__ == "__main__":
 | 
			
		||||
 | 
			
		||||
    # Hyperparameters
 | 
			
		||||
    hparams = dict(
 | 
			
		||||
        distribution=[1, 2, 2],
 | 
			
		||||
        distribution=[3, 4, 5],
 | 
			
		||||
        proto_lr=0.001,
 | 
			
		||||
        bb_lr=0.001,
 | 
			
		||||
    )
 | 
			
		||||
@@ -52,7 +51,10 @@ if __name__ == "__main__":
 | 
			
		||||
    # Initialize the model
 | 
			
		||||
    model = pt.models.LVQMLN(
 | 
			
		||||
        hparams,
 | 
			
		||||
        prototype_initializer=pt.components.SSI(train_ds, transform=backbone),
 | 
			
		||||
        prototypes_initializer=pt.initializers.SSCI(
 | 
			
		||||
            train_ds,
 | 
			
		||||
            transform=backbone,
 | 
			
		||||
        ),
 | 
			
		||||
        backbone=backbone,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
@@ -67,11 +69,21 @@ if __name__ == "__main__":
 | 
			
		||||
        resolution=500,
 | 
			
		||||
        axis_off=True,
 | 
			
		||||
    )
 | 
			
		||||
    pruning = pt.models.PruneLoserPrototypes(
 | 
			
		||||
        threshold=0.01,
 | 
			
		||||
        idle_epochs=20,
 | 
			
		||||
        prune_quota_per_epoch=2,
 | 
			
		||||
        frequency=10,
 | 
			
		||||
        verbose=True,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    # Setup trainer
 | 
			
		||||
    trainer = pl.Trainer.from_argparse_args(
 | 
			
		||||
        args,
 | 
			
		||||
        callbacks=[vis],
 | 
			
		||||
        callbacks=[
 | 
			
		||||
            vis,
 | 
			
		||||
            pruning,
 | 
			
		||||
        ],
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    # Training loop
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user