[refactor] DRY Probabilistic models

This commit is contained in:
Alexander Engelsberger
2021-05-28 17:13:06 +02:00
parent dade502686
commit e3392ee952
2 changed files with 31 additions and 99 deletions

View File

@@ -33,11 +33,12 @@ if __name__ == "__main__":
)
# Initialize the model
model = pt.models.probabilistic.RSLVQ(
model = pt.models.probabilistic.LikelihoodRatioLVQ(
#model = pt.models.probabilistic.RSLVQ(
hparams,
optimizer=torch.optim.Adam,
prototype_initializer=pt.components.SSI(train_ds, noise=2),
#prototype_initializer=pt.components.UniformInitializer(2),
#prototype_initializer=pt.components.SSI(train_ds, noise=2),
prototype_initializer=pt.components.UniformInitializer(2),
)
# Callbacks