fix: example test fixed
This commit is contained in:
@@ -51,8 +51,7 @@ if __name__ == "__main__":
|
||||
# Hyperparameters
|
||||
hparams = dict(
|
||||
distribution=[1, 2, 3],
|
||||
proto_lr=0.01,
|
||||
bb_lr=0.01,
|
||||
lr=0.01,
|
||||
)
|
||||
|
||||
# Initialize the backbone
|
||||
|
@@ -51,8 +51,7 @@ if __name__ == "__main__":
|
||||
# Hyperparameters
|
||||
hparams = dict(
|
||||
distribution=[1, 2, 3],
|
||||
proto_lr=0.01,
|
||||
bb_lr=0.01,
|
||||
lr=0.01,
|
||||
input_dim=2,
|
||||
latent_dim=1,
|
||||
)
|
||||
|
@@ -55,7 +55,9 @@ if __name__ == "__main__":
|
||||
|
||||
# Setup trainer for GNG
|
||||
trainer = pl.Trainer(
|
||||
max_epochs=1000,
|
||||
accelerator="cpu",
|
||||
max_epochs=50 if args.fast_dev_run else
|
||||
1000, # 10 epochs fast dev run reproducible DIV error.
|
||||
callbacks=[
|
||||
es,
|
||||
],
|
||||
|
Reference in New Issue
Block a user