Add more experimental changes

The code gets very messy very quickly as soon as serialization features are
needed.
This commit is contained in:
Jensun Ravichandran
2021-04-21 21:59:19 +02:00
parent e5a62bd0fc
commit fadf8c25bf
2 changed files with 11 additions and 32 deletions

View File

@@ -25,6 +25,8 @@ class GLVQIris(GLVQ):
parser.add_argument("--epochs", type=int, default=1)
parser.add_argument("--lr", type=float, default=1e-1)
parser.add_argument("--batch_size", type=int, default=150)
parser.add_argument("--input_dim", type=int, default=2)
parser.add_argument("--nclasses", type=int, default=3)
parser.add_argument("--prototypes_per_class", type=int, default=3)
parser.add_argument("--prototype_initializer",
type=str,
@@ -101,6 +103,7 @@ if __name__ == "__main__":
# Setup trainer
trainer = pl.Trainer.from_argparse_args(
parser,
max_epochs=10,
callbacks=[
vis, # comment this line out to disable the visualization
],
@@ -109,12 +112,7 @@ if __name__ == "__main__":
# Initialize the model
args = parser.parse_args()
model = GLVQIris(
args,
input_dim=x_train.shape[1],
nclasses=3,
data=[x_train, y_train],
)
model = GLVQIris(args, data=[x_train, y_train])
# Model summary
print(model)