diff --git a/setup.py b/setup.py index 8b22cfa..54b508d 100644 --- a/setup.py +++ b/setup.py @@ -22,7 +22,7 @@ INSTALL_REQUIRES = [ "torch>=1.3.1", "torchvision>=0.7.4", "numpy>=1.9.1", - "sklearn", + "scikit-learn", "matplotlib", ] DATASETS = [