Add siamese example using GMLVQ and Tecator
This commit is contained in:
		
							
								
								
									
										108
									
								
								examples/gmlvq_tecator.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										108
									
								
								examples/gmlvq_tecator.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,108 @@
 | 
				
			|||||||
 | 
					"""ProtoTorch "siamese" GMLVQ example using Tecator."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import matplotlib.pyplot as plt
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					from torch.utils.data import DataLoader
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from prototorch.datasets.tecator import Tecator
 | 
				
			||||||
 | 
					from prototorch.functions.distances import sed
 | 
				
			||||||
 | 
					from prototorch.functions.normalizations import normalize_omegat_
 | 
				
			||||||
 | 
					from prototorch.modules import Prototypes1D
 | 
				
			||||||
 | 
					from prototorch.modules.losses import GLVQLoss
 | 
				
			||||||
 | 
					from prototorch.utils.colors import handles_and_colors
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Prepare the dataset and dataloader
 | 
				
			||||||
 | 
					train_data = Tecator(root='./artifacts', train=True)
 | 
				
			||||||
 | 
					train_loader = DataLoader(train_data, batch_size=128, shuffle=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class Model(torch.nn.Module):
 | 
				
			||||||
 | 
					    def __init__(self, **kwargs):
 | 
				
			||||||
 | 
					        """GMLVQ model as a siamese network."""
 | 
				
			||||||
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					        x, y = train_data.data, train_data.targets
 | 
				
			||||||
 | 
					        self.p1 = Prototypes1D(input_dim=100,
 | 
				
			||||||
 | 
					                               prototypes_per_class=2,
 | 
				
			||||||
 | 
					                               nclasses=2,
 | 
				
			||||||
 | 
					                               prototype_initializer='stratified_random',
 | 
				
			||||||
 | 
					                               data=[x, y])
 | 
				
			||||||
 | 
					        self.omega = torch.nn.Linear(in_features=100,
 | 
				
			||||||
 | 
					                                     out_features=100,
 | 
				
			||||||
 | 
					                                     bias=False)
 | 
				
			||||||
 | 
					        torch.nn.init.eye_(self.omega.weight)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def forward(self, x):
 | 
				
			||||||
 | 
					        protos = self.p1.prototypes
 | 
				
			||||||
 | 
					        plabels = self.p1.prototype_labels
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Process `x` and `protos` through `omega`
 | 
				
			||||||
 | 
					        x_map = self.omega(x)
 | 
				
			||||||
 | 
					        protos_map = self.omega(protos)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Compute distances and output
 | 
				
			||||||
 | 
					        dis = sed(x_map, protos_map)
 | 
				
			||||||
 | 
					        return dis, plabels
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Build the GLVQ model
 | 
				
			||||||
 | 
					model = Model()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Print a summary of the model
 | 
				
			||||||
 | 
					print(model)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Optimize using Adam optimizer from `torch.optim`
 | 
				
			||||||
 | 
					optimizer = torch.optim.Adam(model.parameters(), lr=0.001_0)
 | 
				
			||||||
 | 
					scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=75, gamma=0.1)
 | 
				
			||||||
 | 
					criterion = GLVQLoss(squashing='identity', beta=10)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Training loop
 | 
				
			||||||
 | 
					for epoch in range(150):
 | 
				
			||||||
 | 
					    epoch_loss = 0.0  # zero-out epoch loss
 | 
				
			||||||
 | 
					    optimizer.zero_grad()  # zero-out gradients
 | 
				
			||||||
 | 
					    for xb, yb in train_loader:
 | 
				
			||||||
 | 
					        # Compute loss
 | 
				
			||||||
 | 
					        distances, plabels = model(xb)
 | 
				
			||||||
 | 
					        loss = criterion([distances, plabels], yb)
 | 
				
			||||||
 | 
					        epoch_loss += loss.item()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Backprop
 | 
				
			||||||
 | 
					        loss.backward()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Normalize omega
 | 
				
			||||||
 | 
					        normalize_omegat_(model.omega.weight)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Take a gradient descent step
 | 
				
			||||||
 | 
					    optimizer.step()
 | 
				
			||||||
 | 
					    scheduler.step()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    lr = optimizer.param_groups[0]['lr']
 | 
				
			||||||
 | 
					    print(f'Epoch: {epoch + 1:03d} Loss: {epoch_loss:06.02f} lr: {lr:07.06f}')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Get the omega matrix form the model
 | 
				
			||||||
 | 
					omega = model.omega.weight.data.numpy().T
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Visualize the lambda matrix
 | 
				
			||||||
 | 
					title = 'Lambda Matrix Visualization'
 | 
				
			||||||
 | 
					fig = plt.figure(title)
 | 
				
			||||||
 | 
					ax = fig.gca()
 | 
				
			||||||
 | 
					ax.set_title(title)
 | 
				
			||||||
 | 
					im = ax.imshow(omega.dot(omega.T), cmap='viridis')
 | 
				
			||||||
 | 
					plt.show()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Get the prototypes form the model
 | 
				
			||||||
 | 
					protos = model.p1.prototypes.data.numpy()
 | 
				
			||||||
 | 
					plabels = model.p1.prototype_labels
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Visualize the prototypes
 | 
				
			||||||
 | 
					title = 'Tecator Prototypes'
 | 
				
			||||||
 | 
					fig = plt.figure(title)
 | 
				
			||||||
 | 
					ax = fig.gca()
 | 
				
			||||||
 | 
					ax.set_title(title)
 | 
				
			||||||
 | 
					ax.set_xlabel('Spectral frequencies')
 | 
				
			||||||
 | 
					ax.set_ylabel('Absorption')
 | 
				
			||||||
 | 
					clabels = ['Class 0 - Low fat', 'Class 1 - High fat']
 | 
				
			||||||
 | 
					handles, colors = handles_and_colors(clabels, marker='line')
 | 
				
			||||||
 | 
					for x, y in zip(protos, plabels):
 | 
				
			||||||
 | 
					    ax.plot(x, c=colors[int(y)])
 | 
				
			||||||
 | 
					ax.legend(handles, clabels)
 | 
				
			||||||
 | 
					plt.show()
 | 
				
			||||||
		Reference in New Issue
	
	Block a user