[BUGFIX] examples/lvqmln_iris.py works again
				
					
				
			This commit is contained in:
		@@ -2,11 +2,10 @@
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import argparse
 | 
					import argparse
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import prototorch as pt
 | 
				
			||||||
import pytorch_lightning as pl
 | 
					import pytorch_lightning as pl
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import prototorch as pt
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
class Backbone(torch.nn.Module):
 | 
					class Backbone(torch.nn.Module):
 | 
				
			||||||
    def __init__(self, input_size=4, hidden_size=10, latent_size=2):
 | 
					    def __init__(self, input_size=4, hidden_size=10, latent_size=2):
 | 
				
			||||||
@@ -41,7 +40,7 @@ if __name__ == "__main__":
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    # Hyperparameters
 | 
					    # Hyperparameters
 | 
				
			||||||
    hparams = dict(
 | 
					    hparams = dict(
 | 
				
			||||||
        distribution=[1, 2, 2],
 | 
					        distribution=[3, 4, 5],
 | 
				
			||||||
        proto_lr=0.001,
 | 
					        proto_lr=0.001,
 | 
				
			||||||
        bb_lr=0.001,
 | 
					        bb_lr=0.001,
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
@@ -52,7 +51,10 @@ if __name__ == "__main__":
 | 
				
			|||||||
    # Initialize the model
 | 
					    # Initialize the model
 | 
				
			||||||
    model = pt.models.LVQMLN(
 | 
					    model = pt.models.LVQMLN(
 | 
				
			||||||
        hparams,
 | 
					        hparams,
 | 
				
			||||||
        prototype_initializer=pt.components.SSI(train_ds, transform=backbone),
 | 
					        prototypes_initializer=pt.initializers.SSCI(
 | 
				
			||||||
 | 
					            train_ds,
 | 
				
			||||||
 | 
					            transform=backbone,
 | 
				
			||||||
 | 
					        ),
 | 
				
			||||||
        backbone=backbone,
 | 
					        backbone=backbone,
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -67,11 +69,21 @@ if __name__ == "__main__":
 | 
				
			|||||||
        resolution=500,
 | 
					        resolution=500,
 | 
				
			||||||
        axis_off=True,
 | 
					        axis_off=True,
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					    pruning = pt.models.PruneLoserPrototypes(
 | 
				
			||||||
 | 
					        threshold=0.01,
 | 
				
			||||||
 | 
					        idle_epochs=20,
 | 
				
			||||||
 | 
					        prune_quota_per_epoch=2,
 | 
				
			||||||
 | 
					        frequency=10,
 | 
				
			||||||
 | 
					        verbose=True,
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Setup trainer
 | 
					    # Setup trainer
 | 
				
			||||||
    trainer = pl.Trainer.from_argparse_args(
 | 
					    trainer = pl.Trainer.from_argparse_args(
 | 
				
			||||||
        args,
 | 
					        args,
 | 
				
			||||||
        callbacks=[vis],
 | 
					        callbacks=[
 | 
				
			||||||
 | 
					            vis,
 | 
				
			||||||
 | 
					            pruning,
 | 
				
			||||||
 | 
					        ],
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Training loop
 | 
					    # Training loop
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user