Daankular commited on
Commit
ad8a35e
·
1 Parent(s): 14c3d13

Fix chumpy build isolation: vendor with patched setup.py, revert to Gradio SDK

Browse files

chumpy's setup.py does `from pip._internal.req import parse_requirements`
which fails in pip's default isolated build env. Vendored locally with
setup.py patched to use a hardcoded install_requires instead.
Reverted from docker SDK back to gradio SDK (ZeroGPU requires Gradio SDK).

Dockerfile DELETED
@@ -1,35 +0,0 @@
1
- FROM python:3.10-slim
2
-
3
- # System deps
4
- RUN apt-get update && apt-get install -y --no-install-recommends \
5
- git wget curl build-essential cmake ninja-build pkg-config \
6
- libgl1-mesa-glx libglib2.0-0 libsm6 libxext6 libxrender-dev ffmpeg \
7
- && rm -rf /var/lib/apt/lists/*
8
-
9
- # HF user setup
10
- RUN useradd -m -u 1000 user
11
- USER user
12
- ENV HOME=/home/user PATH=/home/user/.local/bin:$PATH
13
- WORKDIR $HOME/app
14
-
15
- # Upgrade pip first
16
- RUN pip install --user --upgrade pip setuptools wheel
17
-
18
- # chumpy must be installed with --no-build-isolation BEFORE everything else
19
- # (its setup.py does `import pip` which fails in pip's default isolated build env)
20
- RUN pip install --user --no-build-isolation \
21
- "chumpy @ git+https://github.com/mattloper/chumpy.git@580566eafc9ac68b2614b64d6f7aaa84eebb70da"
22
-
23
- # Copy app files
24
- COPY --chown=user . $HOME/app
25
-
26
- # Install remaining requirements (chumpy already satisfied above)
27
- RUN pip install --user --no-cache-dir -r requirements.txt \
28
- "torch<=2.9.1" \
29
- "gradio[oauth,mcp]==6.11.0" \
30
- "uvicorn>=0.14.0" \
31
- "websockets>=10.4" \
32
- "spaces==0.48.1"
33
-
34
- EXPOSE 7860
35
- CMD ["python", "app.py"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
README.md CHANGED
@@ -3,8 +3,9 @@ title: Image2Model
3
  emoji: 🎭
4
  colorFrom: purple
5
  colorTo: blue
6
- sdk: docker
7
- app_port: 7860
 
8
  pinned: false
9
  license: apache-2.0
10
  hardware: zero-a10g
 
3
  emoji: 🎭
4
  colorFrom: purple
5
  colorTo: blue
6
+ sdk: gradio
7
+ sdk_version: "6.11.0"
8
+ app_file: app.py
9
  pinned: false
10
  license: apache-2.0
11
  hardware: zero-a10g
requirements.txt CHANGED
@@ -1,13 +1,12 @@
1
- # HuggingFace ZeroGPU Space — Docker SDK
2
- # chumpy is pre-installed in the Dockerfile with --no-build-isolation
3
- # (its setup.py does `import pip` which breaks in modern pip isolated builds)
4
  spaces
5
 
6
  # Git-pinned installs
7
  hmr2 @ git+https://github.com/shubham-goel/4D-Humans.git@efe18deff163b29dff87ddbd575fa29b716a356c
8
  clip @ git+https://github.com/openai/CLIP.git@d05afc436d78f1c48dc0dbf8e5980a9d471f35f6
9
  mvadapter @ git+https://github.com/huanngzh/MV-Adapter.git@4277e0018232bac82bb2c103caf0893cedb711be
10
- chumpy @ git+https://github.com/mattloper/chumpy.git@580566eafc9ac68b2614b64d6f7aaa84eebb70da
11
  skel @ git+https://github.com/MarilynKeller/SKEL.git@c32cf16581295bff19399379efe5b776d707cd95
12
  nvdiffrast @ git+https://github.com/NVlabs/nvdiffrast.git@253ac4fcea7de5f396371124af597e6cc957bfae
13
  diso @ git+https://github.com/SarahWeiii/diso.git@9792ad928ccb09bdec938779651ee03e395758a6
 
1
+ # HuggingFace ZeroGPU Space — Gradio SDK
2
+ # chumpy vendored locally with patched setup.py (original does `import pip` which breaks isolated builds)
 
3
  spaces
4
 
5
  # Git-pinned installs
6
  hmr2 @ git+https://github.com/shubham-goel/4D-Humans.git@efe18deff163b29dff87ddbd575fa29b716a356c
7
  clip @ git+https://github.com/openai/CLIP.git@d05afc436d78f1c48dc0dbf8e5980a9d471f35f6
8
  mvadapter @ git+https://github.com/huanngzh/MV-Adapter.git@4277e0018232bac82bb2c103caf0893cedb711be
9
+ chumpy @ ./vendor/chumpy
10
  skel @ git+https://github.com/MarilynKeller/SKEL.git@c32cf16581295bff19399379efe5b776d707cd95
11
  nvdiffrast @ git+https://github.com/NVlabs/nvdiffrast.git@253ac4fcea7de5f396371124af597e6cc957bfae
12
  diso @ git+https://github.com/SarahWeiii/diso.git@9792ad928ccb09bdec938779651ee03e395758a6
vendor/chumpy/.circleci/config.yml ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ version: 2.1
2
+
3
+ orbs:
4
+ python: circleci/python@2.1.1 # optional, for helpers
5
+
6
+ jobs:
7
+ python3:
8
+ docker:
9
+ - image: cimg/python:3.12
10
+ steps:
11
+ - checkout
12
+ - run:
13
+ name: Install system deps
14
+ command: |
15
+ sudo apt-get update
16
+ sudo apt-get install -y --no-install-recommends gfortran liblapack-dev
17
+ - restore_cache:
18
+ keys:
19
+ - v1-pip-{{ arch }}-{{ .Branch }}-{{ checksum "requirements.txt" }}
20
+ - v1-pip-{{ arch }}-
21
+ - run:
22
+ name: Install python deps
23
+ command: |
24
+ python -m venv venv
25
+ . venv/bin/activate
26
+ pip install -U pip
27
+ set -o pipefail; pip install -r requirements.txt | cat
28
+ - save_cache:
29
+ key: v1-pip-{{ arch }}-{{ .Branch }}-{{ checksum "requirements.txt" }}
30
+ paths:
31
+ - ~/.cache/pip
32
+ - run:
33
+ name: Show versions
34
+ command: |
35
+ . venv/bin/activate
36
+ pip freeze
37
+ - run:
38
+ name: Run tests
39
+ command: |
40
+ . venv/bin/activate
41
+ make test
42
+
43
+ workflows:
44
+ version: 2
45
+ on-commit:
46
+ jobs:
47
+ - python3
48
+ daily:
49
+ triggers:
50
+ - schedule:
51
+ cron: "0 17 * * *"
52
+ filters:
53
+ branches:
54
+ only: master
55
+ jobs:
56
+ - python3
vendor/chumpy/.gitignore ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # Created by https://www.gitignore.io/api/osx,python
3
+
4
+ ### OSX ###
5
+ # General
6
+ .DS_Store
7
+ .AppleDouble
8
+ .LSOverride
9
+
10
+ # Icon must end with two \r
11
+ Icon
12
+
13
+ # Thumbnails
14
+ ._*
15
+
16
+ # Files that might appear in the root of a volume
17
+ .DocumentRevisions-V100
18
+ .fseventsd
19
+ .Spotlight-V100
20
+ .TemporaryItems
21
+ .Trashes
22
+ .VolumeIcon.icns
23
+ .com.apple.timemachine.donotpresent
24
+
25
+ # Directories potentially created on remote AFP share
26
+ .AppleDB
27
+ .AppleDesktop
28
+ Network Trash Folder
29
+ Temporary Items
30
+ .apdisk
31
+
32
+ ### Python ###
33
+ # Byte-compiled / optimized / DLL files
34
+ __pycache__/
35
+ *.py[cod]
36
+ *$py.class
37
+
38
+ # C extensions
39
+ *.so
40
+
41
+ # Distribution / packaging
42
+ .Python
43
+ build/
44
+ develop-eggs/
45
+ dist/
46
+ downloads/
47
+ eggs/
48
+ .eggs/
49
+ lib/
50
+ lib64/
51
+ parts/
52
+ sdist/
53
+ var/
54
+ wheels/
55
+ *.egg-info/
56
+ .installed.cfg
57
+ *.egg
58
+ MANIFEST
59
+
60
+ # PyInstaller
61
+ # Usually these files are written by a python script from a template
62
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
63
+ *.manifest
64
+ *.spec
65
+
66
+ # Installer logs
67
+ pip-log.txt
68
+ pip-delete-this-directory.txt
69
+
70
+ # Unit test / coverage reports
71
+ htmlcov/
72
+ .tox/
73
+ .coverage
74
+ .coverage.*
75
+ .cache
76
+ nosetests.xml
77
+ coverage.xml
78
+ *.cover
79
+ .hypothesis/
80
+ .pytest_cache/
81
+
82
+ # Translations
83
+ *.mo
84
+ *.pot
85
+
86
+ # Django stuff:
87
+ *.log
88
+ local_settings.py
89
+ db.sqlite3
90
+
91
+ # Flask stuff:
92
+ instance/
93
+ .webassets-cache
94
+
95
+ # Scrapy stuff:
96
+ .scrapy
97
+
98
+ # Sphinx documentation
99
+ docs/_build/
100
+
101
+ # PyBuilder
102
+ target/
103
+
104
+ # Jupyter Notebook
105
+ .ipynb_checkpoints
106
+
107
+ # pyenv
108
+ .python-version
109
+
110
+ # celery beat schedule file
111
+ celerybeat-schedule
112
+
113
+ # SageMath parsed files
114
+ *.sage.py
115
+
116
+ # Environments
117
+ .env
118
+ .venv
119
+ env/
120
+ venv/
121
+ ENV/
122
+ env.bak/
123
+ venv.bak/
124
+
125
+ # Spyder project settings
126
+ .spyderproject
127
+ .spyproject
128
+
129
+ # Rope project settings
130
+ .ropeproject
131
+
132
+ # mkdocs documentation
133
+ /site
134
+
135
+ # mypy
136
+ .mypy_cache/
137
+
138
+ ### Python Patch ###
139
+ .venv/
140
+
141
+
142
+ # End of https://www.gitignore.io/api/osx,python
vendor/chumpy/LICENSE.txt ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ The MIT License (MIT)
2
+
3
+ Copyright (c) 2014 Max-Planck-Gesellschaft
4
+ Copyright (c) 2014 Matthew Loper
5
+
6
+ Permission is hereby granted, free of charge, to any person obtaining a copy
7
+ of this software and associated documentation files (the "Software"), to deal
8
+ in the Software without restriction, including without limitation the rights
9
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10
+ copies of the Software, and to permit persons to whom the Software is
11
+ furnished to do so, subject to the following conditions:
12
+
13
+ The above copyright notice and this permission notice shall be included in
14
+ all copies or substantial portions of the Software.
15
+
16
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
22
+ THE SOFTWARE.
vendor/chumpy/MANIFEST.in ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ global-include . *.py *.c *.h Makefile *.pyx requirements.txt
2
+ global-exclude chumpy/optional_test_performance.py
3
+ prune dist
vendor/chumpy/Makefile ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ all:
2
+
3
+ upload:
4
+ rm -r dist
5
+ python setup.py sdist
6
+ twine upload dist/*
7
+
8
+ test:
9
+ # For some reason the import changes for Python 3 caused the Python 2 test
10
+ # loader to give up without loading any tests. So we discover them ourselves.
11
+ # python -m unittest
12
+ find chumpy -name 'test_*.py' | sed -e 's/\.py$$//' -e 's/\//./' | xargs python -m unittest
13
+
14
+ coverage: clean qcov
15
+ qcov: all
16
+ env LD_PRELOAD=$(PRELOADED) coverage run --source=. -m unittest discover -s .
17
+ coverage html
18
+ coverage report -m
vendor/chumpy/README.md ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ chumpy
2
+ ======
3
+
4
+ [![version](https://img.shields.io/pypi/v/chumpy?style=flat-square)][pypi]
5
+ [![license](https://img.shields.io/pypi/l/chumpy?style=flat-square)][pypi]
6
+ [![python versions](https://img.shields.io/pypi/pyversions/chumpy?style=flat-square)][pypi]
7
+ [![build status](https://img.shields.io/circleci/project/github/mattloper/chumpy/master?style=flat-square)][circle]
8
+
9
+ Autodifferentiation tool for Python.
10
+
11
+ [circle]: https://circleci.com/gh/mattloper/chumpy
12
+ [pypi]: https://pypi.org/project/chumpy/
13
+
14
+
15
+ Installation
16
+ ------------
17
+
18
+ Install the fork:
19
+
20
+ ```sh
21
+ pip install chumpy
22
+ ```
23
+
24
+ Import it:
25
+
26
+ ```py
27
+ import chumpy as ch
28
+ ```
29
+
30
+ Overview
31
+ --------
32
+
33
+ Chumpy is a Python-based framework designed to handle the **auto-differentiation** problem,
34
+ which is to evaluate an expression and its derivatives with respect to its inputs, by use of the chain rule.
35
+
36
+ Chumpy is intended to make construction and local
37
+ minimization of objectives easier.
38
+
39
+ Specifically, it provides:
40
+
41
+ - Easy problem construction by using Numpy’s application interface
42
+ - Easy access to derivatives via auto differentiation
43
+ - Easy local optimization methods (12 of them: most of which use the derivatives)
44
+
45
+
46
+ Usage
47
+ -----
48
+
49
+ Chumpy comes with its own demos, which can be seen by typing the following:
50
+
51
+ ```python
52
+ import chumpy
53
+ chumpy.demo() # prints out a list of possible demos
54
+ ```
55
+
56
+
57
+ License
58
+ -------
59
+
60
+ This project is licensed under the MIT License.
vendor/chumpy/chumpy/__init__.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .ch import *
2
+ from .logic import *
3
+
4
+ from .optimization import minimize
5
+ from . import extras
6
+ from . import testing
7
+ from .version import version as __version__
8
+
9
+ from .version import version as __version__
10
+
11
+
12
+ def test():
13
+ from os.path import split
14
+ import unittest
15
+ test_loader= unittest.TestLoader()
16
+ test_loader = test_loader.discover(split(__file__)[0])
17
+ test_runner = unittest.TextTestRunner()
18
+ test_runner.run( test_loader )
19
+
20
+
21
+ demos = {}
22
+
23
+ demos['scalar'] = """
24
+ import chumpy as ch
25
+
26
+ [x1, x2, x3] = ch.array(10), ch.array(20), ch.array(30)
27
+ result = x1+x2+x3
28
+ print result # prints [ 60.]
29
+ print result.dr_wrt(x1) # prints 1
30
+ """
31
+
32
+ demos['show_tree'] = """
33
+ import chumpy as ch
34
+
35
+ [x1, x2, x3] = ch.array(10), ch.array(20), ch.array(30)
36
+ for i in range(3): x2 = x1 + x2 + x3
37
+
38
+ x2.dr_wrt(x1) # pull cache
39
+ x2.dr_wrt(x3) # pull cache
40
+ x1.label='x1' # for clarity in show_tree()
41
+ x2.label='x2' # for clarity in show_tree()
42
+ x3.label='x3' # for clarity in show_tree()
43
+ x2.show_tree(cachelim=1e-4) # in MB
44
+ """
45
+
46
+ demos['matrix'] = """
47
+ import chumpy as ch
48
+
49
+ x1, x2, x3, x4 = ch.eye(10), ch.array(1), ch.array(5), ch.array(10)
50
+ y = x1*(x2-x3)+x4
51
+ print y
52
+ print y.dr_wrt(x2)
53
+ """
54
+
55
+ demos['linalg'] = """
56
+ import chumpy as ch
57
+
58
+ m = [ch.random.randn(100).reshape((10,10)) for i in range(3)]
59
+ y = m[0].dot(m[1]).dot(ch.linalg.inv(m[2])) * ch.linalg.det(m[0])
60
+ print y.shape
61
+ print y.dr_wrt(m[0]).shape
62
+ """
63
+
64
+ demos['inheritance'] = """
65
+ import chumpy as ch
66
+ import numpy as np
67
+
68
+ class Sin(ch.Ch):
69
+
70
+ dterms = ('x',)
71
+
72
+ def compute_r(self):
73
+ return np.sin(self.x.r)
74
+
75
+ def compute_dr_wrt(self, wrt):
76
+ import scipy.sparse
77
+ if wrt is self.x:
78
+ result = np.cos(self.x.r)
79
+ return scipy.sparse.diags([result.ravel()], [0]) if len(result)>1 else np.atleast_2d(result)
80
+
81
+ x1 = Ch([10,20,30])
82
+ result = Sin(x1) # or "result = Sin(x=x1)"
83
+ print result.r
84
+ print result.dr_wrt(x1)
85
+ """
86
+
87
+ demos['optimization'] = """
88
+ import chumpy as ch
89
+
90
+ x = ch.zeros(10)
91
+ y = ch.zeros(10)
92
+
93
+ # Beale's function
94
+ e1 = 1.5 - x + x*y
95
+ e2 = 2.25 - x + x*(y**2)
96
+ e3 = 2.625 - x + x*(y**3)
97
+
98
+ objective = {'e1': e1, 'e2': e2, 'e3': e3}
99
+ ch.minimize(objective, x0=[x,y], method='dogleg')
100
+ print x # should be all 3.0
101
+ print y # should be all 0.5
102
+ """
103
+
104
+
105
+
106
+
107
+ def demo(which=None):
108
+ if which not in demos:
109
+ print('Please indicate which demo you want, as follows:')
110
+ for key in demos:
111
+ print("\tdemo('%s')" % (key,))
112
+ return
113
+
114
+ print('- - - - - - - - - - - <CODE> - - - - - - - - - - - -')
115
+ print(demos[which])
116
+ print('- - - - - - - - - - - </CODE> - - - - - - - - - - - -\n')
117
+ exec('global np\n' + demos[which], globals(), locals())
vendor/chumpy/chumpy/api_compatibility.py ADDED
@@ -0,0 +1,534 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Author(s): Matthew Loper
3
+
4
+ See LICENCE.txt for licensing and contact information.
5
+ """
6
+
7
+
8
+ from . import ch
9
+ import numpy as np
10
+ from os.path import join, split
11
+ from six import StringIO
12
+ import numpy
13
+ import chumpy
14
+ from six.moves import cPickle as pickle
15
+
16
+ src = ''
17
+ num_passed = 0
18
+ num_not_passed = 0
19
+ which_passed = []
20
+
21
+ def r(fn_name, args_req, args_opt, nplib=numpy, chlib=chumpy):
22
+ global num_passed, num_not_passed
23
+ result = [None, None]
24
+
25
+ for lib in [nplib, chlib]:
26
+
27
+ # if fn_name is 'svd' and lib is chlib:
28
+ # import pdb; pdb.set_trace()
29
+ if lib is nplib:
30
+ fn = getattr(lib, fn_name)
31
+ else:
32
+ try:
33
+ fn = getattr(lib, fn_name)
34
+ except AttributeError:
35
+ result[0] = 'missing'
36
+ result[1] = 'missing'
37
+ num_not_passed += 1
38
+ continue
39
+ try:
40
+ if isinstance(args_req, dict):
41
+ _ = fn(**args_req)
42
+ else:
43
+ _ = fn(*args_req)
44
+ if lib is chlib:
45
+ result[0] = 'passed'
46
+ num_passed += 1
47
+ global which_passed
48
+ which_passed.append(fn_name)
49
+
50
+ if hasattr(_, 'dterms'):
51
+ try:
52
+ _.r
53
+
54
+ try:
55
+ pickle.dumps(_)
56
+ except:
57
+ result[0] += ' (but unpickleable!)'
58
+ except:
59
+ import pdb; pdb.set_trace()
60
+ result[0] += '(but cant get result!)'
61
+ except Exception as e:
62
+ if e is TypeError:
63
+ import pdb; pdb.set_trace()
64
+ if lib is nplib:
65
+ import pdb; pdb.set_trace()
66
+ else:
67
+ num_not_passed += 1
68
+ # if fn_name == 'rot90':
69
+ # import pdb; pdb.set_trace()
70
+ result[0] = e.__class__.__name__
71
+
72
+ try:
73
+ if isinstance(args_req, dict):
74
+ fn(**dict(list(args_req.items()) + list(args_opt.items())))
75
+ else:
76
+ fn(*args_req, **args_opt)
77
+ if lib is chlib:
78
+ result[1] = 'passed'
79
+ except Exception as e:
80
+ if e is TypeError:
81
+ import pdb; pdb.set_trace()
82
+ result[1] = e.__class__.__name__
83
+
84
+ # print '%s: %s, %s' % (fn_name, result[0], result[1])
85
+
86
+ append(fn_name, result[0], result[1])
87
+
88
+ def make_row(a, b, c, b_color, c_color):
89
+ global src
90
+ src += '<tr><td>%s</td><td style="background-color:%s">%s</td><td style="background-color:%s">%s</td></tr>' % (a,b_color, b,c_color, c)
91
+
92
+ def append(a, b, c):
93
+ global src
94
+ b_color = 'white'
95
+ c_color = 'white'
96
+
97
+ b = b.replace('NotImplementedError', 'not yet implemented')
98
+ c = c.replace('NotImplementedError', 'not yet implemented')
99
+ b = b.replace('WontImplement', "won't implement")
100
+ c = c.replace('WontImplement', "won't implement")
101
+ lookup = {
102
+ 'passed': 'lightgreen',
103
+ "won't implement": 'lightgray',
104
+ 'untested': 'lightyellow',
105
+ 'not yet implemented': 'pink'
106
+ }
107
+
108
+ b_color = lookup[b] if b in lookup else 'white'
109
+ c_color = lookup[c] if c in lookup else 'white'
110
+
111
+ print('%s: %s, %s' % (a,b,c))
112
+ make_row(a, b, c, b_color, c_color)
113
+
114
+ def m(s):
115
+ append(s, 'unknown', 'unknown')
116
+ global num_not_passed
117
+ num_not_passed += 1
118
+
119
+ def hd3(s):
120
+ global src
121
+ src += '<tr><td colspan=3><h3 style="margin-bottom:0;">%s</h3></td></tr>' % (s,)
122
+
123
+ def hd2(s):
124
+ global src
125
+ src += '</table><br/><br/><table border=1>'
126
+ src += '<tr><td colspan=3 style="background-color:black;color:white"><h2 style="margin-bottom:0;">%s</h2></td></tr>' % (s,)
127
+
128
+ def main():
129
+
130
+ #sample_array
131
+
132
+ ###############################
133
+ hd2('Array Creation Routines')
134
+
135
+ hd3('Ones and zeros')
136
+
137
+ r('empty', {'shape': (2,4,2)}, {'dtype': np.uint8, 'order': 'C'})
138
+ r('empty_like', {'prototype': np.empty((2,4,2))}, {'dtype': np.float64, 'order': 'C'})
139
+ r('eye', {'N': 10}, {'M': 5, 'k': 0, 'dtype': np.float64})
140
+ r('identity', {'n': 10}, {'dtype': np.float64})
141
+ r('ones', {'shape': (2,4,2)}, {'dtype': np.uint8, 'order': 'C'})
142
+ r('ones_like', {'a': np.empty((2,4,2))}, {'dtype': np.float64, 'order': 'C'})
143
+ r('zeros', {'shape': (2,4,2)}, {'dtype': np.uint8, 'order': 'C'})
144
+ r('zeros_like', {'a': np.empty((2,4,2))}, {'dtype': np.float64, 'order': 'C'})
145
+
146
+ hd3('From existing data')
147
+ r('array', {'object': [1,2,3]}, {'dtype': np.float64, 'order': 'C', 'subok': False, 'ndmin': 2})
148
+ r('asarray', {'a': np.array([1,2,3])}, {'dtype': np.float64, 'order': 'C'})
149
+ r('asanyarray', {'a': np.array([1,2,3])}, {'dtype': np.float64, 'order': 'C'})
150
+ r('ascontiguousarray', {'a': np.array([1,2,3])}, {'dtype': np.float64})
151
+ r('asmatrix', {'data': np.array([1,2,3])}, {'dtype': np.float64})
152
+ r('copy', (np.array([1,2,3]),), {})
153
+ r('frombuffer', {'buffer': np.array([1,2,3])}, {})
154
+ m('fromfile')
155
+ r('fromfunction', {'function': lambda i, j: i + j, 'shape': (3, 3)}, {'dtype': np.float64})
156
+ # function, shape, **kwargs
157
+ # lambda i, j: i + j, (3, 3), dtype=int
158
+ r('fromiter', {'iter': [1,2,3,4], 'dtype': np.float64}, {'count': 2})
159
+ r('fromstring', {'string': '\x01\x02', 'dtype': np.uint8}, {})
160
+ r('loadtxt', {'fname': StringIO("0 1\n2 3")}, {})
161
+
162
+ hd3('Creating record arrays (wont be implemented)')
163
+ hd3('Creating character arrays (wont be implemented)')
164
+
165
+ hd3('Numerical ranges')
166
+ r('arange', {'start': 0, 'stop': 10}, {'step': 2, 'dtype': np.float64})
167
+ r('linspace', {'start': 0, 'stop': 10}, {'num': 2, 'endpoint': 10, 'retstep': 1})
168
+ r('logspace', {'start': 0, 'stop': 10}, {'num': 2, 'endpoint': 10, 'base': 1})
169
+ r('meshgrid', ([1,2,3], [4,5,6]), {})
170
+ m('mgrid')
171
+ m('ogrid')
172
+
173
+ hd3('Building matrices')
174
+ r('diag', {'v': np.arange(9).reshape((3,3))}, {'k': 0})
175
+ r('diagflat', {'v': [[1,2], [3,4]]}, {})
176
+ r('tri', {'N': 3}, {'M': 5, 'k': 2, 'dtype': np.float64})
177
+ r('tril', {'m': [[1,2,3],[4,5,6],[7,8,9],[10,11,12]]}, {'k': -1})
178
+ r('triu', {'m': [[1,2,3],[4,5,6],[7,8,9],[10,11,12]]}, {'k': -1})
179
+ r('vander', {'x': np.array([1, 2, 3, 5])}, {'N': 3})
180
+
181
+ ###############################
182
+ hd2('Array manipulation routines')
183
+
184
+ hd3('Basic operations')
185
+ r('copyto', {'dst': np.eye(3), 'src': np.eye(3)}, {})
186
+
187
+ hd3('Changing array shape')
188
+ r('reshape', {'a': np.eye(3), 'newshape': (9,)}, {'order' : 'C'})
189
+ r('ravel', {'a': np.eye(3)}, {'order' : 'C'})
190
+ m('flat')
191
+ m('flatten')
192
+
193
+ hd3('Transpose-like operations')
194
+ r('rollaxis', {'a': np.ones((3,4,5,6)), 'axis': 3}, {'start': 0})
195
+ r('swapaxes', {'a': np.array([[1,2,3]]), 'axis1': 0, 'axis2': 1}, {})
196
+ r('transpose', {'a': np.arange(4).reshape((2,2))}, {'axes': (1,0)})
197
+
198
+ hd3('Changing number of dimensions')
199
+ r('atleast_1d', (np.eye(3),), {})
200
+ r('atleast_2d', (np.eye(3),), {})
201
+ r('atleast_3d', (np.eye(3),), {})
202
+ m('broadcast')
203
+ m('broadcast_arrays')
204
+ r('expand_dims', (np.array([1,2]),2), {})
205
+ r('squeeze', {'a': (np.array([[[1,2,3]]]))}, {})
206
+
207
+ hd3('Changing kind of array')
208
+ r('asarray', {'a': np.array([1,2,3])}, {'dtype': np.float64, 'order': 'C'})
209
+ r('asanyarray', {'a': np.array([1,2,3])}, {'dtype': np.float64, 'order': 'C'})
210
+ r('asmatrix', {'data': np.array([1,2,3])}, {})
211
+ r('asfarray', {'a': np.array([1,2,3])}, {})
212
+ r('asfortranarray', {'a': np.array([1,2,3])}, {})
213
+ r('asscalar', {'a': np.array([24])}, {})
214
+ r('require', {'a': np.array([24])}, {})
215
+
216
+ hd3('Joining arrays')
217
+ m('column_stack')
218
+ r('concatenate', ((np.eye(3), np.eye(3)),1), {})
219
+ r('dstack', ((np.eye(3), np.eye(3)),), {})
220
+ r('hstack', ((np.eye(3), np.eye(3)),), {})
221
+ r('vstack', ((np.eye(3), np.eye(3)),), {})
222
+
223
+ hd3('Splitting arrays')
224
+ m('array_split')
225
+ m('dsplit')
226
+ m('hsplit')
227
+ m('split')
228
+ m('vsplit')
229
+
230
+ hd3('Tiling arrays')
231
+ r('tile', (np.array([0, 1, 2]),2), {})
232
+ r('repeat', (np.array([[1,2],[3,4]]), 3), {'axis': 1})
233
+
234
+ hd3('Adding and removing elements')
235
+ m('delete')
236
+ m('insert')
237
+ m('append')
238
+ m('resize')
239
+ m('trim_zeros')
240
+ m('unique')
241
+
242
+ hd3('Rearranging elements')
243
+ r('fliplr', (np.eye(3),), {})
244
+ r('flipud', (np.eye(3),), {})
245
+ r('reshape', {'a': np.eye(3), 'newshape': (9,)}, {'order' : 'C'})
246
+ r('roll', (np.arange(10), 2), {})
247
+ r('rot90', (np.arange(4).reshape((2,2)),), {})
248
+
249
+ ###############################
250
+ hd2('Linear algebra (numpy.linalg)')
251
+
252
+ extra_args = {'nplib': numpy.linalg, 'chlib': ch.linalg}
253
+
254
+ hd3('Matrix and dot products')
255
+ r('dot', {'a': np.eye(3), 'b': np.eye(3)}, {})
256
+ r('dot', {'a': np.eye(3).ravel(), 'b': np.eye(3).ravel()}, {})
257
+ r('vdot', (np.eye(3).ravel(), np.eye(3).ravel()), {})
258
+ r('inner', (np.eye(3).ravel(), np.eye(3).ravel()), {})
259
+ r('outer', (np.eye(3).ravel(), np.eye(3).ravel()), {})
260
+ r('tensordot', {'a': np.eye(3), 'b': np.eye(3)}, {})
261
+ m('einsum')
262
+ r('matrix_power', {'M': np.eye(3), 'n': 2}, {}, **extra_args)
263
+ r('kron', {'a': np.eye(3), 'b': np.eye(3)}, {})
264
+
265
+ hd3('Decompositions')
266
+ r('cholesky', {'a': np.eye(3)}, {}, **extra_args)
267
+ r('qr', {'a': np.eye(3)}, {}, **extra_args)
268
+ r('svd', (np.eye(3),), {}, **extra_args)
269
+
270
+ hd3('Matrix eigenvalues')
271
+ r('eig', (np.eye(3),), {}, **extra_args)
272
+ r('eigh', (np.eye(3),), {}, **extra_args)
273
+ r('eigvals', (np.eye(3),), {}, **extra_args)
274
+ r('eigvalsh', (np.eye(3),), {}, **extra_args)
275
+
276
+ hd3('Norms and other numbers')
277
+ r('norm', (np.eye(3),), {}, **extra_args)
278
+ r('cond', (np.eye(3),), {}, **extra_args)
279
+ r('det', (np.eye(3),), {}, **extra_args)
280
+ r('slogdet', (np.eye(3),), {}, **extra_args)
281
+ r('trace', (np.eye(3),), {})
282
+
283
+ hd3('Solving equations and inverting matrices')
284
+ r('solve', (np.eye(3),np.ones(3)), {}, **extra_args)
285
+ r('tensorsolve', (np.eye(3),np.ones(3)), {}, **extra_args)
286
+ r('lstsq', (np.eye(3),np.ones(3)), {}, **extra_args)
287
+ r('inv', (np.eye(3),), {}, **extra_args)
288
+ r('pinv', (np.eye(3),), {}, **extra_args)
289
+ r('tensorinv', (np.eye(4*6).reshape((4,6,8,3)),), {'ind': 2}, **extra_args)
290
+
291
+ hd2('Mathematical functions')
292
+
293
+ hd3('Trigonometric functions')
294
+ r('sin', (np.arange(3),), {})
295
+ r('cos', (np.arange(3),), {})
296
+ r('tan', (np.arange(3),), {})
297
+ r('arcsin', (np.arange(3)/3.,), {})
298
+ r('arccos', (np.arange(3)/3.,), {})
299
+ r('arctan', (np.arange(3)/3.,), {})
300
+ r('hypot', (np.arange(3),np.arange(3)), {})
301
+ r('arctan2', (np.arange(3),np.arange(3)), {})
302
+ r('degrees', (np.arange(3),), {})
303
+ r('radians', (np.arange(3),), {})
304
+ r('unwrap', (np.arange(3),), {})
305
+ r('unwrap', (np.arange(3),), {})
306
+ r('deg2rad', (np.arange(3),), {})
307
+ r('rad2deg', (np.arange(3),), {})
308
+
309
+ hd3('Hyperbolic functions')
310
+ r('sinh', (np.arange(3),), {})
311
+ r('cosh', (np.arange(3),), {})
312
+ r('tanh', (np.arange(3),), {})
313
+ r('arcsinh', (np.arange(3)/9.,), {})
314
+ r('arccosh', (-np.arange(3)/9.,), {})
315
+ r('arctanh', (np.arange(3)/9.,), {})
316
+
317
+ hd3('Rounding')
318
+ r('around', (np.arange(3),), {})
319
+ r('round_', (np.arange(3),), {})
320
+ r('rint', (np.arange(3),), {})
321
+ r('fix', (np.arange(3),), {})
322
+ r('floor', (np.arange(3),), {})
323
+ r('ceil', (np.arange(3),), {})
324
+ r('trunc', (np.arange(3),), {})
325
+
326
+ hd3('Sums, products, differences')
327
+ r('prod', (np.arange(3),), {})
328
+ r('sum', (np.arange(3),), {})
329
+ r('nansum', (np.arange(3),), {})
330
+ r('cumprod', (np.arange(3),), {})
331
+ r('cumsum', (np.arange(3),), {})
332
+ r('diff', (np.arange(3),), {})
333
+ r('ediff1d', (np.arange(3),), {})
334
+ r('gradient', (np.arange(3),), {})
335
+ r('cross', (np.arange(3), np.arange(3)), {})
336
+ r('trapz', (np.arange(3),), {})
337
+
338
+ hd3('Exponents and logarithms')
339
+ r('exp', (np.arange(3),), {})
340
+ r('expm1', (np.arange(3),), {})
341
+ r('exp2', (np.arange(3),), {})
342
+ r('log', (np.arange(3),), {})
343
+ r('log10', (np.arange(3),), {})
344
+ r('log2', (np.arange(3),), {})
345
+ r('log1p', (np.arange(3),), {})
346
+ r('logaddexp', (np.arange(3), np.arange(3)), {})
347
+ r('logaddexp2', (np.arange(3), np.arange(3)), {})
348
+
349
+ hd3('Other special functions')
350
+ r('i0', (np.arange(3),), {})
351
+ r('sinc', (np.arange(3),), {})
352
+
353
+ hd3('Floating point routines')
354
+ r('signbit', (np.arange(3),), {})
355
+ r('copysign', (np.arange(3), np.arange(3)), {})
356
+ r('frexp', (np.arange(3),), {})
357
+ r('ldexp', (np.arange(3), np.arange(3)), {})
358
+
359
+ hd3('Arithmetic operations')
360
+ r('add', (np.arange(3), np.arange(3)), {})
361
+ r('reciprocal', (np.arange(3),), {})
362
+ r('negative', (np.arange(3),), {})
363
+ r('multiply', (np.arange(3), np.arange(3)), {})
364
+ r('divide', (np.arange(3), np.arange(3)), {})
365
+ r('power', (np.arange(3), np.arange(3)), {})
366
+ r('subtract', (np.arange(3), np.arange(3)), {})
367
+ r('true_divide', (np.arange(3), np.arange(3)), {})
368
+ r('floor_divide', (np.arange(3), np.arange(3)), {})
369
+ r('fmod', (np.arange(3), np.arange(3)), {})
370
+ r('mod', (np.arange(3), np.arange(3)), {})
371
+ r('modf', (np.arange(3),), {})
372
+ r('remainder', (np.arange(3), np.arange(3)), {})
373
+
374
+ hd3('Handling complex numbers')
375
+ m('angle')
376
+ m('real')
377
+ m('imag')
378
+ m('conj')
379
+
380
+ hd3('Miscellaneous')
381
+ r('convolve', (np.arange(3), np.arange(3)), {})
382
+ r('clip', (np.arange(3), 0, 2), {})
383
+ r('sqrt', (np.arange(3),), {})
384
+ r('square', (np.arange(3),), {})
385
+ r('absolute', (np.arange(3),), {})
386
+ r('fabs', (np.arange(3),), {})
387
+ r('sign', (np.arange(3),), {})
388
+ r('maximum', (np.arange(3), np.arange(3)), {})
389
+ r('minimum', (np.arange(3), np.arange(3)), {})
390
+ r('fmax', (np.arange(3), np.arange(3)), {})
391
+ r('fmin', (np.arange(3), np.arange(3)), {})
392
+ r('nan_to_num', (np.arange(3),), {})
393
+ r('real_if_close', (np.arange(3),), {})
394
+ r('interp', (2.5, [1,2,3], [3,2,0]), {})
395
+
396
+ extra_args = {'nplib': numpy.random, 'chlib': ch.random}
397
+
398
+ hd2('Random sampling (numpy.random)')
399
+ hd3('Simple random data')
400
+ r('rand', (3,), {}, **extra_args)
401
+ r('randn', (3,), {}, **extra_args)
402
+ r('randint', (3,), {}, **extra_args)
403
+ r('random_integers', (3,), {}, **extra_args)
404
+ r('random_sample', (3,), {}, **extra_args)
405
+ r('random', (3,), {}, **extra_args)
406
+ r('ranf', (3,), {}, **extra_args)
407
+ r('sample', (3,), {}, **extra_args)
408
+ r('choice', (np.ones(3),), {}, **extra_args)
409
+ r('bytes', (3,), {}, **extra_args)
410
+
411
+ hd3('Permutations')
412
+ r('shuffle', (np.ones(3),), {}, **extra_args)
413
+ r('permutation', (3,), {}, **extra_args)
414
+
415
+ hd3('Distributions (these all pass)')
416
+ r('beta', (.5, .5), {}, **extra_args)
417
+ r('binomial', (.5, .5), {}, **extra_args)
418
+ r('chisquare', (.5,), {}, **extra_args)
419
+ r('dirichlet', ((10, 5, 3), 20,), {}, **extra_args)
420
+ r('exponential', [], {}, **extra_args)
421
+ r('f', [1,48,1000], {}, **extra_args)
422
+ r('gamma', [.5], {}, **extra_args)
423
+ make_row('...AND 28 OTHERS...', 'passed', 'passed', 'lightgreen', 'lightgreen')
424
+
425
+
426
+ hd3('Random generator')
427
+ r('seed', [], {}, **extra_args)
428
+ r('get_state', [], {}, **extra_args)
429
+ r('set_state', [np.random.get_state()], {}, **extra_args)
430
+
431
+ ####################################
432
+ hd2('Statistics')
433
+ hd3('Order statistics')
434
+ r('amin', (np.eye(3),),{})
435
+ r('amax', (np.eye(3),),{})
436
+ r('nanmin', (np.eye(3),),{})
437
+ r('nanmax', (np.eye(3),),{})
438
+ r('ptp', (np.eye(3),),{})
439
+ r('percentile', (np.eye(3),50),{})
440
+
441
+ hd3('Averages and variance')
442
+ r('median', (np.eye(3),),{})
443
+ r('average', (np.eye(3),),{})
444
+ r('mean', (np.eye(3),),{})
445
+ r('std', (np.eye(3),),{})
446
+ r('var', (np.eye(3),),{})
447
+ r('nanmean', (np.eye(3),),{})
448
+ r('nanstd', (np.eye(3),),{})
449
+ r('nanvar', (np.eye(3),),{})
450
+
451
+
452
+ hd3('Correlating')
453
+ r('corrcoef', (np.eye(3),),{})
454
+ r('correlate', ([1, 2, 3], [0, 1, 0.5]),{})
455
+ r('cov', (np.eye(3),),{})
456
+
457
+ hd3('Histograms')
458
+ r('histogram', (np.eye(3),),{})
459
+ r('histogram2d', (np.eye(3).ravel(),np.eye(3).ravel()),{})
460
+ r('histogramdd', (np.eye(3).ravel(),),{})
461
+ r('bincount', (np.asarray(np.eye(3).ravel(), np.uint32),),{})
462
+ r('digitize', (np.array([0.2, 6.4, 3.0, 1.6]), np.array([0.0, 1.0, 2.5, 4.0, 10.0])),{})
463
+
464
+ ####################################
465
+ hd2('Sorting, searching, and counting')
466
+
467
+ hd3('Sorting')
468
+ r('sort', (np.array([1,3,1,2.]),), {})
469
+ m('lexsort')
470
+ m('argsort')
471
+ m('msort')
472
+ m('sort_complex')
473
+ m('partition')
474
+ m('argpartition')
475
+
476
+ # sort(a[, axis, kind, order]) Return a sorted copy of an array.
477
+ # lexsort(keys[, axis]) Perform an indirect sort using a sequence of keys.
478
+ # argsort(a[, axis, kind, order]) Returns the indices that would sort an array.
479
+ # ndarray.sort([axis, kind, order]) Sort an array, in-place.
480
+ # msort(a) Return a copy of an array sorted along the first axis.
481
+ # sort_complex(a) Sort a complex array using the real part first, then the imaginary part.
482
+ # partition(a, kth[, axis, kind, order]) Return a partitioned copy of an array.
483
+ # argpartition(a, kth[, axis, kind, order]) Perform an indirect partition along the given axis using the algorithm specified by the kind keyword.
484
+
485
+ a5 = np.arange(5)
486
+
487
+ hd3('Searching')
488
+ r('argmax', (a5,), {})
489
+ r('nanargmax', (a5,), {})
490
+ r('argmin', (a5,), {})
491
+ r('nanargmin', (a5,), {})
492
+ r('argwhere', (a5,), {})
493
+ r('nonzero', (a5,), {})
494
+ r('flatnonzero', (a5,), {})
495
+ r('where', (a5>1,), {})
496
+ r('searchsorted', (a5,a5), {})
497
+ r('extract', (lambda x : x > 1, a5), {})
498
+
499
+ # argmax(a[, axis]) Indices of the maximum values along an axis.
500
+ # nanargmax(a[, axis]) Return the indices of the maximum values in the specified axis ignoring
501
+ # argmin(a[, axis]) Return the indices of the minimum values along an axis.
502
+ # nanargmin(a[, axis]) Return the indices of the minimum values in the specified axis ignoring
503
+ # argwhere(a) Find the indices of array elements that are non-zero, grouped by element.
504
+ # nonzero(a) Return the indices of the elements that are non-zero.
505
+ # flatnonzero(a) Return indices that are non-zero in the flattened version of a.
506
+ # where(condition, [x, y]) Return elements, either from x or y, depending on condition.
507
+ # searchsorted(a, v[, side, sorter]) Find indices where elements should be inserted to maintain order.
508
+ # extract(condition, arr) Return the elements of an array that satisfy some condition.
509
+
510
+ hd3('Counting')
511
+ r('count_nonzero', (a5,), {})
512
+ #count_nonzero(a) Counts the number of non-zero values in the array a.
513
+
514
+
515
+
516
+ # histogram(a[, bins, range, normed, weights, ...]) Compute the histogram of a set of data.
517
+ # histogram2d(x, y[, bins, range, normed, weights]) Compute the bi-dimensional histogram of two data samples.
518
+ # histogramdd(sample[, bins, range, normed, ...]) Compute the multidimensional histogram of some data.
519
+ # bincount(x[, weights, minlength]) Count number of occurrences of each value in array of non-negative ints.
520
+ # digitize(x, bins[, right]) Return the indices of the bins to which each value in input array belongs.
521
+
522
+
523
+ global src
524
+ src = '<html><body><table border=1>' + src + '</table></body></html>'
525
+ open(join(split(__file__)[0], 'api_compatibility.html'), 'w').write(src)
526
+
527
+ print('passed %d, not passed %d' % (num_passed, num_not_passed))
528
+
529
+
530
+
531
if __name__ == '__main__':
    # NOTE: the original had a `global which_passed` statement here; `global`
    # is a no-op at module scope, so the statement has been removed.
    main()
    print(' '.join(which_passed))
vendor/chumpy/chumpy/ch.py ADDED
@@ -0,0 +1,1367 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # encoding: utf-8
3
+ """
4
+ Author(s): Matthew Loper
5
+
6
+ See LICENCE.txt for licensing and contact information.
7
+ """
8
+
9
+
10
+ __all__ = ['Ch', 'depends_on', 'MatVecMult', 'ChHandle', 'ChLambda']
11
+
12
+ import os, sys, time
13
+ import inspect
14
+ import scipy.sparse as sp
15
+ import numpy as np
16
+ import numbers
17
+ import weakref
18
+ import copy as external_copy
19
+ from functools import wraps
20
+ from scipy.sparse.linalg.interface import LinearOperator
21
+ from .utils import row, col, timer, convert_inputs_to_sparse_if_necessary
22
+ import collections
23
+ from copy import deepcopy
24
+ from functools import reduce
25
+
26
+
27
+
28
# Turn this on if you want the profiler injected
DEBUG = False
# Turn this on to make optimizations very chatty for debugging
VERBOSE = False


def pif(msg):
    """Print-if-verbose: write *msg* to stdout only when DEBUG or VERBOSE is set."""
    if VERBOSE or DEBUG:
        sys.stdout.write(msg + '\n')
36
+
37
+ _props_for_dict = weakref.WeakKeyDictionary()
38
+ def _props_for(cls):
39
+ if cls not in _props_for_dict:
40
+ _props_for_dict[cls] = set([p[0] for p in inspect.getmembers(cls, lambda x : isinstance(x, property))])
41
+ return _props_for_dict[cls]
42
+
43
+ _dep_props_for_dict = weakref.WeakKeyDictionary()
44
+ def _dep_props_for(cls):
45
+ if cls not in _dep_props_for_dict:
46
+ _dep_props_for_dict[cls] = [p for p in inspect.getmembers(cls, lambda x : isinstance(x, property)) if hasattr(p[1].fget, 'deps')]
47
+ return _dep_props_for_dict[cls]
48
+
49
+
50
# Per-class cache of conflicting keyword sets; weak keys so classes can be collected.
_kw_conflict_dict = weakref.WeakKeyDictionary()


def _check_kw_conflict(cls):
    """Raise if *cls* declares any of Ch's reserved keywords among its terms/dterms.

    The conflict set is computed once per class and cached.
    Raises Exception naming the offending keywords.
    """
    if cls not in _kw_conflict_dict:
        _kw_conflict_dict[cls] = Ch._reserved_kw.intersection(set(cls.terms).union(set(cls.dterms)))
    kw_conflict = _kw_conflict_dict[cls]
    if kw_conflict:
        # BUG FIX: the original referenced an undefined local name
        # `kw_conflict` in the message, so any real conflict raised a
        # NameError instead of this diagnostic.
        raise Exception("In class %s, don't use reserved keywords in terms/dterms: %s" % (str(cls), str(kw_conflict),))
56
+
57
+
58
class Term(object):
    """Declarative parameter for Ch subclasses.

    ``dr=True`` marks the parameter as differentiable (a dterm). Each
    instance records its creation order so declarations can be sorted back
    into source order, a la Django model fields
    (http://stackoverflow.com/a/3288801/893113).
    """

    creation_counter = 0

    def __init__(self, default=None, desc=None, dr=True):
        self.desc = desc
        self.dr = dr
        self.default = default
        # Stamp this instance with the class-wide counter, then bump it, so
        # sorting by creation_counter reproduces declaration order.
        self.creation_counter = Term.creation_counter
        Term.creation_counter += 1
69
+
70
+
71
+ class Ch(object):
72
+ terms = []
73
+ dterms = ['x']
74
+ __array_priority__ = 2.0
75
+ _cached_parms = {}
76
+ _setup_terms = {}
77
+ _default_kwargs = {'make_dense' : False, 'make_sparse' : False}
78
+ _status = "undefined"
79
+
80
+ called_dr_wrt = False
81
+ profiler = None
82
+
83
+ ########################################################
84
+ # Construction
85
+
86
    def __new__(cls, *args, **kwargs):
        """Construct a Ch node.

        Bookkeeping attributes are installed via ``object.__setattr__`` so
        the overridden ``__setattr__`` (which tracks dirtiness and parent
        links) is not triggered during construction.
        """

        # Ch(lambda: ...) transparently builds a ChLambda instead.
        if len(args) > 0 and type(args[0]) == type(lambda : 0):
            cls = ChLambda

        # Create empty instance
        result = super(Ch, cls).__new__(cls)

        cls.setup_terms()

        # Core state: dirty-term set, iteration tag, weak parent links, and
        # the result ('r') / Jacobian ('drs') cache.
        object.__setattr__(result, '_dirty_vars', set())
        object.__setattr__(result, '_itr', None)
        object.__setattr__(result, '_parents', weakref.WeakKeyDictionary())
        object.__setattr__(result, '_cache', {'r': None, 'drs': weakref.WeakKeyDictionary()})

        if DEBUG:
            object.__setattr__(result, '_cache_info', {})
            object.__setattr__(result, '_status', 'new')

        # Pull the recognized control kwargs (make_dense / make_sparse) out
        # of **kwargs into private flags; the remainder become terms below.
        for name, default_val in list(cls._default_kwargs.items()):
            object.__setattr__(result, '_%s' % name, kwargs.get(name, default_val))
            if name in kwargs:
                del kwargs[name]

        # Set up storage that allows @depends_on to work
        #props = [p for p in inspect.getmembers(cls, lambda x : isinstance(x, property)) if hasattr(p[1].fget, 'deps')]
        props = _dep_props_for(cls)
        cpd = {}
        for p in props:
            func_name = p[0] #id(p[1].fget)
            deps = p[1].fget.deps
            cpd[func_name] = {'deps': deps, 'value': None, 'out_of_date': True}

        object.__setattr__(result, '_depends_on_deps', cpd)

        # Positional args map onto term_order for subclasses; a bare Ch
        # treats its single positional argument as the stored array 'x'.
        if cls != Ch:
            for idx, a in enumerate(args):
                kwargs[cls.term_order[idx]] = a
        elif len(args)>0:
            kwargs['x'] = np.asarray(args[0], np.float64)

        # New-style Term defaults are deep-copied so instances never share
        # mutable default state; explicit kwargs override the defaults.
        defs = {p.name : deepcopy(p.default) for p in cls.parm_declarations() if p.default is not None}
        defs.update(kwargs)
        result.set(**defs)

        return result
132
+
133
    @classmethod
    def parm_declarations(cls):
        """Collect this class's new-style Term declarations, in declaration order.

        Builds an OrderedDict of name -> Term (sorted by creation_counter)
        and caches it in ``Ch._cached_parms``.

        NOTE(review): the cache is keyed by ``cls.__name__``, so two distinct
        classes with the same name would share one entry — confirm intended.
        NOTE(review): callers (``__new__``, ``setup_terms``) iterate the
        return value expecting Term objects (they access ``.name``/``.dr``),
        but iterating an OrderedDict yields key strings — verify against
        upstream, which may store ``.values()`` here.
        """
        if cls.__name__ not in cls._cached_parms:
            parameter_declarations = collections.OrderedDict()
            parameters = inspect.getmembers(cls, lambda x: isinstance(x, Term))
            for name, decl in sorted(parameters, key=lambda x: x[1].creation_counter):
                decl.name = name
                parameter_declarations[name] = decl
            cls._cached_parms[cls.__name__] = parameter_declarations
        return cls._cached_parms[cls.__name__]
143
+
144
    @classmethod
    def setup_terms(cls):
        """One-time normalization of a subclass's terms/dterms declarations.

        Replaces the defaults inherited from Ch with per-class containers,
        promotes bare strings to tuples, derives ``term_order``, and checks
        reserved-keyword conflicts. Idempotent: guarded by an id(cls)
        registry in ``_setup_terms``.
        """
        if id(cls) in cls._setup_terms: return

        if cls == Ch:
            return

        parm_declarations = cls.parm_declarations()

        # Normalize old-style declarations: don't mutate Ch's own defaults,
        # and allow a single term name to be given as a bare string.
        if cls.dterms is Ch.dterms:
            cls.dterms = []
        elif isinstance(cls.dterms, str):
            cls.dterms = (cls.dterms,)
        if cls.terms is Ch.terms:
            cls.terms = []
        elif isinstance(cls.terms, str):
            cls.terms = (cls.terms,)

        # Must be either new or old style
        len_oldstyle_parms = len(cls.dterms)+len(cls.terms)
        if len(parm_declarations) > 0:
            assert(len_oldstyle_parms==0)
            cls.term_order = [t.name for t in parm_declarations]
            cls.dterms = [t.name for t in parm_declarations if t.dr]
            cls.terms = [t.name for t in parm_declarations if not t.dr]
        else:
            if not hasattr(cls, 'term_order'):
                cls.term_order = list(cls.terms) + list(cls.dterms)

        _check_kw_conflict(cls)
        cls._setup_terms[id(cls)] = True
175
+
176
+
177
+ ########################################################
178
+ # Identifiers
179
+
180
+ @property
181
+ def short_name(self):
182
+ return self.label if hasattr(self, 'label') else self.__class__.__name__
183
+
184
+ @property
185
+ def sid(self):
186
+ """Semantic id."""
187
+ pnames = list(self.terms)+list(self.dterms)
188
+ pnames.sort()
189
+ return (self.__class__, tuple([(k, id(self.__dict__[k])) for k in pnames if k in self.__dict__]))
190
+
191
+
192
+ def reshape(self, *args):
193
+ return reshape(a=self, newshape=args if len(args)>1 else args[0])
194
+
195
+ def ravel(self):
196
+ return reshape(a=self, newshape=(-1))
197
+
198
+ def __hash__(self):
199
+ return id(self)
200
+
201
+ @property
202
+ def ndim(self):
203
+ return self.r.ndim
204
+
205
+ @property
206
+ def flat(self):
207
+ return self.r.flat
208
+
209
+ @property
210
+ def dtype(self):
211
+ return self.r.dtype
212
+
213
+ @property
214
+ def itemsize(self):
215
+ return self.r.itemsize
216
+
217
+
218
+ ########################################################
219
+ # Redundancy removal
220
+
221
+ def remove_redundancy(self, cache=None, iterate=True):
222
+
223
+ if cache == None:
224
+ cache = {}
225
+ _ = self.r # may result in the creation of extra dterms that we can cull
226
+
227
+ replacement_occurred = False
228
+ for propname in list(self.dterms):
229
+ prop = self.__dict__[propname]
230
+
231
+ if not hasattr(prop, 'dterms'):
232
+ continue
233
+ sid = prop.sid
234
+ if sid not in cache:
235
+ cache[sid] = prop
236
+ elif self.__dict__[propname] is not cache[sid]:
237
+ self.__dict__[propname] = cache[sid]
238
+ replacement_occurred = True
239
+ if prop.remove_redundancy(cache, iterate=False):
240
+ replacement_occurred = True
241
+
242
+ if not replacement_occurred:
243
+ return False
244
+ else:
245
+ if iterate:
246
+ self.remove_redundancy(cache, iterate=True)
247
+ return False
248
+ else:
249
+ return True
250
+
251
+
252
+
253
+ def print_labeled_residuals(self, print_newline=True, num_decimals=2, where_to_print=None):
254
+
255
+ if where_to_print is None:
256
+ where_to_print = sys.stderr
257
+ if hasattr(self, 'label'):
258
+ where_to_print.write(('%s: %.' + str(num_decimals) + 'e | ') % (self.label, np.sum(self.r**2)))
259
+ for dterm in self.dterms:
260
+ dt = getattr(self, dterm)
261
+ if hasattr(dt, 'dterms'):
262
+ dt.print_labeled_residuals(print_newline=False, where_to_print=where_to_print)
263
+ if print_newline:
264
+ where_to_print.write(('%.' + str(num_decimals) + 'e\n') % (np.sum(self.r**2),))
265
+
266
+
267
+
268
+ ########################################################
269
+ # Default methods, for when Ch is not subclassed
270
+
271
+ def compute_r(self):
272
+ """Default method for objects that just contain a number or ndarray"""
273
+ return self.x
274
+
275
+ def compute_dr_wrt(self,wrt):
276
+ """Default method for objects that just contain a number or ndarray"""
277
+ if wrt is self: # special base case
278
+ return sp.eye(self.x.size, self.x.size)
279
+ #return np.array([[1]])
280
+ return None
281
+
282
+
283
    def _compute_dr_wrt_sliced(self, wrt):
        """Jacobian wrt ``wrt``, looking through Permute wrappers.

        Tries the direct Jacobian first; failing that, unwraps ``wrt``
        through chained Permute nodes and chains the inner Jacobian with the
        wrapper's own Jacobian (chain rule).
        """
        self._call_on_changed()

        # if wrt is self:
        #     return np.array([[1]])

        result = self.compute_dr_wrt(wrt)
        if result is not None:
            return result

        # What allows slicing.
        if True:
            # Walk down through Permute wrappers until we either reach
            # ourselves (no derivative) or find a node we can differentiate.
            inner = wrt
            while issubclass(inner.__class__, Permute):
                inner = inner.a
                if inner is self:
                    return None
                result = self.compute_dr_wrt(inner)

                if result is not None:
                    break

            if result is None:
                return None

        wrt._call_on_changed()

        # d(self)/d(wrt) = d(self)/d(inner) . (d(wrt)/d(inner))^T
        jac = wrt.compute_dr_wrt(inner).T

        return self._superdot(result, jac)
313
+
314
+
315
+ @property
316
+ def shape(self):
317
+ return self.r.shape
318
+
319
+ @property
320
+ def size(self):
321
+ #return self.r.size
322
+ return np.prod(self.shape) # may be cheaper since it doesn't always mean grabbing "r"
323
+
324
+ def __len__(self):
325
+ return len(self.r)
326
+
327
+ def minimize(self, *args, **kwargs):
328
+ from . import optimization
329
+ return optimization.minimize(self, *args, **kwargs)
330
+
331
+ def __array__(self, *args):
332
+ return self.r
333
+
334
+ ########################################################
335
+ # State management
336
+
337
+ def add_dterm(self, dterm_name, dterm):
338
+ self.dterms = list(set(list(self.dterms) + [dterm_name]))
339
+ setattr(self, dterm_name, dterm)
340
+
341
+ def copy(self):
342
+ return copy(self)
343
+
344
+ def __getstate__(self):
345
+ # Have to get rid of WeakKeyDictionaries for serialization
346
+ result = external_copy.copy(self.__dict__)
347
+ del result['_parents']
348
+ del result['_cache']
349
+ return result
350
+
351
+ def __setstate__(self, d):
352
+ # Restore unpickleable WeakKeyDictionaries
353
+ d['_parents'] = weakref.WeakKeyDictionary()
354
+ d['_cache'] = {'r': None, 'drs': weakref.WeakKeyDictionary()}
355
+ object.__setattr__(self, '__dict__', d)
356
+
357
+ # This restores our unpickleable "_parents" attribute
358
+ for k in set(self.dterms).intersection(set(self.__dict__.keys())):
359
+ setattr(self, k, self.__dict__[k])
360
+
361
    def __setattr__(self, name, value, itr=None):
        """Set an attribute, maintaining dirty-flags, caches, and parent links.

        Assigning a dterm wraps plain values in Ch, re-wires the weak
        parent/child links, and invalidates caches up the graph. ``itr``
        tags the cache-clear so repeated sets within one optimization
        iteration stop early.
        """
        #print 'SETTING %s' % (name,)

        # Faster path for basic Ch objects. Not necessary for functionality,
        # but improves performance by a small amount.
        if type(self) == Ch:
            if name == 'x':
                self._dirty_vars.add(name)
                self.clear_cache(itr)
            #else:
            #    import warnings
            #    warnings.warn('Trying to set attribute %s on a basic Ch object? Might be a mistake.' % (name,))

            object.__setattr__(self, name, value)
            return

        name_in_dterms = name in self.dterms
        name_in_terms = name in self.terms
        name_in_props = name in _props_for(self.__class__)# [p[0] for p in inspect.getmembers(self.__class__, lambda x : isinstance(x, property))]

        if name_in_dterms and not name_in_props and type(self) != Ch:
            # dterms must be Ch nodes so the graph can differentiate them.
            if not hasattr(value, 'dterms'):
                value = Ch(value)

            # Make ourselves not the parent of the old value
            if hasattr(self, name):
                term = getattr(self, name)
                if self in term._parents:
                    term._parents[self]['varnames'].remove(name)
                    if len(term._parents[self]['varnames']) == 0:
                        del term._parents[self]

            # Make ourselves parents of the new value
            if self not in value._parents:
                value._parents[self] = {'varnames': set([name])}
            else:
                value._parents[self]['varnames'].add(name)

        if name_in_dterms or name_in_terms:
            self._dirty_vars.add(name)
            self._invalidate_cacheprop_names([name])

            # If one of our terms has changed, it has the capacity to have
            # changed our result and all our derivatives wrt everything
            self.clear_cache(itr)

        object.__setattr__(self, name, value)
408
+
409
+ def _invalidate_cacheprop_names(self, names):
410
+ nameset = set(names)
411
+ for func_name, v in list(self._depends_on_deps.items()):
412
+ if len(nameset.intersection(v['deps'])) > 0:
413
+ v['out_of_date'] = True
414
+
415
+
416
    def clear_cache(self, itr=None):
        """Invalidate cached results and Jacobians on this node and all ancestors.

        Walks parent links iteratively (no recursion). ``itr`` tags visited
        nodes so repeated clears within the same optimization iteration
        prune early. Returns the number of nodes visited.
        """
        todo = [self]
        done = set([])
        nodes_visited = 0
        while len(todo) > 0:
            nodes_visited += 1
            next = todo.pop()
            # Already cleared under this iteration tag; prune the walk.
            if itr is not None and itr==next._itr:
                continue
            if id(next) not in done:
                next._cache['r'] = None
                next._cache['drs'].clear()
                next._itr = itr

                # Mark each parent's referring term names dirty and continue
                # the upward walk.
                for parent, parent_dict in list(next._parents.items()):
                    object.__setattr__(parent, '_dirty_vars', parent._dirty_vars.union(parent_dict['varnames']))
                    parent._invalidate_cacheprop_names(parent_dict['varnames'])
                    todo.append(parent)
                done.add(id(next))
        return nodes_visited
436
+
437
+
438
+ def clear_cache_wrt(self, wrt, itr=None):
439
+ if wrt in self._cache['drs']:
440
+ self._cache['drs'][wrt] = None
441
+
442
+ if hasattr(self, 'dr_cached') and wrt in self.dr_cached:
443
+ self.dr_cached[wrt] = None
444
+
445
+ if itr is None or itr != self._itr:
446
+ for parent, parent_dict in list(self._parents.items()):
447
+ if wrt in parent._cache['drs'] or (hasattr(parent, 'dr_cached') and wrt in parent.dr_cached):
448
+ parent.clear_cache_wrt(wrt=wrt, itr=itr)
449
+ object.__setattr__(parent, '_dirty_vars', parent._dirty_vars.union(parent_dict['varnames']))
450
+ parent._invalidate_cacheprop_names(parent_dict['varnames'])
451
+
452
+ object.__setattr__(self, '_itr', itr)
453
+
454
+ def replace(self, old, new):
455
+ if (hasattr(old, 'dterms') != hasattr(new, 'dterms')):
456
+ raise Exception('Either "old" and "new" must both be "Ch", or they must both be neither.')
457
+
458
+ for term_name in [t for t in list(self.dterms)+list(self.terms) if hasattr(self, t)]:
459
+ term = getattr(self, term_name)
460
+ if term is old:
461
+ setattr(self, term_name, new)
462
+ elif hasattr(term, 'dterms'):
463
+ term.replace(old, new)
464
+ return new
465
+
466
+
467
+ def set(self, **kwargs):
468
+ # Some dterms may be aliases via @property.
469
+ # We want to set those last, in case they access non-property members
470
+ #props = [p[0] for p in inspect.getmembers(self.__class__, lambda x : isinstance(x, property))]
471
+ props = _props_for(self.__class__)
472
+ kwarg_keys = set(kwargs.keys())
473
+ kwsecond = kwarg_keys.intersection(props)
474
+ kwfirst = kwarg_keys.difference(kwsecond)
475
+ kwall = list(kwfirst) + list(kwsecond)
476
+
477
+ # The complexity here comes because we wish to
478
+ # avoid clearing cache redundantly
479
+ if len(kwall) > 0:
480
+ for k in kwall[:-1]:
481
+ self.__setattr__(k, kwargs[k], 9999)
482
+ self.__setattr__(kwall[-1], kwargs[kwall[-1]], None)
483
+
484
+
485
+ def is_dr_wrt(self, wrt):
486
+ if type(self) == Ch:
487
+ return wrt is self
488
+ dterms_we_have = [getattr(self, dterm) for dterm in self.dterms if hasattr(self, dterm)]
489
+ return wrt in dterms_we_have or any([d.is_dr_wrt(wrt) for d in dterms_we_have])
490
+
491
+
492
+ def is_ch_baseclass(self):
493
+ return self.__class__ is Ch
494
+
495
+
496
+ ########################################################
497
+ # Getters for our outputs
498
+
499
    def __getitem__(self, key):
        """Differentiable indexing: return a Select node over the chosen elements.

        The key is applied to an arange of identical shape to discover which
        flat indices (and which output shape) it selects.
        """
        shape = self.shape
        tmp = np.arange(np.prod(shape)).reshape(shape).__getitem__(key)
        idxs = tmp.ravel()
        newshape = tmp.shape
        return Select(a=self, idxs=idxs, preferred_shape=newshape)
505
+
506
+ def __setitem__(self, key, value, itr=None):
507
+
508
+ if hasattr(value, 'dterms'):
509
+ raise Exception("Can't assign a Ch objects as a subset of another.")
510
+ if type(self) == Ch:# self.is_ch_baseclass():
511
+ data = np.atleast_1d(self.x)
512
+ data.__setitem__(key, value)
513
+ self.__setattr__('x', data, itr=itr)
514
+ return
515
+ # elif False: # Interesting but flawed idea
516
+ # parents = [self.__dict__[k] for k in self.dterms]
517
+ # kids = []
518
+ # while len(parents)>0:
519
+ # p = parents.pop()
520
+ # if p.is_ch_baseclass():
521
+ # kids.append(p)
522
+ # else:
523
+ # parents += [p.__dict__[k] for k in p.dterms]
524
+ # from ch.optimization import minimize_dogleg
525
+ # minimize_dogleg(obj=self.__getitem__(key) - value, free_variables=kids, show_residuals=False)
526
+ else:
527
+ inner = self
528
+ while not inner.is_ch_baseclass():
529
+ if issubclass(inner.__class__, Permute):
530
+ inner = inner.a
531
+ else:
532
+ raise Exception("Can't set array that is function of arrays.")
533
+
534
+ self = self[key]
535
+ dr = self.dr_wrt(inner)
536
+ dr_rev = dr.T
537
+ #dr_rev = np.linalg.pinv(dr)
538
+ inner_shape = inner.shape
539
+
540
+ t1 = self._superdot(dr_rev, np.asarray(value).ravel())
541
+ t2 = self._superdot(dr_rev, self._superdot(dr, inner.x.ravel()))
542
+ if sp.issparse(t1): t1 = np.array(t1.todense())
543
+ if sp.issparse(t2): t2 = np.array(t2.todense())
544
+
545
+ inner.x = inner.x + t1.reshape(inner_shape) - t2.reshape(inner_shape)
546
+ #inner.x = inner.x + self._superdot(dr_rev, value.ravel()).reshape(inner_shape) - self._superdot(dr_rev, self._superdot(dr, inner.x.ravel())).reshape(inner_shape)
547
+
548
+
549
+ def __str__(self):
550
+ return str(self.r)
551
+
552
+ def __repr__(self):
553
+ return object.__repr__(self) + '\n' + str(self.r)
554
+
555
+ def __float__(self):
556
+ return self.r.__float__()
557
+
558
+ def __int__(self):
559
+ return self.r.__int__()
560
+
561
+ def on_changed(self, terms):
562
+ pass
563
+
564
+ @property
565
+ def T(self):
566
+ return transpose(self)
567
+
568
+ def transpose(self, *axes):
569
+ return transpose(self, *axes)
570
+
571
+ def squeeze(self, axis=None):
572
+ return squeeze(self, axis)
573
+
574
+ def mean(self, axis=None):
575
+ return mean(self, axis=axis)
576
+
577
+ def sum(self, axis=None):
578
+ return sum(self, axis=axis)
579
+
580
    def _call_on_changed(self):
        """Run optional validity checks, then flush pending on_changed notifications."""

        # Subclasses may define is_valid() -> (bool, message).
        if hasattr(self, 'is_valid'):
            validity, msg = self.is_valid()
            assert validity, msg
        if hasattr(self, '_status'):
            self._status = 'new'

        # Notify once for all accumulated dirty terms, then reset the set
        # (object.__setattr__ avoids re-triggering our __setattr__ logic).
        if len(self._dirty_vars) > 0:
            self.on_changed(self._dirty_vars)
            object.__setattr__(self, '_dirty_vars', set())
591
+
592
    @property
    def r(self):
        """The evaluated result as a read-only float64 ndarray (cached)."""
        self._call_on_changed()
        if self._cache['r'] is None:
            self._cache['r'] = np.asarray(np.atleast_1d(self.compute_r()), dtype=np.float64, order='C')
            # Hand out a non-writeable view so callers can't corrupt the cache.
            self._cache['rview'] = self._cache['r'].view()
            self._cache['rview'].flags.writeable = False

        return self._cache['rview']
601
+
602
    def _superdot(self, lhs, rhs, profiler=None):
        """Generalized matrix product for Jacobian-like operands.

        Handles None (annihilates the product), 1-element ndarrays and
        scalars (plain multiplication), LinearOperator on either side
        (composed lazily), and dense/sparse ndarray combinations. On any
        exception the traceback is printed and, outside DEBUG, re-raised.
        """

        try:
            if lhs is None:
                return None
            if rhs is None:
                return None

            # Collapse 1-element arrays to scalars so the scalar fast path applies.
            if isinstance(lhs, np.ndarray) and lhs.size==1:
                lhs = lhs.ravel()[0]

            if isinstance(rhs, np.ndarray) and rhs.size==1:
                rhs = rhs.ravel()[0]

            if isinstance(lhs, numbers.Number) or isinstance(rhs, numbers.Number):
                return lhs * rhs

            # Compose with LinearOperators lazily rather than materializing.
            if isinstance(rhs, LinearOperator):
                return LinearOperator((lhs.shape[0], rhs.shape[1]), lambda x : lhs.dot(rhs.dot(x)))

            if isinstance(lhs, LinearOperator):
                if sp.issparse(rhs):
                    return LinearOperator((lhs.shape[0], rhs.shape[1]), lambda x : lhs.dot(rhs.dot(x)))
                else:
                    # TODO: ?????????????
                    # return lhs.matmat(rhs)
                    return lhs.dot(rhs)

            # TODO: Figure out how/whether to do this.
            tm_maybe_sparse = timer()
            lhs, rhs = convert_inputs_to_sparse_if_necessary(lhs, rhs)
            if tm_maybe_sparse() > 0.1:
                pif('convert_inputs_to_sparse_if_necessary in {}sec'.format(tm_maybe_sparse()))

            # dense.dot(sparse) is inefficient/invalid; transpose into a
            # sparse-leading product instead.
            if not sp.issparse(lhs) and sp.issparse(rhs):
                return rhs.T.dot(lhs.T).T
            return lhs.dot(rhs)
        except Exception as e:
            import sys, traceback
            traceback.print_exc(file=sys.stdout)
            if DEBUG:
                import pdb; pdb.post_mortem()
            else:
                raise
646
+
647
+ def lmult_wrt(self, lhs, wrt):
648
+ if lhs is None:
649
+ return None
650
+
651
+ self._call_on_changed()
652
+
653
+ drs = []
654
+
655
+ direct_dr = self._compute_dr_wrt_sliced(wrt)
656
+
657
+ if direct_dr != None:
658
+ drs.append(self._superdot(lhs, direct_dr))
659
+
660
+ for k in set(self.dterms):
661
+ p = self.__dict__[k]
662
+
663
+ if hasattr(p, 'dterms') and p is not wrt and p.is_dr_wrt(wrt):
664
+ if not isinstance(p, Ch):
665
+ print('BROKEN!')
666
+ raise Exception('Broken Should be Ch object')
667
+
668
+ indirect_dr = p.lmult_wrt(self._superdot(lhs, self._compute_dr_wrt_sliced(p)), wrt)
669
+ if indirect_dr is not None:
670
+ drs.append(indirect_dr)
671
+
672
+ if len(drs)==0:
673
+ result = None
674
+
675
+ elif len(drs)==1:
676
+ result = drs[0]
677
+
678
+ else:
679
+ result = reduce(lambda x, y: x+y, drs)
680
+
681
+ return result
682
+
683
+
684
+ def compute_lop(self, wrt, lhs):
685
+ dr = self._compute_dr_wrt_sliced(wrt)
686
+ if dr is None: return None
687
+ return self._superdot(lhs, dr) if not isinstance(lhs, LinearOperator) else lhs.matmat(dr)
688
+
689
+
690
    def lop(self, wrt, lhs):
        """Reverse-mode accumulation: propagate ``lhs`` down toward ``wrt``.

        Accumulates the direct contribution plus one recursive contribution
        per child dterm, densifying sparse partials before summation.
        Returns None when nothing depends on ``wrt``.
        """
        self._call_on_changed()

        drs = []
        direct_dr = self.compute_lop(wrt, lhs)
        if direct_dr is not None:
            drs.append(direct_dr)

        for k in set(self.dterms):
            p = getattr(self, k) # self.__dict__[k]
            if hasattr(p, 'dterms') and p is not wrt: # and p.is_dr_wrt(wrt):
                lhs_for_child = self.compute_lop(p, lhs)
                if lhs_for_child is not None: # Can be None with ChLambda, _result etc
                    indirect_dr = p.lop(wrt, lhs_for_child)
                    if indirect_dr is not None:
                        drs.append(indirect_dr)

        # Densify sparse partials so the summation below is well-defined.
        for k in range(len(drs)):
            if sp.issparse(drs[k]):
                drs[k] = drs[k].todense()

        if len(drs)==0:
            result = None

        elif len(drs)==1:
            result = drs[0]

        else:
            result = reduce(lambda x, y: x+y, drs)


        return result
722
+
723
+ def compute_rop(self, wrt, rhs):
724
+ dr = self._compute_dr_wrt_sliced(wrt)
725
+ if dr is None: return None
726
+
727
+ return self._superdot(dr, rhs)
728
+
729
+ def dr_wrt(self, wrt, reverse_mode=False, profiler=None):
730
+ tm_dr_wrt = timer()
731
+ self.called_dr_wrt = True
732
+ self._call_on_changed()
733
+
734
+ drs = []
735
+
736
+ if wrt in self._cache['drs']:
737
+ if DEBUG:
738
+ if wrt not in self._cache_info:
739
+ self._cache_info[wrt] = 0
740
+ self._cache_info[wrt] +=1
741
+ self._status = 'cached'
742
+ return self._cache['drs'][wrt]
743
+
744
+ direct_dr = self._compute_dr_wrt_sliced(wrt)
745
+
746
+ if direct_dr is not None:
747
+ drs.append(direct_dr)
748
+
749
+ if DEBUG:
750
+ self._status = 'pending'
751
+
752
+ propnames = set(_props_for(self.__class__))
753
+ for k in set(self.dterms).intersection(propnames.union(set(self.__dict__.keys()))):
754
+
755
+ p = getattr(self, k)
756
+
757
+ if hasattr(p, 'dterms') and p is not wrt:
758
+
759
+ indirect_dr = None
760
+
761
+ if reverse_mode:
762
+ lhs = self._compute_dr_wrt_sliced(p)
763
+ if isinstance(lhs, LinearOperator):
764
+ tm_dr_wrt.pause()
765
+ dr2 = p.dr_wrt(wrt)
766
+ tm_dr_wrt.resume()
767
+ indirect_dr = lhs.matmat(dr2) if dr2 != None else None
768
+ else:
769
+ indirect_dr = p.lmult_wrt(lhs, wrt)
770
+ else: # forward mode
771
+ tm_dr_wrt.pause()
772
+ dr2 = p.dr_wrt(wrt, profiler=profiler)
773
+ tm_dr_wrt.resume()
774
+ if dr2 is not None:
775
+ indirect_dr = self.compute_rop(p, rhs=dr2)
776
+
777
+ if indirect_dr is not None:
778
+ drs.append(indirect_dr)
779
+
780
+ if len(drs)==0:
781
+ result = None
782
+ elif len(drs)==1:
783
+ result = drs[0]
784
+ else:
785
+ # TODO: ????????
786
+ # result = np.sum(x for x in drs)
787
+ if not np.any([isinstance(a, LinearOperator) for a in drs]):
788
+ result = reduce(lambda x, y: x+y, drs)
789
+ else:
790
+ result = LinearOperator(drs[0].shape, lambda x : reduce(lambda a, b: a.dot(x)+b.dot(x),drs))
791
+
792
+ # TODO: figure out how/whether to do this.
793
+ if result is not None and not sp.issparse(result):
794
+ tm_nonzero = timer()
795
+ nonzero = np.count_nonzero(result)
796
+ if tm_nonzero() > 0.1:
797
+ pif('count_nonzero in {}sec'.format(tm_nonzero()))
798
+ if nonzero == 0 or hasattr(result, 'size') and result.size / float(nonzero) >= 10.0:
799
+ tm_convert_to_sparse = timer()
800
+ result = sp.csc_matrix(result)
801
+ import gc
802
+ gc.collect()
803
+ pif('converting result to sparse in {}sec'.format(tm_convert_to_sparse()))
804
+
805
+ if (result is not None) and (not sp.issparse(result)) and (not isinstance(result, LinearOperator)):
806
+ result = np.atleast_2d(result)
807
+
808
+ # When the number of parents is one, it indicates that
809
+ # caching this is probably not useful because not
810
+ # more than one parent will likely ask for this same
811
+ # thing again in the same iteration of an optimization.
812
+ #
813
+ # When the number of parents is zero, this is the top
814
+ # level object and should be cached; when it's > 1
815
+ # cache the combinations of the children.
816
+ #
817
+ # If we *always* filled in the cache, it would require
818
+ # more memory but would occasionally save a little cpu,
819
+ # on average.
820
+ if len(list(self._parents.keys())) != 1:
821
+ self._cache['drs'][wrt] = result
822
+
823
+ if DEBUG:
824
+ self._status = 'done'
825
+
826
+ if getattr(self, '_make_dense', False) and sp.issparse(result):
827
+ result = result.todense()
828
+ if getattr(self, '_make_sparse', False) and not sp.issparse(result):
829
+ result = sp.csc_matrix(result)
830
+
831
+ if tm_dr_wrt() > 0.1:
832
+ pif('dx of {} wrt {} in {}sec, sparse: {}'.format(self.short_name, wrt.short_name, tm_dr_wrt(), sp.issparse(result)))
833
+
834
+ return result
835
+
836
+
837
+ def __call__(self, **kwargs):
838
+ self.set(**kwargs)
839
+ return self.r
840
+
841
+
842
+ ########################################################
843
+ # Visualization
844
+
845
+ @property
846
+ def reset_flag(self):
847
+ """
848
+ Used as fn in loop_children_do
849
+ """
850
+ return lambda x: setattr(x, 'called_dr_wrt', False)
851
+
852
+ def loop_children_do(self, fn):
853
+ fn(self)
854
+ for dterm in self.dterms:
855
+ if hasattr(self, dterm):
856
+ dtval = getattr(self, dterm)
857
+ if hasattr(dtval, 'dterms') or hasattr(dtval, 'terms'):
858
+ if hasattr(dtval, 'loop_children_do'):
859
+ dtval.loop_children_do(fn)
860
+
861
+
862
+ def show_tree_cache(self, label, current_node=None):
863
+ '''
864
+ Show tree and cache info with color represent _status
865
+ Optionally accpet current_node arg to highlight the current node we are in
866
+ '''
867
+ import os
868
+ import tempfile
869
+ import subprocess
870
+
871
+ assert DEBUG, "Please use dr tree visualization functions in debug mode"
872
+
873
+ cache_path = os.path.abspath('profiles')
874
+ def string_for(self, my_name):
875
+
876
+ color_mapping = {'new' : 'grey', 'pending':'red', 'cached':'yellow', 'done': 'green'}
877
+ if hasattr(self, 'label'):
878
+ my_name = self.label
879
+ my_name = '%s (%s)' % (my_name, str(self.__class__.__name__))
880
+ result = []
881
+ if not hasattr(self, 'dterms'):
882
+ return result
883
+ for dterm in self.dterms:
884
+ if hasattr(self, dterm):
885
+ dtval = getattr(self, dterm)
886
+ if hasattr(dtval, 'dterms') or hasattr(dtval, 'terms'):
887
+ child_label = getattr(dtval, 'label') if hasattr(dtval, 'label') else dterm
888
+ child_label = '%s (%s)' % (child_label, str(dtval.__class__.__name__))
889
+ src = 'aaa%d' % (id(self))
890
+ dst = 'aaa%d' % (id(dtval))
891
+
892
+ s = ''
893
+ color = color_mapping[dtval._status] if hasattr(dtval, '_status') else 'grey'
894
+ if dtval == current_node:
895
+ color = 'blue'
896
+ if isinstance(dtval, Concatenate) and len(dtval.dr_cached) > 0:
897
+ s = 'dr_cached\n'
898
+ for k, v in dtval.dr_cached.items():
899
+ if v is not None:
900
+ issparse = sp.issparse(v)
901
+ size = v.size
902
+ if issparse:
903
+ size = v.shape[0] * v.shape[1]
904
+ nonzero = len(v.data)
905
+ else:
906
+ nonzero = np.count_nonzero(v)
907
+ s += '\nsparse: %s\nsize: %d\nnonzero: %d\n' % (issparse, size, nonzero)
908
+ # if dtval.called_dr_wrt:
909
+ # # dtval.called_dr_wrt = False
910
+ # color = 'brown3'
911
+ # else:
912
+ # color = 'azure1'
913
+ elif len(dtval._cache['drs']) > 0:
914
+ s = '_cache\n'
915
+
916
+ for k, v in dtval._cache['drs'].items():
917
+ if v is not None:
918
+ issparse = sp.issparse(v)
919
+ size = v.size
920
+ if issparse:
921
+ size = v.shape[0] * v.shape[1]
922
+ nonzero = len(v.data)
923
+ else:
924
+ nonzero = np.count_nonzero(v)
925
+
926
+ s += '\nsparse: %s\nsize: %d\nnonzero: %d\n' % (issparse, size, nonzero)
927
+ if hasattr(dtval, '_cache_info'):
928
+ s += '\ncache hit:%s\n' % dtval._cache_info[k]
929
+ # if hasattr(dtval,'called_dr_wrt') and dtval.called_dr_wrt:
930
+ # # dtval.called_dr_wrt = False
931
+ # color = 'brown3'
932
+ # else:
933
+ # color = 'azure1'
934
+ result += ['%s -> %s;' % (src, dst)]
935
+ # Do not overwrite src
936
+ #result += ['%s [label="%s"];' % (src, my_name)]
937
+ result += ['%s [label="%s\n%s\n", color=%s, style=filled];' %
938
+ (dst, child_label, s, color)]
939
+ result += string_for(getattr(self, dterm), dterm)
940
+ return result
941
+
942
+
943
+ dot_file_contents = 'digraph G {\n%s\n}' % '\n'.join(list(set(string_for(self, 'root'))))
944
+ dot_file_name = os.path.join(cache_path, label)
945
+ png_file_name = os.path.join(cache_path, label+'.png')
946
+ with open(dot_file_name, 'w') as dot_file:
947
+ with open(png_file_name, 'w') as png_file:
948
+ dot_file.write(dot_file_contents)
949
+ dot_file.flush()
950
+
951
+ png_file = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
952
+ subprocess.call(['dot', '-Tpng', '-o', png_file.name, dot_file.name])
953
+
954
+ import webbrowser
955
+ webbrowser.open('file://' + png_file.name)
956
+
957
+ self.loop_children_do(self.reset_flag)
958
+
959
+ def show_tree_wrt(self, wrt):
960
+ import tempfile
961
+ import subprocess
962
+
963
+ assert DEBUG, "Please use dr tree visualization functions in debug mode"
964
+
965
+ def string_for(self, my_name, wrt):
966
+ if hasattr(self, 'label'):
967
+ my_name = self.label
968
+ my_name = '%s (%s)' % (my_name, str(self.__class__.__name__))
969
+ result = []
970
+ if not hasattr(self, 'dterms'):
971
+ return result
972
+ for dterm in self.dterms:
973
+ if hasattr(self, dterm):
974
+ dtval = getattr(self, dterm)
975
+ if hasattr(dtval, 'dterms') or hasattr(dtval, 'terms'):
976
+ child_label = getattr(dtval, 'label') if hasattr(dtval, 'label') else dterm
977
+ child_label = '%s (%s)' % (child_label, str(dtval.__class__.__name__))
978
+ src = 'aaa%d' % (id(self))
979
+ dst = 'aaa%d' % (id(dtval))
980
+ result += ['%s -> %s;' % (src, dst)]
981
+ result += ['%s [label="%s"];' % (src, my_name)]
982
+ if wrt in dtval._cache['drs'] and dtval._cache['drs'][wrt] is not None:
983
+ issparse = sp.issparse(dtval._cache['drs'][wrt])
984
+ size = dtval._cache['drs'][wrt].size
985
+ nonzero = np.count_nonzero(dtval._cache['drs'][wrt])
986
+ result += ['%s [label="%s\n is_sparse: %s\nsize: %d\nnonzero: %d"];' %
987
+ (dst, child_label, issparse, size,
988
+ nonzero)]
989
+ else:
990
+ result += ['%s [label="%s"];' % (dst, child_label)]
991
+ result += string_for(getattr(self, dterm), dterm, wrt)
992
+ return result
993
+
994
+
995
+ dot_file_contents = 'digraph G {\n%s\n}' % '\n'.join(list(set(string_for(self, 'root', wrt))))
996
+ dot_file = tempfile.NamedTemporaryFile()
997
+ dot_file.write(dot_file_contents)
998
+ dot_file.flush()
999
+ png_file = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
1000
+ subprocess.call(['dot', '-Tpng', '-o', png_file.name, dot_file.name])
1001
+ import webbrowser
1002
+ webbrowser.open('file://' + png_file.name)
1003
+
1004
+ def show_tree(self, cachelim=np.inf):
1005
+ """Cachelim is in Mb. For any cached jacobians above cachelim, they are also added to the graph. """
1006
+ import tempfile
1007
+ import subprocess
1008
+
1009
+ assert DEBUG, "Please use dr tree visualization functions in debug mode"
1010
+
1011
+ def string_for(self, my_name):
1012
+ if hasattr(self, 'label'):
1013
+ my_name = self.label
1014
+ my_name = '%s (%s)' % (my_name, str(self.__class__.__name__))
1015
+ result = []
1016
+ if not hasattr(self, 'dterms'):
1017
+ return result
1018
+ for dterm in self.dterms:
1019
+ if hasattr(self, dterm):
1020
+ dtval = getattr(self, dterm)
1021
+ if hasattr(dtval, 'dterms') or hasattr(dtval, 'terms'):
1022
+ child_label = getattr(dtval, 'label') if hasattr(dtval, 'label') else dterm
1023
+ child_label = '%s (%s)' % (child_label, str(dtval.__class__.__name__))
1024
+ src = 'aaa%d' % (id(self))
1025
+ dst = 'aaa%d' % (id(dtval))
1026
+ result += ['%s -> %s;' % (src, dst)]
1027
+ result += ['%s [label="%s"];' % (src, my_name)]
1028
+ result += ['%s [label="%s"];' % (dst, child_label)]
1029
+ result += string_for(getattr(self, dterm), dterm)
1030
+
1031
+ if cachelim != np.inf and hasattr(self, '_cache') and 'drs' in self._cache:
1032
+ from six.moves import cPickle as pickle
1033
+ for dtval, jac in list(self._cache['drs'].items()):
1034
+ # child_label = getattr(dtval, 'label') if hasattr(dtval, 'label') else dterm
1035
+ # child_label = '%s (%s)' % (child_label, str(dtval.__class__.__name__))
1036
+ src = 'aaa%d' % (id(self))
1037
+ dst = 'aaa%d' % (id(dtval))
1038
+ try:
1039
+ sz = sys.getsizeof(pickle.dumps(jac, -1))
1040
+ except: # some are functions
1041
+ sz = 0
1042
+ # colorattr = "#%02x%02x%02x" % (szpct*255, 0, (1-szpct)*255)
1043
+ # print colorattr
1044
+ if sz > (cachelim * 1024 * 1024):
1045
+ result += ['%s -> %s [style=dotted,color="<<<%d>>>"];' % (src, dst, sz)]
1046
+ #
1047
+ # result += ['%s -> %s [style=dotted];' % (src, dst)]
1048
+ # result += ['%s [label="%s"];' % (src, my_name)]
1049
+ # result += ['%s [label="%s"];' % (dst, child_label)]
1050
+ # result += string_for(getattr(self, dterm), dterm)
1051
+
1052
+ return result
1053
+
1054
+ dot_file_contents = 'digraph G {\n%s\n}' % '\n'.join(list(set(string_for(self, 'root'))))
1055
+ if cachelim != np.inf:
1056
+ import re
1057
+ strs = re.findall(r'<<<(\d+)>>>', dot_file_contents, re.DOTALL)
1058
+ if len(strs) > 0:
1059
+ the_max = np.max(np.array([int(d) for d in strs]))
1060
+ for s in strs:
1061
+ szpct = float(s)/the_max
1062
+ sz = float(s)
1063
+ unit = 'b'
1064
+ if sz > 1024.:
1065
+ sz /= 1024
1066
+ unit = 'K'
1067
+ if sz > 1024.:
1068
+ sz /= 1024
1069
+ unit = 'M'
1070
+ if sz > 1024.:
1071
+ sz /= 1024
1072
+ unit = 'G'
1073
+ if sz > 1024.:
1074
+ sz /= 1024
1075
+ unit = 'T'
1076
+
1077
+ dot_file_contents = re.sub('<<<%s>>>' % s, '#%02x%02x%02x",label="%d%s' % (szpct*255, 0, (1-szpct)*255, sz, unit), dot_file_contents)
1078
+
1079
+ dot_file = tempfile.NamedTemporaryFile()
1080
+ dot_file.write(dot_file_contents)
1081
+ dot_file.flush()
1082
+ png_file = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
1083
+ subprocess.call(['dot', '-Tpng', '-o', png_file.name, dot_file.name])
1084
+ import webbrowser
1085
+ webbrowser.open('file://' + png_file.name)
1086
+
1087
+
1088
+ def tree_iterator(self, visited=None, path=None):
1089
+ '''
1090
+ Generator function that traverse the dr tree start from this node (self).
1091
+ '''
1092
+ if visited is None:
1093
+ visited = set()
1094
+
1095
+ if self not in visited:
1096
+ if path and isinstance(path, list):
1097
+ path.append(self)
1098
+
1099
+ visited.add(self)
1100
+ yield self
1101
+
1102
+ if not hasattr(self, 'dterms'):
1103
+ yield
1104
+
1105
+ for dterm in self.dterms:
1106
+ if hasattr(self, dterm):
1107
+ child = getattr(self, dterm)
1108
+ if hasattr(child, 'dterms') or hasattr(child, 'terms'):
1109
+ for node in child.tree_iterator(visited):
1110
+ yield node
1111
+
1112
+ def floor(self):
1113
+ return floor(self)
1114
+
1115
+ def ceil(self):
1116
+ return ceil(self)
1117
+
1118
+ def dot(self, other):
1119
+ return dot(self, other)
1120
+
1121
+ def cumsum(self, axis=None):
1122
+ return cumsum(a=self, axis=axis)
1123
+
1124
+ def min(self, axis=None):
1125
+ return amin(a=self, axis=axis)
1126
+
1127
+ def max(self, axis=None):
1128
+ return amax(a=self, axis=axis)
1129
+
1130
+ ########################################################
1131
+ # Operator overloads
1132
+
1133
+ def __pos__(self): return self
1134
+ def __neg__(self): return negative(self)
1135
+
1136
+ def __add__ (self, other): return add(a=self, b=other)
1137
+ def __radd__(self, other): return add(a=other, b=self)
1138
+
1139
+ def __sub__ (self, other): return subtract(a=self, b=other)
1140
+ def __rsub__(self, other): return subtract(a=other, b=self)
1141
+
1142
+ def __mul__ (self, other): return multiply(a=self, b=other)
1143
+ def __rmul__(self, other): return multiply(a=other, b=self)
1144
+
1145
+ def __div__ (self, other): return divide(x1=self, x2=other)
1146
+ def __truediv__ (self, other): return divide(x1=self, x2=other)
1147
+ def __rdiv__(self, other): return divide(x1=other, x2=self)
1148
+
1149
+ def __pow__ (self, other): return power(x=self, pow=other)
1150
+ def __rpow__(self, other): return power(x=other, pow=self)
1151
+
1152
+ def __rand__(self, other): return self.__and__(other)
1153
+
1154
+ def __abs__ (self): return abs(self)
1155
+
1156
+ def __gt__(self, other): return greater(self, other)
1157
+ def __ge__(self, other): return greater_equal(self, other)
1158
+
1159
+ def __lt__(self, other): return less(self, other)
1160
+ def __le__(self, other): return less_equal(self, other)
1161
+
1162
+ def __ne__(self, other): return not_equal(self, other)
1163
+
1164
+ # not added yet because of weak key dict conflicts
1165
+ #def __eq__(self, other): return equal(self, other)
1166
+
1167
+
1168
+ Ch._reserved_kw = set(Ch.__dict__.keys())
1169
+
1170
+
1171
+ class MatVecMult(Ch):
1172
+ terms = 'mtx'
1173
+ dterms = 'vec'
1174
+ def compute_r(self):
1175
+ result = self.mtx.dot(col(self.vec.r.ravel())).ravel()
1176
+ if len(self.vec.r.shape) > 1 and self.vec.r.shape[1] > 1:
1177
+ result = result.reshape((-1,self.vec.r.shape[1]))
1178
+ return result
1179
+
1180
+ def compute_dr_wrt(self, wrt):
1181
+ if wrt is self.vec:
1182
+ return sp.csc_matrix(self.mtx)
1183
+
1184
+
1185
+ #def depends_on(*dependencies):
1186
+ # def _depends_on(func):
1187
+ # @wraps(func)
1188
+ # def with_caching(self, *args, **kwargs):
1189
+ # return func(self, *args, **kwargs)
1190
+ # return property(with_caching)
1191
+ # return _depends_on
1192
+
1193
+
1194
+ def depends_on(*dependencies):
1195
+ deps = set()
1196
+ for dep in dependencies:
1197
+ if isinstance(dep, str):
1198
+ deps.add(dep)
1199
+ else:
1200
+ [deps.add(d) for d in dep]
1201
+
1202
+ def _depends_on(func):
1203
+ want_out = 'out' in inspect.getfullargspec(func).args
1204
+
1205
+ @wraps(func)
1206
+ def with_caching(self, *args, **kwargs):
1207
+ func_name = func.__name__
1208
+ sdf = self._depends_on_deps[func_name]
1209
+ if sdf['out_of_date'] == True:
1210
+ #tm = time.time()
1211
+ if want_out:
1212
+ kwargs['out'] = sdf['value']
1213
+ sdf['value'] = func(self, *args, **kwargs)
1214
+ sdf['out_of_date'] = False
1215
+ #print 'recomputed %s in %.2e' % (func_name, time.time() - tm)
1216
+ return sdf['value']
1217
+ with_caching.deps = deps # set(dependencies)
1218
+ result = property(with_caching)
1219
+ return result
1220
+ return _depends_on
1221
+
1222
+
1223
+
1224
+ class ChHandle(Ch):
1225
+ dterms = ('x',)
1226
+
1227
+ def compute_r(self):
1228
+ assert(self.x is not self)
1229
+ return self.x.r
1230
+
1231
+ def compute_dr_wrt(self, wrt):
1232
+ if wrt is self.x:
1233
+ return 1
1234
+
1235
+
1236
+ class ChLambda(Ch):
1237
+ terms = ['lmb', 'initial_args']
1238
+ dterms = []
1239
+ term_order = ['lmb', 'initial_args']
1240
+
1241
+ def on_changed(self, which):
1242
+ for argname in set(which).intersection(set(self.args.keys())):
1243
+ self.args[argname].x = getattr(self, argname)
1244
+
1245
+ def __init__(self, lmb, initial_args=None):
1246
+ argspec_args = inspect.getfullargspec(lmb).args
1247
+ args = {argname: ChHandle(x=Ch(idx)) for idx, argname in enumerate(argspec_args)}
1248
+ if initial_args is not None:
1249
+ for initial_arg in initial_args:
1250
+ if initial_arg in args:
1251
+ args[initial_arg].x = initial_args[initial_arg]
1252
+ result = lmb(**args)
1253
+ for argname, arg in list(args.items()):
1254
+ if result.is_dr_wrt(arg.x):
1255
+ self.add_dterm(argname, arg.x)
1256
+ else:
1257
+ self.terms.append(argname)
1258
+ setattr(self, argname, arg.x)
1259
+ self.args = args
1260
+ self.add_dterm('_result', result)
1261
+
1262
+ def __getstate__(self):
1263
+ # Have to get rid of lambda for serialization
1264
+ if hasattr(self, 'lmb'):
1265
+ self.lmb = None
1266
+ return super(self.__class__, self).__getstate__()
1267
+
1268
+
1269
+ def compute_r(self):
1270
+ return self._result.r
1271
+
1272
+ def compute_dr_wrt(self, wrt):
1273
+ if wrt is self._result:
1274
+ return 1
1275
+
1276
+ # ChGroup is similar to ChLambda in that it's designed to expose the "internal"
1277
+ # inputs of result but, unlike ChLambda, result is kept internal and called when
1278
+ # compute_r and compute_dr_wrt is called to compute the relevant Jacobians.
1279
+ # This provides a way of effectively applying the chain rule in a different order.
1280
+ class ChGroup(Ch):
1281
+ terms = ['result', 'args']
1282
+ dterms = []
1283
+ term_order = ['result', 'args']
1284
+
1285
+ def on_changed(self, which):
1286
+ for argname in set(which).intersection(set(self.args.keys())):
1287
+ if not self.args[argname].x is getattr(self, argname) :
1288
+ self.args[argname].x = getattr(self, argname)
1289
+
1290
+ # right now the entries in args have to refer to terms/dterms of result,
1291
+ # it would be better if they could be "internal" as well, but for now the idea
1292
+ # is that result may itself be a ChLambda.
1293
+ def __init__(self, result, args):
1294
+ self.args = { argname: ChHandle(x=arg) for argname, arg in list(args.items()) }
1295
+ for argname, arg in list(self.args.items()):
1296
+ setattr(result, argname, arg)
1297
+ if result.is_dr_wrt(arg.x):
1298
+ self.add_dterm(argname, arg.x)
1299
+ else:
1300
+ self.terms.append(argname)
1301
+ setattr(self, argname, arg.x)
1302
+ self._result = result
1303
+
1304
+ def compute_r(self):
1305
+ return self._result.r
1306
+
1307
+ def compute_dr_wrt(self, wrt):
1308
+ return self._result.dr_wrt(wrt)
1309
+
1310
+ from .ch_ops import *
1311
+ from .ch_ops import __all__ as all_ch_ops
1312
+ __all__ += all_ch_ops
1313
+
1314
+ from .reordering import *
1315
+ from .reordering import Permute
1316
+ from .reordering import __all__ as all_reordering
1317
+ __all__ += all_reordering
1318
+
1319
+
1320
+ from . import linalg
1321
+ from . import ch_random as random
1322
+ __all__ += ['linalg', 'random']
1323
+
1324
+
1325
+
1326
+
1327
+
1328
+ class tst(Ch):
1329
+ dterms = ['a', 'b', 'c']
1330
+
1331
+ def compute_r(self):
1332
+ return self.a.r + self.b.r + self.c.r
1333
+
1334
+ def compute_dr_wrt(self, wrt):
1335
+ return 1
1336
+
1337
+ def main():
1338
+ foo = tst
1339
+
1340
+ x10 = Ch(10)
1341
+ x20 = Ch(20)
1342
+ x30 = Ch(30)
1343
+
1344
+ tmp = ChLambda(lambda x, y, z: Ch(1) + Ch(2) * Ch(3) + 4)
1345
+ print(tmp.dr_wrt(tmp.x))
1346
+ import pdb; pdb.set_trace()
1347
+ #a(b(c(d(e(f),g),h)))
1348
+
1349
+ blah = tst(x10, x20, x30)
1350
+
1351
+ print(blah.r)
1352
+
1353
+
1354
+ print(foo)
1355
+
1356
+ import pdb; pdb.set_trace()
1357
+
1358
+ # import unittest
1359
+ # from test_ch import TestCh
1360
+ # suite = unittest.TestLoader().loadTestsFromTestCase(TestCh)
1361
+ # unittest.TextTestRunner(verbosity=2).run(suite)
1362
+
1363
+
1364
+
1365
+ if __name__ == '__main__':
1366
+ main()
1367
+
vendor/chumpy/chumpy/ch_ops.py ADDED
@@ -0,0 +1,814 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # encoding: utf-8
3
+ """
4
+ Author(s): Matthew Loper
5
+
6
+ See LICENCE.txt for licensing and contact information.
7
+ """
8
+
9
+ # Numpy functions
10
+ __all__ = ['array', 'amax','amin', 'max', 'min', 'maximum','minimum','nanmax','nanmin',
11
+ 'sum', 'exp', 'log', 'mean','std', 'var',
12
+ 'sin', 'cos', 'tan', 'arcsin', 'arccos', 'arctan',
13
+ 'sqrt', 'square', 'absolute', 'abs', 'clip',
14
+ 'power',
15
+ 'add', 'divide', 'multiply', 'negative', 'subtract', 'reciprocal',
16
+ 'nan_to_num',
17
+ 'dot', 'cumsum',
18
+ 'floor', 'ceil',
19
+ 'greater', 'greater_equal', 'less', 'less_equal', 'equal', 'not_equal',
20
+ 'nonzero', 'ascontiguousarray', 'asfarray', 'arange', 'asarray', 'copy',
21
+ 'cross',
22
+ 'shape', 'sign']
23
+
24
+
25
+ __all__ += ['SumOfSquares',
26
+ 'NanDivide', ]
27
+
28
+
29
+ # These can be wrapped directly as Ch(routine(*args, **kwargs)),
30
+ # so that for example "ch.eye(3)" translates into Ch(np.eye(3))
31
+ numpy_array_creation_routines = [
32
+ 'empty','empty_like','eye','identity','ones','ones_like','zeros','zeros_like',
33
+ 'array',
34
+ 'arange','linspace','logspace','meshgrid','mgrid','ogrid',
35
+ 'fromfunction', 'fromiter', 'meshgrid', 'tri'
36
+ ]
37
+
38
+
39
+ wont_implement = ['asanyarray', 'asmatrix', 'frombuffer', 'copy', 'fromfile', 'fromstring', 'loadtxt', 'copyto', 'asmatrix', 'asfortranarray', 'asscalar', 'require']
40
+ not_yet_implemented = ['tril', 'triu', 'vander']
41
+
42
+ __all__ += not_yet_implemented
43
+ __all__ += wont_implement
44
+ __all__ += numpy_array_creation_routines
45
+
46
+
47
+ from .ch import Ch
48
+ import six
49
+ import numpy as np
50
+ import warnings
51
+ from six.moves import cPickle as pickle
52
+ import scipy.sparse as sp
53
+ from .utils import row, col
54
+ from copy import copy as copy_copy
55
+ from functools import reduce
56
+
57
+ __all__ += ['pi', 'set_printoptions']
58
+ pi = np.pi
59
+ set_printoptions = np.set_printoptions
60
+ arange = np.arange
61
+
62
+ for rtn in ['argmax', 'nanargmax', 'argmin', 'nanargmin']:
63
+ exec('def %s(a, axis=None) : return np.%s(a.r, axis) if hasattr(a, "compute_r") else np.%s(a, axis)' % (rtn, rtn, rtn))
64
+ __all__ += [rtn]
65
+
66
+ for rtn in ['argwhere', 'nonzero', 'flatnonzero']:
67
+ exec('def %s(a) : return np.%s(a.r) if hasattr(a, "compute_r") else np.%s(a)' % (rtn, rtn, rtn))
68
+ __all__ += [rtn]
69
+
70
+ for rtn in numpy_array_creation_routines:
71
+ exec('def %s(*args, **kwargs) : return Ch(np.%s(*args, **kwargs))' % (rtn, rtn))
72
+
73
+
74
+ class WontImplement(Exception):
75
+ pass
76
+
77
+ for rtn in wont_implement:
78
+ exec('def %s(*args, **kwargs) : raise WontImplement' % (rtn))
79
+
80
+ for rtn in not_yet_implemented:
81
+ exec('def %s(*args, **kwargs) : raise NotImplementedError' % (rtn))
82
+
83
+ def asarray(a, dtype=None, order=None):
84
+ assert(dtype is None or dtype is np.float64)
85
+ assert(order == 'C' or order is None)
86
+ if hasattr(a, 'dterms'):
87
+ return a
88
+ return Ch(np.asarray(a, dtype, order))
89
+
90
+ # Everythign is always c-contiguous
91
+ def ascontiguousarray(a, dtype=None): return a
92
+
93
+ # Everything is always float
94
+ asfarray = ascontiguousarray
95
+
96
+ def copy(self):
97
+ return pickle.loads(pickle.dumps(self))
98
+
99
+ def asfortranarray(a, dtype=None): raise WontImplement
100
+
101
+
102
+ class Simpleton(Ch):
103
+ dterms = 'x'
104
+ def compute_dr_wrt(self, wrt):
105
+ return None
106
+
107
+ class floor(Simpleton):
108
+ def compute_r(self): return np.floor(self.x.r)
109
+
110
+ class ceil(Simpleton):
111
+ def compute_r(self): return np.ceil(self.x.r)
112
+
113
+ class sign(Simpleton):
114
+ def compute_r(self): return np.sign(self.x.r)
115
+
116
+ class Cross(Ch):
117
+ dterms = 'a', 'b'
118
+ terms = 'axisa', 'axisb', 'axisc', 'axis'
119
+ term_order = 'a', 'b', 'axisa', 'axisb', 'axisc', 'axis'
120
+
121
+ def compute_r(self):
122
+ return np.cross(self.a.r, self.b.r, self.axisa, self.axisb, self.axisc, self.axis)
123
+
124
+
125
+ def _load_crossprod_cache(self, h, w):
126
+ if not hasattr(self, '_w'):
127
+ self._w = 0
128
+ self._h = 0
129
+
130
+ if h!=self._h or w!=self._w:
131
+ sz = h*w
132
+ rng = np.arange(sz)
133
+ self._JS = np.repeat(rng.reshape((-1,w)), w, axis=0).ravel()
134
+ self._IS = np.repeat(rng, w)
135
+ self._tiled_identity = np.tile(np.eye(w), (h, 1))
136
+ self._h = h
137
+ self._w = w
138
+
139
+ return self._tiled_identity, self._IS, self._JS,
140
+
141
+
142
+
143
+ # Could be at least 2x faster, with some work
144
+ def compute_dr_wrt(self, wrt):
145
+ if wrt is not self.a and wrt is not self.b:
146
+ return
147
+
148
+ sz = self.a.size
149
+ h, w = self.a.shape
150
+ tiled_identity, IS, JS = self._load_crossprod_cache(h, w)
151
+
152
+ #import time
153
+ #tm = time.time()
154
+ if wrt is self.a:
155
+ rp = np.repeat(-self.b.r, w, axis=0)
156
+ result = np.cross(
157
+ tiled_identity,
158
+ rp,
159
+ self.axisa,
160
+ self.axisb,
161
+ self.axisc,
162
+ self.axis)
163
+
164
+ elif wrt is self.b:
165
+ result = np.cross(
166
+ np.repeat(-self.a.r, w, axis=0),
167
+ tiled_identity,
168
+ self.axisa,
169
+ self.axisb,
170
+ self.axisc,
171
+ self.axis)
172
+
173
+ # rng = np.arange(sz)
174
+ # JS = np.repeat(rng.reshape((-1,w)), w, axis=0).ravel()
175
+ # IS = np.repeat(rng, w)
176
+ data = result.ravel()
177
+ result = sp.csc_matrix((data, (IS,JS)), shape=(self.size, wrt.size))
178
+ #import pdb; pdb.set_trace()
179
+ #print 'B TOOK %es' % (time.time() -tm )
180
+ return result
181
+
182
+ def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None):
183
+ return Cross(a, b, axisa, axisb, axisc, axis)
184
+
185
+
186
+
187
+
188
+ class cumsum(Ch):
189
+ dterms = 'a'
190
+ terms = 'axis'
191
+ term_order = 'a', 'axis'
192
+
193
+ def on_changed(self, which):
194
+ if not hasattr(self, 'axis'):
195
+ self.axis = None
196
+
197
+ def compute_r(self):
198
+ return np.cumsum(self.a.r, axis=self.axis)
199
+
200
+ def compute_dr_wrt(self, wrt):
201
+ if wrt is not self.a:
202
+ return None
203
+
204
+ if self.axis is not None:
205
+ raise NotImplementedError
206
+
207
+ IS = np.tile(row(np.arange(self.a.size)), (self.a.size, 1))
208
+ JS = IS.T
209
+ IS = IS.ravel()
210
+ JS = JS.ravel()
211
+ which = IS >= JS
212
+ IS = IS[which]
213
+ JS = JS[which]
214
+ data = np.ones_like(IS)
215
+ result = sp.csc_matrix((data, (IS, JS)), shape=(self.a.size, self.a.size))
216
+ return result
217
+
218
+
219
+ class UnaryElemwise(Ch):
220
+ dterms = 'x'
221
+
222
+ def compute_r(self):
223
+ return self._r(self.x.r)
224
+
225
+ def compute_dr_wrt(self, wrt):
226
+ if wrt is self.x:
227
+ result = self._d(self.x.r)
228
+ return sp.diags([result.ravel()], [0]) if len(result)>1 else np.atleast_2d(result)
229
+
230
+
231
+ class nan_to_num(UnaryElemwise):
232
+ _r = lambda self, x : np.nan_to_num(x)
233
+ _d = lambda self, x : np.asarray(np.isfinite(x), np.float64)
234
+
235
+ class reciprocal(UnaryElemwise):
236
+ _r = np.reciprocal
237
+ _d = lambda self, x : -np.reciprocal(np.square(x))
238
+
239
+ class square(UnaryElemwise):
240
+ _r = np.square
241
+ _d = lambda self, x : x * 2.
242
+
243
+ def my_power(a, b):
244
+ with warnings.catch_warnings():
245
+ warnings.filterwarnings("ignore",category=RuntimeWarning)
246
+ return np.nan_to_num(np.power(a, b))
247
+
248
+ class sqrt(UnaryElemwise):
249
+ _r = np.sqrt
250
+ _d = lambda self, x : .5 * my_power(x, -0.5)
251
+
252
+ class exp(UnaryElemwise):
253
+ _r = np.exp
254
+ _d = np.exp
255
+
256
+ class log(UnaryElemwise):
257
+ _r = np.log
258
+ _d = np.reciprocal
259
+
260
+ class sin(UnaryElemwise):
261
+ _r = np.sin
262
+ _d = np.cos
263
+
264
+ class arcsin(UnaryElemwise):
265
+ _r = np.arcsin
266
+ _d = lambda self, x : np.reciprocal(np.sqrt(1.-np.square(x)))
267
+
268
+ class cos(UnaryElemwise):
269
+ _r = np.cos
270
+ _d = lambda self, x : -np.sin(x)
271
+
272
+ class arccos(UnaryElemwise):
273
+ _r = np.arccos
274
+ _d = lambda self, x : -np.reciprocal(np.sqrt(1.-np.square(x)))
275
+
276
+ class tan(UnaryElemwise):
277
+ _r = np.tan
278
+ _d = lambda self, x : np.reciprocal(np.cos(x)**2.)
279
+
280
+ class arctan(UnaryElemwise):
281
+ _r = np.arctan
282
+ _d = lambda self, x : np.reciprocal(np.square(x)+1.)
283
+
284
+ class negative(UnaryElemwise):
285
+ _r = np.negative
286
+ _d = lambda self, x : np.negative(np.ones_like(x))
287
+
288
+ class absolute(UnaryElemwise):
289
+ _r = np.abs
290
+ _d = lambda self, x : (x>0)*2-1.
291
+
292
+ abs = absolute
293
+
294
+ class clip(Ch):
295
+ dterms = 'a'
296
+ terms = 'a_min', 'a_max'
297
+ term_order = 'a', 'a_min', 'a_max'
298
+
299
+ def compute_r(self):
300
+ return np.clip(self.a.r, self.a_min, self.a_max)
301
+
302
+ def compute_dr_wrt(self, wrt):
303
+ if wrt is self.a:
304
+ result = np.asarray((self.r != self.a_min) & (self.r != self.a_max), np.float64)
305
+ return sp.diags([result.ravel()], [0]) if len(result)>1 else np.atleast_2d(result)
306
+
307
+ class sum(Ch):
308
+ dterms = 'x',
309
+ terms = 'axis',
310
+ term_order = 'x', 'axis'
311
+
312
+ def on_changed(self, which):
313
+ if not hasattr(self, 'axis'):
314
+ self.axis = None
315
+ if not hasattr(self, 'dr_cache'):
316
+ self.dr_cache = {}
317
+
318
+ def compute_r(self):
319
+ return np.sum(self.x.r, axis=self.axis)
320
+
321
+ def compute_dr_wrt(self, wrt):
322
+ if wrt is not self.x:
323
+ return
324
+ if self.axis == None:
325
+ return row(np.ones((1, len(self.x.r.ravel()))))
326
+ else:
327
+ uid = tuple(list(self.x.shape) + [self.axis])
328
+ if uid not in self.dr_cache:
329
+ idxs_presum = np.arange(self.x.size).reshape(self.x.shape)
330
+ idxs_presum = np.rollaxis(idxs_presum, self.axis, 0)
331
+ idxs_postsum = np.arange(self.r.size).reshape(self.r.shape)
332
+ tp = np.ones(idxs_presum.ndim, dtype=np.uint32)
333
+ tp[0] = idxs_presum.shape[0]
334
+ idxs_postsum = np.tile(idxs_postsum, tp)
335
+ data = np.ones(idxs_postsum.size)
336
+ result = sp.csc_matrix((data, (idxs_postsum.ravel(), idxs_presum.ravel())), (self.r.size, wrt.size))
337
+ self.dr_cache[uid] = result
338
+ return self.dr_cache[uid]
339
+
340
+
341
class mean(Ch):
    """Differentiable analogue of np.mean over child term ``x``.

    Terms:
        x: chumpy node to reduce.
        axis: axis to average over; None (the default) averages all elements.

    Per-(shape, axis) Jacobians are cached in ``self.dr_cache``.
    """
    dterms = 'x',
    terms = 'axis',
    term_order = 'x', 'axis'

    def on_changed(self, which):
        # Lazily default optional state; on_changed can fire before these exist.
        if not hasattr(self, 'axis'):
            self.axis = None
        if not hasattr(self, 'dr_cache'):
            self.dr_cache = {}

    def compute_r(self):
        return np.array(np.mean(self.x.r, axis=self.axis))

    def compute_dr_wrt(self, wrt):
        if wrt is not self.x:
            return
        if self.axis is None:  # was `== None`
            # BUGFIX: previously used len(self.x.r), which is only the first
            # dimension's length; for ndim > 1 inputs that produced a
            # wrong-sized Jacobian. The mean over all elements has derivative
            # 1/size wrt every element.
            n = self.x.r.size
            return row(np.ones((1, n))) / n
        else:
            # Sparse matrix mapping inputs to output cells, weighted 1/axis-length.
            uid = tuple(list(self.x.shape) + [self.axis])
            if uid not in self.dr_cache:
                idxs_presum = np.arange(self.x.size).reshape(self.x.shape)
                idxs_presum = np.rollaxis(idxs_presum, self.axis, 0)
                idxs_postsum = np.arange(self.r.size).reshape(self.r.shape)
                tp = np.ones(idxs_presum.ndim, dtype=np.uint32)
                tp[0] = idxs_presum.shape[0]
                idxs_postsum = np.tile(idxs_postsum, tp)
                data = np.ones(idxs_postsum.size) / self.x.shape[self.axis]
                result = sp.csc_matrix((data, (idxs_postsum.ravel(), idxs_presum.ravel())),
                                       (self.r.size, wrt.size))
                self.dr_cache[uid] = result
            return self.dr_cache[uid]
373
+
374
+
375
def var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False):
    """Differentiable variance of *a* (population variance, ddof=0).

    Only the default dtype/out/ddof/keepdims are supported; anything else
    raises NotImplementedError.
    """
    if dtype is not None or out is not None or ddof != 0 or keepdims:
        # BUGFIX: was `NotImplementedException`, an undefined name that would
        # itself raise NameError when this path was hit.
        raise NotImplementedError('Unimplemented for non-default dtype, out, ddof, and keepdims.')
    # BUGFIX: previously returned mean(a**2.) alone, i.e. the second raw
    # moment, not the variance. Var(a) = E[a^2] - E[a]^2.
    return mean(a**2., axis=axis) - mean(a, axis=axis)**2.
379
+
380
def std(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False):
    """Differentiable standard deviation of *a*: sqrt(var(a)).

    Only the default dtype/out/ddof/keepdims are supported; anything else
    raises NotImplementedError.
    """
    if dtype is not None or out is not None or ddof != 0 or keepdims:
        # BUGFIX: was `NotImplementedException`, an undefined name (NameError).
        raise NotImplementedError('Unimplemented for non-default dtype, out, ddof, and keepdims.')
    return sqrt(var(a, axis=axis))
384
+
385
+
386
class SumOfSquares(Ch):
    """Scalar sum of squared entries of ``x``."""
    dterms = 'x',

    def compute_r(self):
        flat = self.x.r.ravel()
        return np.sum(flat ** 2.)

    def compute_dr_wrt(self, wrt):
        if wrt is not self.x:
            return None
        # d(sum x_i^2)/dx = 2x, laid out as a single row.
        return row(2. * self.x.r.ravel())
395
+
396
+
397
class divide (Ch):
    """Differentiable elementwise (broadcasting) division x1 / x2."""
    dterms = 'x1', 'x2'

    def compute_r(self):
        return self.x1.r / self.x2.r

    def compute_dr_wrt(self, wrt):

        # Equal truth values means wrt is both operands (x1 is x2) or neither;
        # in both cases no single-operand derivative is defined here.
        if (wrt is self.x1) == (wrt is self.x2):
            return None

        IS, JS, input_sz, output_sz = _broadcast_setup(self.x1, self.x2, wrt)

        x1r, x2r = self.x1.r, self.x2.r
        if wrt is self.x1:
            # d(x1/x2)/dx1 = 1/x2
            data = (np.ones_like(x1r) / x2r).ravel()
        else:
            # d(x1/x2)/dx2 = -x1/x2^2
            data = (-x1r / (x2r*x2r)).ravel()

        return sp.csc_matrix((data, (IS, JS)), shape=(self.r.size, wrt.r.size))
417
+
418
+
419
+
420
+
421
class NanDivide(divide):
    """Like :class:`divide`, but inf/nan results (e.g. from division by zero)
    are replaced with 0 in both the value and the derivative."""
    dterms = 'x1', 'x2'

    def compute_r(self):
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            # BUGFIX: was super(self.__class__, self), which recurses forever
            # if NanDivide is ever subclassed; name the class explicitly.
            result = super(NanDivide, self).compute_r()
            shape = result.shape
            result = result.ravel()
            result[np.isinf(result)] = 0
            result[np.isnan(result)] = 0
            return result.reshape(shape)

    def compute_dr_wrt(self, wrt):
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            result = super(NanDivide, self).compute_dr_wrt(wrt)
            if result is not None:
                result = result.copy()
                if sp.issparse(result):
                    result.data[np.isinf(result.data)] = 0
                    result.data[np.isnan(result.data)] = 0
                    return result
                else:
                    # Dense case: zero bad entries in place through a flat view.
                    rr = result.ravel()
                    rr[np.isnan(rr)] = 0.
                    rr[np.isinf(rr)] = 0.
                    return result
449
+
450
+
451
def shape(a):
    """Return the shape of *a*, preferring the object's own ``shape`` attribute
    (works for ndarrays and chumpy nodes) and falling back to np.shape."""
    try:
        return a.shape
    except AttributeError:
        return np.shape(a)
453
+
454
+
455
# Caches for broadcast Jacobian structure, keyed by operand shapes.
# Only the sparsity pattern is cached; `data` is re-filled on each call.
_bs_setup_data1 = {}
_bs_setup_data2 = {}
def _broadcast_matrix(a, b, wrt, data):
    """Build the (sparse) Jacobian of a broadcasting binary op wrt one operand.

    ``data`` is either a scalar (constant derivative, e.g. add/subtract) or an
    ndarray of per-element derivative values. Returns a csc matrix of shape
    (output_size, wrt.size), or a scalar array when the result is 1x1.
    """
    global _bs_setup_data1, _bs_setup_data2

    if len(set((a.shape, b.shape))) == 1:
        # Same-shape fast path: the Jacobian is diagonal.
        uid = a.shape
        if uid not in _bs_setup_data1:
            asz = a.size
            IS = np.arange(asz)
            _bs_setup_data1[uid] = sp.csc_matrix((np.empty(asz), (IS, IS)), shape=(asz, asz))
        # Shallow-copy the cached pattern, then overwrite its data in place.
        result = copy_copy(_bs_setup_data1[uid])
        if isinstance(data, np.ndarray):
            result.data = data.ravel()
        else: # assumed scalar
            result.data = np.empty(result.nnz)
            result.data.fill(data)
    else:
        # General broadcasting: cache a matrix whose data holds permutation
        # indices, so per-call data can be gathered via result.data = data[...].
        uid = (a.shape, b.shape, wrt is a, wrt is b)
        if uid not in _bs_setup_data2:
            input_sz = wrt.size
            output_sz = np.broadcast(a.r, b.r).size
            a2 = np.arange(a.size).reshape(a.shape) if wrt is a else np.zeros(a.shape)
            b2 = np.arange(b.size).reshape(b.shape) if (wrt is b and wrt is not a) else np.zeros(b.shape)
            IS = np.arange(output_sz)
            JS = np.asarray((np.add(a2,b2)).ravel(), np.uint32)

            _bs_setup_data2[uid] = sp.csc_matrix((np.arange(IS.size), (IS, JS)), shape=(output_sz, input_sz))

        result = copy_copy(_bs_setup_data2[uid])
        if isinstance(data, np.ndarray):
            result.data = data[result.data]
        else: # assumed scalar
            result.data = np.empty(result.nnz)
            result.data.fill(data)

    # Degenerate 1x1 Jacobian collapses to a scalar array.
    if np.prod(result.shape) == 1:
        return np.array(data)
    else:
        return result
495
+
496
+
497
+
498
+
499
broadcast_shape_cache = {}
def broadcast_shape(a_shape, b_shape):
    """Compute the numpy broadcast shape of two shapes, with caching.

    NOTE: this function raises unconditionally on entry (see below) — the
    author deliberately disabled it, so everything after the raise is
    currently dead code, kept for reference.
    """
    global broadcast_shape_cache

    raise Exception('This function is probably a bad idea, because shape is not cached and overquerying can occur.')

    uid = (a_shape, b_shape)

    if uid not in broadcast_shape_cache:
        la = len(a_shape)
        lb = len(b_shape)
        ln = la if la > lb else lb

        # Left-pad the shorter shape with ones, then take elementwise max.
        ash = np.ones(ln, dtype=np.uint32)
        bsh = np.ones(ln, dtype=np.uint32)
        ash[-la:] = a_shape
        bsh[-lb:] = b_shape

        our_result = np.max(np.vstack((ash, bsh)), axis=0)

        # Disabled self-check against numpy's own broadcasting.
        if False:
            numpy_result = np.broadcast(np.empty(a_shape), np.empty(b_shape)).shape
            #print 'aaa' + str(our_result)
            #print 'bbb' + str(numpy_result)
            if not np.array_equal(our_result, numpy_result):
                raise Exception('numpy result not equal to our result')
            assert(np.array_equal(our_result, numpy_result))

        broadcast_shape_cache[uid] = tuple(our_result)
    return broadcast_shape_cache[uid]
529
+
530
+
531
def _broadcast_setup(a, b, wrt):
    """Index bookkeeping for the Jacobian of a broadcasting binary op.

    Returns (IS, JS, input_sz, output_sz): row indices over the broadcast
    output, column indices into the flat ``wrt`` operand, and the two sizes.
    """
    if len(set((a.shape, b.shape))) == 1:
        # Same-shape fast path: identity mapping.
        asz = a.size
        IS = np.arange(asz)
        return IS, IS, asz, asz
    input_sz = wrt.r.size
    output_sz = np.broadcast(a.r, b.r).size
    # Broadcast flat element indices of `wrt` against zeros for the other
    # operand; their sum gives, for each output cell, the wrt element index.
    a2 = np.arange(a.size).reshape(a.shape) if wrt is a else np.zeros(a.shape)
    b2 = np.arange(b.size).reshape(b.shape) if (wrt is b and wrt is not a) else np.zeros(b.shape)
    IS = np.arange(output_sz)
    JS = np.asarray((np.add(a2,b2)).ravel(), np.uint32)
    return IS, JS, input_sz, output_sz
543
+
544
+
545
+
546
class add(Ch):
    """Differentiable broadcasting addition of terms ``a`` and ``b``."""
    dterms = 'a', 'b'

    def compute_r(self):
        return self.a.r + self.b.r

    def compute_dr_wrt(self, wrt):
        if wrt is self.a or wrt is self.b:
            # When both operands are literally the same node, the chain rule
            # contributions add up, so the constant derivative doubles.
            multiplier = 2. if self.a is self.b else 1.
            return _broadcast_matrix(self.a, self.b, wrt, multiplier)
        return None
558
+
559
+
560
+
561
+
562
class subtract(Ch):
    """Differentiable broadcasting subtraction a - b."""
    dterms = 'a', 'b'

    def compute_r(self):
        return self.a.r - self.b.r

    def compute_dr_wrt(self, wrt):
        # Derivative exists only when wrt is exactly one of the operands.
        if (wrt is self.a) != (wrt is self.b):
            return _broadcast_matrix(self.a, self.b, wrt, 1. if wrt is self.a else -1.)
        return None
574
+
575
+
576
+
577
+
578
+
579
class power (Ch):
    """Differentiable elementwise power x ** pow.

    Derivatives: wrt x it is pow * x**(pow-1); wrt pow it is log(x) * x**pow.
    Infinities arising from the power computation are zeroed (see safe_power).
    """
    dterms = 'x', 'pow'

    def compute_r(self):
        return self.safe_power(self.x.r, self.pow.r)

    def compute_dr_wrt(self, wrt):

        if wrt is not self.x and wrt is not self.pow:
            return None

        x, pow = self.x.r, self.pow.r
        result = []
        if wrt is self.x:
            result.append(pow * self.safe_power(x, pow-1.))
        if wrt is self.pow:
            # NOTE(review): np.log(x) is nan for x <= 0 here — presumably
            # callers only differentiate wrt pow for positive bases; confirm.
            result.append(np.log(x) * self.safe_power(x, pow))

        # If wrt is both x and pow (same node), the two contributions sum.
        data = reduce(lambda x, y : x + y, result).ravel()

        return _broadcast_matrix(self.x, self.pow, wrt, data)


    def safe_power(self, x, sigma):
        # This throws a RuntimeWarning sometimes, but then the infs are corrected below
        result = np.power(x, sigma)
        result.ravel()[np.isinf(result.ravel())] = 0
        return result
608
+
609
+
610
+
611
+
612
+
613
class A_extremum(Ch):
    """Superclass for various min and max subclasses.

    Subclasses provide ``f`` (the reduction, e.g. np.amax) and ``argf`` (the
    corresponding arg-reduction, e.g. np.argmax). The derivative routes 1.0
    from each output cell to the input element that attained the extremum.
    """
    dterms = 'a'
    terms = 'axis'
    term_order = 'a', 'axis'

    def f(self, axis): raise NotImplementedError
    def argf(self, axis): raise NotImplementedError

    def on_changed(self, which):
        # Default to a full reduction when no axis was supplied.
        if not hasattr(self, 'axis'):
            self.axis = None

    def compute_r(self):
        return self.f(self.a.r, axis=self.axis)

    def compute_dr_wrt(self, wrt):
        if wrt is self.a:

            # Translate per-axis argmin/argmax positions into flat input indices.
            mn, stride = self._stride_for_axis(self.axis, self.a.r)
            JS = np.asarray(np.round(mn + stride * self.argf(self.a.r, axis=self.axis)), dtype=np.uint32).ravel()
            IS = np.arange(JS.size)
            data = np.ones(JS.size)

            if self.r.size * wrt.r.size == 1:
                return data.ravel()[0]
            return sp.csc_matrix((data, (IS, JS)), shape = (self.r.size, wrt.r.size))

    def _stride_for_axis(self,axis, mtx):
        """Base offsets and element stride needed to flatten argf results."""
        if axis is None:
            mn = np.array([0])
            stride = np.array([1])
        else:
            # TODO: make this less expensive. Shouldn't need to call
            # np.amin here probably
            idxs = np.arange(mtx.size).reshape(mtx.shape)
            mn = np.amin(idxs, axis=axis)
            mtx_strides = np.array(mtx.strides)
            stride = mtx_strides / np.min(mtx_strides) # go from bytes to num elements
            stride = stride[axis]
        return mn, stride
654
+
655
+
656
class amax(A_extremum):
    """Differentiable np.amax; derivative flows to the arg-max element."""

    def f(self, *args, **kwargs):
        return np.amax(*args, **kwargs)

    def argf(self, *args, **kwargs):
        return np.argmax(*args, **kwargs)

max = amax
661
+
662
class amin(A_extremum):
    """Differentiable np.amin; derivative flows to the arg-min element."""

    def f(self, *args, **kwargs):
        return np.amin(*args, **kwargs)

    def argf(self, *args, **kwargs):
        return np.argmin(*args, **kwargs)

min = amin
667
+
668
class nanmin(A_extremum):
    """Differentiable np.nanmin (ignores NaNs)."""

    def f(self, *args, **kwargs):
        return np.nanmin(*args, **kwargs)

    def argf(self, *args, **kwargs):
        return np.nanargmin(*args, **kwargs)


class nanmax(A_extremum):
    """Differentiable np.nanmax (ignores NaNs)."""

    def f(self, *args, **kwargs):
        return np.nanmax(*args, **kwargs)

    def argf(self, *args, **kwargs):
        return np.nanargmax(*args, **kwargs)
675
+
676
+
677
class Extremum(Ch):
    """Base for elementwise two-argument extrema (maximum/minimum).

    Subclasses define ``f`` (np.maximum or np.minimum). The derivative wrt an
    operand is 1 where that operand attains the extremum, 0 elsewhere.
    """
    dterms = 'a','b'

    def compute_r(self): return self.f(self.a.r, self.b.r)

    def compute_dr_wrt(self, wrt):
        if wrt is not self.a and wrt is not self.b:
            return None

        IS, JS, input_sz, output_sz = _broadcast_setup(self.a, self.b, wrt)
        # f(1,-1) is +1 for maximum and -1 for minimum; perturbing the *other*
        # operand away from winning resolves exact ties in favor of `wrt`.
        if wrt is self.a:
            whichmax = (self.r == self.f(self.a.r, self.b.r-self.f(1,-1))).ravel()
        else:
            whichmax = (self.r == self.f(self.b.r, self.a.r-self.f(1,-1))).ravel()
        # Keep only the entries where `wrt` won the comparison.
        IS = IS[whichmax]
        JS = JS[whichmax]
        data = np.ones(JS.size)

        return sp.csc_matrix((data, (IS, JS)), shape=(self.r.size, wrt.r.size))
696
+
697
class maximum(Extremum):
    """Elementwise differentiable np.maximum."""

    def f(self, a, b):
        return np.maximum(a, b)


class minimum(Extremum):
    """Elementwise differentiable np.minimum."""

    def f(self, a, b):
        return np.minimum(a, b)
702
+
703
+
704
class multiply(Ch):
    """Differentiable broadcasting multiplication a * b."""
    dterms = 'a', 'b'

    def compute_r(self):
        return self.a.r * self.b.r

    def compute_dr_wrt(self, wrt):
        if wrt is not self.a and wrt is not self.b:
            return None

        # Derivative wrt one operand is the *other* operand's value; ones
        # stand in for the operand being differentiated.
        a2 = self.a.r if wrt is self.b else np.ones(self.a.shape)
        b2 = self.b.r if (wrt is self.a and wrt is not self.b) else np.ones(self.b.shape)
        data = (a2 * b2).ravel()

        # a is b means this is x*x, whose derivative is 2x — double the data.
        if self.a is self.b:
            data *= 2.

        return _broadcast_matrix(self.a, self.b, wrt, data)
722
+
723
+
724
+
725
+
726
+
727
class dot(Ch):
    """Differentiable matrix product a.dot(b), consistent with np.dot for ndim <= 2."""
    dterms = 'a', 'b'

    def compute_r(self):
        return self.a.r.dot(self.b.r)

    def compute_d1(self):
        """Jacobian of the product wrt ``a``."""
        # To stay consistent with numpy, we must upgrade 1D arrays to 2D
        ar = row(self.a.r) if len(self.a.r.shape)<2 else self.a.r.reshape((-1, self.a.r.shape[-1]))
        br = col(self.b.r) if len(self.b.r.shape)<2 else self.b.r.reshape((self.b.r.shape[0], -1))

        if ar.ndim <= 2:
            return sp.kron(sp.eye(ar.shape[0], ar.shape[0]),br.T)
        else:
            raise NotImplementedError

    def compute_d2(self):
        """Jacobian of the product wrt ``b``."""

        # To stay consistent with numpy, we must upgrade 1D arrays to 2D
        ar = row(self.a.r) if len(self.a.r.shape)<2 else self.a.r.reshape((-1, self.a.r.shape[-1]))
        br = col(self.b.r) if len(self.b.r.shape)<2 else self.b.r.reshape((self.b.r.shape[0], -1))

        if br.ndim <= 1:
            # NOTE(review): `self.ar` is not defined anywhere in this class, so
            # this would raise AttributeError — but the branch looks
            # unreachable, since row()/col()/reshape above always yield ndim >= 2.
            return self.ar
        elif br.ndim <= 2:
            return sp.kron(ar, sp.eye(br.shape[1],br.shape[1]))
        else:
            raise NotImplementedError


    def compute_dr_wrt(self, wrt):

        # a is b (x.dot(x)): product rule sums both Jacobians.
        if wrt is self.a and wrt is self.b:
            return self.compute_d1() + self.compute_d2()
        elif wrt is self.a:
            return self.compute_d1()
        elif wrt is self.b:
            return self.compute_d2()
765
+
766
class BinaryElemwiseNoDrv(Ch):
    """Base for binary elementwise ops with no derivative (comparisons).

    Subclasses supply ``_f``; results are piecewise constant, so the
    derivative is always None.
    """
    dterms = 'x1', 'x2'

    def compute_r(self):
        return self._f(self.x1.r, self.x2.r)

    def compute_dr_wrt(self, wrt):
        return None
774
+
775
class greater(BinaryElemwiseNoDrv):
    """Elementwise x1 > x2."""
    def _f(self, lhs, rhs):
        return np.greater(lhs, rhs)

class greater_equal(BinaryElemwiseNoDrv):
    """Elementwise x1 >= x2."""
    def _f(self, lhs, rhs):
        return np.greater_equal(lhs, rhs)

class less(BinaryElemwiseNoDrv):
    """Elementwise x1 < x2."""
    def _f(self, lhs, rhs):
        return np.less(lhs, rhs)

class less_equal(BinaryElemwiseNoDrv):
    """Elementwise x1 <= x2."""
    def _f(self, lhs, rhs):
        return np.less_equal(lhs, rhs)

class equal(BinaryElemwiseNoDrv):
    """Elementwise x1 == x2."""
    def _f(self, lhs, rhs):
        return np.equal(lhs, rhs)

class not_equal(BinaryElemwiseNoDrv):
    """Elementwise x1 != x2."""
    def _f(self, lhs, rhs):
        return np.not_equal(lhs, rhs)
792
+
793
def nonzero(a):
    """np.nonzero on *a*, first unwrapping a chumpy node (anything exposing
    ``compute_r``) to its concrete value."""
    arr = a.r if hasattr(a, 'compute_r') else a
    return np.nonzero(arr)
797
+
798
# Pull the code for tensordot in from numpy and reinterpret it using chumpy ops
# (the file np_tensordot.py is a vendored copy of numpy 1.13's pure-Python
# tensordot; exec'ing it here binds its np.* calls to this module's chumpy
# wrappers, yielding a differentiable tensordot).
import os
source_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'np_tensordot.py')
with open(source_path, 'r') as f:
    source_lines = f.readlines()
exec(''.join(source_lines))
__all__ += ['tensordot']
805
+
806
+
807
+
808
def main():
    """Placeholder entry point; this module is used as a library."""
    return None


if __name__ == '__main__':
    main()
814
+
vendor/chumpy/chumpy/ch_random.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Author(s): Matthew Loper
3
+
4
+ See LICENCE.txt for licensing and contact information.
5
+ """
6
+
7
+ import numpy.random
8
+ from .ch import Ch
9
+
10
+ api_not_implemented = ['choice','bytes','shuffle','permutation']
11
+
12
+ api_wrapped_simple = [
13
+ # simple random data
14
+ 'rand','randn','randint','random_integers','random_sample','random','ranf','sample',
15
+
16
+ # distributions
17
+ 'beta','binomial','chisquare','dirichlet','exponential','f','gamma','geometric','gumbel','hypergeometric',
18
+ 'laplace','logistic','lognormal','logseries','multinomial','multivariate_normal','negative_binomial',
19
+ 'noncentral_chisquare','noncentral_f','normal','pareto','poisson','power','rayleigh','standard_cauchy',
20
+ 'standard_exponential','standard_gamma','standard_normal','standard_t','triangular','uniform','vonmises',
21
+ 'wald','weibull','zipf']
22
+
23
+ api_wrapped_direct = ['seed', 'get_state', 'set_state']
24
+
25
+ for rtn in api_wrapped_simple:
26
+ exec('def %s(*args, **kwargs) : return Ch(numpy.random.%s(*args, **kwargs))' % (rtn, rtn))
27
+
28
+ for rtn in api_wrapped_direct:
29
+ exec('%s = numpy.random.%s' % (rtn, rtn))
30
+
31
+ __all__ = api_wrapped_simple + api_wrapped_direct
32
+
vendor/chumpy/chumpy/extras.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ __author__ = 'matt'
2
+
3
+ from . import ch
4
+ import numpy as np
5
+ from .utils import row, col
6
+ import scipy.sparse as sp
7
+ import scipy.special
8
+
9
class Interp3D(ch.Ch):
    """Differentiable lookup into a 3-D image at (clipped) fractional locations.

    Terms:
        locations: (N, 3) differentiable coordinates into ``image``.
        image: a fixed 3-D array; its gradient is precomputed on change.

    The value is a first-order (gradient-based) correction of the floor-voxel
    sample; the derivative wrt locations is the image gradient at the voxel.
    """
    dterms = 'locations'
    terms = 'image'

    def on_changed(self, which):
        if 'image' in which:
            # Precompute per-axis gradients once per image update.
            self.gx, self.gy, self.gz = np.gradient(self.image)

    def compute_r(self):
        locations = self.locations.r.copy()
        # Clamp coordinates into the valid index range per axis.
        for i in range(3):
            locations[:,i] = np.clip(locations[:,i], 0, self.image.shape[i]-1)
        locs = np.floor(locations).astype(np.uint32)
        result = self.image[locs[:,0], locs[:,1], locs[:,2]]
        # First-order correction: value at voxel + gradient . fractional offset.
        offset = (locations - locs)
        dr = self.dr_wrt(self.locations).dot(offset.ravel())
        return result + dr

    def compute_dr_wrt(self, wrt):
        if wrt is self.locations:
            locations = self.locations.r.copy()
            for i in range(3):
                locations[:,i] = np.clip(locations[:,i], 0, self.image.shape[i]-1)
            # NOTE(review): truncation (not floor of clipped float) — assumes
            # clipped locations are non-negative, which the clip guarantees.
            locations = locations.astype(np.uint32)

            xc = col(self.gx[locations[:,0], locations[:,1], locations[:,2]])
            yc = col(self.gy[locations[:,0], locations[:,1], locations[:,2]])
            zc = col(self.gz[locations[:,0], locations[:,1], locations[:,2]])

            # One 3-gradient row per sample: row i covers columns 3i..3i+2.
            data = np.vstack([xc.ravel(), yc.ravel(), zc.ravel()]).T.copy()
            JS = np.arange(locations.size)
            IS = JS // 3

            return sp.csc_matrix((data.ravel(), (IS, JS)))
43
+
44
+
45
class gamma(ch.Ch):
    """Differentiable gamma function of ``x``.

    d(gamma(x))/dx = gamma(x) * digamma(x), applied elementwise.
    """
    dterms = 'x',

    def compute_r(self):
        return scipy.special.gamma(self.x.r)

    def compute_dr_wrt(self, wrt):
        if wrt is not self.x:
            return None
        deriv = scipy.special.polygamma(0, self.x.r) * self.r
        return sp.diags([deriv.ravel()], [0])
55
+
56
# This function is based directly on the "moment" function
# in scipy, specifically in mstats_basic.py.
def moment(a, moment=1, axis=0):
    """Differentiable n-th central moment of *a* along *axis* (scipy-style)."""
    if moment == 1:
        # By definition the first moment about the mean is 0.
        shape = list(a.shape)
        del shape[axis]
        if shape:
            # return an actual array of the appropriate shape
            return ch.zeros(shape, dtype=float)
        else:
            # the input was 1D, so return a scalar instead of a rank-0 array
            return np.float64(0.0)
    else:
        # E[(a - E[a])^moment], keeping the reduced axis for broadcasting.
        mn = ch.expand_dims(a.mean(axis=axis), axis)
        s = ch.power((a-mn), moment)
        return s.mean(axis=axis)
vendor/chumpy/chumpy/linalg.py ADDED
@@ -0,0 +1,306 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # encoding: utf-8
3
+
4
+ """
5
+ Author(s): Matthew Loper
6
+
7
+ See LICENCE.txt for licensing and contact information.
8
+ """
9
+
10
+
11
+ __all__ = ['inv', 'svd', 'det', 'slogdet', 'pinv', 'lstsq', 'norm']
12
+
13
+ import numpy as np
14
+ import scipy.sparse as sp
15
+ from .ch import Ch, depends_on
16
+ from .ch_ops import NanDivide
17
+ from .ch_ops import asarray as ch_asarray
18
+ from .ch_ops import sqrt as ch_sqrt
19
+ from .ch_ops import sum as ch_sum
20
+ from .reordering import concatenate as ch_concatenate
21
+ from .ch_random import randn as ch_random_randn
22
+ from .utils import row, col
23
+
24
+
25
+ try:
26
+ asarray = ch_asarray
27
+ import inspect
28
+ exec(''.join(inspect.getsourcelines(np.linalg.tensorinv)[0]))
29
+ __all__.append('tensorinv')
30
+ except: pass
31
+
32
def norm(x, ord=None, axis=None):
    """Differentiable Euclidean/Frobenius norm of *x*.

    Only the defaults are supported; non-None ``ord`` or ``axis`` raise
    NotImplementedError.
    """
    if ord is None and axis is None:
        return ch_sqrt(ch_sum(x**2))
    raise NotImplementedError("'ord' and 'axis' should be None for now.")
37
+
38
+ # This version works but derivatives are too slow b/c of nested loop in Svd implementation.
39
+ # def lstsq(a, b):
40
+ # u, s, v = Svd(a)
41
+ # x = (v.T / s).dot(u.T.dot(b))
42
+ # residuals = NotImplementedError # ch_sum((a.dot(x) - b)**2, axis=0)
43
+ # rank = NotImplementedError
44
+ # s = NotImplementedError
45
+ # return x, residuals, rank, s
46
+
47
def lstsq(a, b, rcond=-1):
    """Differentiable least-squares solve, mimicking np.linalg.lstsq's signature.

    Only x and residuals are actually computed; the ``rank`` and ``s`` slots
    return the NotImplementedError class itself as a placeholder (callers
    should ignore them).
    """
    if rcond != -1:
        raise Exception('non-default rcond not yet implemented')

    # Lazy node: x = pinv(a).dot(b), recomputed when a or b change.
    x = Ch(lambda a, b : pinv(a).dot(b))
    x.a = a
    x.b = b
    residuals = ch_sum( (x.a.dot(x) - x.b) **2 , axis=0)
    rank = NotImplementedError
    s = NotImplementedError

    return x, residuals, rank, s
59
+
60
def Svd(x, full_matrices=0, compute_uv=1):
    """Differentiable thin SVD of *x*: returns (u, s, v.T) like np.linalg.svd.

    Only the reduced decomposition is supported. Internally works on a
    tall-or-square matrix (transposing if needed) and swaps U/V back on return.
    """

    if full_matrices != 0:
        raise Exception('full_matrices must be 0')
    if compute_uv != 1:
        raise Exception('compute_uv must be 1')

    # The derivative machinery (SvdV) assumes rows >= cols.
    need_transpose = x.shape[0] < x.shape[1]

    if need_transpose:
        x = x.T

    svd_d = SvdD(x=x)
    svd_v = SvdV(x=x, svd_d=svd_d)
    svd_u = SvdU(x=x, svd_d=svd_d, svd_v=svd_v)

    if need_transpose:
        # svd(x.T) = (V, D, U.T) in terms of svd(x).
        return svd_v, svd_d, svd_u.T
    else:
        return svd_u, svd_d, svd_v.T
80
+
81
+
82
class Pinv(Ch):
    """Differentiable pseudo-inverse via the normal equations.

    Chooses mtx.T (mtx mtx.T)^-1 or (mtx.T mtx)^-1 mtx.T depending on which
    Gram matrix is smaller; derivatives are delegated to that expression tree.
    """
    dterms = 'mtx'

    def on_changed(self, which):
        mtx = self.mtx
        if mtx.shape[1] > mtx.shape[0]:
            # Wide matrix: right pseudo-inverse.
            result = mtx.T.dot(Inv(mtx.dot(mtx.T)))
        else:
            # Tall (or square) matrix: left pseudo-inverse.
            result = Inv(mtx.T.dot(mtx)).dot(mtx.T)
        self._result = result

    def compute_r(self):
        return self._result.r

    def compute_dr_wrt(self, wrt):
        if wrt is self.mtx:
            return self._result.dr_wrt(self.mtx)
99
+
100
+ # Couldn't make the SVD version of pinv work yet...
101
+ #
102
+ # class Pinv(Ch):
103
+ # dterms = 'mtx'
104
+ #
105
+ # def on_changed(self, which):
106
+ # u, s, v = Svd(self.mtx)
107
+ # result = (v.T * (NanDivide(1.,row(s)))).dot(u.T)
108
+ # self.add_dterm('_result', result)
109
+ #
110
+ # def compute_r(self):
111
+ # return self._result.r
112
+ #
113
+ # def compute_dr_wrt(self, wrt):
114
+ # if wrt is self._result:
115
+ # return 1
116
+
117
+
118
+
119
class LogAbsDet(Ch):
    """log|det(x)| via np.linalg.slogdet; the sign is stored on the node
    (used by SignLogAbsDet)."""
    dterms = 'x'

    def on_changed(self, which):
        self.sign, self.slogdet = np.linalg.slogdet(self.x.r)

    def compute_r(self):
        return self.slogdet

    def compute_dr_wrt(self, wrt):
        if wrt is not self.x:
            return None
        # d(log|det A|)/dA = inv(A)^T, flattened into a row.
        return row(np.linalg.inv(self.x.r).T)
131
+
132
class SignLogAbsDet(Ch):
    """The sign of the determinant computed by a companion LogAbsDet node.

    Piecewise constant, so it has no derivative.
    """
    dterms = 'logabsdet',

    def compute_r(self):
        # Touch .r to force the companion node to (re)compute, which is what
        # populates its .sign attribute in LogAbsDet.on_changed.
        _ = self.logabsdet.r
        return self.logabsdet.sign

    def compute_dr_wrt(self, wrt):
        return None
141
+
142
+
143
class Det(Ch):
    """Differentiable matrix determinant."""
    dterms = 'x'

    def compute_r(self):
        return np.linalg.det(self.x.r)

    def compute_dr_wrt(self, wrt):
        if wrt is not self.x:
            return None
        # Jacobi's formula: d(det A)/dA = det(A) * inv(A)^T.
        return row(self.r * np.linalg.inv(self.x.r).T)
152
+
153
+
154
class Inv(Ch):
    """Differentiable matrix inverse.

    For a 2-D input the Jacobian is -kron(inv(A), inv(A)^T); for batched
    (ndim > 2) inputs a block-diagonal of per-matrix Jacobians is built.
    """
    dterms = 'a'

    def compute_r(self):
        return np.linalg.inv(self.a.r)

    def compute_dr_wrt(self, wrt):
        if wrt is not self.a:
            return None

        Ainv = self.r

        if Ainv.ndim <= 2:
            # d(inv A)/dA = -kron(inv(A), inv(A)^T)
            return -np.kron(Ainv, Ainv.T)
        else:
            # Batched case: flatten leading dims, build one kron block per matrix.
            Ainv = np.reshape(Ainv, (-1, Ainv.shape[-2], Ainv.shape[-1]))
            AinvT = np.rollaxis(Ainv, -1, -2)
            AinvT = np.reshape(AinvT, (-1, AinvT.shape[-2], AinvT.shape[-1]))
            result = np.dstack([-np.kron(Ainv[i], AinvT[i]).T for i in range(Ainv.shape[0])]).T
            # NOTE(review): sp.block_diag normally expects a sequence of
            # matrices; passing the stacked array may need verification.
            result = sp.block_diag(result)

        return result
176
+
177
+
178
class SvdD(Ch):
    """Singular values of x, sharing the full (U, D, V) decomposition with
    the SvdV/SvdU companion nodes via the cached UDV property."""
    dterms = 'x'

    @depends_on('x')
    def UDV(self):
        # Thin SVD; V is stored transposed relative to np.linalg.svd's output.
        result = np.linalg.svd(self.x.r, full_matrices=False)
        result = [result[0], result[1], result[2].T]
        # Snap numerically-zero singular values to exactly zero.
        result[1][np.abs(result[1]) < np.spacing(1)] = 0.
        return result

    def compute_r(self):
        return self.UDV[1]

    def compute_dr_wrt(self, wrt):
        if wrt is not self.x:
            return

        u, d, v = self.UDV
        shp = self.x.r.shape
        u = u[:shp[0], :shp[1]]
        v = v[:shp[1], :d.size]

        # d(sigma_k)/dX_ij = u_ik * v_jk
        result = np.einsum('ik,jk->kij', u, v)
        result = result.reshape((result.shape[0], -1))
        return result
203
+
204
+
205
class SvdV(Ch):
    """Right singular vectors of x (as computed by the companion SvdD node).

    The derivative solves, per (i, j, k, l), a 2x2 system for the rotation
    coefficients omega that couple singular vector pairs — an O(M*N*N^2)
    quadruple loop, which is why Svd-based derivatives are slow.
    """
    terms = 'svd_d'
    dterms = 'x'

    def compute_r(self):
        return self.svd_d.UDV[2]

    def compute_dr_wrt(self, wrt):
        if wrt is not self.x:
            return

        U,_D,V = self.svd_d.UDV

        shp = self.svd_d.x.r.shape
        mxsz = max(shp[0], shp[1])
        #mnsz = min(shp[0], shp[1])
        # Zero-pad the singular values up to the larger dimension.
        D = np.zeros(mxsz)
        D[:_D.size] = _D

        omega = np.zeros((shp[0], shp[1], shp[1], shp[1]))

        M = shp[0]
        N = shp[1]

        # Svd() transposes beforehand so that rows >= cols holds here.
        assert(M >= N)

        for i in range(shp[0]):
            for j in range(shp[1]):
                for k in range(N):
                    for l in range(k+1, N):
                        # 2x2 system coupling singular pair (k, l).
                        mtx = np.array([
                            [D[l],D[k]],
                            [D[k],D[l]]])

                        rhs = np.array([U[i,k]*V[j,l], -U[i,l]*V[j,k]])
                        result = np.linalg.solve(mtx, rhs)

                        # omega is antisymmetric in its last two indices.
                        omega[i,j,k,l] = result[1]
                        omega[i,j,l,k] = -result[1]

        #print 'v size is %s' % (str(V.shape),)
        #print 'v omega size is %s' % (str(omega.shape),)
        assert(V.shape[1] == omega.shape[2])
        return np.einsum('ak,ijkl->alij', -V, omega).reshape((self.r.size, wrt.r.size))
249
+
250
+
251
class SvdU(Ch):
    """Left singular vectors of x.

    Rather than deriving dU/dX directly, it uses the identity U = X V / D
    (columnwise) and differentiates that expression; NanDivide zeroes the
    contributions of zero singular values.
    """
    dterms = 'x'
    terms = 'svd_d', 'svd_v'

    def compute_r(self):
        return self.svd_d.UDV[0]

    def compute_dr_wrt(self, wrt):
        if wrt is self.x:
            # return (
            #     self.svd_d.x.dot(self.svd_v)
            #     /
            #     self.svd_d.reshape((1,-1))
            #     ).dr_wrt(self.svd_d.x)
            return (
                NanDivide(
                    self.svd_d.x.dot(self.svd_v),
                    self.svd_d.reshape((1,-1)))
                ).dr_wrt(self.svd_d.x)
270
+
271
+
272
# numpy.linalg-style lowercase aliases.
inv = Inv
svd = Svd
det = Det
pinv = Pinv

def slogdet(*args):
    """Sign(s) and log|det| of one or more square matrices.

    With one argument, returns (sign_node, logabsdet_node); with several,
    returns a list of sign nodes and the concatenation of the log|det| nodes.
    """
    if len(args) == 1:
        logabs = LogAbsDet(x=args[0])
        return SignLogAbsDet(logabs), logabs
    logabs_nodes = [LogAbsDet(x=arg) for arg in args]
    sign_nodes = [SignLogAbsDet(node) for node in logabs_nodes]
    return sign_nodes, ch_concatenate(logabs_nodes)
288
+
289
def main():
    """Smoke test: compare chumpy's slogdet value and directional derivative
    against numpy's, on a random 10x10 matrix."""

    tmp = ch_random_randn(100).reshape((10,10))
    print('chumpy version: ' + str(slogdet(tmp)[1].r))
    print('old version:' + str(np.linalg.slogdet(tmp.r)[1]))

    # Finite-difference check of the derivative along a tiny random direction.
    eps = 1e-10
    diff = np.random.rand(100) * eps
    diff_reshaped = diff.reshape((10,10))
    print(np.linalg.slogdet(tmp.r+diff_reshaped)[1] - np.linalg.slogdet(tmp.r)[1])
    print(slogdet(tmp)[1].dr_wrt(tmp).dot(diff))

    print(np.linalg.slogdet(tmp.r)[0])
    print(slogdet(tmp)[0])

if __name__ == '__main__':
    main()
306
+
vendor/chumpy/chumpy/logic.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Author(s): Matthew Loper
3
+
4
+ See LICENCE.txt for licensing and contact information.
5
+ """
6
+
7
+ __author__ = 'matt'
8
+
9
+
10
+ __all__ = [] # added to incrementally below
11
+
12
+ from . import ch
13
+ from .ch import Ch
14
+ import numpy as np
15
+
16
class LogicFunc(Ch):
    """Wrap a numpy logic predicate (all/any/isnan/...) as a chumpy node.

    The predicate named by ``funcname`` is looked up on np and applied to the
    raw term ``a`` with the stored args/kwargs. Boolean outputs are piecewise
    constant, so there is no derivative.
    """
    dterms = 'a' # we keep this here so that changes to children of "a" will trigger cache changes
    terms = 'args', 'kwargs', 'funcname'

    def compute_r(self):
        # NOTE(review): self.a is passed to numpy as-is (not self.a.r);
        # presumably numpy coerces the chumpy node — confirm against callers.
        arr = self.a
        fn = getattr(np, self.funcname)
        return fn(arr, *self.args, **self.kwargs)

    def compute_dr_wrt(self, wrt):
        pass
27
+
28
+
29
# Generate one module-level wrapper per unary numpy predicate, each returning
# a LogicFunc node, and export them via __all__.
unaries = 'all', 'any', 'isfinite', 'isinf', 'isnan', 'isneginf', 'isposinf', 'logical_not'
for unary in unaries:
    exec("def %s(a, *args, **kwargs): return LogicFunc(a=a, args=args, kwargs=kwargs, funcname='%s')" % (unary, unary))
__all__ += unaries
33
+
34
+
35
+
36
# Ad-hoc smoke test when run as a script.
if __name__ == '__main__':
    from . import ch
    print(all(np.array([1,2,3])))
    print(isinf(np.array([0,2,3])))
vendor/chumpy/chumpy/monitor.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ Logging service for tracking dr tree changes from root objective
3
+ and record every step that incrementally changes the dr tree
4
+
5
+ '''
6
+ import os, sys, time
7
+ import json
8
+ import psutil
9
+
10
+ import scipy.sparse as sp
11
+ import numpy as np
12
+ from . import reordering
13
+
14
+ _TWO_20 = float(2 **20)
15
+
16
+ '''
17
+ memory utils
18
+
19
+ '''
20
def pdb_mem():
    """Debug helper: drop into pdb when resident memory exceeds ~7 GB."""
    from .monitor import get_current_memory
    mem = get_current_memory()
    # 7000 MB threshold chosen by the original author for leak hunting.
    if mem > 7000:
        import pdb;pdb.set_trace()
25
+
26
def get_peak_mem():
    """Peak resident memory (in MB) of this process since it started,
    as reported by resource.getrusage."""
    import resource
    denom = 1024.
    if sys.platform == 'darwin':
        # macOS reports ru_maxrss in bytes rather than kilobytes.
        denom *= denom
    return resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / denom
37
+
38
def get_current_memory():
    """Current resident set size of this process, in MB (via psutil)."""
    p = psutil.Process(os.getpid())
    # memory_info()[0] is RSS in bytes; convert to MB.
    mem = p.memory_info()[0]/_TWO_20

    return mem
43
+
44
+ '''
45
+ Helper for Profiler
46
+ '''
47
+
48
def build_cache_info(k, v, info_dict):
    """Record size/sparsity stats of cached derivative *v* into *info_dict*,
    keyed by the node's short_name; None values are skipped."""
    if v is None:
        return
    is_sparse = sp.issparse(v)
    nonzero = len(v.data) if is_sparse else np.count_nonzero(v)
    info_dict[k.short_name] = {
        'sparse': is_sparse,
        'size': str(v.size),
        'nonzero': nonzero,
    }
61
+
62
+
63
def cache_info(ch_node):
    """Collect cached-derivative stats for a chumpy node.

    Concatenate nodes keep their derivatives in ``dr_cached``; all other nodes
    keep them in ``_cache['drs']``.
    """
    result = {}
    if isinstance(ch_node, reordering.Concatenate) and hasattr(ch_node, 'dr_cached') and len(ch_node.dr_cached) > 0:
        for k, v in ch_node.dr_cached.items():
            build_cache_info(k, v, result)
    elif len(ch_node._cache['drs']) > 0:
        for k, v in ch_node._cache['drs'].items():
            build_cache_info(k, v, result)

    return result
73
+
74
class DrWrtProfiler(object):
    """Records dr-tree evaluation history for a root objective.

    On construction it dumps the tree structure to ``root_<ts>.json``; calls
    to ``record`` append per-node snapshots, and ``harvest`` writes them to
    ``profile_<ts>.json``.
    """
    base_path = os.path.abspath('profiles')

    def __init__(self, root, base_path=None):
        # root is expected to wrap the objective in an .obj attribute.
        self.root = root.obj
        self.history = []

        ts = time.time()
        if base_path:
            self.base_path = base_path

        self.path = os.path.join(self.base_path, 'profile_%s.json' % str(ts))
        self.root_path = os.path.join(self.base_path, 'root_%s.json' % str(ts))


        with open(self.root_path, 'w') as f:
            json.dump(self.dump_tree(self.root), f, indent=4)

    def dump_tree(self, node):
        """Serialize the dterm tree rooted at *node* as nested dicts.

        Marks already-visited nodes as 'indirect' to break sharing/cycles.
        NOTE: mutates nodes by setting a ``visited`` flag.
        """
        if not hasattr(node, 'dterms'):
            return []

        node_dict = self.serialize_node(node, verbose=False)
        if hasattr(node, 'visited') and node.visited:
            node_dict.update({'indirect':True})
            return node_dict

        node.visited = True
        children_list = []
        for dterm in node.dterms:
            if hasattr(node, dterm):
                child = getattr(node, dterm)
                if hasattr(child, 'dterms') or hasattr(child, 'terms'):
                    children_list.append(self.dump_tree(child))
        node_dict.update({'children':children_list})
        return node_dict

    def serialize_node(self, ch_node, verbose=True):
        """One snapshot record for *ch_node*; verbose adds timing/memory/cache."""
        node_id = id(ch_node)
        name = ch_node.short_name
        ts = time.time()
        status = ch_node._status
        mem = get_current_memory()
        node_cache_info = cache_info(ch_node)

        rec = {
            'id': str(node_id),
            'indirect' : False,
        }
        if verbose:
            rec.update({
                'name':name,
                'ts' : ts,
                'status':status,
                'mem': mem,
                'cache': node_cache_info,
            })
        return rec

    def show_tree(self, label):
        '''
        show tree from the root node
        '''
        self.root.show_tree_cache(label)

    def record(self, ch_node):
        '''
        Incremental changes
        '''
        rec = self.serialize_node(ch_node)
        self.history.append(rec)

    def harvest(self):
        """Write all recorded snapshots to the profile JSON file."""
        print('collecting and dump to file %s' % self.path)
        with open(self.path, 'w') as f:
            json.dump(self.history, f, indent=4)
vendor/chumpy/chumpy/np_tensordot.py ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Up to numpy 1.13, the numpy implementation of tensordot could be
2
+ # reinterpreted using chumpy. With numpy 1.14 the implementation started using
3
+ # ufunc.multiply.reduce which can't be understood by chumpy. This is the
4
+ # chumpy-compatible implementation of tensodrot from numpy 1.13.3.
5
+ #
6
+ # i.e.
7
+ #
8
+ # import inspect
9
+ # with open('np_tensordot.py', 'w') as f:
10
+ # f.write(''.join(inspect.getsourcelines(np.tensordot)[0]))
11
+
12
+ """
13
+ Copyright (c) 2005-2017, NumPy Developers.
14
+ All rights reserved.
15
+
16
+ Redistribution and use in source and binary forms, with or without
17
+ modification, are permitted provided that the following conditions are
18
+ met:
19
+
20
+ * Redistributions of source code must retain the above copyright
21
+ notice, this list of conditions and the following disclaimer.
22
+
23
+ * Redistributions in binary form must reproduce the above
24
+ copyright notice, this list of conditions and the following
25
+ disclaimer in the documentation and/or other materials provided
26
+ with the distribution.
27
+
28
+ * Neither the name of the NumPy Developers nor the names of any
29
+ contributors may be used to endorse or promote products derived
30
+ from this software without specific prior written permission.
31
+
32
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
33
+ "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
34
+ LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
35
+ A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
36
+ OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
37
+ SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
38
+ LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
39
+ DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
40
+ THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
41
+ (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
42
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
43
+ """
44
+
45
def tensordot(a, b, axes=2):
    """
    Compute tensor dot product along specified axes for arrays >= 1-D.

    Given two tensors (arrays of dimension greater than or equal to one),
    `a` and `b`, and an array_like object containing two array_like
    objects, ``(a_axes, b_axes)``, sum the products of `a`'s and `b`'s
    elements (components) over the axes specified by ``a_axes`` and
    ``b_axes``. The third argument can be a single non-negative
    integer_like scalar, ``N``; if it is such, then the last ``N``
    dimensions of `a` and the first ``N`` dimensions of `b` are summed
    over.

    Parameters
    ----------
    a, b : array_like, len(shape) >= 1
        Tensors to "dot".
    axes : int or (2,) array_like
        * integer_like
          If an int N, sum over the last N axes of `a` and the first N axes
          of `b` in order. The sizes of the corresponding axes must match.
        * (2,) array_like
          Or, a list of axes to be summed over, first sequence applying to `a`,
          second to `b`. Both elements array_like must be of the same length.

    See Also
    --------
    dot, einsum

    Notes
    -----
    Three common use cases are:
        * ``axes = 0`` : tensor product :math:`a\\otimes b`
        * ``axes = 1`` : tensor dot product :math:`a\\cdot b`
        * ``axes = 2`` : (default) tensor double contraction :math:`a:b`

    When `axes` is integer_like, the sequence for evaluation will be: first
    the -Nth axis in `a` and 0th axis in `b`, and the -1th axis in `a` and
    Nth axis in `b` last.

    When there is more than one axis to sum over - and they are not the last
    (first) axes of `a` (`b`) - the argument `axes` should consist of
    two sequences of the same length, with the first axis to sum over given
    first in both sequences, the second axis second, and so forth.

    Examples
    --------
    A "traditional" example:

    >>> a = np.arange(60.).reshape(3,4,5)
    >>> b = np.arange(24.).reshape(4,3,2)
    >>> c = np.tensordot(a,b, axes=([1,0],[0,1]))
    >>> c.shape
    (5, 2)
    >>> c
    array([[ 4400.,  4730.],
           [ 4532.,  4874.],
           [ 4664.,  5018.],
           [ 4796.,  5162.],
           [ 4928.,  5306.]])

    For further (object-dtype) examples see the upstream ``numpy.tensordot``
    documentation; this is the chumpy-compatible port of the numpy 1.13.3
    implementation.
    """
    # Normalize `axes` into two explicit axis lists.  A plain integer N is
    # not iterable, so iter() raises TypeError and we contract the last N
    # axes of `a` against the first N axes of `b`.
    # (Fixed from a bare `except:` which also swallowed KeyboardInterrupt
    # and SystemExit; iter() raises exactly TypeError for non-iterables.)
    try:
        iter(axes)
    except TypeError:
        axes_a = list(range(-axes, 0))
        axes_b = list(range(0, axes))
    else:
        axes_a, axes_b = axes
    # Each side may itself be a single axis rather than a sequence.
    try:
        na = len(axes_a)
        axes_a = list(axes_a)
    except TypeError:
        axes_a = [axes_a]
        na = 1
    try:
        nb = len(axes_b)
        axes_b = list(axes_b)
    except TypeError:
        axes_b = [axes_b]
        nb = 1

    # `asarray`/`dot` are the chumpy-aware versions supplied by this package,
    # which is why this port exists at all.
    a, b = asarray(a), asarray(b)
    as_ = a.shape
    nda = a.ndim
    bs = b.shape
    ndb = b.ndim
    # Validate that every contracted axis pair has matching lengths, and
    # normalize negative axis indices while we're at it.
    equal = True
    if na != nb:
        equal = False
    else:
        for k in range(na):
            if as_[axes_a[k]] != bs[axes_b[k]]:
                equal = False
                break
            if axes_a[k] < 0:
                axes_a[k] += nda
            if axes_b[k] < 0:
                axes_b[k] += ndb
    if not equal:
        raise ValueError("shape-mismatch for sum")

    # Move the axes to sum over to the end of "a"
    # and to the front of "b"
    notin = [k for k in range(nda) if k not in axes_a]
    newaxes_a = notin + axes_a
    N2 = 1
    for axis in axes_a:
        N2 *= as_[axis]
    newshape_a = (-1, N2)
    olda = [as_[axis] for axis in notin]

    notin = [k for k in range(ndb) if k not in axes_b]
    newaxes_b = axes_b + notin
    N2 = 1
    for axis in axes_b:
        N2 *= bs[axis]
    newshape_b = (N2, -1)
    oldb = [bs[axis] for axis in notin]

    # Collapse to 2-D so the contraction is a single matrix product, then
    # restore the free axes of `a` followed by the free axes of `b`.
    at = a.transpose(newaxes_a).reshape(newshape_a)
    bt = b.transpose(newaxes_b).reshape(newshape_b)
    res = dot(at, bt)
    return res.reshape(olda + oldb)
vendor/chumpy/chumpy/optimization.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ """
4
+ Author(s): Matthew Loper
5
+
6
+ See LICENCE.txt for licensing and contact information.
7
+ """
8
+
9
+ __all__ = ['minimize']
10
+
11
+ import numpy as np
12
+ from . import ch
13
+ import scipy.sparse as sp
14
+ import scipy.optimize
15
+
16
+ from .optimization_internal import minimize_dogleg
17
+
18
+ #from memory_profiler import profile, memory_usage
19
+
20
+ # def disable_cache_for_single_parent_node(node):
21
+ # if hasattr(node, '_parents') and len(node._parents.keys()) == 1:
22
+ # node.want_cache = False
23
+
24
+
25
+ # Nelder-Mead
26
+ # Powell
27
+ # CG
28
+ # BFGS
29
+ # Newton-CG
30
+ # Anneal
31
+ # L-BFGS-B
32
+ # TNC
33
+ # COBYLA
34
+ # SLSQP
35
+ # dogleg
36
+ # trust-ncg
37
def minimize(fun, x0, method='dogleg', bounds=None, constraints=(), tol=None, callback=None, options=None):
    """Minimize a chumpy objective over the free variables in ``x0``.

    ``fun`` may be a single ch object, a list/tuple of them, or a dict of
    label -> ch object; vector objectives are flattened and stacked.
    ``method='dogleg'`` is routed to the custom solver in
    ``optimization_internal``; any other method name is passed through to
    ``scipy.optimize.minimize`` with chumpy supplying the residual,
    gradient and Hessian callbacks.  Returns ``x0`` (the free variables),
    updated in place.
    """

    if method == 'dogleg':
        if options is None: options = {}
        # The custom dogleg solver ignores bounds/constraints/tol.
        return minimize_dogleg(fun, free_variables=x0, on_step=callback, **options)

    if isinstance(fun, list) or isinstance(fun, tuple):
        fun = ch.concatenate([f.ravel() for f in fun])
    if isinstance(fun, dict):
        # Labels are dropped here; only the stacked residual vector survives.
        fun = ch.concatenate([f.ravel() for f in list(fun.values())])
    obj = fun
    free_variables = x0


    from .ch import SumOfSquares

    hessp = None
    hess = None
    # scipy wants a scalar objective: wrap a residual vector in sum-of-squares.
    if obj.size == 1:
        obj_scalar = obj
    else:
        obj_scalar = SumOfSquares(obj)

    def hessp(vs, p,obj, obj_scalar, free_variables):
        # Hessian-vector product using the Gauss-Newton approximation
        # H = 2 J^T J; J is recomputed only when vs actually changed
        # (memoized via attributes stashed on the function object).
        changevars(vs,obj,obj_scalar,free_variables)
        if not hasattr(hessp, 'vs'):
            hessp.vs = vs*0+1e16  # sentinel: force the first recompute
        if np.max(np.abs(vs-hessp.vs)) > 0:

            J = ns_jacfunc(vs,obj,obj_scalar,free_variables)
            hessp.J = J
            hessp.H = 2. * J.T.dot(J)
            hessp.vs = vs
        return np.array(hessp.H.dot(p)).ravel()
        #return 2*np.array(hessp.J.T.dot(hessp.J.dot(p))).ravel()

    if method.lower() != 'newton-cg':
        # NOTE(review): this guard looks inverted — it defines a full Hessian
        # for every method *except* newton-cg (which is the one scipy method
        # that consumes hess/hessp); confirm the intended condition.
        def hess(vs, obj, obj_scalar, free_variables):
            changevars(vs,obj,obj_scalar,free_variables)
            if not hasattr(hessp, 'vs'):
                hessp.vs = vs*0+1e16
            if np.max(np.abs(vs-hessp.vs)) > 0:
                J = ns_jacfunc(vs,obj,obj_scalar,free_variables)
                hessp.H = 2. * J.T.dot(J)
            return hessp.H

    def changevars(vs, obj, obj_scalar, free_variables):
        # Scatter the flat parameter vector vs back into the individual free
        # variables; returns True iff any value actually changed.
        cur = 0
        changed = False
        for idx, freevar in enumerate(free_variables):
            sz = freevar.r.size
            newvals = vs[cur:cur+sz].copy().reshape(free_variables[idx].shape)
            if np.max(np.abs(newvals-free_variables[idx]).ravel()) > 0:
                free_variables[idx][:] = newvals
                changed = True

            cur += sz

        # Some scipy methods never invoke their callback; emulate it here.
        methods_without_callback = ('anneal', 'powell', 'cobyla', 'slsqp')
        if callback is not None and changed and method.lower() in methods_without_callback:
            callback(None)

        return changed

    def residuals(vs,obj, obj_scalar, free_variables):
        # Scalar objective value at vs.
        changevars(vs, obj, obj_scalar, free_variables)
        residuals = obj_scalar.r.ravel()[0]
        return residuals

    def scalar_jacfunc(vs,obj, obj_scalar, free_variables):
        # Gradient of the scalar objective, memoized on vs.
        if not hasattr(scalar_jacfunc, 'vs'):
            scalar_jacfunc.vs = vs*0+1e16
        if np.max(np.abs(vs-scalar_jacfunc.vs)) == 0:
            return scalar_jacfunc.J

        changevars(vs, obj, obj_scalar, free_variables)

        if True: # faster, at least on some problems
            # Left-multiply a unit vector through the graph (vector-Jacobian
            # product) instead of materializing full Jacobians.
            result = np.concatenate([np.array(obj_scalar.lop(wrt, np.array([[1]]))).ravel() for wrt in free_variables])
        else:
            jacs = [obj_scalar.dr_wrt(wrt) for wrt in free_variables]
            for idx, jac in enumerate(jacs):
                if sp.issparse(jac):
                    jacs[idx] = jacs[idx].todense()
            result = np.concatenate([jac.ravel() for jac in jacs])

        scalar_jacfunc.J = result
        scalar_jacfunc.vs = vs
        return result.ravel()

    def ns_jacfunc(vs,obj, obj_scalar, free_variables):
        # Jacobian of the (vector) residual, memoized on vs.
        # NOTE(review): `hstack` is not imported in this module — if this path
        # is reached it would raise NameError (the helper lives in
        # optimization_internal); confirm.
        if not hasattr(ns_jacfunc, 'vs'):
            ns_jacfunc.vs = vs*0+1e16
        if np.max(np.abs(vs-ns_jacfunc.vs)) == 0:
            return ns_jacfunc.J

        changevars(vs, obj, obj_scalar, free_variables)
        jacs = [obj.dr_wrt(wrt) for wrt in free_variables]
        result = hstack(jacs)

        ns_jacfunc.J = result
        ns_jacfunc.vs = vs
        return result


    x1 = scipy.optimize.minimize(
        method=method,
        fun=residuals,
        callback=callback,
        x0=np.concatenate([free_variable.r.ravel() for free_variable in free_variables]),
        jac=scalar_jacfunc,
        hessp=hessp, hess=hess, args=(obj, obj_scalar, free_variables),
        bounds=bounds, constraints=constraints, tol=tol, options=options).x

    # Push the final solution back into the free variables before returning.
    changevars(x1, obj, obj_scalar, free_variables)
    return free_variables
153
+
154
+
155
def main():
    """Placeholder entry point; this module is used as a library."""
    pass


if __name__ == '__main__':
    main()
161
+
vendor/chumpy/chumpy/optimization_internal.py ADDED
@@ -0,0 +1,455 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import warnings
3
+ import numpy as np
4
+ import scipy.sparse as sp
5
+ from . import ch, utils
6
+ from .ch import pif
7
+ from .utils import timer
8
+
9
+
10
def clear_cache_single(node):
    """Empty one node's derivative caches in place.

    Clears ``node._cache['drs']`` and, when the node has one, its
    ``dr_cached`` mapping.  Used as the visitor for a graph-wide cache purge.
    """
    node._cache['drs'].clear()
    if hasattr(node, 'dr_cached'):
        node.dr_cached.clear()
14
+
15
def vstack(x):
    """Stack blocks vertically, mixing dense, sparse and LinearOperator inputs.

    LinearOperators are densified by applying them to an identity matrix.
    The result is a CSC sparse matrix if any input is sparse, else a dense
    ndarray.

    Fixed: reference ``sp.linalg.LinearOperator`` (the public, stable name)
    instead of the deprecated ``sp.linalg.interface`` submodule, which has
    been removed from recent SciPy releases.  Same class object, so behavior
    is unchanged.
    """
    x = [a if not isinstance(a, sp.linalg.LinearOperator) else a.dot(np.eye(a.shape[1])) for a in x]
    return sp.vstack(x, format='csc') if any([sp.issparse(a) for a in x]) else np.vstack(x)
18
def hstack(x):
    """Stack blocks horizontally, mixing dense, sparse and LinearOperator inputs.

    LinearOperators are densified by applying them to an identity matrix.
    The result is a CSC sparse matrix if any input is sparse, else a dense
    ndarray.

    Fixed: reference ``sp.linalg.LinearOperator`` (the public, stable name)
    instead of the deprecated ``sp.linalg.interface`` submodule, which has
    been removed from recent SciPy releases.  Same class object, so behavior
    is unchanged.
    """
    x = [a if not isinstance(a, sp.linalg.LinearOperator) else a.dot(np.eye(a.shape[1])) for a in x]
    return sp.hstack(x, format='csc') if any([sp.issparse(a) for a in x]) else np.hstack(x)
21
+
22
+
23
# Global iteration counter, passed to ch setattr so nodes can stamp updates.
_giter = 0
class ChInputsStacked(ch.Ch):
    """A ch node that presents many free variables as one flat vector ``x``.

    Writing to ``x`` scatters values back into the individual free variables;
    ``r`` is the stacked residual of ``obj``; ``J`` is the stacked Jacobian.
    """
    dterms = 'x', 'obj'          # differentiable terms (chumpy declaration)
    terms = 'free_variables'     # non-differentiable term

    def compute_r(self):
        # Residual of the wrapped objective, flattened; also counts function
        # evaluations for the optimizer's max_fevals stopping condition.
        if not hasattr(self, 'fevals'):
            self.fevals = 0
        self.fevals += 1
        return self.obj.r.ravel()

    def dr_wrt(self, wrt, profiler=None):
        '''
        Loop over free variables and delete cache for the whole tree after finished each one
        '''
        if wrt is self.x:
            jacs = []
            for fvi, freevar in enumerate(self.free_variables):
                tm = timer()
                if isinstance(freevar, ch.Select):
                    # Differentiate w.r.t. the underlying variable, then keep
                    # only the selected columns.
                    new_jac = self.obj.dr_wrt(freevar.a, profiler=profiler)
                    try:
                        new_jac = new_jac[:, freevar.idxs]
                    except:
                        # non-csc sparse matrices may not support column-wise indexing
                        new_jac = new_jac.tocsc()[:, freevar.idxs]
                else:
                    new_jac = self.obj.dr_wrt(freevar, profiler=profiler)

                pif('dx wrt {} in {}sec, sparse: {}'.format(freevar.short_name, tm(), sp.issparse(new_jac)))

                # _make_dense/_make_sparse are presumably populated from the
                # make_dense kwarg by Ch's constructor machinery — TODO confirm.
                if self._make_dense and sp.issparse(new_jac):
                    new_jac = new_jac.todense()
                if self._make_sparse and not sp.issparse(new_jac):
                    new_jac = sp.csc_matrix(new_jac)

                if new_jac is None:
                    raise Exception(
                        'Objective has no derivative wrt free variable {}. '
                        'You should likely remove it.'.format(fvi))

                jacs.append(new_jac)
            # Purge every node's derivative cache to bound peak memory.
            tm = timer()
            utils.dfs_do_func_on_graph(self.obj, clear_cache_single)
            pif('dfs_do_func_on_graph in {}sec'.format(tm()))
            tm = timer()
            J = hstack(jacs)
            pif('hstack in {}sec'.format(tm()))
            return J

    def on_changed(self, which):
        # Scatter the flat vector x back into the individual free variables.
        global _giter
        _giter += 1
        if 'x' in which:
            pos = 0
            for idx, freevar in enumerate(self.free_variables):
                sz = freevar.r.size
                rng = np.arange(pos, pos+sz)
                if isinstance(self.free_variables[idx], ch.Select):
                    # Deal with nested selects
                    selects = []
                    a = self.free_variables[idx]
                    while isinstance(a, ch.Select):
                        selects.append(a.idxs)
                        a = a.a
                    newv = a.x.copy()
                    idxs = selects.pop()
                    while len(selects) > 0:
                        idxs = idxs[selects.pop()]
                    newv.ravel()[idxs] = self.x.r.ravel()[rng]
                    a.__setattr__('x', newv, _giter)
                elif isinstance(self.free_variables[idx].x, np.ndarray):
                    self.free_variables[idx].__setattr__('x', self.x.r[rng].copy().reshape(self.free_variables[idx].x.shape), _giter)
                else: # a number
                    self.free_variables[idx].__setattr__('x', self.x.r[rng], _giter)
                pos += sz

    @property
    def J(self):
        '''
        Compute Jacobian. Analyze dr graph first to disable unnecessary caching
        '''
        result = self.dr_wrt(self.x, profiler=self.profiler).copy()
        if self.profiler:
            self.profiler.harvest()
        # Keep sparse results sparse; guarantee at least 2-D for dense ones.
        return np.atleast_2d(result) if not sp.issparse(result) else result
109
+
110
+
111
def setup_sparse_solver(sparse_solver):
    """Resolve ``sparse_solver`` into a callable solving ``A x = b``.

    Accepts either an arbitrary callable (returned unchanged) or one of the
    string names 'cg' / 'spsolve'; anything else raises.
    """
    # NOTE(review): scipy renamed cg's `tol` kwarg to `rtol` in newer
    # releases — the 'cg' option may need updating there; confirm.
    solvers = {
        'cg': lambda A, x, M=None : sp.linalg.cg(A, x, M=M, tol=1e-10)[0],
        'spsolve': lambda A, x : sp.linalg.spsolve(A, x)
    }
    if callable(sparse_solver):
        return sparse_solver
    if isinstance(sparse_solver, str) and sparse_solver in solvers:
        return solvers[sparse_solver]
    raise Exception('sparse_solver argument must be either a string in the set (%s) or have the api of scipy.sparse.linalg.spsolve.' % ', '.join(list(solvers.keys())))
122
+
123
+
124
def setup_objective(obj, free_variables, on_step=None, disp=True, make_dense=False):
    '''
    obj here can be a list of ch objects or a dict of label: ch objects. Either way, the ch
    objects will be merged into one objective using a ChInputsStacked. The labels are just used
    for printing out values per objective with each iteration. If make_dense is True, the
    resulting object will return a dense Jacobian.

    Returns (stacked_objective, callback); the zero-argument callback invokes
    on_step and optionally prints per-label sum-of-squares to stderr.
    '''
    # Validate free variables
    num_unique_ids = len(np.unique(np.array([id(freevar) for freevar in free_variables])))
    if num_unique_ids != len(free_variables):
        raise Exception('The "free_variables" param contains duplicate variables.')
    # Extract labels
    labels = {}
    if isinstance(obj, list) or isinstance(obj, tuple):
        obj = ch.concatenate([f.ravel() for f in obj])
    elif isinstance(obj, dict):
        labels = obj
        obj = ch.concatenate([f.ravel() for f in list(obj.values())])
    # build objective
    x = np.concatenate([freevar.r.ravel() for freevar in free_variables])
    obj = ChInputsStacked(obj=obj, free_variables=free_variables, x=x, make_dense=make_dense)
    # build callback
    def callback():
        if on_step is not None:
            on_step(obj)
        if disp:
            # One line per step: total SSE, then per-label SSE sorted by label.
            report_line = ['%.2e' % (np.sum(obj.r**2),)]
            for label, objective in sorted(list(labels.items()), key=lambda x: x[0]):
                report_line.append('%s: %.2e' % (label, np.sum(objective.r**2)))
            report_line = " | ".join(report_line) + '\n'
            sys.stderr.write(report_line)
    return obj, callback
156
+
157
+
158
class DoglegState(object):
    '''
    Dogleg preserves a great deal of state from iteration to iteration. Many of the things
    that we need to calculate are dependent only on this state (e.g. the various trust region
    steps, the current jacobian and the A & g that depends on it, etc.). Holding the state and
    the various methods based on that state here allows us to separate a lot of the jacobian
    based calculation from the flow control of the optimization.

    There will be one instance of DoglegState per invocation of minimize_dogleg.
    '''
    def __init__(self, delta, solve):
        # delta: current trust-region radius (None = estimate on first GN step)
        # solve: callable solving A x = b for the Gauss-Newton step
        self.iteration = 0
        self._d_gn = None # gauss-newton
        self._d_sd = None # steepest descent
        self._d_dl = None # dogleg
        self.J = None     # current Jacobian
        self.A = None     # J^T J (Gauss-Newton Hessian approximation)
        self.g = None     # J^T (-r) (negative gradient direction)
        self._p = None
        self.delta = delta
        self.solve = solve
        self._r = None
        self.rho = None
        self.done = False

    @property
    def p(self):
        '''p is the current proposed input vector'''
        return self._p
    @p.setter
    def p(self, val):
        self._p = val.reshape((-1, 1))

    # induce some certainty about what the shape of the steps are:
    # every step vector is stored as a column (-1, 1).
    @property
    def d_gn(self):
        return self._d_gn
    @d_gn.setter
    def d_gn(self, val):
        if val is not None:
            val = val.reshape((-1, 1))
        self._d_gn = val

    @property
    def d_sd(self):
        return self._d_sd
    @d_sd.setter
    def d_sd(self, val):
        if val is not None:
            val = val.reshape((-1, 1))
        self._d_sd = val

    @property
    def d_dl(self):
        return self._d_dl
    @d_dl.setter
    def d_dl(self, val):
        if val is not None:
            val = val.reshape((-1, 1))
        self._d_dl = val

    @property
    def step(self):
        # The accepted dogleg step, as a column vector.
        return self.d_dl.reshape((-1, 1))
    @property
    def step_size(self):
        return np.linalg.norm(self.d_dl)

    def start_iteration(self):
        # Compute the Cauchy (steepest-descent) step; the Gauss-Newton step
        # is computed lazily, only if the Cauchy point stays inside the
        # trust region.
        self.iteration += 1
        pif('beginning iteration %d' % (self.iteration,))
        self.d_sd = (np.linalg.norm(self.g)**2 / np.linalg.norm(self.J.dot(self.g))**2 * self.g).ravel()
        self.d_gn = None

    @property
    def r(self):
        '''r is the residual at the current p'''
        return self._r
    @r.setter
    def r(self, val):
        # Accepting a new residual invalidates A and g; recompute eagerly.
        self._r = val.copy().reshape((-1, 1))
        self.updateAg()

    def updateAg(self):
        tm = timer()
        pif('updating A and g...')
        JT = self.J.T
        self.A = JT.dot(self.J)
        self.g = JT.dot(-self.r).reshape((-1, 1))
        pif('A and g updated in %.2fs' % tm())

    def update_step(self):
        # if the Cauchy point is outside the trust region,
        # take that direction but only to the edge of the trust region
        if self.delta is not None and np.linalg.norm(self.d_sd) >= self.delta:
            pif('PROGRESS: Using stunted cauchy')
            self.d_dl = np.array(self.delta/np.linalg.norm(self.d_sd) * self.d_sd).ravel()
        else:
            if self.d_gn is None:
                # We only need to compute this once per iteration
                self.updateGN()
            # if the gauss-newton solution is within the trust region, use it
            if self.delta is None or np.linalg.norm(self.d_gn) <= self.delta:
                pif('PROGRESS: Using gauss-newton solution')
                self.d_dl = np.array(self.d_gn).ravel()
                if self.delta is None:
                    # First iteration with no user-supplied delta_0:
                    # initialize the trust region from the GN step length.
                    self.delta = np.linalg.norm(self.d_gn)
            else: # between cauchy step and gauss-newton step
                pif('PROGRESS: between cauchy and gauss-newton')
                # apply step
                self.d_dl = self.d_sd + self.beta_multiplier * (self.d_gn - self.d_sd)

    @property
    def beta_multiplier(self):
        # Interpolation factor placing the dogleg step on the segment from
        # the Cauchy point toward the GN point, exactly on the trust-region
        # boundary (solves a quadratic for the intersection).
        delta_sq = self.delta**2
        diff = self.d_gn - self.d_sd
        sqnorm_sd = np.linalg.norm(self.d_sd)**2
        pnow = diff.T.dot(diff)*delta_sq + self.d_gn.T.dot(self.d_sd)**2 - np.linalg.norm(self.d_gn)**2 * sqnorm_sd
        return float(delta_sq - sqnorm_sd) / float((diff).T.dot(self.d_sd) + np.sqrt(pnow))

    def updateGN(self):
        # Solve A d = g for the Gauss-Newton step, with lsqr/lstsq fallbacks
        # when the direct solve produces NaN/inf or fails outright.
        tm = timer()
        if sp.issparse(self.A):
            self.A.eliminate_zeros()
            pif('sparse solve...sparsity infill is %.3f%% (hessian %dx%d)' % (100. * self.A.nnz / (self.A.shape[0] * self.A.shape[1]), self.A.shape[0], self.A.shape[1]))
            if self.g.size > 1:
                self.d_gn = self.solve(self.A, self.g).ravel()
                if np.any(np.isnan(self.d_gn)) or np.any(np.isinf(self.d_gn)):
                    from scipy.sparse.linalg import lsqr
                    warnings.warn("sparse solve failed, falling back to lsqr")
                    self.d_gn = lsqr(self.A, self.g)[0].ravel()
            else:
                # 1x1 system: divide directly.
                self.d_gn = np.atleast_1d(self.g.ravel()[0]/self.A[0,0])
            pif('sparse solve...done in %.2fs' % tm())
        else:
            pif('dense solve...')
            try:
                self.d_gn = np.linalg.solve(self.A, self.g).ravel()
            except Exception:
                warnings.warn("dense solve failed, falling back to lsqr")
                self.d_gn = np.linalg.lstsq(self.A, self.g)[0].ravel()
            pif('dense solve...done in %.2fs' % tm())

    def updateJ(self, obj):
        # Pull a fresh Jacobian from the stacked objective; CSR for fast
        # row-major products downstream.
        tm = timer()
        pif('computing Jacobian...')
        self.J = obj.J
        if self.J is None:
            raise Exception("Computing Jacobian failed!")
        if sp.issparse(self.J):
            tm2 = timer()
            self.J = self.J.tocsr()
            pif('converted to csr in {}secs'.format(tm2()))
            assert(self.J.nnz > 0)
        elif ch.VERBOSE:
            nonzero = np.count_nonzero(self.J)
            pif('Jacobian dense with sparsity %.3f' % (nonzero/self.J.size))
        pif('Jacobian (%dx%d) computed in %.2fs' % (self.J.shape[0], self.J.shape[1], tm()))
        if self.J.shape[1] != self.p.size:
            raise Exception('Jacobian size mismatch with objective input')
        return self.J

    class Trial(object):
        '''
        Inside each iteration of dogleg we propose a step and check to see if it's actually
        an improvement before we accept it. This class encapsulates that trial and the
        testing to see if it is actually an improvement.

        There will be one instance of Trial per iteration in dogleg.
        '''
        def __init__(self, proposed_r, state):
            self.r = proposed_r
            self.state = state
            # rho is the ratio of...
            # (improvement in SSE) / (predicted improvement in SSE)
            self.rho = np.linalg.norm(state.r)**2 - np.linalg.norm(proposed_r)**2
            if self.rho > 0:
                with warnings.catch_warnings():
                    warnings.filterwarnings('ignore',category=RuntimeWarning)
                    predicted_improvement = 2. * state.g.T.dot(state.d_dl) - state.d_dl.T.dot(state.A.dot(state.d_dl))
                    self.rho /= predicted_improvement

        @property
        def is_improvement(self):
            return self.rho > 0

        @property
        def improvement(self):
            # Relative SSE improvement of this trial over the current state.
            return (np.linalg.norm(self.state.r)**2 - np.linalg.norm(self.r)**2) / np.linalg.norm(self.state.r)**2

    def trial_r(self, proposed_r):
        # Build a Trial for a proposed residual against the current state.
        return self.Trial(proposed_r, self)

    def updateRadius(self, rho, lb=.05, ub=.9):
        # Standard trust-region update: grow on very good agreement between
        # predicted and actual improvement, shrink on poor agreement.
        if rho > ub:
            self.delta = max(self.delta, 2.5*np.linalg.norm(self.d_dl))
        elif rho < lb:
            self.delta *= .25
357
+
358
def minimize_dogleg(obj, free_variables, on_step=None,
                    maxiter=200, max_fevals=np.inf, sparse_solver='spsolve',
                    disp=True, e_1=1e-15, e_2=1e-15, e_3=0., delta_0=None,
                    treat_as_dense=False):
    """Nonlinear optimization using Powell's dogleg method.

    See Lourakis et al, 2005, ICCV '05, "Is Levenberg-Marquardt the
    Most Efficient Optimization for Implementing Bundle Adjustment?":
    http://www.ics.forth.gr/cvrl/publications/conferences/0201-P0401-lourakis-levenberg.pdf

    e_N are stopping conditions:
    e_1 is gradient magnitude threshold
    e_2 is step size magnitude threshold
    e_3 is improvement threshold (as a ratio; 0.1 means it must improve by 10%% at each step)

    maxiter and max_fevals are also stopping conditions. Note that they're not quite the same,
    as an iteration may evaluate the function more than once.

    sparse_solver is the solver to use to calculate the Gauss-Newton step in the common case
    that the Jacobian is sparse. It can be 'spsolve' (in which case scipy.sparse.linalg.spsolve
    will be used), 'cg' (in which case scipy.sparse.linalg.cg will be used), or any callable
    that matches the api of scipy.sparse.linalg.spsolve to solve `A x = b` for x where A is sparse.

    cg, uses a Conjugate Gradient method, and will be faster if A is sparse but x is dense.
    spsolve will be faster if x is also sparse.

    delta_0 defines the initial trust region. Generally speaking, if this is set too low then
    the optimization will never really go anywhere (too small a trust region to make any real
    progress before running out of iterations) and if it's set too high then the optimization
    will diverge immediately and go wild (such a large trust region that the initial step so
    far overshoots that it can't recover). If it's left as None, it will be automatically
    estimated on the first iteration; it's always updated at each iteration, so this is treated
    only as an initialization.

    treat_as_dense explicitly converts all Jacobians of obj to dense matrices
    """


    solve = setup_sparse_solver(sparse_solver)
    obj, callback = setup_objective(obj, free_variables, on_step=on_step, disp=disp,
                                    make_dense=treat_as_dense)

    state = DoglegState(delta=delta_0, solve=solve)
    state.p = obj.x.r

    #inject profiler if in DEBUG mode
    if ch.DEBUG:
        from .monitor import DrWrtProfiler
        obj.profiler = DrWrtProfiler(obj)

    callback()
    state.updateJ(obj)
    state.r = obj.r

    def stop(msg):
        # Latch `done` and report the reason (only the first reason prints).
        if not state.done:
            pif(msg)
        state.done = True

    if np.linalg.norm(state.g, np.inf) < e_1:
        stop('stopping because norm(g, np.inf) < %.2e' % e_1)
    while not state.done:
        state.start_iteration()
        # Inner loop: shrink the trust region until a step improves the
        # objective (or a stopping condition triggers).
        while True:
            state.update_step()
            if state.step_size <= e_2 * np.linalg.norm(state.p):
                stop('stopping because of small step size (norm_dl < %.2e)' % (e_2 * np.linalg.norm(state.p)))
            else:
                tm = timer()
                obj.x = state.p + state.step
                trial = state.trial_r(obj.r)
                pif('Residuals computed in %.2fs' % tm())
                # if the objective function improved, update input parameter estimate.
                # Note that the obj.x already has the new parms,
                # and we should not set them again to the same (or we'll bust the cache)
                if trial.is_improvement:
                    state.p = state.p + state.step
                    callback()
                    if e_3 > 0. and trial.improvement < e_3:
                        stop('stopping because improvement < %.1e%%' % (100*e_3))
                    else:
                        state.updateJ(obj)
                        state.r = trial.r
                        if np.linalg.norm(state.g, np.inf) < e_1:
                            stop('stopping because norm(g, np.inf) < %.2e' % e_1)
                else: # Put the old parms back
                    obj.x = ch.Ch(state.p)
                    obj.on_changed('x') # copies from flat vector to free variables
                # update our trust region
                state.updateRadius(trial.rho)
                if state.delta <= e_2*np.linalg.norm(state.p):
                    stop('stopping because trust region is too small')
            if state.done or trial.is_improvement or (obj.fevals >= max_fevals):
                break
        if state.iteration >= maxiter:
            stop('stopping because max number of user-specified iterations (%d) has been met' % maxiter)
        elif obj.fevals >= max_fevals:
            stop('stopping because max number of user-specified func evals (%d) has been met' % max_fevals)
    return obj.free_variables
vendor/chumpy/chumpy/optional_test_performance.py ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # encoding: utf-8
3
+ """
4
+ Author(s): Matthew Loper
5
+
6
+ See LICENCE.txt for licensing and contact information.
7
+ """
8
+
9
+ import unittest
10
+ import numpy as np
11
+ from functools import reduce
12
+
13
+
14
+ has_ressources = True
15
+ try:
16
+ import resource
17
+
18
+ def abstract_ressource_timer():
19
+ return resource.getrusage(resource.RUSAGE_SELF)
20
+ def abstract_ressource_counter(r1, r2):
21
+ _r1 = r1.ru_stime + r1.ru_utime
22
+ _r2 = r2.ru_stime + r2.ru_utime
23
+
24
+ return _r2 - _r1
25
+ except ImportError:
26
+ has_ressources = False
27
+ pass
28
+
29
+
30
+ if not has_ressources:
31
+ try:
32
+ from ctypes import *
33
+
34
+
35
+
36
+ def abstract_ressource_timer():
37
+ val = c_int64()
38
+ windll.Kernel32.QueryPerformanceCounter(byref(val))
39
+ return val
40
+ def abstract_ressource_counter(r1, r2):
41
+ """Returns the elapsed time between r2 and r1 (r2 > r1) in milliseconds"""
42
+ val = c_int64()
43
+ windll.Kernel32.QueryPerformanceFrequency(byref(val))
44
+
45
+ return (1000*float(r2.value-r1.value))/val.value
46
+
47
+ except ImportError:
48
+ has_win32api = False
49
+
50
+
51
+
52
+
53
+ from . import ch
54
+
55
+
56
+
57
class Timer(object):
    """Context manager that measures the cost of its `with` body.

    After exit, the measurement is available as ``self.elapsed`` in
    whatever units the active timer backend reports.
    """

    def __enter__(self):
        # First snapshot, taken when the block is entered.
        self.r1 = abstract_ressource_timer()

    def __exit__(self, exception_type, exception_value, traceback):
        # Second snapshot on exit; elapsed is the backend-defined delta.
        self.r2 = abstract_ressource_timer()
        self.elapsed = abstract_ressource_counter(self.r1, self.r2)
66
+
67
+ # def timer():
68
+ # tm = resource.getrusage(resource.RUSAGE_SELF)
69
+ # return tm.ru_stime + tm.ru_utime
70
+ #
71
+ # svd1
72
+
73
def timer(setup, go, n):
    """Time callable `go` over `n` rounds and return the mean cost.

    If `setup` is not None it is invoked before every round (outside the
    timed region). Units are whatever the active timer backend reports.
    """
    samples = []
    for _ in range(n):
        if setup is not None:
            setup()
        t0 = abstract_ressource_timer()
        go()
        t1 = abstract_ressource_timer()
        samples.append(abstract_ressource_counter(t0, t1))
    # Mean rather than min; see timeit docs, which recommend the minimum.
    return np.mean(samples)
93
+
94
+
95
+
96
+ import timeit
97
class TestPerformance(unittest.TestCase):
    """Speed-regression tests: chumpy ops must stay within fixed slowdown
    factors relative to the equivalent raw-numpy computation."""

    def setUp(self):
        # Fixed seed so the timed workloads are deterministic.
        np.random.seed(0)
        self.mtx_10 = ch.array(np.random.randn(100).reshape((10, 10)))
        self.mtx_1k = ch.array(np.random.randn(1000000).reshape((1000, 1000)))

    def compute_binary_ratios(self, vecsize, numvecs):
        """Return {op_name: chumpy_time / numpy_time} for chained
        elementwise binary ops over `numvecs` vectors of size `vecsize`."""
        ratio = {}
        for funcname in ['add', 'subtract', 'multiply', 'divide', 'power']:
            for xp in ch, np:
                func = getattr(xp, funcname)
                vecs = [xp.random.rand(vecsize) for i in range(numvecs)]

                if xp is ch:
                    # Build the graph once; each timed round perturbs the
                    # leaves and forces a full re-evaluation via .r
                    chained = reduce(lambda x, y: func(x, y), vecs)

                    def go():
                        for v in vecs:
                            v.x *= -1
                        _ = chained.r

                    tm_ch = timer(None, go, 10)
                else:  # xp is np: rebuild the whole chain every round
                    def go():
                        for v in vecs:
                            v *= -1
                        _ = reduce(lambda x, y: func(x, y), vecs)

                    tm_np = timer(None, go, 10)

            ratio[funcname] = tm_ch / tm_np

        return ratio

    def test_binary_ratios(self):
        ratios = self.compute_binary_ratios(vecsize=5000, numvecs=100)
        tol = 1e-1
        # Allowed slowdown factors are empirical per-op budgets.
        self.assertLess(ratios['add'], 8 + tol)
        self.assertLess(ratios['subtract'], 8 + tol)
        self.assertLess(ratios['multiply'], 8 + tol)
        self.assertLess(ratios['divide'], 4 + tol)
        self.assertLess(ratios['power'], 2 + tol)

    def test_svd(self):
        mtx = ch.array(np.random.randn(100).reshape((10, 10)))

        # Time chumpy's differentiable svd: values (.r) and jacobians (dr_wrt).
        from .linalg import svd
        u, s, v = svd(mtx)

        def setup():
            # Invalidate caches so each round really recomputes.
            mtx.x = -mtx.x

        def go_r():
            _ = u.r
            _ = s.r
            _ = v.r

        def go_dr():
            _ = u.dr_wrt(mtx)
            _ = s.dr_wrt(mtx)
            _ = v.dr_wrt(mtx)

        cht_r = timer(setup, go_r, 20)
        cht_dr = timer(setup, go_dr, 1)

        # Baseline: plain numpy svd on the same data.
        def go():
            u, s, v = np.linalg.svd(mtx.x)
        npt = timer(setup=None, go=go, n=20)

        # Budgets: values near-parity, derivatives are far more expensive.
        self.assertLess(cht_r / npt, 3.3)
        self.assertLess(cht_dr / npt, 2700)
181
if __name__ == '__main__':
    # Run just this module's performance suite with verbose output.
    loader = unittest.TestLoader()
    unittest.TextTestRunner(verbosity=2).run(
        loader.loadTestsFromTestCase(TestPerformance))
vendor/chumpy/chumpy/reordering.py ADDED
@@ -0,0 +1,454 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Author(s): Matthew Loper
3
+
4
+ See LICENCE.txt for licensing and contact information.
5
+ """
6
+
7
+ from .ch import Ch
8
+ import numpy as np
9
+ from .utils import row, col
10
+ import scipy.sparse as sp
11
+ import weakref
12
+
13
+ __all__ = ['sort', 'tile', 'repeat', 'transpose', 'rollaxis', 'swapaxes', 'reshape', 'Select',
14
+ 'atleast_1d', 'atleast_2d', 'atleast_3d', 'squeeze', 'expand_dims', 'fliplr', 'flipud',
15
+ 'concatenate', 'vstack', 'hstack', 'dstack', 'ravel', 'diag', 'diagflat', 'roll', 'rot90']
16
+
17
class Permute(Ch):
    """Marker base class: subclasses promise to only reorder/reshape
    their input (every output element equals some input element)."""
    pass
20
+
21
def ravel(a, order='C'):
    """Return a flattened (1-d) differentiable view of `a`.

    Only C (row-major) order is supported, matching Reshape semantics.
    """
    assert order == 'C'
    if isinstance(a, np.ndarray):
        # Wrap plain arrays so the result is a chumpy node.
        # BUG FIX: the original assigned this to an unused local named
        # `self`, silently discarding the wrapper.
        a = Ch(a)

    return reshape(a=a, newshape=(-1,))
27
+
28
class Reorder(Permute):
    """Base for differentiable reorderings.

    Subclasses implement ``reorder(a)`` (the numpy op) and
    ``unique_reorder_id()`` (a hashable cache key, or None when the
    jacobian depends on the data and must not be cached).
    """
    dterms = 'a',

    def on_changed(self, which):
        # Lazily create the per-instance jacobian cache.
        if not hasattr(self, 'dr_lookup'):
            self.dr_lookup = {}

    def compute_r(self):
        return self.reorder(self.a.r)

    def compute_dr_wrt(self, wrt):
        """Jacobian wrt `a`: a sparse 0/1 matrix mapping input positions
        to output positions (each output element copies one input)."""
        if wrt is not self.a:
            return None
        src = self.a
        key = self.unique_reorder_id()
        # key None means "not cacheable": rebuild on every call.
        if key is None or key not in self.dr_lookup:
            # Reorder the flat index grid to learn where each input lands.
            cols = self.reorder(np.arange(src.size).reshape(src.shape))
            rows = np.arange(cols.size)
            vals = np.ones_like(rows)
            self.dr_lookup[key] = sp.csc_matrix(
                (vals, (rows, cols.ravel())),
                shape=(self.r.size, wrt.r.size))
        return self.dr_lookup[key]
55
+
56
class Sort(Reorder):
    """Differentiable wrapper around np.sort."""
    dterms = 'a'
    terms = 'axis', 'kind', 'order'

    def reorder(self, a):
        return np.sort(a, self.axis, self.kind, self.order)

    def unique_reorder_id(self):
        # The permutation depends on the data, so never cache.
        return None


def sort(a, axis=-1, kind='quicksort', order=None):
    """Chumpy analogue of np.sort."""
    return Sort(a=a, axis=axis, kind=kind, order=order)


class Tile(Reorder):
    """Differentiable wrapper around np.tile."""
    dterms = 'a',
    terms = 'reps',
    term_order = 'a', 'reps'

    def reorder(self, a):
        return np.tile(a, self.reps)

    def unique_reorder_id(self):
        return (self.a.shape, tuple(self.reps))


def tile(A, reps):
    """Chumpy analogue of np.tile."""
    return Tile(a=A, reps=reps)
77
+
78
+
79
class Diag(Reorder):
    """Differentiable np.diag: extract the k-th diagonal of a matrix, or
    build a diagonal matrix from a vector."""
    dterms = 'a',
    terms = 'k',

    def reorder(self, a):
        return np.diag(a, self.k)

    def unique_reorder_id(self):
        return (self.a.shape, self.k)


def diag(v, k=0):
    """Chumpy analogue of np.diag."""
    return Diag(a=v, k=k)


class DiagFlat(Reorder):
    """Differentiable np.diagflat: flattened input placed on the k-th
    diagonal of a 2-d output."""
    dterms = 'a',
    terms = 'k',

    def reorder(self, a):
        return np.diagflat(a, self.k)

    def unique_reorder_id(self):
        return (self.a.shape, self.k)


def diagflat(v, k=0):
    """Chumpy analogue of np.diagflat."""
    return DiagFlat(a=v, k=k)


class Repeat(Reorder):
    """Differentiable np.repeat."""
    dterms = 'a',
    terms = 'repeats', 'axis'

    def reorder(self, a):
        return np.repeat(a, self.repeats, self.axis)

    def unique_reorder_id(self):
        return (self.repeats, self.axis)


def repeat(a, repeats, axis=None):
    """Chumpy analogue of np.repeat."""
    return Repeat(a=a, repeats=repeats, axis=axis)
109
+
110
class transpose(Reorder):
    """Differentiable np.transpose; result is forced C-contiguous."""
    dterms = 'a'
    terms = 'axes'
    term_order = 'a', 'axes'

    def reorder(self, a):
        return np.require(np.transpose(a, axes=self.axes), requirements='C')

    def unique_reorder_id(self):
        return (self.a.shape, None if self.axes is None else tuple(self.axes))

    def on_changed(self, which):
        # 'axes' is optional; None means full reversal, as in np.transpose.
        if not hasattr(self, 'axes'):
            self.axes = None
        # BUG FIX: was super(self.__class__, self), which recurses
        # infinitely if this class is ever subclassed.
        super(transpose, self).on_changed(which)


class rollaxis(Reorder):
    """Differentiable np.rollaxis."""
    dterms = 'a'
    terms = 'axis', 'start'
    term_order = 'a', 'axis', 'start'

    def reorder(self, a):
        return np.rollaxis(a, axis=self.axis, start=self.start)

    def unique_reorder_id(self):
        return (self.a.shape, self.axis, self.start)

    def on_changed(self, which):
        # 'start' is optional and defaults to 0, as in np.rollaxis.
        if not hasattr(self, 'start'):
            self.start = 0
        # BUG FIX: was super(self.__class__, self); see transpose above.
        super(rollaxis, self).on_changed(which)


class swapaxes(Reorder):
    """Differentiable np.swapaxes."""
    dterms = 'a'
    terms = 'axis1', 'axis2'
    term_order = 'a', 'axis1', 'axis2'

    def reorder(self, a):
        return np.swapaxes(a, axis1=self.axis1, axis2=self.axis2)

    def unique_reorder_id(self):
        return (self.a.shape, self.axis1, self.axis2)
141
+
142
+
143
+
144
class Roll(Reorder):
    """Differentiable np.roll."""
    dterms = 'a',
    terms = 'shift', 'axis'
    term_order = 'a', 'shift', 'axis'

    def reorder(self, a):
        return np.roll(a, self.shift, self.axis)

    def unique_reorder_id(self):
        return (self.shift, self.axis)


def roll(a, shift, axis=None):
    """Chumpy analogue of np.roll."""
    return Roll(a=a, shift=shift, axis=axis)


class Rot90(Reorder):
    """Differentiable np.rot90 (k quarter-turns)."""
    dterms = 'a',
    terms = 'k',

    def reorder(self, a):
        return np.rot90(a, self.k)

    def unique_reorder_id(self):
        return (self.a.shape, self.k)


def rot90(m, k=1):
    """Chumpy analogue of np.rot90."""
    return Rot90(a=m, k=k)
164
+
165
class Reshape(Permute):
    """Differentiable reshape; element order is unchanged, so the
    jacobian wrt the input is the identity."""
    dterms = 'a',
    terms = 'newshape',
    term_order = 'a', 'newshape'

    def compute_r(self):
        return self.a.r.reshape(self.newshape)

    def compute_dr_wrt(self, wrt):
        if wrt is self.a:
            return sp.eye(self.a.size, self.a.size)


def reshape(a, newshape):
    """Reshape `a`, collapsing any chain of Reshape nodes into one
    (reshape-of-reshape reads straight through to the innermost input)."""
    while isinstance(a, Reshape):
        a = a.a
    return Reshape(a=a, newshape=newshape)
186
+
187
+ # class And(Ch):
188
+ # dterms = 'x1', 'x2'
189
+ #
190
+ # def compute_r(self):
191
+ # if True:
192
+ # needs_work = [self.x1, self.x2]
193
+ # done = []
194
+ # while len(needs_work) > 0:
195
+ # todo = needs_work.pop()
196
+ # if isinstance(todo, And):
197
+ # needs_work += [todo.x1, todo.x2]
198
+ # else:
199
+ # done = [todo] + done
200
+ # return np.concatenate([d.r.ravel() for d in done])
201
+ # else:
202
+ # return np.concatenate((self.x1.r.ravel(), self.x2.r.ravel()))
203
+ #
204
+ # # This is only here for reverse mode to work.
205
+ # # Most of the time, the overridden dr_wrt is callpath gets used.
206
+ # def compute_dr_wrt(self, wrt):
207
+ #
208
+ # if wrt is not self.x1 and wrt is not self.x2:
209
+ # return
210
+ #
211
+ # input_len = wrt.r.size
212
+ # x1_len = self.x1.r.size
213
+ # x2_len = self.x2.r.size
214
+ #
215
+ # mtxs = []
216
+ # if wrt is self.x1:
217
+ # mtxs.append(sp.spdiags(np.ones(x1_len), 0, x1_len, x1_len))
218
+ # else:
219
+ # mtxs.append(sp.csc_matrix((x1_len, input_len)))
220
+ #
221
+ # if wrt is self.x2:
222
+ # mtxs.append(sp.spdiags(np.ones(x2_len), 0, x2_len, x2_len))
223
+ # else:
224
+ # mtxs.append(sp.csc_matrix((x2_len, input_len)))
225
+ #
226
+ #
227
+ # if any([sp.issparse(mtx) for mtx in mtxs]):
228
+ # result = sp.vstack(mtxs, format='csc')
229
+ # else:
230
+ # result = np.vstack(mtxs)
231
+ #
232
+ # return result
233
+ #
234
+ # def dr_wrt(self, wrt, want_stacks=False, reverse_mode=False):
235
+ # self._call_on_changed()
236
+ #
237
+ # input_len = wrt.r.size
238
+ # x1_len = self.x1.r.size
239
+ # x2_len = self.x2.r.size
240
+ #
241
+ # mtxs = []
242
+ # if wrt is self.x1:
243
+ # mtxs.append(sp.spdiags(np.ones(x1_len), 0, x1_len, x1_len))
244
+ # else:
245
+ # if isinstance(self.x1, And):
246
+ # tmp_mtxs = self.x1.dr_wrt(wrt, want_stacks=True, reverse_mode=reverse_mode)
247
+ # for mtx in tmp_mtxs:
248
+ # mtxs.append(mtx)
249
+ # else:
250
+ # mtxs.append(self.x1.dr_wrt(wrt, reverse_mode=reverse_mode))
251
+ # if mtxs[-1] is None:
252
+ # mtxs[-1] = sp.csc_matrix((x1_len, input_len))
253
+ #
254
+ # if wrt is self.x2:
255
+ # mtxs.append(sp.spdiags(np.ones(x2_len), 0, x2_len, x2_len))
256
+ # else:
257
+ # if isinstance(self.x2, And):
258
+ # tmp_mtxs = self.x2.dr_wrt(wrt, want_stacks=True, reverse_mode=reverse_mode)
259
+ # for mtx in tmp_mtxs:
260
+ # mtxs.append(mtx)
261
+ # else:
262
+ # mtxs.append(self.x2.dr_wrt(wrt, reverse_mode=reverse_mode))
263
+ # if mtxs[-1] is None:
264
+ # mtxs[-1] = sp.csc_matrix((x2_len, input_len))
265
+ #
266
+ # if want_stacks:
267
+ # return mtxs
268
+ # else:
269
+ # if any([sp.issparse(mtx) for mtx in mtxs]):
270
+ # result = sp.vstack(mtxs, format='csc')
271
+ # else:
272
+ # result = np.vstack(mtxs)
273
+ #
274
+ # return result
275
+
276
class Select(Permute):
    """Pick elements of flattened `a` at flat positions `idxs`,
    optionally reshaped to `preferred_shape`."""
    terms = ['idxs', 'preferred_shape']
    dterms = ['a']
    term_order = 'a', 'idxs', 'preferred_shape'

    def compute_r(self):
        picked = self.a.r.ravel()[self.idxs].copy()
        if hasattr(self, 'preferred_shape'):
            picked = picked.reshape(self.preferred_shape)
        return picked

    def compute_dr_wrt(self, obj):
        if obj is not self.a:
            return None
        if not hasattr(self, '_dr_cached'):
            # One 1 per output row, in the column of the selected input.
            n = len(self.idxs)
            rows = np.arange(n)
            cols = self.idxs.ravel()
            vals = np.ones(n)
            self._dr_cached = sp.csc_matrix(
                (vals, (rows, cols)),
                shape=(n, np.prod(self.a.shape)))
        return self._dr_cached

    def on_changed(self, which):
        # Invalidate the cached jacobian when the index set changed or the
        # input changed size.
        if hasattr(self, '_dr_cached'):
            if 'idxs' in which or self.a.r.size != self._dr_cached.shape[1]:
                del self._dr_cached
302
+
303
+
304
+
305
class AtleastNd(Ch):
    """Reshape x so it has at least `ndims` dimensions (1, 2 or 3),
    mirroring np.atleast_{1,2,3}d."""
    dterms = 'x'
    terms = 'ndims'

    def compute_r(self):
        xr = self.x.r
        if self.ndims == 1:
            shaped = np.atleast_1d(xr)
        elif self.ndims == 2:
            shaped = np.atleast_2d(xr)
        elif self.ndims == 3:
            shaped = np.atleast_3d(xr)
        else:
            raise Exception('Need ndims to be 1, 2, or 3.')

        # Use numpy only to learn the target shape; reshape preserves order.
        return xr.reshape(shaped.shape)

    def compute_dr_wrt(self, wrt):
        if wrt is self.x:
            # Pure reshape: identity derivative.
            return 1


def atleast_nd(ndims, *arys):
    """Wrap each input; single input returns a node, several a list."""
    wrapped = [AtleastNd(x=ary, ndims=ndims) for ary in arys]
    return wrapped if len(wrapped) > 1 else wrapped[0]


def atleast_1d(*arys):
    return atleast_nd(1, *arys)


def atleast_2d(*arys):
    return atleast_nd(2, *arys)


def atleast_3d(*arys):
    return atleast_nd(3, *arys)
338
+
339
def squeeze(a, axis=None):
    """Chumpy analogue of np.squeeze (falls through for plain arrays)."""
    if isinstance(a, np.ndarray):
        return np.squeeze(a, axis)
    # Learn the squeezed shape from numpy, then apply a differentiable reshape.
    return a.reshape(np.squeeze(a.r, axis).shape)


def expand_dims(a, axis):
    """Chumpy analogue of np.expand_dims (falls through for plain arrays)."""
    if isinstance(a, np.ndarray):
        return np.expand_dims(a, axis)
    return a.reshape(np.expand_dims(a.r, axis).shape)


def fliplr(m):
    """Reverse columns via a (differentiable) slice."""
    return m[:, ::-1]


def flipud(m):
    """Reverse rows via a (differentiable) slice."""
    return m[::-1, ...]
356
+
357
class Concatenate(Ch):
    """Differentiable np.concatenate over a variable number of dterms
    (named m0..mN by the `concatenate` factory below)."""

    def on_changed(self, which):
        # Weak-keyed so cached jacobians don't keep wrt-nodes alive.
        if not hasattr(self, 'dr_cached'):
            self.dr_cached = weakref.WeakKeyDictionary()

    @property
    def our_terms(self):
        if not hasattr(self, '_our_terms'):
            self._our_terms = [getattr(self, name) for name in self.dterms]
        return self._our_terms

    def __getstate__(self):
        # WeakKeyDictionary is unpicklable; drop it before serializing.
        if hasattr(self, 'dr_cached'):
            del self.dr_cached
        return super(self.__class__, self).__getstate__()

    def compute_r(self):
        return np.concatenate([t.r for t in self.our_terms], axis=self.axis)

    @property
    def everything(self):
        # Flat output indices, with the concat axis rotated to the front so
        # each term occupies a contiguous leading slab.
        if not hasattr(self, '_everything'):
            flat_ids = np.arange(self.r.size).reshape(self.r.shape)
            self._everything = np.swapaxes(flat_ids, self.axis, 0)
        return self._everything

    def compute_dr_wrt(self, wrt):
        if not hasattr(self, 'dr_cached'):
            self.dr_cached = weakref.WeakKeyDictionary()
        if wrt in self.dr_cached and self.dr_cached[wrt] is not None:
            return self.dr_cached[wrt]

        if wrt not in self.our_terms:
            return

        wrt_cols = np.arange(wrt.size)
        wrt_vals = np.ones(wrt.size)

        IS = []
        JS = []
        data = []

        # Walk the terms along the concat axis; only the slab(s) belonging
        # to `wrt` contribute nonzeros.
        offset = 0
        for term in self.our_terms:
            span = term.shape[self.axis]
            if term is wrt:
                JS.append(wrt_cols)
                data.append(wrt_vals)
                IS.append(np.swapaxes(
                    self.everything[offset:offset + span],
                    self.axis, 0).ravel())
            offset += span
        rows = np.concatenate(IS).ravel()
        cols = np.concatenate(JS).ravel()
        vals = np.concatenate(data)

        res = sp.csc_matrix((vals, (rows, cols)),
                            shape=(self.r.size, wrt.size))

        # With a single parent the parent's own cache suffices, so skip
        # caching in that case (presumably a memory-saving measure).
        if len(list(self._parents.keys())) != 1:
            self.dr_cached[wrt] = res
        else:
            self.dr_cached[wrt] = None

        return res
421
+
422
+
423
def expand_concatenates(mtxs, axis=0):
    """Flatten nested Concatenate nodes (along the same axis) into one
    flat list of leaf terms, preserving order."""
    pending = list(mtxs)
    flat = []
    while pending:
        node = pending.pop(0)
        if isinstance(node, Concatenate) and node.axis == axis:
            # Splice the child's terms in place of the nested concat.
            pending = [getattr(node, name) for name in node.dterms] + pending
        else:
            flat.append(node)
    return flat


def concatenate(mtxs, axis=0, **kwargs):
    """Chumpy analogue of np.concatenate. Nested concatenations along the
    same axis are flattened into a single node with dterms m0..mN."""
    mtxs = expand_concatenates(mtxs, axis)

    result = Concatenate(**kwargs)
    result.dterms = []
    for i, mtx in enumerate(mtxs):
        term_name = 'm%d' % (i,)
        result.dterms.append(term_name)
        setattr(result, term_name, mtx)
    result.axis = axis
    return result


def hstack(mtxs, **kwargs):
    """Chumpy analogue of np.hstack."""
    return concatenate(mtxs, axis=1, **kwargs)


def vstack(mtxs, **kwargs):
    """Chumpy analogue of np.vstack (inputs promoted to 2-d)."""
    return concatenate([atleast_2d(m) for m in mtxs], axis=0, **kwargs)


def dstack(mtxs, **kwargs):
    """Chumpy analogue of np.dstack (inputs promoted to 3-d)."""
    return concatenate([atleast_3d(m) for m in mtxs], axis=2, **kwargs)
vendor/chumpy/chumpy/test_ch.py ADDED
@@ -0,0 +1,621 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # encoding: utf-8
3
+ """
4
+ Author(s): Matthew Loper
5
+
6
+ See LICENCE.txt for licensing and contact information.
7
+ """
8
+
9
+ import time
10
+
11
+ import unittest
12
+ import numpy as np
13
+ import scipy.sparse as sp
14
+
15
+ from . import ch
16
+
17
+ class TestCh(unittest.TestCase):
18
+
19
+
20
+ def test_cachehits(self):
21
+ """Test how many nodes are visited when cache is cleared.
22
+ If the number of hits changes, it has to be carefully
23
+ looked at to make sure that correctness and performance
24
+ don't get messed up by a change."""
25
+
26
+ a = ch.array(1)
27
+ b = ch.array(2)
28
+ c = a
29
+ for i in range(10):
30
+ c = a + c + b
31
+
32
+ c.dr_wrt(a)
33
+ c.dr_wrt(b)
34
+ self.assertEqual(a.clear_cache() + b.clear_cache(), 59)
35
+ c.dr_wrt(a)
36
+ c.dr_wrt(b)
37
+ self.assertEqual(a.clear_cache(123) + b.clear_cache(123), 41)
38
+
39
+ def test_nested_concatenate(self):
40
+ aa = ch.arange(3)
41
+ bb = ch.arange(4)
42
+ cc = ch.arange(5)
43
+
44
+ result = ch.concatenate((ch.concatenate((aa,bb)),cc))
45
+ self.assertTrue(result.m0 is aa)
46
+ self.assertTrue(result.m1 is bb)
47
+ self.assertTrue(result.m2 is cc)
48
+
49
+ self.assertTrue(result.dr_wrt(aa).nnz > 0)
50
+ self.assertTrue(result.dr_wrt(bb).nnz > 0)
51
+ self.assertTrue(result.dr_wrt(cc).nnz > 0)
52
+
53
+ def test_nandivide(self):
54
+ foo = ch.array(np.random.randn(16).reshape((4,4)))
55
+ bar = ch.array(np.random.randn(16).reshape((4,4)))
56
+ bar[2,2] = 0
57
+ self.assertEqual(ch.NanDivide(foo,bar)[2,2].r, 0.)
58
+ foo[2,2] = 0
59
+ self.assertEqual(ch.NanDivide(foo,bar)[2,2].r, 0.)
60
+
61
+ def test_casting(self):
62
+ for fn in float, int:
63
+ self.assertEqual(fn(np.array(5)), fn(ch.array(5)))
64
+ self.assertEqual(fn(np.array([[5]])), fn(ch.array([[5]])))
65
+
66
+ def test_tensordot(self):
67
+ an = np.arange(60.).reshape(3,4,5)
68
+ bn = np.arange(24.).reshape(4,3,2)
69
+ cn = np.tensordot(an,bn, axes=([1,0],[0,1]))
70
+
71
+ ac = ch.arange(60.).reshape(3,4,5)
72
+ bc = ch.arange(24.).reshape(4,3,2)
73
+ cc = ch.tensordot(ac,bc, axes=([1,0],[0,1]))
74
+
75
+ cc.r
76
+ cc.dr_wrt(ac)
77
+ cc.dr_wrt(bc)
78
+ #print cn
79
+
80
+ def test_make_sure_is_double(self):
81
+ x = ch.array([0])
82
+ self.assertTrue(isinstance(x.r[0], np.float64))
83
+
84
+ def test_cross(self):
85
+ aa = ch.random.randn(30).reshape((10,3))
86
+ bb = ch.random.randn(30).reshape((10,3))
87
+
88
+ cross_ch = ch.cross(aa, bb)
89
+ cross_np = np.cross(aa.r, bb.r)
90
+
91
+ # print cross_ch.r
92
+ # print cross_np
93
+
94
+ eps = 1.0
95
+ step = (np.random.rand(30) - .5).reshape((10,3)) * eps
96
+
97
+ gt_diff = np.cross(aa.r, bb.r+step) - cross_np
98
+ pr_diff = cross_ch.dr_wrt(bb).dot(step.ravel())
99
+ # print gt_diff
100
+ # print pr_diff
101
+ # print np.max(np.abs(gt_diff.ravel()-pr_diff.ravel()))
102
+ self.assertTrue(1e-14 > np.max(np.abs(gt_diff.ravel()-pr_diff.ravel())))
103
+
104
+ gt_diff = np.cross(aa.r+step, bb.r) - cross_np
105
+ pr_diff = cross_ch.dr_wrt(aa).dot(step.ravel())
106
+ #print gt_diff
107
+ # print pr_diff
108
+ # print np.max(np.abs(gt_diff.ravel()-pr_diff.ravel()))
109
+ self.assertTrue(1e-14 > np.max(np.abs(gt_diff.ravel()-pr_diff.ravel())))
110
+
111
+ def test_dr_wrt_selection(self):
112
+ aa = ch.arange(10,20)
113
+ bb = ch.arange(1,11)
114
+ cc = aa * bb + aa + bb +2
115
+
116
+ dr0 = cc.dr_wrt(aa[4:6])
117
+ dr1 = cc.dr_wrt(aa)[:,4:6]
118
+ self.assertTrue((dr0 - dr1).nnz == 0)
119
+
120
+ dr0 = cc.dr_wrt(bb[5:8])
121
+ dr1 = cc.dr_wrt(bb)[:,5:8]
122
+ self.assertTrue((dr0 - dr1).nnz == 0)
123
+
124
+
125
+ def test_sum_mean_std_var(self):
126
+ for fn in [ch.sum, ch.mean, ch.var, ch.std]:
127
+
128
+ # Create fake input and differences in input space
129
+ data1 = ch.ones((3,4,7,2))
130
+ data2 = ch.array(data1.r + .1 * np.random.rand(data1.size).reshape(data1.shape))
131
+ diff = data2.r - data1.r
132
+
133
+ # Compute outputs
134
+ result1 = fn(data1, axis=2)
135
+ result2 = fn(data2, axis=2)
136
+
137
+ # Empirical and predicted derivatives
138
+ gt = result2.r - result1.r
139
+ pred = result1.dr_wrt(data1).dot(diff.ravel()).reshape(gt.shape)
140
+
141
+ #print np.max(np.abs(gt - pred))
142
+
143
+ if fn in [ch.std, ch.var]:
144
+ self.assertTrue(1e-2 > np.max(np.abs(gt - pred)))
145
+ else:
146
+ self.assertTrue(1e-14 > np.max(np.abs(gt - pred)))
147
+ # test caching
148
+ dr0 = result1.dr_wrt(data1)
149
+ data1[:] = np.random.randn(data1.size).reshape(data1.shape)
150
+ self.assertTrue(result1.dr_wrt(data1) is dr0) # changing values shouldn't force recompute
151
+ result1.axis=1
152
+ self.assertTrue(result1.dr_wrt(data1) is not dr0)
153
+
154
+ self.assertEqual(ch.mean(ch.eye(3),axis=1).ndim, np.mean(np.eye(3),axis=1).ndim)
155
+ self.assertEqual(ch.mean(ch.eye(3),axis=0).ndim, np.mean(np.eye(3),axis=0).ndim)
156
+ self.assertEqual(ch.sum(ch.eye(3),axis=1).ndim, np.sum(np.eye(3),axis=1).ndim)
157
+ self.assertEqual(ch.sum(ch.eye(3),axis=0).ndim, np.sum(np.eye(3),axis=0).ndim)
158
+
159
+
160
+
161
+ def test_cumsum(self):
162
+ a = ch.array([1.,5.,3.,7.])
163
+ cs = ch.cumsum(a)
164
+ r1 = cs.r
165
+ dr = cs.dr_wrt(a)
166
+ diff = (ch.random.rand(4)-.5)*.1
167
+ a.x += diff.r
168
+ pred = dr.dot(diff.r)
169
+ gt = cs.r - r1
170
+ self.assertTrue(1e-13 > np.max(np.abs(gt - pred)))
171
+
172
+
173
+ def test_iteration_cache(self):
174
+ """ Each time you set an attribute, the cache (of r's and dr's) of
175
+ ancestors is cleared. Because children share ancestors, this means
176
+ these can be cleared multiple times unnecessarily; in some cases,
177
+ where lots of objects exist, this cache clearing can actually be a bottleneck.
178
+
179
+ Therefore, the concept of an iteration was added; intended to be used in
180
+ an optimization setting (see optimization.py) and in the set() method, it
181
+ avoids such redundant clearing of cache."""
182
+
183
+ a, b, c = ch.Ch(1), ch.Ch(2), ch.Ch(3)
184
+ x = a+b
185
+ y = x+c
186
+ self.assertTrue(y.r[0]==6)
187
+
188
+ a.__setattr__('x', 10, 1)
189
+ self.assertTrue(y.r == 15)
190
+ a.__setattr__('x', 100, 1)
191
+ self.assertTrue(y.r == 15)
192
+ a.__setattr__('x', 100, 2)
193
+ self.assertTrue(y.r == 105)
194
+
195
+ a, b, c = ch.array([1]), ch.array([2]), ch.array([3])
196
+ x = a+b
197
+ y = x+c
198
+ self.assertTrue(y.r[0]==6)
199
+
200
+ a.__setattr__('x', np.array([10]), 1)
201
+ self.assertTrue(y.r[0] == 15)
202
+ a.__setattr__('x', np.array(100), 1)
203
+ self.assertTrue(y.r[0] == 15)
204
+ a.__setattr__('x', np.array(100), 2)
205
+ self.assertTrue(y.r[0] == 105)
206
+ a.__setitem__(list(range(0,1)), np.array(200), 2)
207
+ self.assertTrue(y.r[0] == 105)
208
+ a.__setitem__(list(range(0,1)), np.array(200), 3)
209
+ self.assertTrue(y.r[0] == 205)
210
+
211
+
212
+
213
+ def test_stacking(self):
214
+
215
+ a1 = ch.Ch(np.arange(10).reshape(2,5))
216
+ b1 = ch.Ch(np.arange(20).reshape(4,5))
217
+ c1 = ch.vstack((a1,b1))
218
+ c1_check = np.vstack((a1.r, b1.r))
219
+ residuals1 = (c1_check - c1.r).ravel()
220
+
221
+
222
+ a2 = ch.Ch(np.arange(10).reshape(5,2))
223
+ b2 = ch.Ch(np.arange(20).reshape(5,4))
224
+ c2 = ch.hstack((a2,b2))
225
+ c2_check = np.hstack((a2.r, b2.r))
226
+ residuals2 = (c2_check - c2.r).ravel()
227
+
228
+ self.assertFalse(np.any(residuals1))
229
+ self.assertFalse(np.any(residuals2))
230
+
231
+ d0 = ch.array(np.arange(60).reshape((10,6)))
232
+ d1 = ch.vstack((d0[:4], d0[4:]))
233
+ d2 = ch.hstack((d1[:,:3], d1[:,3:]))
234
+ tmp = d2.dr_wrt(d0).todense()
235
+ diff = tmp - np.eye(tmp.shape[0])
236
+ self.assertFalse(np.any(diff.ravel()))
237
+
238
+
239
+
240
+ #def test_drs(self):
241
+ # a = ch.Ch(2)
242
+ # b = ch.Ch(3)
243
+ # c = a * b
244
+ # print c.dr_wrt(a)
245
+ # print c.compute_drs_wrt(a).r
246
+
247
+ @unittest.skip('We are using LinearOperator for this for now. Might change back though.')
248
+ def test_reorder_caching(self):
249
+ a = ch.Ch(np.zeros(8).reshape((4,2)))
250
+ b = a.T
251
+ dr0 = b.dr_wrt(a)
252
+ a.x = a.x + 1.
253
+ dr1 = b.dr_wrt(a)
254
+ self.assertTrue(dr0 is dr1)
255
+ a.x = np.zeros(4).reshape((2,2))
256
+ dr2 = b.dr_wrt(a)
257
+ self.assertTrue(dr2 is not dr1)
258
+
259
+ def test_transpose(self):
260
+ from .utils import row, col
261
+ from copy import deepcopy
262
+ for which in ('C', 'F'): # test in fortran and contiguous mode
263
+ a = ch.Ch(np.require(np.zeros(8).reshape((4,2)), requirements=which))
264
+ b = a.T
265
+
266
+ b1 = b.r.copy()
267
+ #dr = b.dr_wrt(a).copy()
268
+ dr = deepcopy(b.dr_wrt(a))
269
+
270
+ diff = np.arange(a.size).reshape(a.shape)
271
+ a.x = np.require(a.r + diff, requirements=which)
272
+ b2 = b.r.copy()
273
+
274
+ diff_pred = dr.dot(col(diff)).ravel()
275
+ diff_emp = (b2 - b1).ravel()
276
+ np.testing.assert_array_equal(diff_pred, diff_emp)
277
+
278
+
279
+ def test_unary(self):
280
+ fns = [ch.exp, ch.log, ch.sin, ch.arcsin, ch.cos, ch.arccos, ch.tan, ch.arctan, ch.negative, ch.square, ch.sqrt, ch.abs, ch.reciprocal]
281
+
282
+ eps = 1e-8
283
+ for f in fns:
284
+
285
+ x0 = ch.Ch(.25)
286
+ x1 = ch.Ch(x0.r+eps)
287
+
288
+ pred = f(x0).dr_wrt(x0)
289
+ empr = (f(x1).r - f(x0).r) / eps
290
+
291
+ # print pred
292
+ # print empr
293
+ if f is ch.reciprocal:
294
+ self.assertTrue(1e-6 > np.abs(pred.ravel()[0] - empr.ravel()[0]))
295
+ else:
296
+ self.assertTrue(1e-7 > np.abs(pred.ravel()[0] - empr.ravel()[0]))
297
+
298
+
299
+ def test_serialization(self):
300
+ # The main challenge with serialization is the "_parents"
301
+ # attribute, which is a nonserializable WeakKeyDictionary.
302
+ # So we pickle/unpickle, change a child and verify the value
303
+ # at root, and verify that both children have parentage.
304
+ from six.moves import cPickle as pickle
305
+ tmp = ch.Ch(10) + ch.Ch(20)
306
+ tmp = pickle.loads(pickle.dumps(tmp))
307
+ tmp.b.x = 30
308
+ self.assertTrue(tmp.r[0] == 40)
309
+ self.assertTrue(list(tmp.a._parents.keys())[0] == tmp)
310
+ self.assertTrue(list(tmp.a._parents.keys())[0] == list(tmp.b._parents.keys())[0])
311
+
312
+ def test_chlambda1(self):
313
+ c1, c2, c3 = ch.Ch(1), ch.Ch(2), ch.Ch(3)
314
+ adder = ch.ChLambda(lambda x, y: x+y)
315
+ adder.x = c1
316
+ adder.y = c2
317
+ self.assertTrue(adder.r == 3)
318
+ adder.x = c2
319
+ self.assertTrue(adder.r == 4)
320
+ adder.x = c1
321
+ self.assertTrue(adder.r == 3)
322
+
323
+
324
+ def test_chlambda2(self):
325
+ passthrough = ch.ChLambda( lambda x : x)
326
+ self.assertTrue(passthrough.dr_wrt(passthrough.x) is not None)
327
+ passthrough.x = ch.Ch(123)
328
+ self.assertTrue(passthrough.dr_wrt(passthrough.x) is not None)
329
+
330
+ # It's probably not reasonable to expect this
331
+ # to work for ChLambda
332
+ #def test_chlambda3(self):
333
+ # c1, c2, c3 = ch.Ch(1), ch.Ch(2), ch.Ch(3)
334
+ # triple = ch.ChLambda( lambda x, y, z : x(y, z))
335
+ # triple.x = Add
336
+ # triple.y = c2
337
+ # triple.z = c3
338
+
339
+
340
+
341
+
342
+
343
+ def test_amax(self):
344
+ from .ch import amax
345
+ import numpy as np
346
+ arr = np.empty((5,2,3,7))
347
+ arr.flat[:] = np.sin(np.arange(arr.size)*1000.)
348
+ #arr = np.array(np.sin(np.arange(24)*10000.).reshape(2,3,4))
349
+
350
+ for axis in range(len(arr.shape)):
351
+ a = amax(a=arr, axis=axis)
352
+ pred = a.dr_wrt(a.a).dot(arr.ravel())
353
+ real = np.amax(arr, axis=axis).ravel()
354
+ self.assertTrue(np.max(np.abs(pred-real)) < 1e-10)
355
+
356
+ def test_maximum(self):
357
+ from .utils import row, col
358
+ from .ch import maximum
359
+
360
+ # Make sure that when we compare the max of two *identical* numbers,
361
+ # we get the right derivatives wrt both
362
+ the_max = maximum(ch.Ch(1), ch.Ch(1))
363
+ self.assertTrue(the_max.r.ravel()[0] == 1.)
364
+ self.assertTrue(the_max.dr_wrt(the_max.a)[0,0] == 1.)
365
+ self.assertTrue(the_max.dr_wrt(the_max.b)[0,0] == 1.)
366
+
367
+ # Now test given that all numbers are different, by allocating from
368
+ # a pool of randomly permuted numbers.
369
+ # We test combinations of scalars and 2d arrays.
370
+ rnd = np.asarray(np.random.permutation(np.arange(20)), np.float64)
371
+ c1 = ch.Ch(rnd[:6].reshape((2,3)))
372
+ c2 = ch.Ch(rnd[6:12].reshape((2,3)))
373
+ s1 = ch.Ch(rnd[12])
374
+ s2 = ch.Ch(rnd[13])
375
+
376
+ eps = .1
377
+ for first in [c1, s1]:
378
+ for second in [c2, s2]:
379
+ the_max = maximum(first, second)
380
+
381
+ for which_to_change in [first, second]:
382
+
383
+
384
+ max_r0 = the_max.r.copy()
385
+ max_r_diff = np.max(np.abs(max_r0 - np.maximum(first.r, second.r)))
386
+ self.assertTrue(max_r_diff == 0)
387
+ max_dr = the_max.dr_wrt(which_to_change).copy()
388
+ which_to_change.x = which_to_change.x + eps
389
+ max_r1 = the_max.r.copy()
390
+
391
+ emp_diff = (the_max.r - max_r0).ravel()
392
+ pred_diff = max_dr.dot(col(eps*np.ones(max_dr.shape[1]))).ravel()
393
+
394
+ #print 'comparing the following numbers/vectors:'
395
+ #print first.r
396
+ #print second.r
397
+ #print 'empirical vs predicted difference:'
398
+ #print emp_diff
399
+ #print pred_diff
400
+ #print '-----'
401
+
402
+ max_dr_diff = np.max(np.abs(emp_diff-pred_diff))
403
+ #print 'max dr diff: %.2e' % (max_dr_diff,)
404
+ self.assertTrue(max_dr_diff < 1e-14)
405
+
406
+
407
+ def test_shared(self):
408
+
409
+ chs = [ch.Ch(i) for i in range(10)]
410
+ vrs = [float(i) for i in range(10)]
411
+
412
+ func = lambda a : a[0]*a[1] + (a[2]*a[3])/a[4]
413
+
414
+ chained_result = func(chs).r
415
+ regular_result = func(vrs)
416
+
417
+ self.assertTrue(chained_result == regular_result)
418
+ #print chained_result
419
+ #print regular_result
420
+
421
+ chained_func = func(chs)
422
+ chained_func.replace(chs[0], ch.Ch(50))
423
+ vrs[0] = 50
424
+
425
+ chained_result = chained_func.r
426
+ regular_result = func(vrs)
427
+
428
+ self.assertTrue(chained_result == regular_result)
429
+ #print chained_result
430
+ #print regular_result
431
+
432
+
433
+ def test_matmatmult(self):
434
+ from .ch import dot
435
+ mtx1 = ch.Ch(np.arange(6).reshape((3,2)))
436
+ mtx2 = ch.Ch(np.arange(8).reshape((2,4))*10)
437
+
438
+ mtx3 = dot(mtx1, mtx2)
439
+ #print mtx1.r
440
+ #print mtx2.r
441
+ #print mtx3.r
442
+ #print mtx3.dr_wrt(mtx1).todense()
443
+ #print mtx3.dr_wrt(mtx2).todense()
444
+
445
+ for mtx in [mtx1, mtx2]:
446
+ oldval = mtx3.r.copy()
447
+ mtxd = mtx3.dr_wrt(mtx).copy()
448
+ mtx_diff = np.random.rand(mtx.r.size).reshape(mtx.r.shape)
449
+ mtx.x = mtx.r + mtx_diff
450
+ mtx_emp = mtx3.r - oldval
451
+ mtx_pred = mtxd.dot(mtx_diff.ravel()).reshape(mtx_emp.shape)
452
+
453
+ self.assertTrue(np.max(np.abs(mtx_emp - mtx_pred)) < 1e-11)
454
+
455
+
456
+ def test_ndim(self):
457
+ vs = [ch.Ch(np.random.randn(6).reshape(2,3)) for i in range(6)]
458
+ res = vs[0] + vs[1] - vs[2] * vs[3] / (vs[4] ** 2) ** vs[5]
459
+ self.assertTrue(res.shape[0]==2 and res.shape[1]==3)
460
+ res = (vs[0] + 1) + (vs[1] - 2) - (vs[2] * 3) * (vs[3] / 4) / (vs[4] ** 2) ** vs[5]
461
+ self.assertTrue(res.shape[0]==2 and res.shape[1]==3)
462
+ drs = [res.dr_wrt(v) for v in vs]
463
+
464
+
465
+ def test_indexing(self):
466
+ big = ch.Ch(np.arange(60).reshape((10,6)))
467
+ little = big[1:3, 3:6]
468
+ self.assertTrue(np.max(np.abs(little.r - np.array([[9,10,11],[15,16,17]]))) == 0)
469
+
470
+ little = big[5]
471
+ self.assertTrue(np.max(np.abs(little.r - np.arange(30, 36))) == 0)
472
+ self.assertTrue(np.max(np.abs(sp.coo_matrix(little.dr_wrt(big)).col - np.arange(30,36))) == 0)
473
+
474
+ little = big[2, 3]
475
+ self.assertTrue(little.r[0] == 15.0)
476
+
477
+ little = big[2, 3:5]
478
+ self.assertTrue(np.max(np.abs(little.r - np.array([15, 16]))) == 0.)
479
+ _ = little.dr_wrt(big)
480
+
481
+ # Tests assignment through reorderings
482
+ aa = ch.arange(4*4*4).reshape((4,4,4))[:3,:3,:3]
483
+ aa[0,1,2] = 100
484
+ self.assertTrue(aa[0,1,2].r[0] == 100)
485
+
486
+ # Tests assignment through reorderings (NaN's are a special case)
487
+ aa = ch.arange(9).reshape((3,3))
488
+ aa[1,1] = np.nan
489
+ self.assertTrue(np.isnan(aa.r[1,1]))
490
+ self.assertFalse(np.isnan(aa.r[0,0]))
491
+
492
+
493
+ def test_redundancy_removal(self):
494
+
495
+ for MT in [False, True]:
496
+ x1, x2 = ch.Ch(10), ch.Ch(20)
497
+ x1_plus_x2_1 = x1 + x2
498
+ x1_plus_x2_2 = x1 + x2
499
+ redundant_sum = (x1_plus_x2_1 + x1_plus_x2_2) * 2
500
+ redundant_sum.MT = MT
501
+
502
+ self.assertTrue(redundant_sum.a.a is not redundant_sum.a.b)
503
+ redundant_sum.remove_redundancy()
504
+ self.assertTrue(redundant_sum.a.a is redundant_sum.a.b)
505
+
506
+ def test_caching(self):
507
+
508
+ vals = [10, 20, 30, 40, 50]
509
+ f = lambda a, b, c, d, e : a + (b * c) - d ** e
510
+
511
+ # Set up our objects
512
+ Cs = [ch.Ch(v) for v in vals]
513
+ C_result = f(*Cs)
514
+
515
+ # Sometimes residuals should be cached
516
+ r1 = C_result.r
517
+ r2 = C_result.r
518
+ self.assertTrue(r1 is r2)
519
+
520
+ # Other times residuals need refreshing
521
+ Cs[0].set(x=5)
522
+ r3 = C_result.r
523
+ self.assertTrue(r3 is not r2)
524
+
525
+ # Sometimes derivatives should be cached
526
+ dr1 = C_result.dr_wrt(Cs[1])
527
+ dr2 = C_result.dr_wrt(Cs[1])
528
+ self.assertTrue(dr1 is dr2)
529
+
530
+ # Other times derivatives need refreshing
531
+ Cs[2].set(x=5)
532
+ dr3 = C_result.dr_wrt(Cs[1])
533
+ self.assertTrue(dr3 is not dr2)
534
+
535
+
536
+ def test_scalars(self):
537
+
538
+ try:
539
+ import theano.tensor as T
540
+ from theano import function
541
+ except:
542
+ return
543
+
544
+ # Set up variables and function
545
+ vals = [1, 2, 3, 4, 5]
546
+ f = lambda a, b, c, d, e : a + (b * c) - d ** e
547
+
548
+ # Set up our objects
549
+ Cs = [ch.Ch(v) for v in vals]
550
+ C_result = f(*Cs)
551
+
552
+ # Set up Theano's equivalents
553
+ Ts = T.dscalars('T1', 'T2', 'T3', 'T4', 'T5')
554
+ TF = f(*Ts)
555
+ T_result = function(Ts, TF)
556
+
557
+ # Make sure values and derivatives are equal
558
+ self.assertEqual(C_result.r, T_result(*vals))
559
+ for k in range(len(vals)):
560
+ theano_derivative = function(Ts, T.grad(TF, Ts[k]))(*vals)
561
+ #print C_result.dr_wrt(Cs[k])
562
+ our_derivative = C_result.dr_wrt(Cs[k])[0,0]
563
+ #print theano_derivative, our_derivative
564
+ self.assertEqual(theano_derivative, our_derivative)
565
+
566
+
567
+ def test_vectors(self):
568
+
569
+ try:
570
+ import theano.tensor as T
571
+ from theano import function
572
+ except:
573
+ return
574
+
575
+ for MT in [False, True]:
576
+
577
+ # Set up variables and function
578
+ vals = [np.random.randn(20) for i in range(5)]
579
+ f = lambda a, b, c, d, e : a + (b * c) - d ** e
580
+
581
+ # Set up our objects
582
+ Cs = [ch.Ch(v) for v in vals]
583
+ C_result = f(*Cs)
584
+ C_result.MT = MT
585
+
586
+ # Set up Theano equivalents
587
+ Ts = T.dvectors('T1', 'T2', 'T3', 'T4', 'T5')
588
+ TF = f(*Ts)
589
+ T_result = function(Ts, TF)
590
+
591
+ if False:
592
+ import theano.gradient
593
+ which = 1
594
+ theano_sse = (TF**2.).sum()
595
+ theano_grad = theano.gradient.grad(theano_sse, Ts[which])
596
+ theano_fn = function(Ts, theano_grad)
597
+ print(theano_fn(*vals))
598
+ C_result_grad = ch.SumOfSquares(C_result).dr_wrt(Cs[which])
599
+ print(C_result_grad)
600
+
601
+ # if True:
602
+ # aaa = np.linalg.solve(C_result_grad.T.dot(C_result_grad), C_result_grad.dot(np.zeros(C_result_grad.shape[1])))
603
+ # theano_hes = theano.R_obbb = theano.R_op()
604
+
605
+ import pdb; pdb.set_trace()
606
+
607
+ # Make sure values and derivatives are equal
608
+ np.testing.assert_array_equal(C_result.r, T_result(*vals))
609
+ for k in range(len(vals)):
610
+ theano_derivative = function(Ts, T.jacobian(TF, Ts[k]))(*vals)
611
+ our_derivative = np.array(C_result.dr_wrt(Cs[k]).todense())
612
+ #print theano_derivative, our_derivative
613
+
614
+ # Theano produces has more nans than we do during exponentiation.
615
+ # So we test only on entries where Theano is without NaN's
616
+ without_nans = np.nonzero(np.logical_not(np.isnan(theano_derivative.flatten())))[0]
617
+ np.testing.assert_array_equal(theano_derivative.flatten()[without_nans], our_derivative.flatten()[without_nans])
618
+
619
+
620
+ if __name__ == '__main__':
621
+ unittest.main()
vendor/chumpy/chumpy/test_inner_composition.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # encoding: utf-8
3
+ """
4
+ Author(s): Matthew Loper
5
+
6
+ See LICENCE.txt for licensing and contact information.
7
+ """
8
+
9
+ import unittest
10
+ from .ch import Ch, depends_on
11
+
12
+ class TestInnerComposition(unittest.TestCase):
13
+
14
+ def test_ic(self):
15
+ child = Child(a=Ch(10))
16
+ parent = Parent(child=child, aliased=Ch(50))
17
+
18
+ junk = [parent.aliased_dependency for k in range(3)]
19
+ self.assertTrue(parent.dcount == 1)
20
+ self.assertTrue(parent.ocount == 0)
21
+ self.assertTrue(parent.rcount == 0)
22
+
23
+ junk = [parent.r for k in range(3)]
24
+ self.assertTrue(parent.dcount == 1)
25
+ self.assertTrue(parent.ocount == 1)
26
+ self.assertTrue(parent.rcount == 1)
27
+
28
+ parent.aliased = Ch(20)
29
+ junk = [parent.aliased_dependency for k in range(3)]
30
+ self.assertTrue(parent.dcount == 2)
31
+ self.assertTrue(parent.ocount == 1)
32
+ self.assertTrue(parent.rcount == 1)
33
+
34
+ junk = [parent.r for k in range(3)]
35
+ self.assertTrue(parent.dcount == 2)
36
+ self.assertTrue(parent.ocount == 2)
37
+ self.assertTrue(parent.rcount == 2)
38
+
39
+ class Parent(Ch):
40
+ dterms = ('aliased', 'child')
41
+
42
+ def __init__(self, *args, **kwargs):
43
+ self.dcount = 0
44
+ self.ocount = 0
45
+ self.rcount = 0
46
+
47
+
48
+ def on_changed(self, which):
49
+ assert('aliased' in which and 'child' in which)
50
+ if 'aliased' in which:
51
+ self.ocount += 1
52
+
53
+ @depends_on('aliased')
54
+ def aliased_dependency(self):
55
+ self.dcount += 1
56
+
57
+ @property
58
+ def aliased(self):
59
+ return self.child.a
60
+
61
+ @aliased.setter
62
+ def aliased(self, val):
63
+ self.child.a = val
64
+
65
+ def compute_r(self):
66
+ self.rcount += 1
67
+ return 0
68
+
69
+ def compute_dr_wrt(self, wrt):
70
+ pass
71
+
72
+
73
+ class Child(Ch):
74
+ dterms = ('a',)
75
+
76
+
77
+
78
+ if __name__ == '__main__':
79
+ suite = unittest.TestLoader().loadTestsFromTestCase(TestInnerComposition)
80
+ unittest.TextTestRunner(verbosity=2).run(suite)
vendor/chumpy/chumpy/test_linalg.py ADDED
@@ -0,0 +1,272 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # encoding: utf-8
3
+ """
4
+ Author(s): Matthew Loper
5
+
6
+ See LICENCE.txt for licensing and contact information.
7
+ """
8
+
9
+ import numpy as np
10
+ import unittest
11
+
12
+ from .ch import Ch
13
+
14
+
15
+
16
+
17
+ class TestLinalg(unittest.TestCase):
18
+
19
+ def setUp(self):
20
+ np.random.seed(0)
21
+
22
+
23
+ def test_slogdet(self):
24
+ from . import ch
25
+ tmp = ch.random.randn(100).reshape((10,10))
26
+ # print 'chumpy version: ' + str(slogdet(tmp)[1].r)
27
+ # print 'old version:' + str(np.linalg.slogdet(tmp.r)[1])
28
+
29
+ eps = 1e-10
30
+ diff = np.random.rand(100) * eps
31
+ diff_reshaped = diff.reshape((10,10))
32
+ gt = np.linalg.slogdet(tmp.r+diff_reshaped)[1] - np.linalg.slogdet(tmp.r)[1]
33
+ pred = ch.linalg.slogdet(tmp)[1].dr_wrt(tmp).dot(diff)
34
+ #print gt
35
+ #print pred
36
+ diff = gt - pred
37
+
38
+ self.assertTrue(np.max(np.abs(diff)) < 1e-12)
39
+
40
+ sgn_gt = np.linalg.slogdet(tmp.r)[0]
41
+ sgn_pred = ch.linalg.slogdet(tmp)[0]
42
+
43
+ #print sgn_gt
44
+ #print sgn_pred
45
+ diff = sgn_gt - sgn_pred.r
46
+ self.assertTrue(np.max(np.abs(diff)) < 1e-12)
47
+
48
+
49
+ def test_lstsq(self):
50
+ from .linalg import lstsq
51
+
52
+ shapes = ([10, 3], [3, 10])
53
+
54
+ for shape in shapes:
55
+ for b2d in True, False:
56
+ A = (np.random.rand(np.prod(shape))-.5).reshape(shape)
57
+ if b2d:
58
+ b = np.random.randn(shape[0],2)
59
+ else:
60
+ b = np.random.randn(shape[0])
61
+
62
+ x1, residuals1, rank1, s1 = lstsq(A, b)
63
+ x2, residuals2, rank2, s2 = np.linalg.lstsq(A, b)
64
+
65
+ #print x1.r
66
+ #print x2
67
+ #print residuals1.r
68
+ #print residuals2
69
+ self.assertTrue(np.max(np.abs(x1.r-x2)) < 1e-14)
70
+ if len(residuals2) > 0:
71
+ self.assertTrue(np.max(np.abs(residuals1.r-residuals2)) < 1e-14)
72
+
73
+
74
+
75
+
76
+ def test_pinv(self):
77
+ from .linalg import Pinv
78
+
79
+ data = (np.random.rand(12)-.5).reshape((3, 4))
80
+ pc_tall = Pinv(data)
81
+ pc_wide = Pinv(data.T)
82
+
83
+ pn_tall = np.linalg.pinv(data)
84
+ pn_wide = np.linalg.pinv(data.T)
85
+
86
+ tall_correct = np.max(np.abs(pc_tall.r - pn_tall)) < 1e-12
87
+ wide_correct = np.max(np.abs(pc_wide.r - pn_wide)) < 1e-12
88
+ # if not tall_correct or not wide_correct:
89
+ # print tall_correct
90
+ # print wide_correct
91
+ # import pdb; pdb.set_trace()
92
+ self.assertTrue(tall_correct)
93
+ self.assertTrue(wide_correct)
94
+
95
+ return # FIXME. how to test derivs?
96
+
97
+ for pc in [pc_tall, pc_wide]:
98
+
99
+ self.chkd(pc, pc.mtx)
100
+ import pdb; pdb.set_trace()
101
+
102
+
103
+
104
+ def test_svd(self):
105
+ from .linalg import Svd
106
+ eps = 1e-3
107
+ idx = 10
108
+
109
+ data = np.sin(np.arange(300)*100+10).reshape((-1,3))
110
+ data[3,:] = data[3,:]*0+10
111
+ data[:,1] *= 2
112
+ data[:,2] *= 4
113
+ data = data.copy()
114
+ u,s,v = np.linalg.svd(data, full_matrices=False)
115
+ data = Ch(data)
116
+ data2 = data.r.copy()
117
+ data2.ravel()[idx] += eps
118
+ u2,s2,v2 = np.linalg.svd(data2, full_matrices=False)
119
+
120
+
121
+ svdu, svdd, svdv = Svd(x=data)
122
+
123
+ # test singular values
124
+ diff_emp = (s2-s) / eps
125
+ diff_pred = svdd.dr_wrt(data)[:,idx]
126
+ #print diff_emp
127
+ #print diff_pred
128
+ ratio = diff_emp / diff_pred
129
+ #print ratio
130
+ self.assertTrue(np.max(np.abs(ratio - 1.)) < 1e-4)
131
+
132
+ # test V
133
+ diff_emp = (v2 - v) / eps
134
+ diff_pred = svdv.dr_wrt(data)[:,idx].reshape(diff_emp.shape)
135
+ ratio = diff_emp / diff_pred
136
+ #print ratio
137
+ self.assertTrue(np.max(np.abs(ratio - 1.)) < 1e-2)
138
+
139
+ # test U
140
+ diff_emp = (u2 - u) / eps
141
+ diff_pred = svdu.dr_wrt(data)[:,idx].reshape(diff_emp.shape)
142
+ ratio = diff_emp / diff_pred
143
+ #print ratio
144
+ self.assertTrue(np.max(np.abs(ratio - 1.)) < 1e-2)
145
+
146
+
147
+ def test_det(self):
148
+ from .linalg import Det
149
+
150
+ mtx1 = Ch(np.sin(2**np.arange(9)).reshape((3,3)))
151
+ mtx1_det = Det(mtx1)
152
+ dr = mtx1_det.dr_wrt(mtx1)
153
+
154
+ eps = 1e-5
155
+ mtx2 = mtx1.r.copy()
156
+ input_diff = np.sin(np.arange(mtx2.size)).reshape(mtx2.shape) * eps
157
+ mtx2 += input_diff
158
+ mtx2_det = Det(mtx2)
159
+
160
+ output_diff_emp = (np.linalg.det(mtx2) - np.linalg.det(mtx1.r)).ravel()
161
+
162
+ output_diff_pred = Det(mtx1).dr_wrt(mtx1).dot(input_diff.ravel())
163
+
164
+ #print output_diff_emp
165
+ #print output_diff_pred
166
+
167
+ self.assertTrue(np.max(np.abs(output_diff_emp - output_diff_pred)) < eps*1e-4)
168
+ self.assertTrue(np.max(np.abs(mtx1_det.r - np.linalg.det(mtx1.r)).ravel()) == 0)
169
+
170
+
171
+
172
+ def test_inv1(self):
173
+ from .linalg import Inv
174
+
175
+ mtx1 = Ch(np.sin(2**np.arange(9)).reshape((3,3)))
176
+ mtx1_inv = Inv(mtx1)
177
+ dr = mtx1_inv.dr_wrt(mtx1)
178
+
179
+ eps = 1e-5
180
+ mtx2 = mtx1.r.copy()
181
+ input_diff = np.sin(np.arange(mtx2.size)).reshape(mtx2.shape) * eps
182
+ mtx2 += input_diff
183
+ mtx2_inv = Inv(mtx2)
184
+
185
+ output_diff_emp = (np.linalg.inv(mtx2) - np.linalg.inv(mtx1.r)).ravel()
186
+ output_diff_pred = Inv(mtx1).dr_wrt(mtx1).dot(input_diff.ravel())
187
+
188
+ #print output_diff_emp
189
+ #print output_diff_pred
190
+
191
+ self.assertTrue(np.max(np.abs(output_diff_emp - output_diff_pred)) < eps*1e-4)
192
+ self.assertTrue(np.max(np.abs(mtx1_inv.r - np.linalg.inv(mtx1.r)).ravel()) == 0)
193
+
194
+ def test_inv2(self):
195
+ from .linalg import Inv
196
+
197
+ eps = 1e-8
198
+ idx = 13
199
+
200
+ mtx1 = np.random.rand(100).reshape((10,10))
201
+ mtx2 = mtx1.copy()
202
+ mtx2.ravel()[idx] += eps
203
+
204
+ diff_emp = (np.linalg.inv(mtx2) - np.linalg.inv(mtx1)) / eps
205
+
206
+ mtx1 = Ch(mtx1)
207
+ diff_pred = Inv(mtx1).dr_wrt(mtx1)[:,13].reshape(diff_emp.shape)
208
+ #print diff_emp
209
+ #print diff_pred
210
+ #print diff_emp - diff_pred
211
+ self.assertTrue(np.max(np.abs(diff_pred.ravel()-diff_emp.ravel())) < 1e-4)
212
+
213
+ @unittest.skipIf(np.__version__ < '1.8',
214
+ "broadcasting for matrix inverse not supported in numpy < 1.8")
215
+ def test_inv3(self):
216
+ """Test linalg.inv with broadcasting support."""
217
+
218
+ from .linalg import Inv
219
+
220
+ mtx1 = Ch(np.sin(2**np.arange(12)).reshape((3,2,2)))
221
+ mtx1_inv = Inv(mtx1)
222
+ dr = mtx1_inv.dr_wrt(mtx1)
223
+
224
+ eps = 1e-5
225
+ mtx2 = mtx1.r.copy()
226
+ input_diff = np.sin(np.arange(mtx2.size)).reshape(mtx2.shape) * eps
227
+ mtx2 += input_diff
228
+ mtx2_inv = Inv(mtx2)
229
+
230
+ output_diff_emp = (np.linalg.inv(mtx2) - np.linalg.inv(mtx1.r)).ravel()
231
+ output_diff_pred = Inv(mtx1).dr_wrt(mtx1).dot(input_diff.ravel())
232
+
233
+ # print output_diff_emp
234
+ # print output_diff_pred
235
+
236
+ self.assertTrue(np.max(np.abs(output_diff_emp.ravel() - output_diff_pred.ravel())) < eps*1e-3)
237
+ self.assertTrue(np.max(np.abs(mtx1_inv.r - np.linalg.inv(mtx1.r)).ravel()) == 0)
238
+
239
+ def chkd(self, obj, parm, eps=1e-14):
240
+ backed_up = parm.x
241
+
242
+ if True:
243
+ diff = (np.random.rand(parm.size)-.5).reshape(parm.shape)
244
+ else:
245
+ diff = np.zeros(parm.shape)
246
+ diff.ravel()[4] = 2.
247
+
248
+ dr = obj.dr_wrt(parm)
249
+
250
+ parm.x = backed_up - diff*eps
251
+ r_lower = obj.r
252
+
253
+ parm.x = backed_up + diff*eps
254
+ r_upper = obj.r
255
+
256
+ diff_emp = (r_upper - r_lower) / (eps*2.)
257
+ diff_pred = dr.dot(diff.ravel()).reshape(diff_emp.shape)
258
+
259
+ #print diff_emp
260
+ #print diff_pred
261
+ print(diff_emp / diff_pred)
262
+ print(diff_emp - diff_pred)
263
+
264
+ parm.x = backed_up
265
+
266
+
267
+
268
+ suite = unittest.TestLoader().loadTestsFromTestCase(TestLinalg)
269
+
270
+ if __name__ == '__main__':
271
+ unittest.main()
272
+
vendor/chumpy/chumpy/test_optimization.py ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # encoding: utf-8
3
+ """
4
+ Author(s): Matthew Loper
5
+
6
+ See LICENCE.txt for licensing and contact information.
7
+ """
8
+
9
+ import time
10
+ from numpy import *
11
+ import unittest
12
+ from . import ch
13
+ from .optimization import minimize
14
+ from .ch import Ch
15
+ import numpy as np
16
+ from scipy.optimize import rosen, rosen_der
17
+ from .utils import row, col
18
+
19
+
20
+ visualize = False
21
+
22
+
23
+ def Rosen():
24
+
25
+ args = {
26
+ 'x1': Ch(-120.),
27
+ 'x2': Ch(-100.)
28
+ }
29
+ r1 = Ch(lambda x1, x2 : (x2 - x1**2.) * 10., args)
30
+ r2 = Ch(lambda x1 : x1 * -1. + 1, args)
31
+
32
+ func = [r1, r2]
33
+
34
+ return func, [args['x1'], args['x2']]
35
+
36
+ class Madsen(Ch):
37
+ dterms = ('x',)
38
+ def compute_r(self):
39
+ x1 = self.x.r[0]
40
+ x2 = self.x.r[1]
41
+ result = np.array((
42
+ x1**2 + x2**2 + x1 * x2,
43
+ np.sin(x1),
44
+ np.cos(x2)
45
+ ))
46
+ return result
47
+
48
+ def compute_dr_wrt(self, wrt):
49
+ if wrt is not self.x:
50
+ return None
51
+ jac = np.zeros((3,2))
52
+ x1 = self.x.r[0]
53
+ x2 = self.x.r[1]
54
+ jac[0,0] = 2. * x1 + x2
55
+ jac[0,1] = 2. * x2 + x1
56
+
57
+ jac[1,0] = np.cos(x1)
58
+ jac[1,1] = 0
59
+
60
+ jac[2,0] = 0
61
+ jac[2,1] = -np.sin(x2)
62
+
63
+ return jac
64
+
65
+
66
+ def set_and_get_r(self, x_in):
67
+ self.x = Ch(x_in)
68
+ return col(self.r)
69
+
70
+ def set_and_get_dr(self, x_in):
71
+ self.x = Ch(x_in)
72
+ return self.dr_wrt(self.x)
73
+
74
+
75
+
76
+
77
+ class RosenCh(Ch):
78
+ dterms = ('x',)
79
+ def compute_r(self):
80
+
81
+ result = np.array((rosen(self.x.r) ))
82
+
83
+ return result
84
+
85
+ def set_and_get_r(self, x_in):
86
+ self.x = Ch(x_in)
87
+ return col(self.r)
88
+
89
+ def set_and_get_dr(self, x_in):
90
+ self.x = Ch(x_in)
91
+ return self.dr_wrt(self.x).flatten()
92
+
93
+
94
+ def compute_dr_wrt(self, wrt):
95
+ if wrt is self.x:
96
+ if visualize:
97
+ import matplotlib.pyplot as plt
98
+ residuals = np.sum(self.r**2)
99
+ print('------> RESIDUALS %.2e' % (residuals,))
100
+ print('------> CURRENT GUESS %s' % (str(self.x.r),))
101
+ plt.figure(123)
102
+
103
+ if not hasattr(self, 'vs'):
104
+ self.vs = []
105
+ self.xs = []
106
+ self.ys = []
107
+ self.vs.append(residuals)
108
+ self.xs.append(self.x.r[0])
109
+ self.ys.append(self.x.r[1])
110
+ plt.clf();
111
+ plt.subplot(1,2,1)
112
+ plt.plot(self.vs)
113
+ plt.subplot(1,2,2)
114
+ plt.plot(self.xs, self.ys)
115
+ plt.draw()
116
+
117
+
118
+ return row(rosen_der(self.x.r))
119
+
120
+
121
+
122
+ class TestOptimization(unittest.TestCase):
123
+
124
+ def test_dogleg_rosen(self):
125
+ obj, freevars = Rosen()
126
+ minimize(fun=obj, x0=freevars, method='dogleg', options={'maxiter': 337, 'disp': False})
127
+ self.assertTrue(freevars[0].r[0]==1.)
128
+ self.assertTrue(freevars[1].r[0]==1.)
129
+
130
+ def test_dogleg_madsen(self):
131
+ obj = Madsen(x = Ch(np.array((3.,1.))))
132
+ minimize(fun=obj, x0=[obj.x], method='dogleg', options={'maxiter': 34, 'disp': False})
133
+ self.assertTrue(np.sum(obj.r**2)/2 < 0.386599528247)
134
+
135
+ @unittest.skip('negative sign in exponent screws with reverse mode')
136
+ def test_bfgs_rosen(self):
137
+ from .optimization import minimize_bfgs_lsq
138
+ obj, freevars = Rosen()
139
+ minimize_bfgs_lsq(obj=obj, niters=421, verbose=False, free_variables=freevars)
140
+ self.assertTrue(freevars[0].r[0]==1.)
141
+ self.assertTrue(freevars[1].r[0]==1.)
142
+
143
+ def test_bfgs_madsen(self):
144
+ from .ch import SumOfSquares
145
+ import scipy.optimize
146
+ obj = Ch(lambda x : SumOfSquares(Madsen(x = x)) )
147
+
148
+ def errfunc(x):
149
+ obj.x = Ch(x)
150
+ return obj.r
151
+
152
+ def gradfunc(x):
153
+ obj.x = Ch(x)
154
+ return obj.dr_wrt(obj.x).ravel()
155
+
156
+ x0 = np.array((3., 1.))
157
+
158
+ # Optimize with built-in bfgs.
159
+ # Note: with 8 iters, this actually requires 14 gradient evaluations.
160
+ # This can be verified by setting "disp" to 1.
161
+ #tm = time.time()
162
+ x1 = scipy.optimize.fmin_bfgs(errfunc, x0, fprime=gradfunc, maxiter=8, disp=0)
163
+ #print 'forward: took %.es' % (time.time() - tm,)
164
+ self.assertLess(obj.r/2., 0.4)
165
+
166
+ # Optimize with chumpy's minimize (which uses scipy's bfgs).
167
+ obj.x = x0
168
+ minimize(fun=obj, x0=[obj.x], method='bfgs', options={'maxiter': 8, 'disp': False})
169
+ self.assertLess(obj.r/2., 0.4)
170
+
171
+ def test_nested_select(self):
172
+ def beales(x, y):
173
+ e1 = 1.5 - x + x*y
174
+ e2 = 2.25 - x + x*(y**2)
175
+ e3 = 2.625 - x + x*(y**3)
176
+ return {'e1': e1, 'e2': e2, 'e3': e3}
177
+
178
+ x1 = ch.zeros(10)
179
+ y1 = ch.zeros(10)
180
+
181
+ # With a single select this worked
182
+ minimize(beales(x1, y1), x0=[x1[1:4], y1], method='dogleg', options={'disp': False})
183
+
184
+ x2 = ch.zeros(10)
185
+ y2 = ch.zeros(10)
186
+
187
+ # But this used to raise `AttributeError: 'Select' object has no attribute 'x'`
188
+ minimize(beales(x2, y2), x0=[x2[1:8][:3], y2], method='dogleg', options={'disp': False})
189
+ np.testing.assert_array_equal(x1, x2)
190
+ np.testing.assert_array_equal(y1, y2)
191
+
192
+
193
+ suite = unittest.TestLoader().loadTestsFromTestCase(TestOptimization)
194
+
195
+ if __name__ == '__main__':
196
+
197
+ if False: # show rosen
198
+ import matplotlib.pyplot as plt
199
+ visualize = True
200
+ plt.ion()
201
+ unittest.main()
202
+ import pdb; pdb.set_trace()
203
+ else:
204
+ unittest.main()
vendor/chumpy/chumpy/testing.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from . import ch
2
+ import numpy as np
3
+
4
+ fn1 = 'assert_allclose', 'assert_almost_equal', 'assert_approx_equal', 'assert_array_almost_equal', 'assert_array_almost_equal_nulp', 'assert_array_equal', 'assert_array_less', 'assert_array_max_ulp', 'assert_equal', 'assert_no_warnings', 'assert_string_equal'
5
+ fn2 = 'assert_raises', 'assert_warns'
6
+
7
+ # These are unhandled
8
+ fn3 = 'build_err_msg', 'dec', 'decorate_methods', 'decorators', 'division', 'importall', 'jiffies', 'measure', 'memusage', 'nosetester', 'numpytest', 'print_assert_equal', 'print_function', 'raises', 'rand', 'run_module_suite', 'rundocs', 'runstring', 'test', 'utils', 'verbose'
9
+
10
+ __all__ = fn1 + fn2
11
+
12
+ for rtn in fn1:
13
+ exec('def %s(*args, **kwargs) : return np.testing.%s(np.asarray(args[0]), np.asarray(args[1]), *args[2:], **kwargs)' % (rtn, rtn))
14
+
15
+ for rtn in fn2:
16
+ exec('def %s(*args, **kwargs) : return np.testing.%s(*args, **kwargs)' % (rtn, rtn))
17
+
18
+
19
+
20
+ if __name__ == '__main__':
21
+ main()
vendor/chumpy/chumpy/utils.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Author(s): Matthew Loper
3
+
4
+ See LICENCE.txt for licensing and contact information.
5
+ """
6
+ import scipy.sparse as sp
7
+ import numpy as np
8
+
9
+ def row(A):
10
+ return A.reshape((1, -1))
11
+
12
+
13
+ def col(A):
14
+ return A.reshape((-1, 1))
15
+
16
+ class timer(object):
17
+ def time(self):
18
+ import time
19
+ return time.time()
20
+ def __init__(self):
21
+ self._elapsed = 0
22
+ self._start = self.time()
23
+ def __call__(self):
24
+ if self._start is not None:
25
+ return self._elapsed + self.time() - self._start
26
+ else:
27
+ return self._elapsed
28
+ def pause(self):
29
+ assert self._start is not None
30
+ self._elapsed += self.time() - self._start
31
+ self._start = None
32
+ def resume(self):
33
+ assert self._start is None
34
+ self._start = self.time()
35
+
36
+ def dfs_do_func_on_graph(node, func, *args, **kwargs):
37
+ '''
38
+ invoke func on each node of the dr graph
39
+ '''
40
+ for _node in node.tree_iterator():
41
+ func(_node, *args, **kwargs)
42
+
43
+
44
+ def sparse_is_desireable(lhs, rhs):
45
+ '''
46
+ Examines a pair of matrices and determines if the result of their multiplication should be sparse or not.
47
+ '''
48
+ return False
49
+ if len(lhs.shape) == 1:
50
+ return False
51
+ else:
52
+ lhs_rows, lhs_cols = lhs.shape
53
+
54
+ if len(rhs.shape) == 1:
55
+ rhs_rows = 1
56
+ rhs_cols = rhs.size
57
+ else:
58
+ rhs_rows, rhs_cols = rhs.shape
59
+
60
+ result_size = lhs_rows * rhs_cols
61
+
62
+ if sp.issparse(lhs) and sp.issparse(rhs):
63
+ return True
64
+ elif sp.issparse(lhs):
65
+ lhs_zero_rows = lhs_rows - np.unique(lhs.nonzero()[0]).size
66
+ rhs_zero_cols = np.all(rhs==0, axis=0).sum()
67
+
68
+ elif sp.issparse(rhs):
69
+ lhs_zero_rows = np.all(lhs==0, axis=1).sum()
70
+ rhs_zero_cols = rhs_cols- np.unique(rhs.nonzero()[1]).size
71
+ else:
72
+ lhs_zero_rows = np.all(lhs==0, axis=1).sum()
73
+ rhs_zero_cols = np.all(rhs==0, axis=0).sum()
74
+
75
+ num_zeros = lhs_zero_rows * rhs_cols + rhs_zero_cols * lhs_rows - lhs_zero_rows * rhs_zero_cols
76
+
77
+ # A sparse matrix uses roughly 16 bytes per nonzero element (8 + 2 4-byte inds), while a dense matrix uses 8 bytes per element. So the break even point for sparsity is 50% nonzero. But in practice, it seems to be that the compression in a csc or csr matrix gets us break even at ~65% nonzero, which lets us say 50% is a conservative, worst cases cutoff.
78
+ return (float(num_zeros) / float(size)) >= 0.5
79
+
80
+
81
+ def convert_inputs_to_sparse_if_necessary(lhs, rhs):
82
+ '''
83
+ This function checks to see if a sparse output is desireable given the inputs and if so, casts the inputs to sparse in order to make it so.
84
+ '''
85
+ if not sp.issparse(lhs) or not sp.issparse(rhs):
86
+ if sparse_is_desireable(lhs, rhs):
87
+ if not sp.issparse(lhs):
88
+ lhs = sp.csc_matrix(lhs)
89
+ #print "converting lhs into sparse matrix"
90
+ if not sp.issparse(rhs):
91
+ rhs = sp.csc_matrix(rhs)
92
+ #print "converting rhs into sparse matrix"
93
+ return lhs, rhs
vendor/chumpy/chumpy/version.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version = '0.71'
2
+ short_version = version
3
+ full_version = version
vendor/chumpy/requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ numpy>=1.8.1
2
+ scipy>=0.13.0
3
+ six>=1.11.0
vendor/chumpy/setup.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Author(s): Matthew Loper
3
+
4
+ See LICENCE.txt for licensing and contact information.
5
+ """
6
+
7
+ from distutils.core import setup
8
+ from runpy import run_path
9
+
10
+ def get_version():
11
+ namespace = run_path('chumpy/version.py')
12
+ return namespace['version']
13
+
14
+ setup(name='chumpy',
15
+ version=get_version(),
16
+ packages = ['chumpy'],
17
+ author='Matthew Loper',
18
+ author_email='matt.loper@gmail.com',
19
+ url='https://github.com/mattloper/chumpy',
20
+ description='chumpy',
21
+ license='MIT',
22
+ install_requires=['numpy', 'scipy', 'matplotlib'],
23
+
24
+ classifiers=[
25
+ 'Development Status :: 4 - Beta',
26
+ 'Intended Audience :: Science/Research',
27
+ 'Topic :: Scientific/Engineering :: Mathematics',
28
+ 'License :: OSI Approved :: MIT License',
29
+ 'Programming Language :: Python :: 2',
30
+ 'Programming Language :: Python :: 2.7',
31
+ 'Programming Language :: Python :: 3',
32
+ 'Operating System :: MacOS :: MacOS X',
33
+ 'Operating System :: POSIX :: Linux'
34
+ ],
35
+ )