Compare commits

38 Commits

v0.2.0...feature/je

| Author | SHA1 | Date |
|---|---|---|
|  | 9e64f00579 |  |
|  | d54fc5dad1 |  |
|  | c203e13604 |  |
|  | 4923ab8ef1 |  |
|  | 597a7afa67 |  |
|  | 7020ac587b |  |
|  | 872bad9b86 |  |
|  | 8693ecbfb6 |  |
|  | 6370ff61a6 |  |
|  | 328e789c86 |  |
|  | 5bc8c57490 |  |
|  | 75ab2897c4 |  |
|  | f4519eb430 |  |
|  | 8ed385f6d2 |  |
|  | c88bf9c6b7 |  |
|  | 26cc0690ef |  |
|  | 84f90d026d |  |
|  | df99f1bc18 |  |
|  | 76c147b57a |  |
|  | 6aa8a59a57 |  |
|  | 2da3a8f226 |  |
|  | 67fff5df3c |  |
|  | 7d4a041df2 |  |
|  | 04c51c00c6 |  |
|  | 62185b38cf |  |
|  | 7b93cd4ad5 |  |
|  | d7834e2cc0 |  |
|  | 0af8cf36f8 |  |
|  | f8ad1d83eb |  |
|  | 23a3683860 |  |
|  | 4be9fb81eb |  |
|  | 9d38123114 |  |
|  | 0f9f24e36a |  |
|  | 09e3ef1d0e |  |
|  | 7b9b767113 |  |
|  | f56ec44afe |  |
|  | 67a20124e8 |  |
|  | 72af03b991 |  |

@@ -1,5 +1,5 @@
 [bumpversion]
-current_version = 0.2.0
+current_version = 0.3.0
 commit = True
 tag = True
 parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)

.ci/gpu.Dockerfile (new file, 5 lines)
@@ -0,0 +1,5 @@
+FROM nvcr.io/nvidia/pytorch:21.10-py3
+
+RUN adduser --uid 1000 jenkins
+
+USER jenkins

.ci/python310.Dockerfile (new file, 5 lines)
@@ -0,0 +1,5 @@
+FROM python:3.9
+
+RUN adduser --uid 1000 jenkins
+
+USER jenkins

.ci/python36.Dockerfile (new file, 5 lines)
@@ -0,0 +1,5 @@
+FROM python:3.6
+
+RUN adduser --uid 1000 jenkins
+
+USER jenkins

.ci/python37.Dockerfile (new file, 5 lines)
@@ -0,0 +1,5 @@
+FROM python:3.7
+
+RUN adduser --uid 1000 jenkins
+
+USER jenkins

.ci/python38.Dockerfile (new file, 5 lines)
@@ -0,0 +1,5 @@
+FROM python:3.8
+
+RUN adduser --uid 1000 jenkins
+
+USER jenkins

.ci/python39.Dockerfile (new file, 5 lines)
@@ -0,0 +1,5 @@
+FROM python:3.9
+
+RUN adduser --uid 1000 jenkins
+
+USER jenkins

.github/ISSUE_TEMPLATE/bug_report.md (new file, 38 lines, vendored)
@@ -0,0 +1,38 @@
+---
+name: Bug report
+about: Create a report to help us improve
+title: ''
+labels: ''
+assignees: ''
+
+---
+
+**Describe the bug**
+A clear and concise description of what the bug is.
+
+**Steps to reproduce the behavior**
+1. ...
+2. Run script '...' or this snippet:
+```python
+import prototorch as pt
+
+...
+```
+3. See errors
+
+**Expected behavior**
+A clear and concise description of what you expected to happen.
+
+**Observed behavior**
+A clear and concise description of what actually happened.
+
+**Screenshots**
+If applicable, add screenshots to help explain your problem.
+
+**System and version information**
+- OS: [e.g. Ubuntu 20.10]
+- ProtoTorch Version: [e.g. 0.4.0]
+- Python Version: [e.g. 3.9.5]
+
+**Additional context**
+Add any other context about the problem here.

.github/ISSUE_TEMPLATE/feature_request.md (new file, 20 lines, vendored)
@@ -0,0 +1,20 @@
+---
+name: Feature request
+about: Suggest an idea for this project
+title: ''
+labels: ''
+assignees: ''
+
+---
+
+**Is your feature request related to a problem? Please describe.**
+A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
+
+**Describe the solution you'd like**
+A clear and concise description of what you want to happen.
+
+**Describe alternatives you've considered**
+A clear and concise description of any alternative solutions or features you've considered.
+
+**Additional context**
+Add any other context or screenshots about the feature request here.

.travis.yml (deleted, 25 lines)
@@ -1,25 +0,0 @@
-dist: bionic
-sudo: false
-language: python
-python: 3.9
-cache:
-  directories:
-  - "$HOME/.cache/pip"
-  - "./tests/artifacts"
-  - "$HOME/datasets"
-install:
-- pip install git+git://github.com/si-cim/prototorch@dev --progress-bar off
-- pip install .[all] --progress-bar off
-script:
-- coverage run -m pytest
-- ./tests/test_examples.sh examples/
-after_success:
-- bash <(curl -s https://codecov.io/bash)
-deploy:
-  provider: pypi
-  username: __token__
-  password:
-    secure: PDoASdYdVlt1aIROYilAsCW6XpBs/TDel0CSptDzX0CI7i4+ksEW6Jk0JyL58bQt7V4F8PeGty4A8SODzAUIk2d8sty5RI4VJjvXZFCXlUsW+JGUN3EvWNqJLnwN8TDxgu2ENao37GUh0dC6pL8b6bVDGeOLaY1E/YR1jimmTJuxxjKjBIU8ByqTNBnC3rzybMTPU3nRoOM/WMQUyReHrPoUJj685sLqrLruhAqhiYsPbotP8xY6i8+KBbhp5vgiARV2+LkbeGcYZwozCzrEqPKY7YIfVPh895cw0v4NRyFwK1P2jyyIt22Z9Ni0Uy1J5/Qp9Sv6mBPeGjm3pnpDCQyS+2bNIDaj08KUYTIo1mC/Jcu4jQgppZEF+oey9q1tgGo+/JhsTeERKV9BoPF5HDiRArU1s5aWJjFnCsHfu+W1XqX8bwN3aTYsEIaApT3/irc6XyFJIfMN82+z+lUcZ4Y1yAHT3nH1Vif+pZYZB0UOSGrHwuI/UayjKzbCzHMuHWylWB/9ehd4o4YVp6iubVHc7Sj0KQkwBgwgl6TvwNcUuFsplFabCxmX0mVcavXsWiOBc+ivPmU6574zGj0JcEk5ghVgnKH+QS96aVrKOzegwbl4O13jY8dJp+/zgXl0gJOvRKr4BhuBJKcBaMQHdSKUChVsJJtqDyt59GvWcbg=
-  on:
-    tags: true
-    skip_existing: true

Jenkinsfile (new file, 118 lines, vendored)
@@ -0,0 +1,118 @@
+pipeline {
+  agent none
+  stages {
+    stage('Unit Tests') {
+          agent {
+            dockerfile {
+              filename 'python310.Dockerfile'
+              dir '.ci'
+            }
+
+          }
+          steps {
+            sh 'pip install pip --upgrade --progress-bar off'
+            sh 'pip install .[all] --progress-bar off'
+            sh '~/.local/bin/pytest -v --junitxml=reports/result.xml --cov=prototorch/ --cov-report=xml:reports/coverage.xml'
+            cobertura coberturaReportFile: 'reports/coverage.xml'
+            junit 'reports/**/*.xml'
+          }
+    }
+
+    stage('CPU Examples') {
+      parallel {
+        stage('3.10') {
+          agent {
+            dockerfile {
+              filename 'python310.Dockerfile'
+              dir '.ci'
+            }
+
+          }
+          steps {
+            sh 'pip install pip --upgrade --progress-bar off'
+            sh 'pip install .[all] --progress-bar off'
+            sh './tests/test_examples.sh examples'
+          }
+        }
+
+        stage('3.9') {
+          agent {
+            dockerfile {
+              filename 'python39.Dockerfile'
+              dir '.ci'
+            }
+
+          }
+          steps {
+            sh 'pip install pip --upgrade --progress-bar off'
+            sh 'pip install .[all] --progress-bar off'
+            sh './tests/test_examples.sh examples'
+          }
+        }
+
+        stage('3.8') {
+          agent {
+            dockerfile {
+              filename 'python38.Dockerfile'
+              dir '.ci'
+            }
+
+          }
+          steps {
+            sh 'pip install pip --upgrade --progress-bar off'
+            sh 'pip install .[all] --progress-bar off'
+            sh './tests/test_examples.sh examples'
+          }
+        }
+
+        stage('3.7') {
+          agent {
+            dockerfile {
+              filename 'python37.Dockerfile'
+              dir '.ci'
+            }
+
+          }
+          steps {
+            sh 'pip install pip --upgrade --progress-bar off'
+            sh 'pip install .[all] --progress-bar off'
+            sh './tests/test_examples.sh examples'
+          }
+        }
+
+        stage('3.6') {
+          agent {
+            dockerfile {
+              filename 'python36.Dockerfile'
+              dir '.ci'
+            }
+
+          }
+          steps {
+            sh 'pip install pip --upgrade --progress-bar off'
+            sh 'pip install .[all] --progress-bar off'
+            sh './tests/test_examples.sh examples'
+          }
+        }
+
+      }
+    }
+
+    stage('GPU Examples') {
+      agent {
+        dockerfile {
+          filename 'gpu.Dockerfile'
+          dir '.ci'
+          args '--gpus 1'
+        }
+
+      }
+      steps {
+        sh 'pip install -U pip --progress-bar off'
+        sh 'pip install .[all] --progress-bar off'
+        sh './tests/test_examples.sh examples --gpu'
+      }
+    }
+
+  }
+}

@@ -36,6 +36,7 @@ be available for use in your Python environment as `prototorch.models`.
 - Soft Learning Vector Quantization (SLVQ)
 - Robust Soft Learning Vector Quantization (RSLVQ)
 - Probabilistic Learning Vector Quantization (PLVQ)
+- Median-LVQ
 
 ### Other
 

@@ -51,7 +52,6 @@ be available for use in your Python environment as `prototorch.models`.
 
 ## Planned models
 
-- Median-LVQ
 - Generalized Tangent Learning Vector Quantization (GTLVQ)
 - Self-Incremental Learning Vector Quantization (SILVQ)
 

deprecated.travis.yml (new file, 44 lines)
@@ -0,0 +1,44 @@
+dist: bionic
+sudo: false
+language: python
+python:
+  - 3.9
+  - 3.8
+  - 3.7
+  - 3.6
+cache:
+  directories:
+  - "$HOME/.cache/pip"
+  - "./tests/artifacts"
+  - "$HOME/datasets"
+install:
+- pip install git+git://github.com/si-cim/prototorch@dev --progress-bar off
+- pip install .[all] --progress-bar off
+script:
+- coverage run -m pytest
+- ./tests/test_examples.sh examples/
+after_success:
+- bash <(curl -s https://codecov.io/bash)
+
+# Publish on PyPI
+jobs:
+  include:
+    - stage: build
+      python: 3.9
+      script: echo "Starting Pypi build"
+      deploy:
+        provider: pypi
+        username: __token__
+        distributions: "sdist bdist_wheel"
+        password:
+          secure: PDoASdYdVlt1aIROYilAsCW6XpBs/TDel0CSptDzX0CI7i4+ksEW6Jk0JyL58bQt7V4F8PeGty4A8SODzAUIk2d8sty5RI4VJjvXZFCXlUsW+JGUN3EvWNqJLnwN8TDxgu2ENao37GUh0dC6pL8b6bVDGeOLaY1E/YR1jimmTJuxxjKjBIU8ByqTNBnC3rzybMTPU3nRoOM/WMQUyReHrPoUJj685sLqrLruhAqhiYsPbotP8xY6i8+KBbhp5vgiARV2+LkbeGcYZwozCzrEqPKY7YIfVPh895cw0v4NRyFwK1P2jyyIt22Z9Ni0Uy1J5/Qp9Sv6mBPeGjm3pnpDCQyS+2bNIDaj08KUYTIo1mC/Jcu4jQgppZEF+oey9q1tgGo+/JhsTeERKV9BoPF5HDiRArU1s5aWJjFnCsHfu+W1XqX8bwN3aTYsEIaApT3/irc6XyFJIfMN82+z+lUcZ4Y1yAHT3nH1Vif+pZYZB0UOSGrHwuI/UayjKzbCzHMuHWylWB/9ehd4o4YVp6iubVHc7Sj0KQkwBgwgl6TvwNcUuFsplFabCxmX0mVcavXsWiOBc+ivPmU6574zGj0JcEk5ghVgnKH+QS96aVrKOzegwbl4O13jY8dJp+/zgXl0gJOvRKr4BhuBJKcBaMQHdSKUChVsJJtqDyt59GvWcbg=
+        on:
+          tags: true
+          skip_existing: true
+
+# The password is encrypted with:
+# `cd prototorch && travis encrypt your-pypi-api-token --add deploy.password`
+# See https://docs.travis-ci.com/user/deployment/pypi and
+# https://github.com/travis-ci/travis.rb#installation
+# for more details
+# Note: The encrypt command does not work well in ZSH.

@@ -23,7 +23,7 @@ author = "Jensun Ravichandran"
 
 # The full version, including alpha/beta/rc tags
 #
-release = "0.2.0"
+release = "0.3.0"
 
 # -- General configuration ---------------------------------------------------

File diff suppressed because one or more lines are too long

@@ -1,12 +1,11 @@
 """GMLVQ example using the MNIST dataset."""
 
-import torch
-from pytorch_lightning.utilities.cli import LightningCLI
-
 import prototorch as pt
+import torch
 from prototorch.models import ImageGMLVQ
 from prototorch.models.abstract import PrototypeModel
 from prototorch.models.data import MNISTDataModule
+from pytorch_lightning.utilities.cli import LightningCLI
 
 
 class ExperimentClass(ImageGMLVQ):

@@ -66,7 +66,7 @@ if __name__ == "__main__":
         args,
         callbacks=[
             vis,
-            # es, # FIXME
+            es,
             pruning,
         ],
         terminate_on_nan=True,

@@ -2,12 +2,11 @@
 
 import argparse
 
+import prototorch as pt
 import pytorch_lightning as pl
 import torch
 from sklearn.datasets import load_iris
 
-import prototorch as pt
-
 if __name__ == "__main__":
     # Command-line arguments
     parser = argparse.ArgumentParser()

examples/median_lvq_iris.py (new file, 52 lines)
@@ -0,0 +1,52 @@
+"""Median-LVQ example using the Iris dataset."""
+
+import argparse
+
+import prototorch as pt
+import pytorch_lightning as pl
+import torch
+
+if __name__ == "__main__":
+    # Command-line arguments
+    parser = argparse.ArgumentParser()
+    parser = pl.Trainer.add_argparse_args(parser)
+    args = parser.parse_args()
+
+    # Dataset
+    train_ds = pt.datasets.Iris(dims=[0, 2])
+
+    # Dataloaders
+    train_loader = torch.utils.data.DataLoader(
+        train_ds,
+        batch_size=len(train_ds),  # MedianLVQ cannot handle mini-batches
+    )
+
+    # Initialize the model
+    model = pt.models.MedianLVQ(
+        hparams=dict(distribution=(3, 2), lr=0.01),
+        prototypes_initializer=pt.initializers.SSCI(train_ds),
+    )
+
+    # Compute intermediate input and output sizes
+    model.example_input_array = torch.zeros(4, 2)
+
+    # Callbacks
+    vis = pt.models.VisGLVQ2D(data=train_ds)
+    es = pl.callbacks.EarlyStopping(
+        monitor="train_acc",
+        min_delta=0.01,
+        patience=5,
+        mode="max",
+        verbose=True,
+        check_on_train_epoch_end=True,
+    )
+
+    # Setup trainer
+    trainer = pl.Trainer.from_argparse_args(
+        args,
+        callbacks=[vis, es],
+        weights_summary="full",
+    )
+
+    # Training loop
+    trainer.fit(model, train_loader)

@@ -37,7 +37,7 @@ if __name__ == "__main__":
 
     # Setup trainer for GNG
     trainer = pl.Trainer(
-        max_epochs=200,
+        max_epochs=100,
         callbacks=[es],
         weights_summary=None,
     )

@@ -71,11 +71,30 @@ if __name__ == "__main__":
 
     # Callbacks
     vis = pt.models.VisGLVQ2D(data=train_ds)
+    pruning = pt.models.PruneLoserPrototypes(
+        threshold=0.02,
+        idle_epochs=2,
+        prune_quota_per_epoch=5,
+        frequency=1,
+        verbose=True,
+    )
+    es = pl.callbacks.EarlyStopping(
+        monitor="train_loss",
+        min_delta=0.001,
+        patience=10,
+        mode="min",
+        verbose=True,
+        check_on_train_epoch_end=True,
+    )
 
     # Setup trainer
     trainer = pl.Trainer.from_argparse_args(
         args,
-        callbacks=[vis],
+        callbacks=[
+            vis,
+            pruning,
+            es,
+        ],
         weights_summary="full",
         accelerator="ddp",
     )

@@ -1,7 +1,5 @@
 """`models` plugin for the `prototorch` package."""
 
-from importlib.metadata import PackageNotFoundError, version
-
 from .callbacks import PrototypeConvergence, PruneLoserPrototypes
 from .cbc import CBC, ImageCBC
 from .glvq import (

@@ -23,4 +21,4 @@ from .probabilistic import CELVQ, PLVQ, RSLVQ, SLVQ
 from .unsupervised import GrowingNeuralGas, HeskesSOM, KohonenSOM, NeuralGas
 from .vis import *
 
-__version__ = "0.2.0"
+__version__ = "0.3.0"

@@ -1,7 +1,5 @@
 """Abstract classes to be inherited by prototorch models."""
 
-from typing import Final, final
-
 import pytorch_lightning as pl
 import torch
 import torchmetrics

@@ -14,20 +12,8 @@ from ..core.pooling import stratified_min_pooling
 from ..nn.wrappers import LambdaLayer
 
 
-class ProtoTorchMixin(object):
-    pass
-
-
 class ProtoTorchBolt(pl.LightningModule):
     """All ProtoTorch models are ProtoTorch Bolts."""
-    def __repr__(self):
-        surep = super().__repr__()
-        indented = "".join([f"\t{line}\n" for line in surep.splitlines()])
-        wrapped = f"ProtoTorch Bolt(\n{indented})"
-        return wrapped
-
-
-class PrototypeModel(ProtoTorchBolt):
     def __init__(self, hparams, **kwargs):
         super().__init__()
 

@@ -42,6 +28,33 @@ class PrototypeModel(ProtoTorchBolt):
         self.lr_scheduler = kwargs.get("lr_scheduler", None)
         self.lr_scheduler_kwargs = kwargs.get("lr_scheduler_kwargs", dict())
 
+    def configure_optimizers(self):
+        optimizer = self.optimizer(self.parameters(), lr=self.hparams.lr)
+        if self.lr_scheduler is not None:
+            scheduler = self.lr_scheduler(optimizer,
+                                          **self.lr_scheduler_kwargs)
+            sch = {
+                "scheduler": scheduler,
+                "interval": "step",
+            }  # called after each training step
+            return [optimizer], [sch]
+        else:
+            return optimizer
+
+    def reconfigure_optimizers(self):
+        self.trainer.accelerator.setup_optimizers(self.trainer)
+
+    def __repr__(self):
+        surep = super().__repr__()
+        indented = "".join([f"\t{line}\n" for line in surep.splitlines()])
+        wrapped = f"ProtoTorch Bolt(\n{indented})"
+        return wrapped
+
+
+class PrototypeModel(ProtoTorchBolt):
+    def __init__(self, hparams, **kwargs):
+        super().__init__(hparams, **kwargs)
+
         distance_fn = kwargs.get("distance_fn", euclidean_distance)
         self.distance_layer = LambdaLayer(distance_fn)
 

@@ -58,23 +71,6 @@ class PrototypeModel(ProtoTorchBolt):
         """Only an alias for the prototypes."""
         return self.prototypes
 
-    def configure_optimizers(self):
-        optimizer = self.optimizer(self.parameters(), lr=self.hparams.lr)
-        if self.lr_scheduler is not None:
-            scheduler = self.lr_scheduler(optimizer,
-                                          **self.lr_scheduler_kwargs)
-            sch = {
-                "scheduler": scheduler,
-                "interval": "step",
-            }  # called after each training step
-            return [optimizer], [sch]
-        else:
-            return optimizer
-
-    @final
-    def reconfigure_optimizers(self):
-        self.trainer.accelerator_backend.setup_optimizers(self.trainer)
-
     def add_prototypes(self, *args, **kwargs):
         self.proto_layer.add_components(*args, **kwargs)
         self.reconfigure_optimizers()

@@ -97,7 +93,7 @@ class UnsupervisedPrototypeModel(PrototypeModel):
             )
 
     def compute_distances(self, x):
-        protos = self.proto_layer()
+        protos = self.proto_layer().type_as(x)
         distances = self.distance_layer(x, protos)
         return distances
 

@@ -137,14 +133,14 @@ class SupervisedPrototypeModel(PrototypeModel):
 
     def forward(self, x):
         distances = self.compute_distances(x)
-        plabels = self.proto_layer.labels
+        _, plabels = self.proto_layer()
         winning = stratified_min_pooling(distances, plabels)
         y_pred = torch.nn.functional.softmin(winning)
         return y_pred
 
     def predict_from_distances(self, distances):
         with torch.no_grad():
-            plabels = self.proto_layer.labels
+            _, plabels = self.proto_layer()
             y_pred = self.competition_layer(distances, plabels)
         return y_pred
 

@@ -167,11 +163,16 @@ class SupervisedPrototypeModel(PrototypeModel):
                  logger=True)
 
 
+class ProtoTorchMixin(object):
+    """All mixins are ProtoTorchMixins."""
+    pass
+
+
 class NonGradientMixin(ProtoTorchMixin):
     """Mixin for custom non-gradient optimization."""
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self.automatic_optimization: Final = False
+        self.automatic_optimization = False
 
     def training_step(self, train_batch, batch_idx, optimizer_idx=None):
         raise NotImplementedError

@@ -179,7 +180,6 @@ class NonGradientMixin(ProtoTorchMixin):
 
 class ImagePrototypesMixin(ProtoTorchMixin):
     """Mixin for models with image prototypes."""
-    @final
     def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
         """Constrain the components to the range [0, 1] by clamping after updates."""
         self.proto_layer.components.data.clamp_(0.0, 1.0)

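A recurring change in the hunks above and below is that `plabels = self.proto_layer.labels` becomes `_, plabels = self.proto_layer()`, i.e. the prototype layer is now called and returns components and labels together. A minimal stand-in illustrating that call convention (an assumption drawn from this diff, not from prototorch's documented API; `FakeProtoLayer` is hypothetical):

```python
import torch

class FakeProtoLayer(torch.nn.Module):
    """Hypothetical stand-in: calling the layer yields (components, labels)."""

    def __init__(self, components, labels):
        super().__init__()
        self._components = torch.nn.Parameter(components)
        self._labels = labels

    def forward(self):
        # Convention assumed by this branch: one call returns both tensors.
        return self._components, self._labels

layer = FakeProtoLayer(torch.zeros(6, 2), torch.tensor([0, 0, 1, 1, 2, 2]))
protos, plabels = layer()  # replaces the old `layer.labels` attribute access
```
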
@@ -55,7 +55,7 @@ class PruneLoserPrototypes(pl.Callback):
                 distribution = dict(zip(labels.tolist(), counts.tolist()))
                 if self.verbose:
                     print(f"Re-adding pruned prototypes...")
-                    print(f"{distribution=}")
+                    print(f"distribution={distribution}")
                 pl_module.add_prototypes(
                     distribution=distribution,
                     components_initializer=self.prototypes_initializer)

@@ -134,4 +134,4 @@ class GNGCallback(pl.Callback):
             pl_module.errors[
                 worst_neighbor] = errors[worst_neighbor] * self.reduction
 
-            trainer.accelerator_backend.setup_optimizers(trainer)
+            trainer.accelerator.setup_optimizers(trainer)

@@ -48,7 +48,7 @@ class CBC(SiameseGLVQ):
         y_pred = self(x)
         num_classes = self.num_classes
         y_true = torch.nn.functional.one_hot(y.long(), num_classes=num_classes)
-        loss = self.loss(y_pred, y_true).mean(dim=0)
+        loss = self.loss(y_pred, y_true).mean()
         return y_pred, loss
 
     def training_step(self, batch, batch_idx, optimizer_idx=None):

@@ -5,13 +5,12 @@ Mainly used for PytorchLightningCLI configurations.
 """
 from typing import Any, Optional, Type
 
+import prototorch as pt
 import pytorch_lightning as pl
 from torch.utils.data import DataLoader, Dataset, random_split
 from torchvision import transforms
 from torchvision.datasets import MNIST
 
-import prototorch as pt
-
 
 # MNIST
 class MNISTDataModule(pl.LightningDataModule):

@@ -6,8 +6,8 @@ from torch.nn.parameter import Parameter
 from ..core.competitions import wtac
 from ..core.distances import lomega_distance, omega_distance, squared_euclidean_distance
 from ..core.initializers import EyeTransformInitializer
-from ..core.losses import glvq_loss, lvq1_loss, lvq21_loss
-from ..nn.activations import get_activation
+from ..core.losses import GLVQLoss, lvq1_loss, lvq21_loss
+from ..core.transforms import LinearTransform
 from ..nn.wrappers import LambdaLayer, LossLayer
 from .abstract import ImagePrototypesMixin, SupervisedPrototypeModel
 

@@ -18,15 +18,16 @@ class GLVQ(SupervisedPrototypeModel):
         super().__init__(hparams, **kwargs)
 
         # Default hparams
+        self.hparams.setdefault("margin", 0.0)
         self.hparams.setdefault("transfer_fn", "identity")
         self.hparams.setdefault("transfer_beta", 10.0)
 
-        # Layers
-        transfer_fn = get_activation(self.hparams.transfer_fn)
-        self.transfer_layer = LambdaLayer(transfer_fn)
-
         # Loss
-        self.loss = LossLayer(glvq_loss)
+        self.loss = GLVQLoss(
+            margin=self.hparams.margin,
+            transfer_fn=self.hparams.transfer_fn,
+            beta=self.hparams.transfer_beta,
+        )
 
     def initialize_prototype_win_ratios(self):
         self.register_buffer(

@@ -54,10 +55,8 @@ class GLVQ(SupervisedPrototypeModel):
     def shared_step(self, batch, batch_idx, optimizer_idx=None):
         x, y = batch
         out = self.compute_distances(x)
-        plabels = self.proto_layer.labels
-        mu = self.loss(out, y, prototype_labels=plabels)
-        batch_loss = self.transfer_layer(mu, beta=self.hparams.transfer_beta)
-        loss = batch_loss.sum(dim=0)
+        _, plabels = self.proto_layer()
+        loss = self.loss(out, y, plabels)
         return out, loss
 
     def training_step(self, batch, batch_idx, optimizer_idx=None):

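In the hunks above, GLVQ's loss moves from `LossLayer(glvq_loss)` plus a separate transfer layer into a single `GLVQLoss` configured with margin, transfer function and beta. For reference, a minimal sketch of the classic GLVQ cost that the removed pipeline computed (plain PyTorch; the function name `glvq_cost` and the sigmoid transfer are illustrative assumptions, not the library's actual implementation):

```python
import torch

def glvq_cost(d_plus, d_minus, beta=10.0):
    """Illustrative GLVQ cost: relative distance difference, squashed and summed.

    d_plus / d_minus are distances to the closest correct / incorrect prototype.
    """
    mu = (d_plus - d_minus) / (d_plus + d_minus)  # classifier function in [-1, 1]
    return torch.sigmoid(beta * mu).sum()

# Toy check: being closer to the correct prototype yields the lower cost.
print(glvq_cost(torch.tensor([0.1]), torch.tensor([0.9])))  # small
print(glvq_cost(torch.tensor([0.9]), torch.tensor([0.1])))  # large
```
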
@@ -113,7 +112,8 @@ class SiameseGLVQ(GLVQ):
         proto_opt = self.optimizer(self.proto_layer.parameters(),
                                    lr=self.hparams.proto_lr)
         # Only add a backbone optimizer if backbone has trainable parameters
-        if (bb_params := list(self.backbone.parameters())):
+        bb_params = list(self.backbone.parameters())
+        if (bb_params):
             bb_opt = self.optimizer(bb_params, lr=self.hparams.bb_lr)
             optimizers = [proto_opt, bb_opt]
         else:

@@ -208,18 +208,22 @@ class SiameseGMLVQ(SiameseGLVQ):
         super().__init__(hparams, **kwargs)
 
         # Override the backbone
-        self.backbone = torch.nn.Linear(self.hparams.input_dim,
-                                        self.hparams.latent_dim,
-                                        bias=False)
+        omega_initializer = kwargs.get("omega_initializer",
+                                       EyeTransformInitializer())
+        self.backbone = LinearTransform(
+            self.hparams.input_dim,
+            self.hparams.output_dim,
+            initializer=omega_initializer,
+        )
 
     @property
     def omega_matrix(self):
-        return self.backbone.weight.detach().cpu()
+        return self.backbone.weights
 
     @property
     def lambda_matrix(self):
-        omega = self.backbone.weight  # (latent_dim, input_dim)
-        lam = omega.T @ omega
+        omega = self.backbone.weight  # (input_dim, latent_dim)
+        lam = omega @ omega.T
         return lam.detach().cpu()
 
 

@@ -1,6 +1,8 @@
 """LVQ models that are optimized using non-gradient methods."""
 
+from ..core.losses import _get_dp_dm
+from ..nn.activations import get_activation
 from ..nn.wrappers import LambdaLayer
 from .abstract import NonGradientMixin
 from .glvq import GLVQ
 

@@ -8,9 +10,7 @@ from .glvq import GLVQ
 class LVQ1(NonGradientMixin, GLVQ):
     """Learning Vector Quantization 1."""
     def training_step(self, train_batch, batch_idx, optimizer_idx=None):
-        protos = self.proto_layer.components
-        plabels = self.proto_layer.labels
-
+        protos, plables = self.proto_layer()
         x, y = train_batch
         dis = self.compute_distances(x)
         # TODO Vectorized implementation

@@ -28,8 +28,8 @@ class LVQ1(NonGradientMixin, GLVQ):
             self.proto_layer.load_state_dict({"_components": updated_protos},
                                              strict=False)
 
-        print(f"{dis=}")
-        print(f"{y=}")
+        print(f"dis={dis}")
+        print(f"y={y}")
         # Logging
         self.log_acc(dis, y, tag="train_acc")
 

@@ -39,8 +39,7 @@ class LVQ1(NonGradientMixin, GLVQ):
 class LVQ21(NonGradientMixin, GLVQ):
     """Learning Vector Quantization 2.1."""
     def training_step(self, train_batch, batch_idx, optimizer_idx=None):
-        protos = self.proto_layer.components
-        plabels = self.proto_layer.labels
+        protos, plabels = self.proto_layer()
 
         x, y = train_batch
         dis = self.compute_distances(x)

@@ -66,4 +65,60 @@ class LVQ21(NonGradientMixin, GLVQ):
 
 
 class MedianLVQ(NonGradientMixin, GLVQ):
-    """Median LVQ"""
+    """Median LVQ
+
+    # TODO Avoid computing distances over and over
+
+    """
+    def __init__(self, hparams, verbose=True, **kwargs):
+        self.verbose = verbose
+        super().__init__(hparams, **kwargs)
+
+        self.transfer_layer = LambdaLayer(
+            get_activation(self.hparams.transfer_fn))
+
+    def _f(self, x, y, protos, plabels):
+        d = self.distance_layer(x, protos)
+        dp, dm = _get_dp_dm(d, y, plabels)
+        mu = (dp - dm) / (dp + dm)
+        invmu = -1.0 * mu
+        f = self.transfer_layer(invmu, beta=self.hparams.transfer_beta) + 1.0
+        return f
+
+    def expectation(self, x, y, protos, plabels):
+        f = self._f(x, y, protos, plabels)
+        gamma = f / f.sum()
+        return gamma
+
+    def lower_bound(self, x, y, protos, plabels, gamma):
+        f = self._f(x, y, protos, plabels)
+        lower_bound = (gamma * f.log()).sum()
+        return lower_bound
+
+    def training_step(self, train_batch, batch_idx, optimizer_idx=None):
+        protos, plabels = self.proto_layer()
+
+        x, y = train_batch
+        dis = self.compute_distances(x)
+
+        for i, _ in enumerate(protos):
+            # Expectation step
+            gamma = self.expectation(x, y, protos, plabels)
+            lower_bound = self.lower_bound(x, y, protos, plabels, gamma)
+
+            # Maximization step
+            _protos = protos + 0
+            for k, xk in enumerate(x):
+                _protos[i] = xk
+                _lower_bound = self.lower_bound(x, y, _protos, plabels, gamma)
+                if _lower_bound > lower_bound:
+                    if self.verbose:
+                        print(f"Updating prototype {i} to data {k}...")
+                    self.proto_layer.load_state_dict({"_components": _protos},
+                                                     strict=False)
+                    break
+
+        # Logging
+        self.log_acc(dis, y, tag="train_acc")
+
+        return None

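The `MedianLVQ.expectation` step added above reduces to normalising the per-sample scores `f` so they sum to one; the maximisation step then tries each data point as a candidate prototype and keeps it only when the lower bound improves. A tiny numeric illustration of the expectation step (values are made up):

```python
import torch

f = torch.tensor([0.2, 1.0, 0.8])  # per-sample scores, as returned by _f(...)
gamma = f / f.sum()                # expectation step: normalise to sum to one
print(gamma)                       # tensor([0.1000, 0.5000, 0.4000])
print(gamma.sum())                 # tensor(1.)
```
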
@@ -20,11 +20,11 @@ class CELVQ(GLVQ):
     def shared_step(self, batch, batch_idx, optimizer_idx=None):
         x, y = batch
         out = self.compute_distances(x)  # [None, num_protos]
-        plabels = self.proto_layer.labels
+        _, plabels = self.proto_layer()
         winning = stratified_min_pooling(out, plabels)  # [None, num_classes]
         probs = -1.0 * winning
         batch_loss = self.loss(probs, y.long())
-        loss = batch_loss.sum(dim=0)
+        loss = batch_loss.sum()
         return out, loss
 
 

@@ -54,9 +54,9 @@ class ProbabilisticLVQ(GLVQ):
     def training_step(self, batch, batch_idx, optimizer_idx=None):
         x, y = batch
         out = self.forward(x)
-        plabels = self.proto_layer.labels
+        _, plabels = self.proto_layer()
         batch_loss = self.loss(out, y, plabels)
-        loss = batch_loss.sum(dim=0)
+        loss = batch_loss.sum()
         return loss
 
 

@@ -92,5 +92,5 @@ class PLVQ(ProbabilisticLVQ, SiameseGMLVQ):
     #     x, y = batch
     #     y_pred = self(x)
     #     batch_loss = self.loss(y_pred, y)
-    #     loss = batch_loss.sum(dim=0)
+    #     loss = batch_loss.sum()
     #     return loss

@@ -53,7 +53,7 @@ class KohonenSOM(NonGradientMixin, UnsupervisedPrototypeModel):
         grid = self._grid.view(-1, 2)
         gd = squared_euclidean_distance(wp, grid)
         nh = torch.exp(-gd / self._sigma**2)
-        protos = self.proto_layer.components
+        protos = self.proto_layer()
         diff = x.unsqueeze(dim=1) - protos
         delta = self._lr * self.hparams.alpha * nh.unsqueeze(-1) * diff
         updated_protos = protos + delta.sum(dim=0)

@@ -132,7 +132,7 @@ class GrowingNeuralGas(NeuralGas):
         mask[torch.arange(len(mask)), winner] = 1.0
         dp = d * mask
 
-        self.errors += torch.sum(dp * dp, dim=0)
+        self.errors += torch.sum(dp * dp)
         self.errors *= self.hparams.step_reduction
 
         self.topology_layer(d)

@@ -251,8 +251,6 @@ class VisImgComp(Vis2DAbstract):
                                    size=self.embedding_data,
                                    replace=False)
             data = self.x_train[ind]
-            # print(f"{data.shape=}")
-            # print(f"{self.y_train[ind].shape=}")
             tb.add_embedding(data.view(len(ind), -1),
                              label_img=data,
                              global_step=None,

setup.py (11 lines changed)
@@ -22,7 +22,7 @@ with open("README.md", "r") as fh:
     long_description = fh.read()
 
 INSTALL_REQUIRES = [
-    "prototorch>=0.6.0",
+    "prototorch>=0.7.0",
     "pytorch_lightning>=1.3.5",
     "torchmetrics",
 ]

@@ -46,14 +46,14 @@ EXAMPLES = [
     "scikit-learn",
 ]
 TESTS = [
-    "codecov",
+    "pytest-cov",
     "pytest",
 ]
 ALL = CLI + DEV + DOCS + EXAMPLES + TESTS
 
 setup(
     name=safe_name("prototorch_" + PLUGIN_NAME),
-    version="0.2.0",
+    version="0.3.0",
     description="Pre-packaged prototype-based "
     "machine learning models using ProtoTorch and PyTorch-Lightning.",
     long_description=long_description,

@@ -63,7 +63,7 @@ setup(
     url=PROJECT_URL,
     download_url=DOWNLOAD_URL,
     license="MIT",
-    python_requires=">=3.9",
+    python_requires=">=3.6",
     install_requires=INSTALL_REQUIRES,
     extras_require={
         "dev": DEV,

@@ -80,6 +80,9 @@ setup(
         "License :: OSI Approved :: MIT License",
         "Natural Language :: English",
         "Programming Language :: Python :: 3.9",
+        "Programming Language :: Python :: 3.8",
+        "Programming Language :: Python :: 3.7",
+        "Programming Language :: Python :: 3.6",
         "Operating System :: OS Independent",
         "Topic :: Scientific/Engineering :: Artificial Intelligence",
         "Topic :: Software Development :: Libraries",

@@ -1,11 +1,27 @@
 #! /bin/bash
 
+
+# Read Flags
+gpu=0
+while [ -n "$1" ]; do
+	case "$1" in
+	    --gpu) gpu=1;;
+	    -g) gpu=1;;
+        *) path=$1;;
+	esac
+	shift
+done
+
+python --version
+echo "Using GPU: " $gpu
+
+# Loop
 failed=0
 
-for example in $(find $1 -maxdepth 1 -name "*.py")
+for example in $(find $path -maxdepth 1 -name "*.py")
 do
     echo  -n "$x" $example '... '
-    export DISPLAY= && python $example --fast_dev_run 1 &> run_log.txt
+    export DISPLAY= && python $example --fast_dev_run 1 --gpus $gpu &> run_log.txt
     if [[ $? -ne 0 ]]; then
         echo "FAILED!!"
         cat run_log.txt