Spaces:
Running on Zero
Running on Zero
Fix chumpy build isolation: vendor with patched setup.py, revert to Gradio SDK
Browse fileschumpy's setup.py does `from pip._internal.req import parse_requirements`
which fails in pip's default isolated build env. Vendored locally with
setup.py patched to use a hardcoded install_requires instead.
Reverted from docker SDK back to gradio SDK (ZeroGPU requires Gradio SDK).
- Dockerfile +0 -35
- README.md +3 -2
- requirements.txt +3 -4
- vendor/chumpy/.circleci/config.yml +56 -0
- vendor/chumpy/.gitignore +142 -0
- vendor/chumpy/LICENSE.txt +22 -0
- vendor/chumpy/MANIFEST.in +3 -0
- vendor/chumpy/Makefile +18 -0
- vendor/chumpy/README.md +60 -0
- vendor/chumpy/chumpy/__init__.py +117 -0
- vendor/chumpy/chumpy/api_compatibility.py +534 -0
- vendor/chumpy/chumpy/ch.py +1367 -0
- vendor/chumpy/chumpy/ch_ops.py +814 -0
- vendor/chumpy/chumpy/ch_random.py +32 -0
- vendor/chumpy/chumpy/extras.py +72 -0
- vendor/chumpy/chumpy/linalg.py +306 -0
- vendor/chumpy/chumpy/logic.py +39 -0
- vendor/chumpy/chumpy/monitor.py +149 -0
- vendor/chumpy/chumpy/np_tensordot.py +228 -0
- vendor/chumpy/chumpy/optimization.py +161 -0
- vendor/chumpy/chumpy/optimization_internal.py +455 -0
- vendor/chumpy/chumpy/optional_test_performance.py +183 -0
- vendor/chumpy/chumpy/reordering.py +454 -0
- vendor/chumpy/chumpy/test_ch.py +621 -0
- vendor/chumpy/chumpy/test_inner_composition.py +80 -0
- vendor/chumpy/chumpy/test_linalg.py +272 -0
- vendor/chumpy/chumpy/test_optimization.py +204 -0
- vendor/chumpy/chumpy/testing.py +21 -0
- vendor/chumpy/chumpy/utils.py +93 -0
- vendor/chumpy/chumpy/version.py +3 -0
- vendor/chumpy/requirements.txt +3 -0
- vendor/chumpy/setup.py +35 -0
Dockerfile
DELETED
|
@@ -1,35 +0,0 @@
|
|
| 1 |
-
FROM python:3.10-slim
|
| 2 |
-
|
| 3 |
-
# System deps
|
| 4 |
-
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 5 |
-
git wget curl build-essential cmake ninja-build pkg-config \
|
| 6 |
-
libgl1-mesa-glx libglib2.0-0 libsm6 libxext6 libxrender-dev ffmpeg \
|
| 7 |
-
&& rm -rf /var/lib/apt/lists/*
|
| 8 |
-
|
| 9 |
-
# HF user setup
|
| 10 |
-
RUN useradd -m -u 1000 user
|
| 11 |
-
USER user
|
| 12 |
-
ENV HOME=/home/user PATH=/home/user/.local/bin:$PATH
|
| 13 |
-
WORKDIR $HOME/app
|
| 14 |
-
|
| 15 |
-
# Upgrade pip first
|
| 16 |
-
RUN pip install --user --upgrade pip setuptools wheel
|
| 17 |
-
|
| 18 |
-
# chumpy must be installed with --no-build-isolation BEFORE everything else
|
| 19 |
-
# (its setup.py does `import pip` which fails in pip's default isolated build env)
|
| 20 |
-
RUN pip install --user --no-build-isolation \
|
| 21 |
-
"chumpy @ git+https://github.com/mattloper/chumpy.git@580566eafc9ac68b2614b64d6f7aaa84eebb70da"
|
| 22 |
-
|
| 23 |
-
# Copy app files
|
| 24 |
-
COPY --chown=user . $HOME/app
|
| 25 |
-
|
| 26 |
-
# Install remaining requirements (chumpy already satisfied above)
|
| 27 |
-
RUN pip install --user --no-cache-dir -r requirements.txt \
|
| 28 |
-
"torch<=2.9.1" \
|
| 29 |
-
"gradio[oauth,mcp]==6.11.0" \
|
| 30 |
-
"uvicorn>=0.14.0" \
|
| 31 |
-
"websockets>=10.4" \
|
| 32 |
-
"spaces==0.48.1"
|
| 33 |
-
|
| 34 |
-
EXPOSE 7860
|
| 35 |
-
CMD ["python", "app.py"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
README.md
CHANGED
|
@@ -3,8 +3,9 @@ title: Image2Model
|
|
| 3 |
emoji: 🎭
|
| 4 |
colorFrom: purple
|
| 5 |
colorTo: blue
|
| 6 |
-
sdk:
|
| 7 |
-
|
|
|
|
| 8 |
pinned: false
|
| 9 |
license: apache-2.0
|
| 10 |
hardware: zero-a10g
|
|
|
|
| 3 |
emoji: 🎭
|
| 4 |
colorFrom: purple
|
| 5 |
colorTo: blue
|
| 6 |
+
sdk: gradio
|
| 7 |
+
sdk_version: "6.11.0"
|
| 8 |
+
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
license: apache-2.0
|
| 11 |
hardware: zero-a10g
|
requirements.txt
CHANGED
|
@@ -1,13 +1,12 @@
|
|
| 1 |
-
# HuggingFace ZeroGPU Space —
|
| 2 |
-
# chumpy
|
| 3 |
-
# (its setup.py does `import pip` which breaks in modern pip isolated builds)
|
| 4 |
spaces
|
| 5 |
|
| 6 |
# Git-pinned installs
|
| 7 |
hmr2 @ git+https://github.com/shubham-goel/4D-Humans.git@efe18deff163b29dff87ddbd575fa29b716a356c
|
| 8 |
clip @ git+https://github.com/openai/CLIP.git@d05afc436d78f1c48dc0dbf8e5980a9d471f35f6
|
| 9 |
mvadapter @ git+https://github.com/huanngzh/MV-Adapter.git@4277e0018232bac82bb2c103caf0893cedb711be
|
| 10 |
-
chumpy @
|
| 11 |
skel @ git+https://github.com/MarilynKeller/SKEL.git@c32cf16581295bff19399379efe5b776d707cd95
|
| 12 |
nvdiffrast @ git+https://github.com/NVlabs/nvdiffrast.git@253ac4fcea7de5f396371124af597e6cc957bfae
|
| 13 |
diso @ git+https://github.com/SarahWeiii/diso.git@9792ad928ccb09bdec938779651ee03e395758a6
|
|
|
|
| 1 |
+
# HuggingFace ZeroGPU Space — Gradio SDK
|
| 2 |
+
# chumpy vendored locally with patched setup.py (original does `import pip` which breaks isolated builds)
|
|
|
|
| 3 |
spaces
|
| 4 |
|
| 5 |
# Git-pinned installs
|
| 6 |
hmr2 @ git+https://github.com/shubham-goel/4D-Humans.git@efe18deff163b29dff87ddbd575fa29b716a356c
|
| 7 |
clip @ git+https://github.com/openai/CLIP.git@d05afc436d78f1c48dc0dbf8e5980a9d471f35f6
|
| 8 |
mvadapter @ git+https://github.com/huanngzh/MV-Adapter.git@4277e0018232bac82bb2c103caf0893cedb711be
|
| 9 |
+
chumpy @ ./vendor/chumpy
|
| 10 |
skel @ git+https://github.com/MarilynKeller/SKEL.git@c32cf16581295bff19399379efe5b776d707cd95
|
| 11 |
nvdiffrast @ git+https://github.com/NVlabs/nvdiffrast.git@253ac4fcea7de5f396371124af597e6cc957bfae
|
| 12 |
diso @ git+https://github.com/SarahWeiii/diso.git@9792ad928ccb09bdec938779651ee03e395758a6
|
vendor/chumpy/.circleci/config.yml
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version: 2.1
|
| 2 |
+
|
| 3 |
+
orbs:
|
| 4 |
+
python: circleci/python@2.1.1 # optional, for helpers
|
| 5 |
+
|
| 6 |
+
jobs:
|
| 7 |
+
python3:
|
| 8 |
+
docker:
|
| 9 |
+
- image: cimg/python:3.12
|
| 10 |
+
steps:
|
| 11 |
+
- checkout
|
| 12 |
+
- run:
|
| 13 |
+
name: Install system deps
|
| 14 |
+
command: |
|
| 15 |
+
sudo apt-get update
|
| 16 |
+
sudo apt-get install -y --no-install-recommends gfortran liblapack-dev
|
| 17 |
+
- restore_cache:
|
| 18 |
+
keys:
|
| 19 |
+
- v1-pip-{{ arch }}-{{ .Branch }}-{{ checksum "requirements.txt" }}
|
| 20 |
+
- v1-pip-{{ arch }}-
|
| 21 |
+
- run:
|
| 22 |
+
name: Install python deps
|
| 23 |
+
command: |
|
| 24 |
+
python -m venv venv
|
| 25 |
+
. venv/bin/activate
|
| 26 |
+
pip install -U pip
|
| 27 |
+
set -o pipefail; pip install -r requirements.txt | cat
|
| 28 |
+
- save_cache:
|
| 29 |
+
key: v1-pip-{{ arch }}-{{ .Branch }}-{{ checksum "requirements.txt" }}
|
| 30 |
+
paths:
|
| 31 |
+
- ~/.cache/pip
|
| 32 |
+
- run:
|
| 33 |
+
name: Show versions
|
| 34 |
+
command: |
|
| 35 |
+
. venv/bin/activate
|
| 36 |
+
pip freeze
|
| 37 |
+
- run:
|
| 38 |
+
name: Run tests
|
| 39 |
+
command: |
|
| 40 |
+
. venv/bin/activate
|
| 41 |
+
make test
|
| 42 |
+
|
| 43 |
+
workflows:
|
| 44 |
+
version: 2
|
| 45 |
+
on-commit:
|
| 46 |
+
jobs:
|
| 47 |
+
- python3
|
| 48 |
+
daily:
|
| 49 |
+
triggers:
|
| 50 |
+
- schedule:
|
| 51 |
+
cron: "0 17 * * *"
|
| 52 |
+
filters:
|
| 53 |
+
branches:
|
| 54 |
+
only: master
|
| 55 |
+
jobs:
|
| 56 |
+
- python3
|
vendor/chumpy/.gitignore
ADDED
|
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
# Created by https://www.gitignore.io/api/osx,python
|
| 3 |
+
|
| 4 |
+
### OSX ###
|
| 5 |
+
# General
|
| 6 |
+
.DS_Store
|
| 7 |
+
.AppleDouble
|
| 8 |
+
.LSOverride
|
| 9 |
+
|
| 10 |
+
# Icon must end with two \r
|
| 11 |
+
Icon
|
| 12 |
+
|
| 13 |
+
# Thumbnails
|
| 14 |
+
._*
|
| 15 |
+
|
| 16 |
+
# Files that might appear in the root of a volume
|
| 17 |
+
.DocumentRevisions-V100
|
| 18 |
+
.fseventsd
|
| 19 |
+
.Spotlight-V100
|
| 20 |
+
.TemporaryItems
|
| 21 |
+
.Trashes
|
| 22 |
+
.VolumeIcon.icns
|
| 23 |
+
.com.apple.timemachine.donotpresent
|
| 24 |
+
|
| 25 |
+
# Directories potentially created on remote AFP share
|
| 26 |
+
.AppleDB
|
| 27 |
+
.AppleDesktop
|
| 28 |
+
Network Trash Folder
|
| 29 |
+
Temporary Items
|
| 30 |
+
.apdisk
|
| 31 |
+
|
| 32 |
+
### Python ###
|
| 33 |
+
# Byte-compiled / optimized / DLL files
|
| 34 |
+
__pycache__/
|
| 35 |
+
*.py[cod]
|
| 36 |
+
*$py.class
|
| 37 |
+
|
| 38 |
+
# C extensions
|
| 39 |
+
*.so
|
| 40 |
+
|
| 41 |
+
# Distribution / packaging
|
| 42 |
+
.Python
|
| 43 |
+
build/
|
| 44 |
+
develop-eggs/
|
| 45 |
+
dist/
|
| 46 |
+
downloads/
|
| 47 |
+
eggs/
|
| 48 |
+
.eggs/
|
| 49 |
+
lib/
|
| 50 |
+
lib64/
|
| 51 |
+
parts/
|
| 52 |
+
sdist/
|
| 53 |
+
var/
|
| 54 |
+
wheels/
|
| 55 |
+
*.egg-info/
|
| 56 |
+
.installed.cfg
|
| 57 |
+
*.egg
|
| 58 |
+
MANIFEST
|
| 59 |
+
|
| 60 |
+
# PyInstaller
|
| 61 |
+
# Usually these files are written by a python script from a template
|
| 62 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 63 |
+
*.manifest
|
| 64 |
+
*.spec
|
| 65 |
+
|
| 66 |
+
# Installer logs
|
| 67 |
+
pip-log.txt
|
| 68 |
+
pip-delete-this-directory.txt
|
| 69 |
+
|
| 70 |
+
# Unit test / coverage reports
|
| 71 |
+
htmlcov/
|
| 72 |
+
.tox/
|
| 73 |
+
.coverage
|
| 74 |
+
.coverage.*
|
| 75 |
+
.cache
|
| 76 |
+
nosetests.xml
|
| 77 |
+
coverage.xml
|
| 78 |
+
*.cover
|
| 79 |
+
.hypothesis/
|
| 80 |
+
.pytest_cache/
|
| 81 |
+
|
| 82 |
+
# Translations
|
| 83 |
+
*.mo
|
| 84 |
+
*.pot
|
| 85 |
+
|
| 86 |
+
# Django stuff:
|
| 87 |
+
*.log
|
| 88 |
+
local_settings.py
|
| 89 |
+
db.sqlite3
|
| 90 |
+
|
| 91 |
+
# Flask stuff:
|
| 92 |
+
instance/
|
| 93 |
+
.webassets-cache
|
| 94 |
+
|
| 95 |
+
# Scrapy stuff:
|
| 96 |
+
.scrapy
|
| 97 |
+
|
| 98 |
+
# Sphinx documentation
|
| 99 |
+
docs/_build/
|
| 100 |
+
|
| 101 |
+
# PyBuilder
|
| 102 |
+
target/
|
| 103 |
+
|
| 104 |
+
# Jupyter Notebook
|
| 105 |
+
.ipynb_checkpoints
|
| 106 |
+
|
| 107 |
+
# pyenv
|
| 108 |
+
.python-version
|
| 109 |
+
|
| 110 |
+
# celery beat schedule file
|
| 111 |
+
celerybeat-schedule
|
| 112 |
+
|
| 113 |
+
# SageMath parsed files
|
| 114 |
+
*.sage.py
|
| 115 |
+
|
| 116 |
+
# Environments
|
| 117 |
+
.env
|
| 118 |
+
.venv
|
| 119 |
+
env/
|
| 120 |
+
venv/
|
| 121 |
+
ENV/
|
| 122 |
+
env.bak/
|
| 123 |
+
venv.bak/
|
| 124 |
+
|
| 125 |
+
# Spyder project settings
|
| 126 |
+
.spyderproject
|
| 127 |
+
.spyproject
|
| 128 |
+
|
| 129 |
+
# Rope project settings
|
| 130 |
+
.ropeproject
|
| 131 |
+
|
| 132 |
+
# mkdocs documentation
|
| 133 |
+
/site
|
| 134 |
+
|
| 135 |
+
# mypy
|
| 136 |
+
.mypy_cache/
|
| 137 |
+
|
| 138 |
+
### Python Patch ###
|
| 139 |
+
.venv/
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
# End of https://www.gitignore.io/api/osx,python
|
vendor/chumpy/LICENSE.txt
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
The MIT License (MIT)
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2014 Max-Planck-Gesellschaft
|
| 4 |
+
Copyright (c) 2014 Matthew Loper
|
| 5 |
+
|
| 6 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 7 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 8 |
+
in the Software without restriction, including without limitation the rights
|
| 9 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 10 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 11 |
+
furnished to do so, subject to the following conditions:
|
| 12 |
+
|
| 13 |
+
The above copyright notice and this permission notice shall be included in
|
| 14 |
+
all copies or substantial portions of the Software.
|
| 15 |
+
|
| 16 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 17 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 18 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 19 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 20 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 21 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
| 22 |
+
THE SOFTWARE.
|
vendor/chumpy/MANIFEST.in
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
global-include . *.py *.c *.h Makefile *.pyx requirements.txt
|
| 2 |
+
global-exclude chumpy/optional_test_performance.py
|
| 3 |
+
prune dist
|
vendor/chumpy/Makefile
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
all:
|
| 2 |
+
|
| 3 |
+
upload:
|
| 4 |
+
rm -r dist
|
| 5 |
+
python setup.py sdist
|
| 6 |
+
twine upload dist/*
|
| 7 |
+
|
| 8 |
+
test:
|
| 9 |
+
# For some reason the import changes for Python 3 caused the Python 2 test
|
| 10 |
+
# loader to give up without loading any tests. So we discover them ourselves.
|
| 11 |
+
# python -m unittest
|
| 12 |
+
find chumpy -name 'test_*.py' | sed -e 's/\.py$$//' -e 's/\//./' | xargs python -m unittest
|
| 13 |
+
|
| 14 |
+
coverage: clean qcov
|
| 15 |
+
qcov: all
|
| 16 |
+
env LD_PRELOAD=$(PRELOADED) coverage run --source=. -m unittest discover -s .
|
| 17 |
+
coverage html
|
| 18 |
+
coverage report -m
|
vendor/chumpy/README.md
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
chumpy
|
| 2 |
+
======
|
| 3 |
+
|
| 4 |
+
[][pypi]
|
| 5 |
+
[][pypi]
|
| 6 |
+
[][pypi]
|
| 7 |
+
[][circle]
|
| 8 |
+
|
| 9 |
+
Autodifferentiation tool for Python.
|
| 10 |
+
|
| 11 |
+
[circle]: https://circleci.com/gh/mattloper/chumpy
|
| 12 |
+
[pypi]: https://pypi.org/project/chumpy/
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
Installation
|
| 16 |
+
------------
|
| 17 |
+
|
| 18 |
+
Install the fork:
|
| 19 |
+
|
| 20 |
+
```sh
|
| 21 |
+
pip install chumpy
|
| 22 |
+
```
|
| 23 |
+
|
| 24 |
+
Import it:
|
| 25 |
+
|
| 26 |
+
```py
|
| 27 |
+
import chumpy as ch
|
| 28 |
+
```
|
| 29 |
+
|
| 30 |
+
Overview
|
| 31 |
+
--------
|
| 32 |
+
|
| 33 |
+
Chumpy is a Python-based framework designed to handle the **auto-differentiation** problem,
|
| 34 |
+
which is to evaluate an expression and its derivatives with respect to its inputs, by use of the chain rule.
|
| 35 |
+
|
| 36 |
+
Chumpy is intended to make construction and local
|
| 37 |
+
minimization of objectives easier.
|
| 38 |
+
|
| 39 |
+
Specifically, it provides:
|
| 40 |
+
|
| 41 |
+
- Easy problem construction by using Numpy’s application interface
|
| 42 |
+
- Easy access to derivatives via auto differentiation
|
| 43 |
+
- Easy local optimization methods (12 of them: most of which use the derivatives)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
Usage
|
| 47 |
+
-----
|
| 48 |
+
|
| 49 |
+
Chumpy comes with its own demos, which can be seen by typing the following:
|
| 50 |
+
|
| 51 |
+
```python
|
| 52 |
+
import chumpy
|
| 53 |
+
chumpy.demo() # prints out a list of possible demos
|
| 54 |
+
```
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
License
|
| 58 |
+
-------
|
| 59 |
+
|
| 60 |
+
This project is licensed under the MIT License.
|
vendor/chumpy/chumpy/__init__.py
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .ch import *
|
| 2 |
+
from .logic import *
|
| 3 |
+
|
| 4 |
+
from .optimization import minimize
|
| 5 |
+
from . import extras
|
| 6 |
+
from . import testing
|
| 7 |
+
from .version import version as __version__
|
| 8 |
+
|
| 9 |
+
from .version import version as __version__
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def test():
|
| 13 |
+
from os.path import split
|
| 14 |
+
import unittest
|
| 15 |
+
test_loader= unittest.TestLoader()
|
| 16 |
+
test_loader = test_loader.discover(split(__file__)[0])
|
| 17 |
+
test_runner = unittest.TextTestRunner()
|
| 18 |
+
test_runner.run( test_loader )
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
demos = {}
|
| 22 |
+
|
| 23 |
+
demos['scalar'] = """
|
| 24 |
+
import chumpy as ch
|
| 25 |
+
|
| 26 |
+
[x1, x2, x3] = ch.array(10), ch.array(20), ch.array(30)
|
| 27 |
+
result = x1+x2+x3
|
| 28 |
+
print result # prints [ 60.]
|
| 29 |
+
print result.dr_wrt(x1) # prints 1
|
| 30 |
+
"""
|
| 31 |
+
|
| 32 |
+
demos['show_tree'] = """
|
| 33 |
+
import chumpy as ch
|
| 34 |
+
|
| 35 |
+
[x1, x2, x3] = ch.array(10), ch.array(20), ch.array(30)
|
| 36 |
+
for i in range(3): x2 = x1 + x2 + x3
|
| 37 |
+
|
| 38 |
+
x2.dr_wrt(x1) # pull cache
|
| 39 |
+
x2.dr_wrt(x3) # pull cache
|
| 40 |
+
x1.label='x1' # for clarity in show_tree()
|
| 41 |
+
x2.label='x2' # for clarity in show_tree()
|
| 42 |
+
x3.label='x3' # for clarity in show_tree()
|
| 43 |
+
x2.show_tree(cachelim=1e-4) # in MB
|
| 44 |
+
"""
|
| 45 |
+
|
| 46 |
+
demos['matrix'] = """
|
| 47 |
+
import chumpy as ch
|
| 48 |
+
|
| 49 |
+
x1, x2, x3, x4 = ch.eye(10), ch.array(1), ch.array(5), ch.array(10)
|
| 50 |
+
y = x1*(x2-x3)+x4
|
| 51 |
+
print y
|
| 52 |
+
print y.dr_wrt(x2)
|
| 53 |
+
"""
|
| 54 |
+
|
| 55 |
+
demos['linalg'] = """
|
| 56 |
+
import chumpy as ch
|
| 57 |
+
|
| 58 |
+
m = [ch.random.randn(100).reshape((10,10)) for i in range(3)]
|
| 59 |
+
y = m[0].dot(m[1]).dot(ch.linalg.inv(m[2])) * ch.linalg.det(m[0])
|
| 60 |
+
print y.shape
|
| 61 |
+
print y.dr_wrt(m[0]).shape
|
| 62 |
+
"""
|
| 63 |
+
|
| 64 |
+
demos['inheritance'] = """
|
| 65 |
+
import chumpy as ch
|
| 66 |
+
import numpy as np
|
| 67 |
+
|
| 68 |
+
class Sin(ch.Ch):
|
| 69 |
+
|
| 70 |
+
dterms = ('x',)
|
| 71 |
+
|
| 72 |
+
def compute_r(self):
|
| 73 |
+
return np.sin(self.x.r)
|
| 74 |
+
|
| 75 |
+
def compute_dr_wrt(self, wrt):
|
| 76 |
+
import scipy.sparse
|
| 77 |
+
if wrt is self.x:
|
| 78 |
+
result = np.cos(self.x.r)
|
| 79 |
+
return scipy.sparse.diags([result.ravel()], [0]) if len(result)>1 else np.atleast_2d(result)
|
| 80 |
+
|
| 81 |
+
x1 = Ch([10,20,30])
|
| 82 |
+
result = Sin(x1) # or "result = Sin(x=x1)"
|
| 83 |
+
print result.r
|
| 84 |
+
print result.dr_wrt(x1)
|
| 85 |
+
"""
|
| 86 |
+
|
| 87 |
+
demos['optimization'] = """
|
| 88 |
+
import chumpy as ch
|
| 89 |
+
|
| 90 |
+
x = ch.zeros(10)
|
| 91 |
+
y = ch.zeros(10)
|
| 92 |
+
|
| 93 |
+
# Beale's function
|
| 94 |
+
e1 = 1.5 - x + x*y
|
| 95 |
+
e2 = 2.25 - x + x*(y**2)
|
| 96 |
+
e3 = 2.625 - x + x*(y**3)
|
| 97 |
+
|
| 98 |
+
objective = {'e1': e1, 'e2': e2, 'e3': e3}
|
| 99 |
+
ch.minimize(objective, x0=[x,y], method='dogleg')
|
| 100 |
+
print x # should be all 3.0
|
| 101 |
+
print y # should be all 0.5
|
| 102 |
+
"""
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def demo(which=None):
|
| 108 |
+
if which not in demos:
|
| 109 |
+
print('Please indicate which demo you want, as follows:')
|
| 110 |
+
for key in demos:
|
| 111 |
+
print("\tdemo('%s')" % (key,))
|
| 112 |
+
return
|
| 113 |
+
|
| 114 |
+
print('- - - - - - - - - - - <CODE> - - - - - - - - - - - -')
|
| 115 |
+
print(demos[which])
|
| 116 |
+
print('- - - - - - - - - - - </CODE> - - - - - - - - - - - -\n')
|
| 117 |
+
exec('global np\n' + demos[which], globals(), locals())
|
vendor/chumpy/chumpy/api_compatibility.py
ADDED
|
@@ -0,0 +1,534 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Author(s): Matthew Loper
|
| 3 |
+
|
| 4 |
+
See LICENCE.txt for licensing and contact information.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
from . import ch
|
| 9 |
+
import numpy as np
|
| 10 |
+
from os.path import join, split
|
| 11 |
+
from six import StringIO
|
| 12 |
+
import numpy
|
| 13 |
+
import chumpy
|
| 14 |
+
from six.moves import cPickle as pickle
|
| 15 |
+
|
| 16 |
+
src = ''
|
| 17 |
+
num_passed = 0
|
| 18 |
+
num_not_passed = 0
|
| 19 |
+
which_passed = []
|
| 20 |
+
|
| 21 |
+
def r(fn_name, args_req, args_opt, nplib=numpy, chlib=chumpy):
|
| 22 |
+
global num_passed, num_not_passed
|
| 23 |
+
result = [None, None]
|
| 24 |
+
|
| 25 |
+
for lib in [nplib, chlib]:
|
| 26 |
+
|
| 27 |
+
# if fn_name is 'svd' and lib is chlib:
|
| 28 |
+
# import pdb; pdb.set_trace()
|
| 29 |
+
if lib is nplib:
|
| 30 |
+
fn = getattr(lib, fn_name)
|
| 31 |
+
else:
|
| 32 |
+
try:
|
| 33 |
+
fn = getattr(lib, fn_name)
|
| 34 |
+
except AttributeError:
|
| 35 |
+
result[0] = 'missing'
|
| 36 |
+
result[1] = 'missing'
|
| 37 |
+
num_not_passed += 1
|
| 38 |
+
continue
|
| 39 |
+
try:
|
| 40 |
+
if isinstance(args_req, dict):
|
| 41 |
+
_ = fn(**args_req)
|
| 42 |
+
else:
|
| 43 |
+
_ = fn(*args_req)
|
| 44 |
+
if lib is chlib:
|
| 45 |
+
result[0] = 'passed'
|
| 46 |
+
num_passed += 1
|
| 47 |
+
global which_passed
|
| 48 |
+
which_passed.append(fn_name)
|
| 49 |
+
|
| 50 |
+
if hasattr(_, 'dterms'):
|
| 51 |
+
try:
|
| 52 |
+
_.r
|
| 53 |
+
|
| 54 |
+
try:
|
| 55 |
+
pickle.dumps(_)
|
| 56 |
+
except:
|
| 57 |
+
result[0] += ' (but unpickleable!)'
|
| 58 |
+
except:
|
| 59 |
+
import pdb; pdb.set_trace()
|
| 60 |
+
result[0] += '(but cant get result!)'
|
| 61 |
+
except Exception as e:
|
| 62 |
+
if e is TypeError:
|
| 63 |
+
import pdb; pdb.set_trace()
|
| 64 |
+
if lib is nplib:
|
| 65 |
+
import pdb; pdb.set_trace()
|
| 66 |
+
else:
|
| 67 |
+
num_not_passed += 1
|
| 68 |
+
# if fn_name == 'rot90':
|
| 69 |
+
# import pdb; pdb.set_trace()
|
| 70 |
+
result[0] = e.__class__.__name__
|
| 71 |
+
|
| 72 |
+
try:
|
| 73 |
+
if isinstance(args_req, dict):
|
| 74 |
+
fn(**dict(list(args_req.items()) + list(args_opt.items())))
|
| 75 |
+
else:
|
| 76 |
+
fn(*args_req, **args_opt)
|
| 77 |
+
if lib is chlib:
|
| 78 |
+
result[1] = 'passed'
|
| 79 |
+
except Exception as e:
|
| 80 |
+
if e is TypeError:
|
| 81 |
+
import pdb; pdb.set_trace()
|
| 82 |
+
result[1] = e.__class__.__name__
|
| 83 |
+
|
| 84 |
+
# print '%s: %s, %s' % (fn_name, result[0], result[1])
|
| 85 |
+
|
| 86 |
+
append(fn_name, result[0], result[1])
|
| 87 |
+
|
| 88 |
+
def make_row(a, b, c, b_color, c_color):
|
| 89 |
+
global src
|
| 90 |
+
src += '<tr><td>%s</td><td style="background-color:%s">%s</td><td style="background-color:%s">%s</td></tr>' % (a,b_color, b,c_color, c)
|
| 91 |
+
|
| 92 |
+
def append(a, b, c):
|
| 93 |
+
global src
|
| 94 |
+
b_color = 'white'
|
| 95 |
+
c_color = 'white'
|
| 96 |
+
|
| 97 |
+
b = b.replace('NotImplementedError', 'not yet implemented')
|
| 98 |
+
c = c.replace('NotImplementedError', 'not yet implemented')
|
| 99 |
+
b = b.replace('WontImplement', "won't implement")
|
| 100 |
+
c = c.replace('WontImplement', "won't implement")
|
| 101 |
+
lookup = {
|
| 102 |
+
'passed': 'lightgreen',
|
| 103 |
+
"won't implement": 'lightgray',
|
| 104 |
+
'untested': 'lightyellow',
|
| 105 |
+
'not yet implemented': 'pink'
|
| 106 |
+
}
|
| 107 |
+
|
| 108 |
+
b_color = lookup[b] if b in lookup else 'white'
|
| 109 |
+
c_color = lookup[c] if c in lookup else 'white'
|
| 110 |
+
|
| 111 |
+
print('%s: %s, %s' % (a,b,c))
|
| 112 |
+
make_row(a, b, c, b_color, c_color)
|
| 113 |
+
|
| 114 |
+
def m(s):
|
| 115 |
+
append(s, 'unknown', 'unknown')
|
| 116 |
+
global num_not_passed
|
| 117 |
+
num_not_passed += 1
|
| 118 |
+
|
| 119 |
+
def hd3(s):
|
| 120 |
+
global src
|
| 121 |
+
src += '<tr><td colspan=3><h3 style="margin-bottom:0;">%s</h3></td></tr>' % (s,)
|
| 122 |
+
|
| 123 |
+
def hd2(s):
|
| 124 |
+
global src
|
| 125 |
+
src += '</table><br/><br/><table border=1>'
|
| 126 |
+
src += '<tr><td colspan=3 style="background-color:black;color:white"><h2 style="margin-bottom:0;">%s</h2></td></tr>' % (s,)
|
| 127 |
+
|
| 128 |
+
def main():
|
| 129 |
+
|
| 130 |
+
#sample_array
|
| 131 |
+
|
| 132 |
+
###############################
|
| 133 |
+
hd2('Array Creation Routines')
|
| 134 |
+
|
| 135 |
+
hd3('Ones and zeros')
|
| 136 |
+
|
| 137 |
+
r('empty', {'shape': (2,4,2)}, {'dtype': np.uint8, 'order': 'C'})
|
| 138 |
+
r('empty_like', {'prototype': np.empty((2,4,2))}, {'dtype': np.float64, 'order': 'C'})
|
| 139 |
+
r('eye', {'N': 10}, {'M': 5, 'k': 0, 'dtype': np.float64})
|
| 140 |
+
r('identity', {'n': 10}, {'dtype': np.float64})
|
| 141 |
+
r('ones', {'shape': (2,4,2)}, {'dtype': np.uint8, 'order': 'C'})
|
| 142 |
+
r('ones_like', {'a': np.empty((2,4,2))}, {'dtype': np.float64, 'order': 'C'})
|
| 143 |
+
r('zeros', {'shape': (2,4,2)}, {'dtype': np.uint8, 'order': 'C'})
|
| 144 |
+
r('zeros_like', {'a': np.empty((2,4,2))}, {'dtype': np.float64, 'order': 'C'})
|
| 145 |
+
|
| 146 |
+
hd3('From existing data')
|
| 147 |
+
r('array', {'object': [1,2,3]}, {'dtype': np.float64, 'order': 'C', 'subok': False, 'ndmin': 2})
|
| 148 |
+
r('asarray', {'a': np.array([1,2,3])}, {'dtype': np.float64, 'order': 'C'})
|
| 149 |
+
r('asanyarray', {'a': np.array([1,2,3])}, {'dtype': np.float64, 'order': 'C'})
|
| 150 |
+
r('ascontiguousarray', {'a': np.array([1,2,3])}, {'dtype': np.float64})
|
| 151 |
+
r('asmatrix', {'data': np.array([1,2,3])}, {'dtype': np.float64})
|
| 152 |
+
r('copy', (np.array([1,2,3]),), {})
|
| 153 |
+
r('frombuffer', {'buffer': np.array([1,2,3])}, {})
|
| 154 |
+
m('fromfile')
|
| 155 |
+
r('fromfunction', {'function': lambda i, j: i + j, 'shape': (3, 3)}, {'dtype': np.float64})
|
| 156 |
+
# function, shape, **kwargs
|
| 157 |
+
# lambda i, j: i + j, (3, 3), dtype=int
|
| 158 |
+
r('fromiter', {'iter': [1,2,3,4], 'dtype': np.float64}, {'count': 2})
|
| 159 |
+
r('fromstring', {'string': '\x01\x02', 'dtype': np.uint8}, {})
|
| 160 |
+
r('loadtxt', {'fname': StringIO("0 1\n2 3")}, {})
|
| 161 |
+
|
| 162 |
+
hd3('Creating record arrays (wont be implemented)')
|
| 163 |
+
hd3('Creating character arrays (wont be implemented)')
|
| 164 |
+
|
| 165 |
+
hd3('Numerical ranges')
|
| 166 |
+
r('arange', {'start': 0, 'stop': 10}, {'step': 2, 'dtype': np.float64})
|
| 167 |
+
r('linspace', {'start': 0, 'stop': 10}, {'num': 2, 'endpoint': 10, 'retstep': 1})
|
| 168 |
+
r('logspace', {'start': 0, 'stop': 10}, {'num': 2, 'endpoint': 10, 'base': 1})
|
| 169 |
+
r('meshgrid', ([1,2,3], [4,5,6]), {})
|
| 170 |
+
m('mgrid')
|
| 171 |
+
m('ogrid')
|
| 172 |
+
|
| 173 |
+
hd3('Building matrices')
|
| 174 |
+
r('diag', {'v': np.arange(9).reshape((3,3))}, {'k': 0})
|
| 175 |
+
r('diagflat', {'v': [[1,2], [3,4]]}, {})
|
| 176 |
+
r('tri', {'N': 3}, {'M': 5, 'k': 2, 'dtype': np.float64})
|
| 177 |
+
r('tril', {'m': [[1,2,3],[4,5,6],[7,8,9],[10,11,12]]}, {'k': -1})
|
| 178 |
+
r('triu', {'m': [[1,2,3],[4,5,6],[7,8,9],[10,11,12]]}, {'k': -1})
|
| 179 |
+
r('vander', {'x': np.array([1, 2, 3, 5])}, {'N': 3})
|
| 180 |
+
|
| 181 |
+
###############################
|
| 182 |
+
hd2('Array manipulation routines')
|
| 183 |
+
|
| 184 |
+
hd3('Basic operations')
|
| 185 |
+
r('copyto', {'dst': np.eye(3), 'src': np.eye(3)}, {})
|
| 186 |
+
|
| 187 |
+
hd3('Changing array shape')
|
| 188 |
+
r('reshape', {'a': np.eye(3), 'newshape': (9,)}, {'order' : 'C'})
|
| 189 |
+
r('ravel', {'a': np.eye(3)}, {'order' : 'C'})
|
| 190 |
+
m('flat')
|
| 191 |
+
m('flatten')
|
| 192 |
+
|
| 193 |
+
hd3('Transpose-like operations')
|
| 194 |
+
r('rollaxis', {'a': np.ones((3,4,5,6)), 'axis': 3}, {'start': 0})
|
| 195 |
+
r('swapaxes', {'a': np.array([[1,2,3]]), 'axis1': 0, 'axis2': 1}, {})
|
| 196 |
+
r('transpose', {'a': np.arange(4).reshape((2,2))}, {'axes': (1,0)})
|
| 197 |
+
|
| 198 |
+
hd3('Changing number of dimensions')
|
| 199 |
+
r('atleast_1d', (np.eye(3),), {})
|
| 200 |
+
r('atleast_2d', (np.eye(3),), {})
|
| 201 |
+
r('atleast_3d', (np.eye(3),), {})
|
| 202 |
+
m('broadcast')
|
| 203 |
+
m('broadcast_arrays')
|
| 204 |
+
r('expand_dims', (np.array([1,2]),2), {})
|
| 205 |
+
r('squeeze', {'a': (np.array([[[1,2,3]]]))}, {})
|
| 206 |
+
|
| 207 |
+
hd3('Changing kind of array')
|
| 208 |
+
r('asarray', {'a': np.array([1,2,3])}, {'dtype': np.float64, 'order': 'C'})
|
| 209 |
+
r('asanyarray', {'a': np.array([1,2,3])}, {'dtype': np.float64, 'order': 'C'})
|
| 210 |
+
r('asmatrix', {'data': np.array([1,2,3])}, {})
|
| 211 |
+
r('asfarray', {'a': np.array([1,2,3])}, {})
|
| 212 |
+
r('asfortranarray', {'a': np.array([1,2,3])}, {})
|
| 213 |
+
r('asscalar', {'a': np.array([24])}, {})
|
| 214 |
+
r('require', {'a': np.array([24])}, {})
|
| 215 |
+
|
| 216 |
+
hd3('Joining arrays')
|
| 217 |
+
m('column_stack')
|
| 218 |
+
r('concatenate', ((np.eye(3), np.eye(3)),1), {})
|
| 219 |
+
r('dstack', ((np.eye(3), np.eye(3)),), {})
|
| 220 |
+
r('hstack', ((np.eye(3), np.eye(3)),), {})
|
| 221 |
+
r('vstack', ((np.eye(3), np.eye(3)),), {})
|
| 222 |
+
|
| 223 |
+
hd3('Splitting arrays')
|
| 224 |
+
m('array_split')
|
| 225 |
+
m('dsplit')
|
| 226 |
+
m('hsplit')
|
| 227 |
+
m('split')
|
| 228 |
+
m('vsplit')
|
| 229 |
+
|
| 230 |
+
hd3('Tiling arrays')
|
| 231 |
+
r('tile', (np.array([0, 1, 2]),2), {})
|
| 232 |
+
r('repeat', (np.array([[1,2],[3,4]]), 3), {'axis': 1})
|
| 233 |
+
|
| 234 |
+
hd3('Adding and removing elements')
|
| 235 |
+
m('delete')
|
| 236 |
+
m('insert')
|
| 237 |
+
m('append')
|
| 238 |
+
m('resize')
|
| 239 |
+
m('trim_zeros')
|
| 240 |
+
m('unique')
|
| 241 |
+
|
| 242 |
+
hd3('Rearranging elements')
|
| 243 |
+
r('fliplr', (np.eye(3),), {})
|
| 244 |
+
r('flipud', (np.eye(3),), {})
|
| 245 |
+
r('reshape', {'a': np.eye(3), 'newshape': (9,)}, {'order' : 'C'})
|
| 246 |
+
r('roll', (np.arange(10), 2), {})
|
| 247 |
+
r('rot90', (np.arange(4).reshape((2,2)),), {})
|
| 248 |
+
|
| 249 |
+
###############################
|
| 250 |
+
hd2('Linear algebra (numpy.linalg)')
|
| 251 |
+
|
| 252 |
+
extra_args = {'nplib': numpy.linalg, 'chlib': ch.linalg}
|
| 253 |
+
|
| 254 |
+
hd3('Matrix and dot products')
|
| 255 |
+
r('dot', {'a': np.eye(3), 'b': np.eye(3)}, {})
|
| 256 |
+
r('dot', {'a': np.eye(3).ravel(), 'b': np.eye(3).ravel()}, {})
|
| 257 |
+
r('vdot', (np.eye(3).ravel(), np.eye(3).ravel()), {})
|
| 258 |
+
r('inner', (np.eye(3).ravel(), np.eye(3).ravel()), {})
|
| 259 |
+
r('outer', (np.eye(3).ravel(), np.eye(3).ravel()), {})
|
| 260 |
+
r('tensordot', {'a': np.eye(3), 'b': np.eye(3)}, {})
|
| 261 |
+
m('einsum')
|
| 262 |
+
r('matrix_power', {'M': np.eye(3), 'n': 2}, {}, **extra_args)
|
| 263 |
+
r('kron', {'a': np.eye(3), 'b': np.eye(3)}, {})
|
| 264 |
+
|
| 265 |
+
hd3('Decompositions')
|
| 266 |
+
r('cholesky', {'a': np.eye(3)}, {}, **extra_args)
|
| 267 |
+
r('qr', {'a': np.eye(3)}, {}, **extra_args)
|
| 268 |
+
r('svd', (np.eye(3),), {}, **extra_args)
|
| 269 |
+
|
| 270 |
+
hd3('Matrix eigenvalues')
|
| 271 |
+
r('eig', (np.eye(3),), {}, **extra_args)
|
| 272 |
+
r('eigh', (np.eye(3),), {}, **extra_args)
|
| 273 |
+
r('eigvals', (np.eye(3),), {}, **extra_args)
|
| 274 |
+
r('eigvalsh', (np.eye(3),), {}, **extra_args)
|
| 275 |
+
|
| 276 |
+
hd3('Norms and other numbers')
|
| 277 |
+
r('norm', (np.eye(3),), {}, **extra_args)
|
| 278 |
+
r('cond', (np.eye(3),), {}, **extra_args)
|
| 279 |
+
r('det', (np.eye(3),), {}, **extra_args)
|
| 280 |
+
r('slogdet', (np.eye(3),), {}, **extra_args)
|
| 281 |
+
r('trace', (np.eye(3),), {})
|
| 282 |
+
|
| 283 |
+
hd3('Solving equations and inverting matrices')
|
| 284 |
+
r('solve', (np.eye(3),np.ones(3)), {}, **extra_args)
|
| 285 |
+
r('tensorsolve', (np.eye(3),np.ones(3)), {}, **extra_args)
|
| 286 |
+
r('lstsq', (np.eye(3),np.ones(3)), {}, **extra_args)
|
| 287 |
+
r('inv', (np.eye(3),), {}, **extra_args)
|
| 288 |
+
r('pinv', (np.eye(3),), {}, **extra_args)
|
| 289 |
+
r('tensorinv', (np.eye(4*6).reshape((4,6,8,3)),), {'ind': 2}, **extra_args)
|
| 290 |
+
|
| 291 |
+
hd2('Mathematical functions')
|
| 292 |
+
|
| 293 |
+
hd3('Trigonometric functions')
|
| 294 |
+
r('sin', (np.arange(3),), {})
|
| 295 |
+
r('cos', (np.arange(3),), {})
|
| 296 |
+
r('tan', (np.arange(3),), {})
|
| 297 |
+
r('arcsin', (np.arange(3)/3.,), {})
|
| 298 |
+
r('arccos', (np.arange(3)/3.,), {})
|
| 299 |
+
r('arctan', (np.arange(3)/3.,), {})
|
| 300 |
+
r('hypot', (np.arange(3),np.arange(3)), {})
|
| 301 |
+
r('arctan2', (np.arange(3),np.arange(3)), {})
|
| 302 |
+
r('degrees', (np.arange(3),), {})
|
| 303 |
+
r('radians', (np.arange(3),), {})
|
| 304 |
+
r('unwrap', (np.arange(3),), {})
|
| 305 |
+
r('unwrap', (np.arange(3),), {})
|
| 306 |
+
r('deg2rad', (np.arange(3),), {})
|
| 307 |
+
r('rad2deg', (np.arange(3),), {})
|
| 308 |
+
|
| 309 |
+
hd3('Hyperbolic functions')
|
| 310 |
+
r('sinh', (np.arange(3),), {})
|
| 311 |
+
r('cosh', (np.arange(3),), {})
|
| 312 |
+
r('tanh', (np.arange(3),), {})
|
| 313 |
+
r('arcsinh', (np.arange(3)/9.,), {})
|
| 314 |
+
r('arccosh', (-np.arange(3)/9.,), {})
|
| 315 |
+
r('arctanh', (np.arange(3)/9.,), {})
|
| 316 |
+
|
| 317 |
+
hd3('Rounding')
|
| 318 |
+
r('around', (np.arange(3),), {})
|
| 319 |
+
r('round_', (np.arange(3),), {})
|
| 320 |
+
r('rint', (np.arange(3),), {})
|
| 321 |
+
r('fix', (np.arange(3),), {})
|
| 322 |
+
r('floor', (np.arange(3),), {})
|
| 323 |
+
r('ceil', (np.arange(3),), {})
|
| 324 |
+
r('trunc', (np.arange(3),), {})
|
| 325 |
+
|
| 326 |
+
hd3('Sums, products, differences')
|
| 327 |
+
r('prod', (np.arange(3),), {})
|
| 328 |
+
r('sum', (np.arange(3),), {})
|
| 329 |
+
r('nansum', (np.arange(3),), {})
|
| 330 |
+
r('cumprod', (np.arange(3),), {})
|
| 331 |
+
r('cumsum', (np.arange(3),), {})
|
| 332 |
+
r('diff', (np.arange(3),), {})
|
| 333 |
+
r('ediff1d', (np.arange(3),), {})
|
| 334 |
+
r('gradient', (np.arange(3),), {})
|
| 335 |
+
r('cross', (np.arange(3), np.arange(3)), {})
|
| 336 |
+
r('trapz', (np.arange(3),), {})
|
| 337 |
+
|
| 338 |
+
hd3('Exponents and logarithms')
|
| 339 |
+
r('exp', (np.arange(3),), {})
|
| 340 |
+
r('expm1', (np.arange(3),), {})
|
| 341 |
+
r('exp2', (np.arange(3),), {})
|
| 342 |
+
r('log', (np.arange(3),), {})
|
| 343 |
+
r('log10', (np.arange(3),), {})
|
| 344 |
+
r('log2', (np.arange(3),), {})
|
| 345 |
+
r('log1p', (np.arange(3),), {})
|
| 346 |
+
r('logaddexp', (np.arange(3), np.arange(3)), {})
|
| 347 |
+
r('logaddexp2', (np.arange(3), np.arange(3)), {})
|
| 348 |
+
|
| 349 |
+
hd3('Other special functions')
|
| 350 |
+
r('i0', (np.arange(3),), {})
|
| 351 |
+
r('sinc', (np.arange(3),), {})
|
| 352 |
+
|
| 353 |
+
hd3('Floating point routines')
|
| 354 |
+
r('signbit', (np.arange(3),), {})
|
| 355 |
+
r('copysign', (np.arange(3), np.arange(3)), {})
|
| 356 |
+
r('frexp', (np.arange(3),), {})
|
| 357 |
+
r('ldexp', (np.arange(3), np.arange(3)), {})
|
| 358 |
+
|
| 359 |
+
hd3('Arithmetic operations')
|
| 360 |
+
r('add', (np.arange(3), np.arange(3)), {})
|
| 361 |
+
r('reciprocal', (np.arange(3),), {})
|
| 362 |
+
r('negative', (np.arange(3),), {})
|
| 363 |
+
r('multiply', (np.arange(3), np.arange(3)), {})
|
| 364 |
+
r('divide', (np.arange(3), np.arange(3)), {})
|
| 365 |
+
r('power', (np.arange(3), np.arange(3)), {})
|
| 366 |
+
r('subtract', (np.arange(3), np.arange(3)), {})
|
| 367 |
+
r('true_divide', (np.arange(3), np.arange(3)), {})
|
| 368 |
+
r('floor_divide', (np.arange(3), np.arange(3)), {})
|
| 369 |
+
r('fmod', (np.arange(3), np.arange(3)), {})
|
| 370 |
+
r('mod', (np.arange(3), np.arange(3)), {})
|
| 371 |
+
r('modf', (np.arange(3),), {})
|
| 372 |
+
r('remainder', (np.arange(3), np.arange(3)), {})
|
| 373 |
+
|
| 374 |
+
hd3('Handling complex numbers')
|
| 375 |
+
m('angle')
|
| 376 |
+
m('real')
|
| 377 |
+
m('imag')
|
| 378 |
+
m('conj')
|
| 379 |
+
|
| 380 |
+
hd3('Miscellaneous')
|
| 381 |
+
r('convolve', (np.arange(3), np.arange(3)), {})
|
| 382 |
+
r('clip', (np.arange(3), 0, 2), {})
|
| 383 |
+
r('sqrt', (np.arange(3),), {})
|
| 384 |
+
r('square', (np.arange(3),), {})
|
| 385 |
+
r('absolute', (np.arange(3),), {})
|
| 386 |
+
r('fabs', (np.arange(3),), {})
|
| 387 |
+
r('sign', (np.arange(3),), {})
|
| 388 |
+
r('maximum', (np.arange(3), np.arange(3)), {})
|
| 389 |
+
r('minimum', (np.arange(3), np.arange(3)), {})
|
| 390 |
+
r('fmax', (np.arange(3), np.arange(3)), {})
|
| 391 |
+
r('fmin', (np.arange(3), np.arange(3)), {})
|
| 392 |
+
r('nan_to_num', (np.arange(3),), {})
|
| 393 |
+
r('real_if_close', (np.arange(3),), {})
|
| 394 |
+
r('interp', (2.5, [1,2,3], [3,2,0]), {})
|
| 395 |
+
|
| 396 |
+
extra_args = {'nplib': numpy.random, 'chlib': ch.random}
|
| 397 |
+
|
| 398 |
+
hd2('Random sampling (numpy.random)')
|
| 399 |
+
hd3('Simple random data')
|
| 400 |
+
r('rand', (3,), {}, **extra_args)
|
| 401 |
+
r('randn', (3,), {}, **extra_args)
|
| 402 |
+
r('randint', (3,), {}, **extra_args)
|
| 403 |
+
r('random_integers', (3,), {}, **extra_args)
|
| 404 |
+
r('random_sample', (3,), {}, **extra_args)
|
| 405 |
+
r('random', (3,), {}, **extra_args)
|
| 406 |
+
r('ranf', (3,), {}, **extra_args)
|
| 407 |
+
r('sample', (3,), {}, **extra_args)
|
| 408 |
+
r('choice', (np.ones(3),), {}, **extra_args)
|
| 409 |
+
r('bytes', (3,), {}, **extra_args)
|
| 410 |
+
|
| 411 |
+
hd3('Permutations')
|
| 412 |
+
r('shuffle', (np.ones(3),), {}, **extra_args)
|
| 413 |
+
r('permutation', (3,), {}, **extra_args)
|
| 414 |
+
|
| 415 |
+
hd3('Distributions (these all pass)')
|
| 416 |
+
r('beta', (.5, .5), {}, **extra_args)
|
| 417 |
+
r('binomial', (.5, .5), {}, **extra_args)
|
| 418 |
+
r('chisquare', (.5,), {}, **extra_args)
|
| 419 |
+
r('dirichlet', ((10, 5, 3), 20,), {}, **extra_args)
|
| 420 |
+
r('exponential', [], {}, **extra_args)
|
| 421 |
+
r('f', [1,48,1000], {}, **extra_args)
|
| 422 |
+
r('gamma', [.5], {}, **extra_args)
|
| 423 |
+
make_row('...AND 28 OTHERS...', 'passed', 'passed', 'lightgreen', 'lightgreen')
|
| 424 |
+
|
| 425 |
+
|
| 426 |
+
hd3('Random generator')
|
| 427 |
+
r('seed', [], {}, **extra_args)
|
| 428 |
+
r('get_state', [], {}, **extra_args)
|
| 429 |
+
r('set_state', [np.random.get_state()], {}, **extra_args)
|
| 430 |
+
|
| 431 |
+
####################################
|
| 432 |
+
hd2('Statistics')
|
| 433 |
+
hd3('Order statistics')
|
| 434 |
+
r('amin', (np.eye(3),),{})
|
| 435 |
+
r('amax', (np.eye(3),),{})
|
| 436 |
+
r('nanmin', (np.eye(3),),{})
|
| 437 |
+
r('nanmax', (np.eye(3),),{})
|
| 438 |
+
r('ptp', (np.eye(3),),{})
|
| 439 |
+
r('percentile', (np.eye(3),50),{})
|
| 440 |
+
|
| 441 |
+
hd3('Averages and variance')
|
| 442 |
+
r('median', (np.eye(3),),{})
|
| 443 |
+
r('average', (np.eye(3),),{})
|
| 444 |
+
r('mean', (np.eye(3),),{})
|
| 445 |
+
r('std', (np.eye(3),),{})
|
| 446 |
+
r('var', (np.eye(3),),{})
|
| 447 |
+
r('nanmean', (np.eye(3),),{})
|
| 448 |
+
r('nanstd', (np.eye(3),),{})
|
| 449 |
+
r('nanvar', (np.eye(3),),{})
|
| 450 |
+
|
| 451 |
+
|
| 452 |
+
hd3('Correlating')
|
| 453 |
+
r('corrcoef', (np.eye(3),),{})
|
| 454 |
+
r('correlate', ([1, 2, 3], [0, 1, 0.5]),{})
|
| 455 |
+
r('cov', (np.eye(3),),{})
|
| 456 |
+
|
| 457 |
+
hd3('Histograms')
|
| 458 |
+
r('histogram', (np.eye(3),),{})
|
| 459 |
+
r('histogram2d', (np.eye(3).ravel(),np.eye(3).ravel()),{})
|
| 460 |
+
r('histogramdd', (np.eye(3).ravel(),),{})
|
| 461 |
+
r('bincount', (np.asarray(np.eye(3).ravel(), np.uint32),),{})
|
| 462 |
+
r('digitize', (np.array([0.2, 6.4, 3.0, 1.6]), np.array([0.0, 1.0, 2.5, 4.0, 10.0])),{})
|
| 463 |
+
|
| 464 |
+
####################################
|
| 465 |
+
hd2('Sorting, searching, and counting')
|
| 466 |
+
|
| 467 |
+
hd3('Sorting')
|
| 468 |
+
r('sort', (np.array([1,3,1,2.]),), {})
|
| 469 |
+
m('lexsort')
|
| 470 |
+
m('argsort')
|
| 471 |
+
m('msort')
|
| 472 |
+
m('sort_complex')
|
| 473 |
+
m('partition')
|
| 474 |
+
m('argpartition')
|
| 475 |
+
|
| 476 |
+
# sort(a[, axis, kind, order]) Return a sorted copy of an array.
|
| 477 |
+
# lexsort(keys[, axis]) Perform an indirect sort using a sequence of keys.
|
| 478 |
+
# argsort(a[, axis, kind, order]) Returns the indices that would sort an array.
|
| 479 |
+
# ndarray.sort([axis, kind, order]) Sort an array, in-place.
|
| 480 |
+
# msort(a) Return a copy of an array sorted along the first axis.
|
| 481 |
+
# sort_complex(a) Sort a complex array using the real part first, then the imaginary part.
|
| 482 |
+
# partition(a, kth[, axis, kind, order]) Return a partitioned copy of an array.
|
| 483 |
+
# argpartition(a, kth[, axis, kind, order]) Perform an indirect partition along the given axis using the algorithm specified by the kind keyword.
|
| 484 |
+
|
| 485 |
+
a5 = np.arange(5)
|
| 486 |
+
|
| 487 |
+
hd3('Searching')
|
| 488 |
+
r('argmax', (a5,), {})
|
| 489 |
+
r('nanargmax', (a5,), {})
|
| 490 |
+
r('argmin', (a5,), {})
|
| 491 |
+
r('nanargmin', (a5,), {})
|
| 492 |
+
r('argwhere', (a5,), {})
|
| 493 |
+
r('nonzero', (a5,), {})
|
| 494 |
+
r('flatnonzero', (a5,), {})
|
| 495 |
+
r('where', (a5>1,), {})
|
| 496 |
+
r('searchsorted', (a5,a5), {})
|
| 497 |
+
r('extract', (lambda x : x > 1, a5), {})
|
| 498 |
+
|
| 499 |
+
# argmax(a[, axis]) Indices of the maximum values along an axis.
|
| 500 |
+
# nanargmax(a[, axis]) Return the indices of the maximum values in the specified axis ignoring
|
| 501 |
+
# argmin(a[, axis]) Return the indices of the minimum values along an axis.
|
| 502 |
+
# nanargmin(a[, axis]) Return the indices of the minimum values in the specified axis ignoring
|
| 503 |
+
# argwhere(a) Find the indices of array elements that are non-zero, grouped by element.
|
| 504 |
+
# nonzero(a) Return the indices of the elements that are non-zero.
|
| 505 |
+
# flatnonzero(a) Return indices that are non-zero in the flattened version of a.
|
| 506 |
+
# where(condition, [x, y]) Return elements, either from x or y, depending on condition.
|
| 507 |
+
# searchsorted(a, v[, side, sorter]) Find indices where elements should be inserted to maintain order.
|
| 508 |
+
# extract(condition, arr) Return the elements of an array that satisfy some condition.
|
| 509 |
+
|
| 510 |
+
hd3('Counting')
|
| 511 |
+
r('count_nonzero', (a5,), {})
|
| 512 |
+
#count_nonzero(a) Counts the number of non-zero values in the array a.
|
| 513 |
+
|
| 514 |
+
|
| 515 |
+
|
| 516 |
+
# histogram(a[, bins, range, normed, weights, ...]) Compute the histogram of a set of data.
|
| 517 |
+
# histogram2d(x, y[, bins, range, normed, weights]) Compute the bi-dimensional histogram of two data samples.
|
| 518 |
+
# histogramdd(sample[, bins, range, normed, ...]) Compute the multidimensional histogram of some data.
|
| 519 |
+
# bincount(x[, weights, minlength]) Count number of occurrences of each value in array of non-negative ints.
|
| 520 |
+
# digitize(x, bins[, right]) Return the indices of the bins to which each value in input array belongs.
|
| 521 |
+
|
| 522 |
+
|
| 523 |
+
global src
|
| 524 |
+
src = '<html><body><table border=1>' + src + '</table></body></html>'
|
| 525 |
+
open(join(split(__file__)[0], 'api_compatibility.html'), 'w').write(src)
|
| 526 |
+
|
| 527 |
+
print('passed %d, not passed %d' % (num_passed, num_not_passed))
|
| 528 |
+
|
| 529 |
+
|
| 530 |
+
|
| 531 |
+
if __name__ == '__main__':
|
| 532 |
+
global which_passed
|
| 533 |
+
main()
|
| 534 |
+
print(' '.join(which_passed))
|
vendor/chumpy/chumpy/ch.py
ADDED
|
@@ -0,0 +1,1367 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# encoding: utf-8
|
| 3 |
+
"""
|
| 4 |
+
Author(s): Matthew Loper
|
| 5 |
+
|
| 6 |
+
See LICENCE.txt for licensing and contact information.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
__all__ = ['Ch', 'depends_on', 'MatVecMult', 'ChHandle', 'ChLambda']
|
| 11 |
+
|
| 12 |
+
import os, sys, time
|
| 13 |
+
import inspect
|
| 14 |
+
import scipy.sparse as sp
|
| 15 |
+
import numpy as np
|
| 16 |
+
import numbers
|
| 17 |
+
import weakref
|
| 18 |
+
import copy as external_copy
|
| 19 |
+
from functools import wraps
|
| 20 |
+
from scipy.sparse.linalg.interface import LinearOperator
|
| 21 |
+
from .utils import row, col, timer, convert_inputs_to_sparse_if_necessary
|
| 22 |
+
import collections
|
| 23 |
+
from copy import deepcopy
|
| 24 |
+
from functools import reduce
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# Turn this on if you want the profiler injected
|
| 29 |
+
DEBUG = False
|
| 30 |
+
# Turn this on to make optimizations very chatty for debugging
|
| 31 |
+
VERBOSE = False
|
| 32 |
+
def pif(msg):
|
| 33 |
+
# print-if-verbose.
|
| 34 |
+
if DEBUG or VERBOSE:
|
| 35 |
+
sys.stdout.write(msg + '\n')
|
| 36 |
+
|
| 37 |
+
_props_for_dict = weakref.WeakKeyDictionary()
|
| 38 |
+
def _props_for(cls):
|
| 39 |
+
if cls not in _props_for_dict:
|
| 40 |
+
_props_for_dict[cls] = set([p[0] for p in inspect.getmembers(cls, lambda x : isinstance(x, property))])
|
| 41 |
+
return _props_for_dict[cls]
|
| 42 |
+
|
| 43 |
+
_dep_props_for_dict = weakref.WeakKeyDictionary()
|
| 44 |
+
def _dep_props_for(cls):
|
| 45 |
+
if cls not in _dep_props_for_dict:
|
| 46 |
+
_dep_props_for_dict[cls] = [p for p in inspect.getmembers(cls, lambda x : isinstance(x, property)) if hasattr(p[1].fget, 'deps')]
|
| 47 |
+
return _dep_props_for_dict[cls]
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
_kw_conflict_dict = weakref.WeakKeyDictionary()
|
| 51 |
+
def _check_kw_conflict(cls):
|
| 52 |
+
if cls not in _kw_conflict_dict:
|
| 53 |
+
_kw_conflict_dict[cls] = Ch._reserved_kw.intersection(set(cls.terms).union(set(cls.dterms)))
|
| 54 |
+
if _kw_conflict_dict[cls]:
|
| 55 |
+
raise Exception("In class %s, don't use reserved keywords in terms/dterms: %s" % (str(cls), str(kw_conflict),))
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class Term(object):
|
| 59 |
+
creation_counter = 0
|
| 60 |
+
def __init__(self, default=None, desc=None, dr=True):
|
| 61 |
+
self.default = default
|
| 62 |
+
self.desc = desc
|
| 63 |
+
self.dr = dr
|
| 64 |
+
|
| 65 |
+
# Add a creation_counter, a la Django models, so we can preserve the order in which parameters are defined in the job.
|
| 66 |
+
# http://stackoverflow.com/a/3288801/893113
|
| 67 |
+
self.creation_counter = Term.creation_counter
|
| 68 |
+
Term.creation_counter += 1
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class Ch(object):
|
| 72 |
+
terms = []
|
| 73 |
+
dterms = ['x']
|
| 74 |
+
__array_priority__ = 2.0
|
| 75 |
+
_cached_parms = {}
|
| 76 |
+
_setup_terms = {}
|
| 77 |
+
_default_kwargs = {'make_dense' : False, 'make_sparse' : False}
|
| 78 |
+
_status = "undefined"
|
| 79 |
+
|
| 80 |
+
called_dr_wrt = False
|
| 81 |
+
profiler = None
|
| 82 |
+
|
| 83 |
+
########################################################
|
| 84 |
+
# Construction
|
| 85 |
+
|
| 86 |
+
def __new__(cls, *args, **kwargs):
|
| 87 |
+
|
| 88 |
+
if len(args) > 0 and type(args[0]) == type(lambda : 0):
|
| 89 |
+
cls = ChLambda
|
| 90 |
+
|
| 91 |
+
# Create empty instance
|
| 92 |
+
result = super(Ch, cls).__new__(cls)
|
| 93 |
+
|
| 94 |
+
cls.setup_terms()
|
| 95 |
+
|
| 96 |
+
object.__setattr__(result, '_dirty_vars', set())
|
| 97 |
+
object.__setattr__(result, '_itr', None)
|
| 98 |
+
object.__setattr__(result, '_parents', weakref.WeakKeyDictionary())
|
| 99 |
+
object.__setattr__(result, '_cache', {'r': None, 'drs': weakref.WeakKeyDictionary()})
|
| 100 |
+
|
| 101 |
+
if DEBUG:
|
| 102 |
+
object.__setattr__(result, '_cache_info', {})
|
| 103 |
+
object.__setattr__(result, '_status', 'new')
|
| 104 |
+
|
| 105 |
+
for name, default_val in list(cls._default_kwargs.items()):
|
| 106 |
+
object.__setattr__(result, '_%s' % name, kwargs.get(name, default_val))
|
| 107 |
+
if name in kwargs:
|
| 108 |
+
del kwargs[name]
|
| 109 |
+
|
| 110 |
+
# Set up storage that allows @depends_on to work
|
| 111 |
+
#props = [p for p in inspect.getmembers(cls, lambda x : isinstance(x, property)) if hasattr(p[1].fget, 'deps')]
|
| 112 |
+
props = _dep_props_for(cls)
|
| 113 |
+
cpd = {}
|
| 114 |
+
for p in props:
|
| 115 |
+
func_name = p[0] #id(p[1].fget)
|
| 116 |
+
deps = p[1].fget.deps
|
| 117 |
+
cpd[func_name] = {'deps': deps, 'value': None, 'out_of_date': True}
|
| 118 |
+
|
| 119 |
+
object.__setattr__(result, '_depends_on_deps', cpd)
|
| 120 |
+
|
| 121 |
+
if cls != Ch:
|
| 122 |
+
for idx, a in enumerate(args):
|
| 123 |
+
kwargs[cls.term_order[idx]] = a
|
| 124 |
+
elif len(args)>0:
|
| 125 |
+
kwargs['x'] = np.asarray(args[0], np.float64)
|
| 126 |
+
|
| 127 |
+
defs = {p.name : deepcopy(p.default) for p in cls.parm_declarations() if p.default is not None}
|
| 128 |
+
defs.update(kwargs)
|
| 129 |
+
result.set(**defs)
|
| 130 |
+
|
| 131 |
+
return result
|
| 132 |
+
|
| 133 |
+
@classmethod
|
| 134 |
+
def parm_declarations(cls):
|
| 135 |
+
if cls.__name__ not in cls._cached_parms:
|
| 136 |
+
parameter_declarations = collections.OrderedDict()
|
| 137 |
+
parameters = inspect.getmembers(cls, lambda x: isinstance(x, Term))
|
| 138 |
+
for name, decl in sorted(parameters, key=lambda x: x[1].creation_counter):
|
| 139 |
+
decl.name = name
|
| 140 |
+
parameter_declarations[name] = decl
|
| 141 |
+
cls._cached_parms[cls.__name__] = parameter_declarations
|
| 142 |
+
return cls._cached_parms[cls.__name__]
|
| 143 |
+
|
| 144 |
+
@classmethod
|
| 145 |
+
def setup_terms(cls):
|
| 146 |
+
if id(cls) in cls._setup_terms: return
|
| 147 |
+
|
| 148 |
+
if cls == Ch:
|
| 149 |
+
return
|
| 150 |
+
|
| 151 |
+
parm_declarations = cls.parm_declarations()
|
| 152 |
+
|
| 153 |
+
if cls.dterms is Ch.dterms:
|
| 154 |
+
cls.dterms = []
|
| 155 |
+
elif isinstance(cls.dterms, str):
|
| 156 |
+
cls.dterms = (cls.dterms,)
|
| 157 |
+
if cls.terms is Ch.terms:
|
| 158 |
+
cls.terms = []
|
| 159 |
+
elif isinstance(cls.terms, str):
|
| 160 |
+
cls.terms = (cls.terms,)
|
| 161 |
+
|
| 162 |
+
# Must be either new or old style
|
| 163 |
+
len_oldstyle_parms = len(cls.dterms)+len(cls.terms)
|
| 164 |
+
if len(parm_declarations) > 0:
|
| 165 |
+
assert(len_oldstyle_parms==0)
|
| 166 |
+
cls.term_order = [t.name for t in parm_declarations]
|
| 167 |
+
cls.dterms = [t.name for t in parm_declarations if t.dr]
|
| 168 |
+
cls.terms = [t.name for t in parm_declarations if not t.dr]
|
| 169 |
+
else:
|
| 170 |
+
if not hasattr(cls, 'term_order'):
|
| 171 |
+
cls.term_order = list(cls.terms) + list(cls.dterms)
|
| 172 |
+
|
| 173 |
+
_check_kw_conflict(cls)
|
| 174 |
+
cls._setup_terms[id(cls)] = True
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
########################################################
|
| 178 |
+
# Identifiers
|
| 179 |
+
|
| 180 |
+
@property
|
| 181 |
+
def short_name(self):
|
| 182 |
+
return self.label if hasattr(self, 'label') else self.__class__.__name__
|
| 183 |
+
|
| 184 |
+
@property
|
| 185 |
+
def sid(self):
|
| 186 |
+
"""Semantic id."""
|
| 187 |
+
pnames = list(self.terms)+list(self.dterms)
|
| 188 |
+
pnames.sort()
|
| 189 |
+
return (self.__class__, tuple([(k, id(self.__dict__[k])) for k in pnames if k in self.__dict__]))
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
def reshape(self, *args):
|
| 193 |
+
return reshape(a=self, newshape=args if len(args)>1 else args[0])
|
| 194 |
+
|
| 195 |
+
def ravel(self):
|
| 196 |
+
return reshape(a=self, newshape=(-1))
|
| 197 |
+
|
| 198 |
+
def __hash__(self):
|
| 199 |
+
return id(self)
|
| 200 |
+
|
| 201 |
+
@property
|
| 202 |
+
def ndim(self):
|
| 203 |
+
return self.r.ndim
|
| 204 |
+
|
| 205 |
+
@property
|
| 206 |
+
def flat(self):
|
| 207 |
+
return self.r.flat
|
| 208 |
+
|
| 209 |
+
@property
|
| 210 |
+
def dtype(self):
|
| 211 |
+
return self.r.dtype
|
| 212 |
+
|
| 213 |
+
@property
|
| 214 |
+
def itemsize(self):
|
| 215 |
+
return self.r.itemsize
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
########################################################
|
| 219 |
+
# Redundancy removal
|
| 220 |
+
|
| 221 |
+
def remove_redundancy(self, cache=None, iterate=True):
|
| 222 |
+
|
| 223 |
+
if cache == None:
|
| 224 |
+
cache = {}
|
| 225 |
+
_ = self.r # may result in the creation of extra dterms that we can cull
|
| 226 |
+
|
| 227 |
+
replacement_occurred = False
|
| 228 |
+
for propname in list(self.dterms):
|
| 229 |
+
prop = self.__dict__[propname]
|
| 230 |
+
|
| 231 |
+
if not hasattr(prop, 'dterms'):
|
| 232 |
+
continue
|
| 233 |
+
sid = prop.sid
|
| 234 |
+
if sid not in cache:
|
| 235 |
+
cache[sid] = prop
|
| 236 |
+
elif self.__dict__[propname] is not cache[sid]:
|
| 237 |
+
self.__dict__[propname] = cache[sid]
|
| 238 |
+
replacement_occurred = True
|
| 239 |
+
if prop.remove_redundancy(cache, iterate=False):
|
| 240 |
+
replacement_occurred = True
|
| 241 |
+
|
| 242 |
+
if not replacement_occurred:
|
| 243 |
+
return False
|
| 244 |
+
else:
|
| 245 |
+
if iterate:
|
| 246 |
+
self.remove_redundancy(cache, iterate=True)
|
| 247 |
+
return False
|
| 248 |
+
else:
|
| 249 |
+
return True
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
def print_labeled_residuals(self, print_newline=True, num_decimals=2, where_to_print=None):
|
| 254 |
+
|
| 255 |
+
if where_to_print is None:
|
| 256 |
+
where_to_print = sys.stderr
|
| 257 |
+
if hasattr(self, 'label'):
|
| 258 |
+
where_to_print.write(('%s: %.' + str(num_decimals) + 'e | ') % (self.label, np.sum(self.r**2)))
|
| 259 |
+
for dterm in self.dterms:
|
| 260 |
+
dt = getattr(self, dterm)
|
| 261 |
+
if hasattr(dt, 'dterms'):
|
| 262 |
+
dt.print_labeled_residuals(print_newline=False, where_to_print=where_to_print)
|
| 263 |
+
if print_newline:
|
| 264 |
+
where_to_print.write(('%.' + str(num_decimals) + 'e\n') % (np.sum(self.r**2),))
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
########################################################
|
| 269 |
+
# Default methods, for when Ch is not subclassed
|
| 270 |
+
|
| 271 |
+
def compute_r(self):
|
| 272 |
+
"""Default method for objects that just contain a number or ndarray"""
|
| 273 |
+
return self.x
|
| 274 |
+
|
| 275 |
+
def compute_dr_wrt(self,wrt):
|
| 276 |
+
"""Default method for objects that just contain a number or ndarray"""
|
| 277 |
+
if wrt is self: # special base case
|
| 278 |
+
return sp.eye(self.x.size, self.x.size)
|
| 279 |
+
#return np.array([[1]])
|
| 280 |
+
return None
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
def _compute_dr_wrt_sliced(self, wrt):
|
| 284 |
+
self._call_on_changed()
|
| 285 |
+
|
| 286 |
+
# if wrt is self:
|
| 287 |
+
# return np.array([[1]])
|
| 288 |
+
|
| 289 |
+
result = self.compute_dr_wrt(wrt)
|
| 290 |
+
if result is not None:
|
| 291 |
+
return result
|
| 292 |
+
|
| 293 |
+
# What allows slicing.
|
| 294 |
+
if True:
|
| 295 |
+
inner = wrt
|
| 296 |
+
while issubclass(inner.__class__, Permute):
|
| 297 |
+
inner = inner.a
|
| 298 |
+
if inner is self:
|
| 299 |
+
return None
|
| 300 |
+
result = self.compute_dr_wrt(inner)
|
| 301 |
+
|
| 302 |
+
if result is not None:
|
| 303 |
+
break
|
| 304 |
+
|
| 305 |
+
if result is None:
|
| 306 |
+
return None
|
| 307 |
+
|
| 308 |
+
wrt._call_on_changed()
|
| 309 |
+
|
| 310 |
+
jac = wrt.compute_dr_wrt(inner).T
|
| 311 |
+
|
| 312 |
+
return self._superdot(result, jac)
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
@property
|
| 316 |
+
def shape(self):
|
| 317 |
+
return self.r.shape
|
| 318 |
+
|
| 319 |
+
@property
|
| 320 |
+
def size(self):
|
| 321 |
+
#return self.r.size
|
| 322 |
+
return np.prod(self.shape) # may be cheaper since it doesn't always mean grabbing "r"
|
| 323 |
+
|
| 324 |
+
def __len__(self):
|
| 325 |
+
return len(self.r)
|
| 326 |
+
|
| 327 |
+
def minimize(self, *args, **kwargs):
|
| 328 |
+
from . import optimization
|
| 329 |
+
return optimization.minimize(self, *args, **kwargs)
|
| 330 |
+
|
| 331 |
+
def __array__(self, *args):
|
| 332 |
+
return self.r
|
| 333 |
+
|
| 334 |
+
########################################################
|
| 335 |
+
# State management
|
| 336 |
+
|
| 337 |
+
def add_dterm(self, dterm_name, dterm):
|
| 338 |
+
self.dterms = list(set(list(self.dterms) + [dterm_name]))
|
| 339 |
+
setattr(self, dterm_name, dterm)
|
| 340 |
+
|
| 341 |
+
def copy(self):
|
| 342 |
+
return copy(self)
|
| 343 |
+
|
| 344 |
+
def __getstate__(self):
|
| 345 |
+
# Have to get rid of WeakKeyDictionaries for serialization
|
| 346 |
+
result = external_copy.copy(self.__dict__)
|
| 347 |
+
del result['_parents']
|
| 348 |
+
del result['_cache']
|
| 349 |
+
return result
|
| 350 |
+
|
| 351 |
+
def __setstate__(self, d):
|
| 352 |
+
# Restore unpickleable WeakKeyDictionaries
|
| 353 |
+
d['_parents'] = weakref.WeakKeyDictionary()
|
| 354 |
+
d['_cache'] = {'r': None, 'drs': weakref.WeakKeyDictionary()}
|
| 355 |
+
object.__setattr__(self, '__dict__', d)
|
| 356 |
+
|
| 357 |
+
# This restores our unpickleable "_parents" attribute
|
| 358 |
+
for k in set(self.dterms).intersection(set(self.__dict__.keys())):
|
| 359 |
+
setattr(self, k, self.__dict__[k])
|
| 360 |
+
|
| 361 |
+
def __setattr__(self, name, value, itr=None):
|
| 362 |
+
#print 'SETTING %s' % (name,)
|
| 363 |
+
|
| 364 |
+
# Faster path for basic Ch objects. Not necessary for functionality,
|
| 365 |
+
# but improves performance by a small amount.
|
| 366 |
+
if type(self) == Ch:
|
| 367 |
+
if name == 'x':
|
| 368 |
+
self._dirty_vars.add(name)
|
| 369 |
+
self.clear_cache(itr)
|
| 370 |
+
#else:
|
| 371 |
+
# import warnings
|
| 372 |
+
# warnings.warn('Trying to set attribute %s on a basic Ch object? Might be a mistake.' % (name,))
|
| 373 |
+
|
| 374 |
+
object.__setattr__(self, name, value)
|
| 375 |
+
return
|
| 376 |
+
|
| 377 |
+
name_in_dterms = name in self.dterms
|
| 378 |
+
name_in_terms = name in self.terms
|
| 379 |
+
name_in_props = name in _props_for(self.__class__)# [p[0] for p in inspect.getmembers(self.__class__, lambda x : isinstance(x, property))]
|
| 380 |
+
|
| 381 |
+
if name_in_dterms and not name_in_props and type(self) != Ch:
|
| 382 |
+
if not hasattr(value, 'dterms'):
|
| 383 |
+
value = Ch(value)
|
| 384 |
+
|
| 385 |
+
# Make ourselves not the parent of the old value
|
| 386 |
+
if hasattr(self, name):
|
| 387 |
+
term = getattr(self, name)
|
| 388 |
+
if self in term._parents:
|
| 389 |
+
term._parents[self]['varnames'].remove(name)
|
| 390 |
+
if len(term._parents[self]['varnames']) == 0:
|
| 391 |
+
del term._parents[self]
|
| 392 |
+
|
| 393 |
+
# Make ourselves parents of the new value
|
| 394 |
+
if self not in value._parents:
|
| 395 |
+
value._parents[self] = {'varnames': set([name])}
|
| 396 |
+
else:
|
| 397 |
+
value._parents[self]['varnames'].add(name)
|
| 398 |
+
|
| 399 |
+
if name_in_dterms or name_in_terms:
|
| 400 |
+
self._dirty_vars.add(name)
|
| 401 |
+
self._invalidate_cacheprop_names([name])
|
| 402 |
+
|
| 403 |
+
# If one of our terms has changed, it has the capacity to have
|
| 404 |
+
# changed our result and all our derivatives wrt everything
|
| 405 |
+
self.clear_cache(itr)
|
| 406 |
+
|
| 407 |
+
object.__setattr__(self, name, value)
|
| 408 |
+
|
| 409 |
+
def _invalidate_cacheprop_names(self, names):
|
| 410 |
+
nameset = set(names)
|
| 411 |
+
for func_name, v in list(self._depends_on_deps.items()):
|
| 412 |
+
if len(nameset.intersection(v['deps'])) > 0:
|
| 413 |
+
v['out_of_date'] = True
|
| 414 |
+
|
| 415 |
+
|
| 416 |
+
def clear_cache(self, itr=None):
|
| 417 |
+
todo = [self]
|
| 418 |
+
done = set([])
|
| 419 |
+
nodes_visited = 0
|
| 420 |
+
while len(todo) > 0:
|
| 421 |
+
nodes_visited += 1
|
| 422 |
+
next = todo.pop()
|
| 423 |
+
if itr is not None and itr==next._itr:
|
| 424 |
+
continue
|
| 425 |
+
if id(next) not in done:
|
| 426 |
+
next._cache['r'] = None
|
| 427 |
+
next._cache['drs'].clear()
|
| 428 |
+
next._itr = itr
|
| 429 |
+
|
| 430 |
+
for parent, parent_dict in list(next._parents.items()):
|
| 431 |
+
object.__setattr__(parent, '_dirty_vars', parent._dirty_vars.union(parent_dict['varnames']))
|
| 432 |
+
parent._invalidate_cacheprop_names(parent_dict['varnames'])
|
| 433 |
+
todo.append(parent)
|
| 434 |
+
done.add(id(next))
|
| 435 |
+
return nodes_visited
|
| 436 |
+
|
| 437 |
+
|
| 438 |
+
def clear_cache_wrt(self, wrt, itr=None):
|
| 439 |
+
if wrt in self._cache['drs']:
|
| 440 |
+
self._cache['drs'][wrt] = None
|
| 441 |
+
|
| 442 |
+
if hasattr(self, 'dr_cached') and wrt in self.dr_cached:
|
| 443 |
+
self.dr_cached[wrt] = None
|
| 444 |
+
|
| 445 |
+
if itr is None or itr != self._itr:
|
| 446 |
+
for parent, parent_dict in list(self._parents.items()):
|
| 447 |
+
if wrt in parent._cache['drs'] or (hasattr(parent, 'dr_cached') and wrt in parent.dr_cached):
|
| 448 |
+
parent.clear_cache_wrt(wrt=wrt, itr=itr)
|
| 449 |
+
object.__setattr__(parent, '_dirty_vars', parent._dirty_vars.union(parent_dict['varnames']))
|
| 450 |
+
parent._invalidate_cacheprop_names(parent_dict['varnames'])
|
| 451 |
+
|
| 452 |
+
object.__setattr__(self, '_itr', itr)
|
| 453 |
+
|
| 454 |
+
def replace(self, old, new):
|
| 455 |
+
if (hasattr(old, 'dterms') != hasattr(new, 'dterms')):
|
| 456 |
+
raise Exception('Either "old" and "new" must both be "Ch", or they must both be neither.')
|
| 457 |
+
|
| 458 |
+
for term_name in [t for t in list(self.dterms)+list(self.terms) if hasattr(self, t)]:
|
| 459 |
+
term = getattr(self, term_name)
|
| 460 |
+
if term is old:
|
| 461 |
+
setattr(self, term_name, new)
|
| 462 |
+
elif hasattr(term, 'dterms'):
|
| 463 |
+
term.replace(old, new)
|
| 464 |
+
return new
|
| 465 |
+
|
| 466 |
+
|
| 467 |
+
def set(self, **kwargs):
|
| 468 |
+
# Some dterms may be aliases via @property.
|
| 469 |
+
# We want to set those last, in case they access non-property members
|
| 470 |
+
#props = [p[0] for p in inspect.getmembers(self.__class__, lambda x : isinstance(x, property))]
|
| 471 |
+
props = _props_for(self.__class__)
|
| 472 |
+
kwarg_keys = set(kwargs.keys())
|
| 473 |
+
kwsecond = kwarg_keys.intersection(props)
|
| 474 |
+
kwfirst = kwarg_keys.difference(kwsecond)
|
| 475 |
+
kwall = list(kwfirst) + list(kwsecond)
|
| 476 |
+
|
| 477 |
+
# The complexity here comes because we wish to
|
| 478 |
+
# avoid clearing cache redundantly
|
| 479 |
+
if len(kwall) > 0:
|
| 480 |
+
for k in kwall[:-1]:
|
| 481 |
+
self.__setattr__(k, kwargs[k], 9999)
|
| 482 |
+
self.__setattr__(kwall[-1], kwargs[kwall[-1]], None)
|
| 483 |
+
|
| 484 |
+
|
| 485 |
+
def is_dr_wrt(self, wrt):
|
| 486 |
+
if type(self) == Ch:
|
| 487 |
+
return wrt is self
|
| 488 |
+
dterms_we_have = [getattr(self, dterm) for dterm in self.dterms if hasattr(self, dterm)]
|
| 489 |
+
return wrt in dterms_we_have or any([d.is_dr_wrt(wrt) for d in dterms_we_have])
|
| 490 |
+
|
| 491 |
+
|
| 492 |
+
def is_ch_baseclass(self):
|
| 493 |
+
return self.__class__ is Ch
|
| 494 |
+
|
| 495 |
+
|
| 496 |
+
########################################################
|
| 497 |
+
# Getters for our outputs
|
| 498 |
+
|
| 499 |
+
def __getitem__(self, key):
|
| 500 |
+
shape = self.shape
|
| 501 |
+
tmp = np.arange(np.prod(shape)).reshape(shape).__getitem__(key)
|
| 502 |
+
idxs = tmp.ravel()
|
| 503 |
+
newshape = tmp.shape
|
| 504 |
+
return Select(a=self, idxs=idxs, preferred_shape=newshape)
|
| 505 |
+
|
| 506 |
+
def __setitem__(self, key, value, itr=None):
|
| 507 |
+
|
| 508 |
+
if hasattr(value, 'dterms'):
|
| 509 |
+
raise Exception("Can't assign a Ch objects as a subset of another.")
|
| 510 |
+
if type(self) == Ch:# self.is_ch_baseclass():
|
| 511 |
+
data = np.atleast_1d(self.x)
|
| 512 |
+
data.__setitem__(key, value)
|
| 513 |
+
self.__setattr__('x', data, itr=itr)
|
| 514 |
+
return
|
| 515 |
+
# elif False: # Interesting but flawed idea
|
| 516 |
+
# parents = [self.__dict__[k] for k in self.dterms]
|
| 517 |
+
# kids = []
|
| 518 |
+
# while len(parents)>0:
|
| 519 |
+
# p = parents.pop()
|
| 520 |
+
# if p.is_ch_baseclass():
|
| 521 |
+
# kids.append(p)
|
| 522 |
+
# else:
|
| 523 |
+
# parents += [p.__dict__[k] for k in p.dterms]
|
| 524 |
+
# from ch.optimization import minimize_dogleg
|
| 525 |
+
# minimize_dogleg(obj=self.__getitem__(key) - value, free_variables=kids, show_residuals=False)
|
| 526 |
+
else:
|
| 527 |
+
inner = self
|
| 528 |
+
while not inner.is_ch_baseclass():
|
| 529 |
+
if issubclass(inner.__class__, Permute):
|
| 530 |
+
inner = inner.a
|
| 531 |
+
else:
|
| 532 |
+
raise Exception("Can't set array that is function of arrays.")
|
| 533 |
+
|
| 534 |
+
self = self[key]
|
| 535 |
+
dr = self.dr_wrt(inner)
|
| 536 |
+
dr_rev = dr.T
|
| 537 |
+
#dr_rev = np.linalg.pinv(dr)
|
| 538 |
+
inner_shape = inner.shape
|
| 539 |
+
|
| 540 |
+
t1 = self._superdot(dr_rev, np.asarray(value).ravel())
|
| 541 |
+
t2 = self._superdot(dr_rev, self._superdot(dr, inner.x.ravel()))
|
| 542 |
+
if sp.issparse(t1): t1 = np.array(t1.todense())
|
| 543 |
+
if sp.issparse(t2): t2 = np.array(t2.todense())
|
| 544 |
+
|
| 545 |
+
inner.x = inner.x + t1.reshape(inner_shape) - t2.reshape(inner_shape)
|
| 546 |
+
#inner.x = inner.x + self._superdot(dr_rev, value.ravel()).reshape(inner_shape) - self._superdot(dr_rev, self._superdot(dr, inner.x.ravel())).reshape(inner_shape)
|
| 547 |
+
|
| 548 |
+
|
| 549 |
+
def __str__(self):
|
| 550 |
+
return str(self.r)
|
| 551 |
+
|
| 552 |
+
def __repr__(self):
|
| 553 |
+
return object.__repr__(self) + '\n' + str(self.r)
|
| 554 |
+
|
| 555 |
+
def __float__(self):
|
| 556 |
+
return self.r.__float__()
|
| 557 |
+
|
| 558 |
+
def __int__(self):
|
| 559 |
+
return self.r.__int__()
|
| 560 |
+
|
| 561 |
+
def on_changed(self, terms):
|
| 562 |
+
pass
|
| 563 |
+
|
| 564 |
+
@property
|
| 565 |
+
def T(self):
|
| 566 |
+
return transpose(self)
|
| 567 |
+
|
| 568 |
+
def transpose(self, *axes):
|
| 569 |
+
return transpose(self, *axes)
|
| 570 |
+
|
| 571 |
+
def squeeze(self, axis=None):
|
| 572 |
+
return squeeze(self, axis)
|
| 573 |
+
|
| 574 |
+
def mean(self, axis=None):
|
| 575 |
+
return mean(self, axis=axis)
|
| 576 |
+
|
| 577 |
+
def sum(self, axis=None):
|
| 578 |
+
return sum(self, axis=axis)
|
| 579 |
+
|
| 580 |
+
def _call_on_changed(self):
|
| 581 |
+
|
| 582 |
+
if hasattr(self, 'is_valid'):
|
| 583 |
+
validity, msg = self.is_valid()
|
| 584 |
+
assert validity, msg
|
| 585 |
+
if hasattr(self, '_status'):
|
| 586 |
+
self._status = 'new'
|
| 587 |
+
|
| 588 |
+
if len(self._dirty_vars) > 0:
|
| 589 |
+
self.on_changed(self._dirty_vars)
|
| 590 |
+
object.__setattr__(self, '_dirty_vars', set())
|
| 591 |
+
|
| 592 |
+
@property
|
| 593 |
+
def r(self):
|
| 594 |
+
self._call_on_changed()
|
| 595 |
+
if self._cache['r'] is None:
|
| 596 |
+
self._cache['r'] = np.asarray(np.atleast_1d(self.compute_r()), dtype=np.float64, order='C')
|
| 597 |
+
self._cache['rview'] = self._cache['r'].view()
|
| 598 |
+
self._cache['rview'].flags.writeable = False
|
| 599 |
+
|
| 600 |
+
return self._cache['rview']
|
| 601 |
+
|
| 602 |
+
def _superdot(self, lhs, rhs, profiler=None):
|
| 603 |
+
|
| 604 |
+
try:
|
| 605 |
+
if lhs is None:
|
| 606 |
+
return None
|
| 607 |
+
if rhs is None:
|
| 608 |
+
return None
|
| 609 |
+
|
| 610 |
+
if isinstance(lhs, np.ndarray) and lhs.size==1:
|
| 611 |
+
lhs = lhs.ravel()[0]
|
| 612 |
+
|
| 613 |
+
if isinstance(rhs, np.ndarray) and rhs.size==1:
|
| 614 |
+
rhs = rhs.ravel()[0]
|
| 615 |
+
|
| 616 |
+
if isinstance(lhs, numbers.Number) or isinstance(rhs, numbers.Number):
|
| 617 |
+
return lhs * rhs
|
| 618 |
+
|
| 619 |
+
if isinstance(rhs, LinearOperator):
|
| 620 |
+
return LinearOperator((lhs.shape[0], rhs.shape[1]), lambda x : lhs.dot(rhs.dot(x)))
|
| 621 |
+
|
| 622 |
+
if isinstance(lhs, LinearOperator):
|
| 623 |
+
if sp.issparse(rhs):
|
| 624 |
+
return LinearOperator((lhs.shape[0], rhs.shape[1]), lambda x : lhs.dot(rhs.dot(x)))
|
| 625 |
+
else:
|
| 626 |
+
# TODO: ?????????????
|
| 627 |
+
# return lhs.matmat(rhs)
|
| 628 |
+
return lhs.dot(rhs)
|
| 629 |
+
|
| 630 |
+
# TODO: Figure out how/whether to do this.
|
| 631 |
+
tm_maybe_sparse = timer()
|
| 632 |
+
lhs, rhs = convert_inputs_to_sparse_if_necessary(lhs, rhs)
|
| 633 |
+
if tm_maybe_sparse() > 0.1:
|
| 634 |
+
pif('convert_inputs_to_sparse_if_necessary in {}sec'.format(tm_maybe_sparse()))
|
| 635 |
+
|
| 636 |
+
if not sp.issparse(lhs) and sp.issparse(rhs):
|
| 637 |
+
return rhs.T.dot(lhs.T).T
|
| 638 |
+
return lhs.dot(rhs)
|
| 639 |
+
except Exception as e:
|
| 640 |
+
import sys, traceback
|
| 641 |
+
traceback.print_exc(file=sys.stdout)
|
| 642 |
+
if DEBUG:
|
| 643 |
+
import pdb; pdb.post_mortem()
|
| 644 |
+
else:
|
| 645 |
+
raise
|
| 646 |
+
|
| 647 |
+
def lmult_wrt(self, lhs, wrt):
|
| 648 |
+
if lhs is None:
|
| 649 |
+
return None
|
| 650 |
+
|
| 651 |
+
self._call_on_changed()
|
| 652 |
+
|
| 653 |
+
drs = []
|
| 654 |
+
|
| 655 |
+
direct_dr = self._compute_dr_wrt_sliced(wrt)
|
| 656 |
+
|
| 657 |
+
if direct_dr != None:
|
| 658 |
+
drs.append(self._superdot(lhs, direct_dr))
|
| 659 |
+
|
| 660 |
+
for k in set(self.dterms):
|
| 661 |
+
p = self.__dict__[k]
|
| 662 |
+
|
| 663 |
+
if hasattr(p, 'dterms') and p is not wrt and p.is_dr_wrt(wrt):
|
| 664 |
+
if not isinstance(p, Ch):
|
| 665 |
+
print('BROKEN!')
|
| 666 |
+
raise Exception('Broken Should be Ch object')
|
| 667 |
+
|
| 668 |
+
indirect_dr = p.lmult_wrt(self._superdot(lhs, self._compute_dr_wrt_sliced(p)), wrt)
|
| 669 |
+
if indirect_dr is not None:
|
| 670 |
+
drs.append(indirect_dr)
|
| 671 |
+
|
| 672 |
+
if len(drs)==0:
|
| 673 |
+
result = None
|
| 674 |
+
|
| 675 |
+
elif len(drs)==1:
|
| 676 |
+
result = drs[0]
|
| 677 |
+
|
| 678 |
+
else:
|
| 679 |
+
result = reduce(lambda x, y: x+y, drs)
|
| 680 |
+
|
| 681 |
+
return result
|
| 682 |
+
|
| 683 |
+
|
| 684 |
+
def compute_lop(self, wrt, lhs):
|
| 685 |
+
dr = self._compute_dr_wrt_sliced(wrt)
|
| 686 |
+
if dr is None: return None
|
| 687 |
+
return self._superdot(lhs, dr) if not isinstance(lhs, LinearOperator) else lhs.matmat(dr)
|
| 688 |
+
|
| 689 |
+
|
| 690 |
+
def lop(self, wrt, lhs):
|
| 691 |
+
self._call_on_changed()
|
| 692 |
+
|
| 693 |
+
drs = []
|
| 694 |
+
direct_dr = self.compute_lop(wrt, lhs)
|
| 695 |
+
if direct_dr is not None:
|
| 696 |
+
drs.append(direct_dr)
|
| 697 |
+
|
| 698 |
+
for k in set(self.dterms):
|
| 699 |
+
p = getattr(self, k) # self.__dict__[k]
|
| 700 |
+
if hasattr(p, 'dterms') and p is not wrt: # and p.is_dr_wrt(wrt):
|
| 701 |
+
lhs_for_child = self.compute_lop(p, lhs)
|
| 702 |
+
if lhs_for_child is not None: # Can be None with ChLambda, _result etc
|
| 703 |
+
indirect_dr = p.lop(wrt, lhs_for_child)
|
| 704 |
+
if indirect_dr is not None:
|
| 705 |
+
drs.append(indirect_dr)
|
| 706 |
+
|
| 707 |
+
for k in range(len(drs)):
|
| 708 |
+
if sp.issparse(drs[k]):
|
| 709 |
+
drs[k] = drs[k].todense()
|
| 710 |
+
|
| 711 |
+
if len(drs)==0:
|
| 712 |
+
result = None
|
| 713 |
+
|
| 714 |
+
elif len(drs)==1:
|
| 715 |
+
result = drs[0]
|
| 716 |
+
|
| 717 |
+
else:
|
| 718 |
+
result = reduce(lambda x, y: x+y, drs)
|
| 719 |
+
|
| 720 |
+
|
| 721 |
+
return result
|
| 722 |
+
|
| 723 |
+
def compute_rop(self, wrt, rhs):
|
| 724 |
+
dr = self._compute_dr_wrt_sliced(wrt)
|
| 725 |
+
if dr is None: return None
|
| 726 |
+
|
| 727 |
+
return self._superdot(dr, rhs)
|
| 728 |
+
|
| 729 |
+
def dr_wrt(self, wrt, reverse_mode=False, profiler=None):
|
| 730 |
+
tm_dr_wrt = timer()
|
| 731 |
+
self.called_dr_wrt = True
|
| 732 |
+
self._call_on_changed()
|
| 733 |
+
|
| 734 |
+
drs = []
|
| 735 |
+
|
| 736 |
+
if wrt in self._cache['drs']:
|
| 737 |
+
if DEBUG:
|
| 738 |
+
if wrt not in self._cache_info:
|
| 739 |
+
self._cache_info[wrt] = 0
|
| 740 |
+
self._cache_info[wrt] +=1
|
| 741 |
+
self._status = 'cached'
|
| 742 |
+
return self._cache['drs'][wrt]
|
| 743 |
+
|
| 744 |
+
direct_dr = self._compute_dr_wrt_sliced(wrt)
|
| 745 |
+
|
| 746 |
+
if direct_dr is not None:
|
| 747 |
+
drs.append(direct_dr)
|
| 748 |
+
|
| 749 |
+
if DEBUG:
|
| 750 |
+
self._status = 'pending'
|
| 751 |
+
|
| 752 |
+
propnames = set(_props_for(self.__class__))
|
| 753 |
+
for k in set(self.dterms).intersection(propnames.union(set(self.__dict__.keys()))):
|
| 754 |
+
|
| 755 |
+
p = getattr(self, k)
|
| 756 |
+
|
| 757 |
+
if hasattr(p, 'dterms') and p is not wrt:
|
| 758 |
+
|
| 759 |
+
indirect_dr = None
|
| 760 |
+
|
| 761 |
+
if reverse_mode:
|
| 762 |
+
lhs = self._compute_dr_wrt_sliced(p)
|
| 763 |
+
if isinstance(lhs, LinearOperator):
|
| 764 |
+
tm_dr_wrt.pause()
|
| 765 |
+
dr2 = p.dr_wrt(wrt)
|
| 766 |
+
tm_dr_wrt.resume()
|
| 767 |
+
indirect_dr = lhs.matmat(dr2) if dr2 != None else None
|
| 768 |
+
else:
|
| 769 |
+
indirect_dr = p.lmult_wrt(lhs, wrt)
|
| 770 |
+
else: # forward mode
|
| 771 |
+
tm_dr_wrt.pause()
|
| 772 |
+
dr2 = p.dr_wrt(wrt, profiler=profiler)
|
| 773 |
+
tm_dr_wrt.resume()
|
| 774 |
+
if dr2 is not None:
|
| 775 |
+
indirect_dr = self.compute_rop(p, rhs=dr2)
|
| 776 |
+
|
| 777 |
+
if indirect_dr is not None:
|
| 778 |
+
drs.append(indirect_dr)
|
| 779 |
+
|
| 780 |
+
if len(drs)==0:
|
| 781 |
+
result = None
|
| 782 |
+
elif len(drs)==1:
|
| 783 |
+
result = drs[0]
|
| 784 |
+
else:
|
| 785 |
+
# TODO: ????????
|
| 786 |
+
# result = np.sum(x for x in drs)
|
| 787 |
+
if not np.any([isinstance(a, LinearOperator) for a in drs]):
|
| 788 |
+
result = reduce(lambda x, y: x+y, drs)
|
| 789 |
+
else:
|
| 790 |
+
result = LinearOperator(drs[0].shape, lambda x : reduce(lambda a, b: a.dot(x)+b.dot(x),drs))
|
| 791 |
+
|
| 792 |
+
# TODO: figure out how/whether to do this.
|
| 793 |
+
if result is not None and not sp.issparse(result):
|
| 794 |
+
tm_nonzero = timer()
|
| 795 |
+
nonzero = np.count_nonzero(result)
|
| 796 |
+
if tm_nonzero() > 0.1:
|
| 797 |
+
pif('count_nonzero in {}sec'.format(tm_nonzero()))
|
| 798 |
+
if nonzero == 0 or hasattr(result, 'size') and result.size / float(nonzero) >= 10.0:
|
| 799 |
+
tm_convert_to_sparse = timer()
|
| 800 |
+
result = sp.csc_matrix(result)
|
| 801 |
+
import gc
|
| 802 |
+
gc.collect()
|
| 803 |
+
pif('converting result to sparse in {}sec'.format(tm_convert_to_sparse()))
|
| 804 |
+
|
| 805 |
+
if (result is not None) and (not sp.issparse(result)) and (not isinstance(result, LinearOperator)):
|
| 806 |
+
result = np.atleast_2d(result)
|
| 807 |
+
|
| 808 |
+
# When the number of parents is one, it indicates that
|
| 809 |
+
# caching this is probably not useful because not
|
| 810 |
+
# more than one parent will likely ask for this same
|
| 811 |
+
# thing again in the same iteration of an optimization.
|
| 812 |
+
#
|
| 813 |
+
# When the number of parents is zero, this is the top
|
| 814 |
+
# level object and should be cached; when it's > 1
|
| 815 |
+
# cache the combinations of the children.
|
| 816 |
+
#
|
| 817 |
+
# If we *always* filled in the cache, it would require
|
| 818 |
+
# more memory but would occasionally save a little cpu,
|
| 819 |
+
# on average.
|
| 820 |
+
if len(list(self._parents.keys())) != 1:
|
| 821 |
+
self._cache['drs'][wrt] = result
|
| 822 |
+
|
| 823 |
+
if DEBUG:
|
| 824 |
+
self._status = 'done'
|
| 825 |
+
|
| 826 |
+
if getattr(self, '_make_dense', False) and sp.issparse(result):
|
| 827 |
+
result = result.todense()
|
| 828 |
+
if getattr(self, '_make_sparse', False) and not sp.issparse(result):
|
| 829 |
+
result = sp.csc_matrix(result)
|
| 830 |
+
|
| 831 |
+
if tm_dr_wrt() > 0.1:
|
| 832 |
+
pif('dx of {} wrt {} in {}sec, sparse: {}'.format(self.short_name, wrt.short_name, tm_dr_wrt(), sp.issparse(result)))
|
| 833 |
+
|
| 834 |
+
return result
|
| 835 |
+
|
| 836 |
+
|
| 837 |
+
def __call__(self, **kwargs):
|
| 838 |
+
self.set(**kwargs)
|
| 839 |
+
return self.r
|
| 840 |
+
|
| 841 |
+
|
| 842 |
+
########################################################
|
| 843 |
+
# Visualization
|
| 844 |
+
|
| 845 |
+
@property
|
| 846 |
+
def reset_flag(self):
|
| 847 |
+
"""
|
| 848 |
+
Used as fn in loop_children_do
|
| 849 |
+
"""
|
| 850 |
+
return lambda x: setattr(x, 'called_dr_wrt', False)
|
| 851 |
+
|
| 852 |
+
def loop_children_do(self, fn):
|
| 853 |
+
fn(self)
|
| 854 |
+
for dterm in self.dterms:
|
| 855 |
+
if hasattr(self, dterm):
|
| 856 |
+
dtval = getattr(self, dterm)
|
| 857 |
+
if hasattr(dtval, 'dterms') or hasattr(dtval, 'terms'):
|
| 858 |
+
if hasattr(dtval, 'loop_children_do'):
|
| 859 |
+
dtval.loop_children_do(fn)
|
| 860 |
+
|
| 861 |
+
|
| 862 |
+
def show_tree_cache(self, label, current_node=None):
|
| 863 |
+
'''
|
| 864 |
+
Show tree and cache info with color represent _status
|
| 865 |
+
Optionally accpet current_node arg to highlight the current node we are in
|
| 866 |
+
'''
|
| 867 |
+
import os
|
| 868 |
+
import tempfile
|
| 869 |
+
import subprocess
|
| 870 |
+
|
| 871 |
+
assert DEBUG, "Please use dr tree visualization functions in debug mode"
|
| 872 |
+
|
| 873 |
+
cache_path = os.path.abspath('profiles')
|
| 874 |
+
def string_for(self, my_name):
|
| 875 |
+
|
| 876 |
+
color_mapping = {'new' : 'grey', 'pending':'red', 'cached':'yellow', 'done': 'green'}
|
| 877 |
+
if hasattr(self, 'label'):
|
| 878 |
+
my_name = self.label
|
| 879 |
+
my_name = '%s (%s)' % (my_name, str(self.__class__.__name__))
|
| 880 |
+
result = []
|
| 881 |
+
if not hasattr(self, 'dterms'):
|
| 882 |
+
return result
|
| 883 |
+
for dterm in self.dterms:
|
| 884 |
+
if hasattr(self, dterm):
|
| 885 |
+
dtval = getattr(self, dterm)
|
| 886 |
+
if hasattr(dtval, 'dterms') or hasattr(dtval, 'terms'):
|
| 887 |
+
child_label = getattr(dtval, 'label') if hasattr(dtval, 'label') else dterm
|
| 888 |
+
child_label = '%s (%s)' % (child_label, str(dtval.__class__.__name__))
|
| 889 |
+
src = 'aaa%d' % (id(self))
|
| 890 |
+
dst = 'aaa%d' % (id(dtval))
|
| 891 |
+
|
| 892 |
+
s = ''
|
| 893 |
+
color = color_mapping[dtval._status] if hasattr(dtval, '_status') else 'grey'
|
| 894 |
+
if dtval == current_node:
|
| 895 |
+
color = 'blue'
|
| 896 |
+
if isinstance(dtval, Concatenate) and len(dtval.dr_cached) > 0:
|
| 897 |
+
s = 'dr_cached\n'
|
| 898 |
+
for k, v in dtval.dr_cached.items():
|
| 899 |
+
if v is not None:
|
| 900 |
+
issparse = sp.issparse(v)
|
| 901 |
+
size = v.size
|
| 902 |
+
if issparse:
|
| 903 |
+
size = v.shape[0] * v.shape[1]
|
| 904 |
+
nonzero = len(v.data)
|
| 905 |
+
else:
|
| 906 |
+
nonzero = np.count_nonzero(v)
|
| 907 |
+
s += '\nsparse: %s\nsize: %d\nnonzero: %d\n' % (issparse, size, nonzero)
|
| 908 |
+
# if dtval.called_dr_wrt:
|
| 909 |
+
# # dtval.called_dr_wrt = False
|
| 910 |
+
# color = 'brown3'
|
| 911 |
+
# else:
|
| 912 |
+
# color = 'azure1'
|
| 913 |
+
elif len(dtval._cache['drs']) > 0:
|
| 914 |
+
s = '_cache\n'
|
| 915 |
+
|
| 916 |
+
for k, v in dtval._cache['drs'].items():
|
| 917 |
+
if v is not None:
|
| 918 |
+
issparse = sp.issparse(v)
|
| 919 |
+
size = v.size
|
| 920 |
+
if issparse:
|
| 921 |
+
size = v.shape[0] * v.shape[1]
|
| 922 |
+
nonzero = len(v.data)
|
| 923 |
+
else:
|
| 924 |
+
nonzero = np.count_nonzero(v)
|
| 925 |
+
|
| 926 |
+
s += '\nsparse: %s\nsize: %d\nnonzero: %d\n' % (issparse, size, nonzero)
|
| 927 |
+
if hasattr(dtval, '_cache_info'):
|
| 928 |
+
s += '\ncache hit:%s\n' % dtval._cache_info[k]
|
| 929 |
+
# if hasattr(dtval,'called_dr_wrt') and dtval.called_dr_wrt:
|
| 930 |
+
# # dtval.called_dr_wrt = False
|
| 931 |
+
# color = 'brown3'
|
| 932 |
+
# else:
|
| 933 |
+
# color = 'azure1'
|
| 934 |
+
result += ['%s -> %s;' % (src, dst)]
|
| 935 |
+
# Do not overwrite src
|
| 936 |
+
#result += ['%s [label="%s"];' % (src, my_name)]
|
| 937 |
+
result += ['%s [label="%s\n%s\n", color=%s, style=filled];' %
|
| 938 |
+
(dst, child_label, s, color)]
|
| 939 |
+
result += string_for(getattr(self, dterm), dterm)
|
| 940 |
+
return result
|
| 941 |
+
|
| 942 |
+
|
| 943 |
+
dot_file_contents = 'digraph G {\n%s\n}' % '\n'.join(list(set(string_for(self, 'root'))))
|
| 944 |
+
dot_file_name = os.path.join(cache_path, label)
|
| 945 |
+
png_file_name = os.path.join(cache_path, label+'.png')
|
| 946 |
+
with open(dot_file_name, 'w') as dot_file:
|
| 947 |
+
with open(png_file_name, 'w') as png_file:
|
| 948 |
+
dot_file.write(dot_file_contents)
|
| 949 |
+
dot_file.flush()
|
| 950 |
+
|
| 951 |
+
png_file = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
|
| 952 |
+
subprocess.call(['dot', '-Tpng', '-o', png_file.name, dot_file.name])
|
| 953 |
+
|
| 954 |
+
import webbrowser
|
| 955 |
+
webbrowser.open('file://' + png_file.name)
|
| 956 |
+
|
| 957 |
+
self.loop_children_do(self.reset_flag)
|
| 958 |
+
|
| 959 |
+
def show_tree_wrt(self, wrt):
    """Render this node's derivative tree as a graphviz PNG and open it in a
    browser, annotating each child whose Jacobian w.r.t. `wrt` is cached with
    that Jacobian's sparsity, size and nonzero count.

    Requires DEBUG mode and the graphviz `dot` executable on PATH.
    """
    import tempfile
    import subprocess

    assert DEBUG, "Please use dr tree visualization functions in debug mode"

    def string_for(self, my_name, wrt):
        # Emit graphviz edge/node statements for self, then recurse into dterms.
        if hasattr(self, 'label'):
            my_name = self.label
        my_name = '%s (%s)' % (my_name, str(self.__class__.__name__))
        result = []
        if not hasattr(self, 'dterms'):
            return result
        for dterm in self.dterms:
            if hasattr(self, dterm):
                dtval = getattr(self, dterm)
                if hasattr(dtval, 'dterms') or hasattr(dtval, 'terms'):
                    child_label = getattr(dtval, 'label') if hasattr(dtval, 'label') else dterm
                    child_label = '%s (%s)' % (child_label, str(dtval.__class__.__name__))
                    src = 'aaa%d' % (id(self))
                    dst = 'aaa%d' % (id(dtval))
                    result += ['%s -> %s;' % (src, dst)]
                    result += ['%s [label="%s"];' % (src, my_name)]
                    if wrt in dtval._cache['drs'] and dtval._cache['drs'][wrt] is not None:
                        # Annotate the node with stats of the cached Jacobian.
                        issparse = sp.issparse(dtval._cache['drs'][wrt])
                        size = dtval._cache['drs'][wrt].size
                        nonzero = np.count_nonzero(dtval._cache['drs'][wrt])
                        result += ['%s [label="%s\n is_sparse: %s\nsize: %d\nnonzero: %d"];' %
                                   (dst, child_label, issparse, size,
                                    nonzero)]
                    else:
                        result += ['%s [label="%s"];' % (dst, child_label)]
                    result += string_for(getattr(self, dterm), dterm, wrt)
        return result

    dot_file_contents = 'digraph G {\n%s\n}' % '\n'.join(list(set(string_for(self, 'root', wrt))))
    # BUGFIX: NamedTemporaryFile defaults to binary mode ('w+b'); writing a
    # str to it raises TypeError on Python 3.  Open in text mode instead.
    dot_file = tempfile.NamedTemporaryFile(mode='w')
    dot_file.write(dot_file_contents)
    dot_file.flush()
    png_file = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
    subprocess.call(['dot', '-Tpng', '-o', png_file.name, dot_file.name])
    import webbrowser
    webbrowser.open('file://' + png_file.name)
|
| 1004 |
+
def show_tree(self, cachelim=np.inf):
    """Cachelim is in Mb. For any cached jacobians above cachelim, they are also added to the graph.

    Renders the dterm tree as a graphviz PNG and opens it in a browser.
    Requires DEBUG mode and the graphviz `dot` executable on PATH.
    """
    import tempfile
    import subprocess

    assert DEBUG, "Please use dr tree visualization functions in debug mode"

    def string_for(self, my_name):
        # Emit graphviz statements for self and recurse into dterms.
        if hasattr(self, 'label'):
            my_name = self.label
        my_name = '%s (%s)' % (my_name, str(self.__class__.__name__))
        result = []
        if not hasattr(self, 'dterms'):
            return result
        for dterm in self.dterms:
            if hasattr(self, dterm):
                dtval = getattr(self, dterm)
                if hasattr(dtval, 'dterms') or hasattr(dtval, 'terms'):
                    child_label = getattr(dtval, 'label') if hasattr(dtval, 'label') else dterm
                    child_label = '%s (%s)' % (child_label, str(dtval.__class__.__name__))
                    src = 'aaa%d' % (id(self))
                    dst = 'aaa%d' % (id(dtval))
                    result += ['%s -> %s;' % (src, dst)]
                    result += ['%s [label="%s"];' % (src, my_name)]
                    result += ['%s [label="%s"];' % (dst, child_label)]
                    result += string_for(getattr(self, dterm), dterm)

        if cachelim != np.inf and hasattr(self, '_cache') and 'drs' in self._cache:
            from six.moves import cPickle as pickle
            for dtval, jac in list(self._cache['drs'].items()):
                src = 'aaa%d' % (id(self))
                dst = 'aaa%d' % (id(dtval))
                # Estimate the Jacobian's memory footprint via its pickle size.
                try:
                    sz = sys.getsizeof(pickle.dumps(jac, -1))
                except: # some are functions
                    sz = 0
                if sz > (cachelim * 1024 * 1024):
                    # Placeholder <<<sz>>> is rewritten into a color/label below.
                    result += ['%s -> %s [style=dotted,color="<<<%d>>>"];' % (src, dst, sz)]

        return result

    dot_file_contents = 'digraph G {\n%s\n}' % '\n'.join(list(set(string_for(self, 'root'))))
    if cachelim != np.inf:
        import re
        strs = re.findall(r'<<<(\d+)>>>', dot_file_contents, re.DOTALL)
        if len(strs) > 0:
            the_max = np.max(np.array([int(d) for d in strs]))
            for s in strs:
                szpct = float(s)/the_max
                sz = float(s)
                unit = 'b'
                if sz > 1024.:
                    sz /= 1024
                    unit = 'K'
                if sz > 1024.:
                    sz /= 1024
                    unit = 'M'
                if sz > 1024.:
                    sz /= 1024
                    unit = 'G'
                if sz > 1024.:
                    sz /= 1024
                    unit = 'T'

                # BUGFIX: '%02x' requires an int on Python 3; the original
                # passed floats (szpct*255 etc.), which raises TypeError.
                dot_file_contents = re.sub('<<<%s>>>' % s, '#%02x%02x%02x",label="%d%s' % (int(szpct*255), 0, int((1-szpct)*255), int(sz), unit), dot_file_contents)

    # BUGFIX: NamedTemporaryFile defaults to binary mode; writing a str to it
    # raises TypeError on Python 3.  Open in text mode instead.
    dot_file = tempfile.NamedTemporaryFile(mode='w')
    dot_file.write(dot_file_contents)
    dot_file.flush()
    png_file = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
    subprocess.call(['dot', '-Tpng', '-o', png_file.name, dot_file.name])
    import webbrowser
    webbrowser.open('file://' + png_file.name)
|
| 1087 |
+
|
| 1088 |
+
def tree_iterator(self, visited=None, path=None):
    '''
    Generator that traverses the dr tree depth-first starting from this
    node (self), yielding each node at most once.

    visited: set of already-yielded nodes, shared across recursive calls.
    path: optional list; this node is appended on first visit.
          NOTE(review): path is not forwarded to the recursive calls, so only
          the top-level node is ever appended -- confirm intent.
    '''
    if visited is None:
        visited = set()

    if self not in visited:
        if path and isinstance(path, list):
            path.append(self)

        visited.add(self)
        yield self

    # NOTE(review): this bare `yield` emits None into the iteration when the
    # node has no dterms (it does not stop the generator) -- confirm whether
    # `return` was intended before changing.
    if not hasattr(self, 'dterms'):
        yield

    for dterm in self.dterms:
        if hasattr(self, dterm):
            child = getattr(self, dterm)
            if hasattr(child, 'dterms') or hasattr(child, 'terms'):
                for node in child.tree_iterator(visited):
                    yield node
|
| 1112 |
+
def floor(self):
    # Elementwise floor; delegates to the module-level differentiable op.
    return floor(self)

def ceil(self):
    # Elementwise ceil; delegates to the module-level differentiable op.
    return ceil(self)

def dot(self, other):
    # Differentiable dot product with `other`.
    return dot(self, other)

def cumsum(self, axis=None):
    # Differentiable cumulative sum along `axis` (None = over the flat array).
    return cumsum(a=self, axis=axis)

def min(self, axis=None):
    # Differentiable minimum along `axis` (maps to module-level amin).
    return amin(a=self, axis=axis)

def max(self, axis=None):
    # Differentiable maximum along `axis` (maps to module-level amax).
    return amax(a=self, axis=axis)
|
| 1130 |
+
########################################################
# Operator overloads
#
# Each overload routes to the corresponding differentiable op so that
# ordinary arithmetic on Ch objects builds the computation graph.

def __pos__(self): return self
def __neg__(self): return negative(self)

def __add__ (self, other): return add(a=self, b=other)
def __radd__(self, other): return add(a=other, b=self)

def __sub__ (self, other): return subtract(a=self, b=other)
def __rsub__(self, other): return subtract(a=other, b=self)

def __mul__ (self, other): return multiply(a=self, b=other)
def __rmul__(self, other): return multiply(a=other, b=self)

# __div__/__rdiv__ are Python 2 names; __truediv__ covers Python 3 division.
def __div__ (self, other): return divide(x1=self, x2=other)
def __truediv__ (self, other): return divide(x1=self, x2=other)
def __rdiv__(self, other): return divide(x1=other, x2=self)

def __pow__ (self, other): return power(x=self, pow=other)
def __rpow__(self, other): return power(x=other, pow=self)

def __rand__(self, other): return self.__and__(other)

def __abs__ (self): return abs(self)

def __gt__(self, other): return greater(self, other)
def __ge__(self, other): return greater_equal(self, other)

def __lt__(self, other): return less(self, other)
def __le__(self, other): return less_equal(self, other)

def __ne__(self, other): return not_equal(self, other)

# not added yet because of weak key dict conflicts
#def __eq__(self, other): return equal(self, other)
|
| 1167 |
+
|
| 1168 |
+
Ch._reserved_kw = set(Ch.__dict__.keys())
|
| 1169 |
+
|
| 1170 |
+
|
| 1171 |
+
class MatVecMult(Ch):
    """Differentiable matrix-vector product: flattens `vec`, multiplies by
    `mtx`, and (for multi-column inputs) restores the column count."""
    terms = 'mtx'
    dterms = 'vec'

    def compute_r(self):
        vec_r = self.vec.r
        # Multiply against the flattened vector, then flatten the product.
        flat = self.mtx.dot(col(vec_r.ravel())).ravel()
        shape = vec_r.shape
        if len(shape) > 1 and shape[1] > 1:
            # Input had multiple columns: reshape the result to match.
            return flat.reshape((-1, shape[1]))
        return flat

    def compute_dr_wrt(self, wrt):
        # The product is linear in vec, so the Jacobian is mtx itself
        # (in CSC form); other wrt nodes yield no derivative.
        if wrt is self.vec:
            return sp.csc_matrix(self.mtx)
|
| 1184 |
+
|
| 1185 |
+
#def depends_on(*dependencies):
|
| 1186 |
+
# def _depends_on(func):
|
| 1187 |
+
# @wraps(func)
|
| 1188 |
+
# def with_caching(self, *args, **kwargs):
|
| 1189 |
+
# return func(self, *args, **kwargs)
|
| 1190 |
+
# return property(with_caching)
|
| 1191 |
+
# return _depends_on
|
| 1192 |
+
|
| 1193 |
+
|
| 1194 |
+
def depends_on(*dependencies):
    """Decorator factory turning a method into a cached property whose value
    is recomputed only while its cache entry is marked out of date.

    Each argument is a dependency name or an iterable of names; the flattened
    set is attached to the wrapper as ``.deps``.  Cache state lives in
    ``self._depends_on_deps[<func name>]`` -- a dict with keys 'out_of_date'
    and 'value'.
    """
    # Flatten string / iterable-of-strings arguments into one set of names.
    deps = set()
    for dep in dependencies:
        if isinstance(dep, str):
            deps.add(dep)
        else:
            deps.update(dep)

    def _depends_on(func):
        # If the function accepts an 'out' kwarg, hand it the previous cached
        # value so it can reuse that storage.
        want_out = 'out' in inspect.getfullargspec(func).args

        @wraps(func)
        def with_caching(self, *args, **kwargs):
            state = self._depends_on_deps[func.__name__]
            if state['out_of_date'] == True:
                if want_out:
                    kwargs['out'] = state['value']
                state['value'] = func(self, *args, **kwargs)
                state['out_of_date'] = False
            return state['value']

        with_caching.deps = deps
        return property(with_caching)

    return _depends_on
|
| 1222 |
+
|
| 1223 |
+
|
| 1224 |
+
class ChHandle(Ch):
    """Identity pass-through node exposing its single differentiable term x."""
    dterms = ('x',)

    def compute_r(self):
        # A self-referential handle would recurse forever; guard against it.
        assert self.x is not self
        return self.x.r

    def compute_dr_wrt(self, wrt):
        # d(x)/d(x) is the identity (scalar 1); anything else: no dependency.
        return 1 if wrt is self.x else None
|
| 1235 |
+
|
| 1236 |
+
class ChLambda(Ch):
    """Wrap a lambda/function of Ch-valued arguments as one differentiable node.

    The callable's parameter names are introspected; each becomes a term of
    this node (a dterm if the result actually differentiates through it), so
    setting e.g. ``self.x`` re-feeds the wrapped graph via on_changed.
    """
    terms = ['lmb', 'initial_args']
    dterms = []
    term_order = ['lmb', 'initial_args']

    def on_changed(self, which):
        # Forward changed attribute values into the corresponding ChHandle inputs.
        for argname in set(which).intersection(set(self.args.keys())):
            self.args[argname].x = getattr(self, argname)

    def __init__(self, lmb, initial_args=None):
        # One ChHandle per lambda parameter; placeholder value is the
        # parameter's positional index.
        argspec_args = inspect.getfullargspec(lmb).args
        args = {argname: ChHandle(x=Ch(idx)) for idx, argname in enumerate(argspec_args)}
        if initial_args is not None:
            for initial_arg in initial_args:
                if initial_arg in args:
                    args[initial_arg].x = initial_args[initial_arg]
        result = lmb(**args)
        for argname, arg in list(args.items()):
            # Inputs the result differentiates through become dterms;
            # the rest are recorded as plain terms.
            if result.is_dr_wrt(arg.x):
                self.add_dterm(argname, arg.x)
            else:
                self.terms.append(argname)
                setattr(self, argname, arg.x)
        self.args = args
        self.add_dterm('_result', result)

    def __getstate__(self):
        # Have to get rid of lambda for serialization (lambdas don't pickle)
        if hasattr(self, 'lmb'):
            self.lmb = None
        return super(self.__class__, self).__getstate__()


    def compute_r(self):
        return self._result.r

    def compute_dr_wrt(self, wrt):
        # Identity Jacobian w.r.t. the wrapped result; the chain rule covers
        # everything inside the wrapped graph.
        if wrt is self._result:
            return 1
+
|
| 1276 |
+
# ChGroup is similar to ChLambda in that it's designed to expose the "internal"
|
| 1277 |
+
# inputs of result but, unlike ChLambda, result is kept internal and called when
|
| 1278 |
+
# compute_r and compute_dr_wrt is called to compute the relevant Jacobians.
|
| 1279 |
+
# This provides a way of effectively applying the chain rule in a different order.
|
| 1280 |
+
class ChGroup(Ch):
    """Like ChLambda, exposes the "internal" inputs of `result`, but keeps
    `result` internal: its value and Jacobians are computed on demand in
    compute_r / compute_dr_wrt, effectively applying the chain rule in a
    different order.
    """
    terms = ['result', 'args']
    dterms = []
    term_order = ['result', 'args']

    def on_changed(self, which):
        # Push changed attribute values into the matching ChHandle inputs,
        # skipping handles that already point at the same object.
        for argname in set(which).intersection(set(self.args.keys())):
            if not self.args[argname].x is getattr(self, argname) :
                self.args[argname].x = getattr(self, argname)

    # right now the entries in args have to refer to terms/dterms of result,
    # it would be better if they could be "internal" as well, but for now the idea
    # is that result may itself be a ChLambda.
    def __init__(self, result, args):
        self.args = { argname: ChHandle(x=arg) for argname, arg in list(args.items()) }
        for argname, arg in list(self.args.items()):
            # Rewire result's input to go through the handle, then classify
            # the input as dterm (differentiated through) or plain term.
            setattr(result, argname, arg)
            if result.is_dr_wrt(arg.x):
                self.add_dterm(argname, arg.x)
            else:
                self.terms.append(argname)
                setattr(self, argname, arg.x)
        self._result = result

    def compute_r(self):
        return self._result.r

    def compute_dr_wrt(self, wrt):
        # Delegate Jacobian computation to the wrapped result node.
        return self._result.dr_wrt(wrt)
+
|
| 1310 |
+
from .ch_ops import *
|
| 1311 |
+
from .ch_ops import __all__ as all_ch_ops
|
| 1312 |
+
__all__ += all_ch_ops
|
| 1313 |
+
|
| 1314 |
+
from .reordering import *
|
| 1315 |
+
from .reordering import Permute
|
| 1316 |
+
from .reordering import __all__ as all_reordering
|
| 1317 |
+
__all__ += all_reordering
|
| 1318 |
+
|
| 1319 |
+
|
| 1320 |
+
from . import linalg
|
| 1321 |
+
from . import ch_random as random
|
| 1322 |
+
__all__ += ['linalg', 'random']
|
| 1323 |
+
|
| 1324 |
+
|
| 1325 |
+
|
| 1326 |
+
|
| 1327 |
+
|
| 1328 |
+
class tst(Ch):
    # Minimal demo node: value is the sum of three differentiable terms.
    dterms = ['a', 'b', 'c']

    def compute_r(self):
        return self.a.r + self.b.r + self.c.r

    def compute_dr_wrt(self, wrt):
        # NOTE(review): returns the scalar 1 for *any* wrt, not only a/b/c.
        return 1
|
| 1337 |
+
def main():
    """Ad-hoc smoke test / demo of the Ch graph machinery (run as a script)."""
    foo = tst

    x10 = Ch(10)
    x20 = Ch(20)
    x30 = Ch(30)

    tmp = ChLambda(lambda x, y, z: Ch(1) + Ch(2) * Ch(3) + 4)
    print(tmp.dr_wrt(tmp.x))
    # BUGFIX: removed leftover `import pdb; pdb.set_trace()` breakpoints that
    # halted this demo whenever it was run.
    #a(b(c(d(e(f),g),h)))

    blah = tst(x10, x20, x30)

    print(blah.r)


    print(foo)

    # import unittest
    # from test_ch import TestCh
    # suite = unittest.TestLoader().loadTestsFromTestCase(TestCh)
    # unittest.TextTestRunner(verbosity=2).run(suite)
|
| 1364 |
+
|
| 1365 |
+
if __name__ == '__main__':
|
| 1366 |
+
main()
|
| 1367 |
+
|
vendor/chumpy/chumpy/ch_ops.py
ADDED
|
@@ -0,0 +1,814 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# encoding: utf-8
|
| 3 |
+
"""
|
| 4 |
+
Author(s): Matthew Loper
|
| 5 |
+
|
| 6 |
+
See LICENCE.txt for licensing and contact information.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
# Numpy functions
|
| 10 |
+
__all__ = ['array', 'amax','amin', 'max', 'min', 'maximum','minimum','nanmax','nanmin',
|
| 11 |
+
'sum', 'exp', 'log', 'mean','std', 'var',
|
| 12 |
+
'sin', 'cos', 'tan', 'arcsin', 'arccos', 'arctan',
|
| 13 |
+
'sqrt', 'square', 'absolute', 'abs', 'clip',
|
| 14 |
+
'power',
|
| 15 |
+
'add', 'divide', 'multiply', 'negative', 'subtract', 'reciprocal',
|
| 16 |
+
'nan_to_num',
|
| 17 |
+
'dot', 'cumsum',
|
| 18 |
+
'floor', 'ceil',
|
| 19 |
+
'greater', 'greater_equal', 'less', 'less_equal', 'equal', 'not_equal',
|
| 20 |
+
'nonzero', 'ascontiguousarray', 'asfarray', 'arange', 'asarray', 'copy',
|
| 21 |
+
'cross',
|
| 22 |
+
'shape', 'sign']
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
__all__ += ['SumOfSquares',
|
| 26 |
+
'NanDivide', ]
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
# These can be wrapped directly as Ch(routine(*args, **kwargs)),
|
| 30 |
+
# so that for example "ch.eye(3)" translates into Ch(np.eye(3))
|
| 31 |
+
numpy_array_creation_routines = [
|
| 32 |
+
'empty','empty_like','eye','identity','ones','ones_like','zeros','zeros_like',
|
| 33 |
+
'array',
|
| 34 |
+
'arange','linspace','logspace','meshgrid','mgrid','ogrid',
|
| 35 |
+
'fromfunction', 'fromiter', 'meshgrid', 'tri'
|
| 36 |
+
]
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
wont_implement = ['asanyarray', 'asmatrix', 'frombuffer', 'copy', 'fromfile', 'fromstring', 'loadtxt', 'copyto', 'asmatrix', 'asfortranarray', 'asscalar', 'require']
|
| 40 |
+
not_yet_implemented = ['tril', 'triu', 'vander']
|
| 41 |
+
|
| 42 |
+
__all__ += not_yet_implemented
|
| 43 |
+
__all__ += wont_implement
|
| 44 |
+
__all__ += numpy_array_creation_routines
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
from .ch import Ch
|
| 48 |
+
import six
|
| 49 |
+
import numpy as np
|
| 50 |
+
import warnings
|
| 51 |
+
from six.moves import cPickle as pickle
|
| 52 |
+
import scipy.sparse as sp
|
| 53 |
+
from .utils import row, col
|
| 54 |
+
from copy import copy as copy_copy
|
| 55 |
+
from functools import reduce
|
| 56 |
+
|
| 57 |
+
__all__ += ['pi', 'set_printoptions']
|
| 58 |
+
pi = np.pi
|
| 59 |
+
set_printoptions = np.set_printoptions
|
| 60 |
+
arange = np.arange
|
| 61 |
+
|
| 62 |
+
for rtn in ['argmax', 'nanargmax', 'argmin', 'nanargmin']:
|
| 63 |
+
exec('def %s(a, axis=None) : return np.%s(a.r, axis) if hasattr(a, "compute_r") else np.%s(a, axis)' % (rtn, rtn, rtn))
|
| 64 |
+
__all__ += [rtn]
|
| 65 |
+
|
| 66 |
+
for rtn in ['argwhere', 'nonzero', 'flatnonzero']:
|
| 67 |
+
exec('def %s(a) : return np.%s(a.r) if hasattr(a, "compute_r") else np.%s(a)' % (rtn, rtn, rtn))
|
| 68 |
+
__all__ += [rtn]
|
| 69 |
+
|
| 70 |
+
for rtn in numpy_array_creation_routines:
|
| 71 |
+
exec('def %s(*args, **kwargs) : return Ch(np.%s(*args, **kwargs))' % (rtn, rtn))
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
class WontImplement(Exception):
|
| 75 |
+
pass
|
| 76 |
+
|
| 77 |
+
for rtn in wont_implement:
|
| 78 |
+
exec('def %s(*args, **kwargs) : raise WontImplement' % (rtn))
|
| 79 |
+
|
| 80 |
+
for rtn in not_yet_implemented:
|
| 81 |
+
exec('def %s(*args, **kwargs) : raise NotImplementedError' % (rtn))
|
| 82 |
+
|
| 83 |
+
def asarray(a, dtype=None, order=None):
    """Chumpy analogue of np.asarray: graph nodes (anything with `dterms`)
    pass through unchanged; everything else is wrapped in a Ch.
    Only float64 dtype and C order are supported."""
    assert dtype is None or dtype is np.float64
    assert order is None or order == 'C'
    if hasattr(a, 'dterms'):
        # Already a differentiable node -- return it as-is.
        return a
    return Ch(np.asarray(a, dtype, order))
+
|
| 90 |
+
# Everything is always c-contiguous, so this is a no-op pass-through.
def ascontiguousarray(a, dtype=None): return a

# Everything is always float, so asfarray is the same no-op.
asfarray = ascontiguousarray
|
| 96 |
+
def copy(self):
    """Deep-copy an object by round-tripping it through pickle."""
    payload = pickle.dumps(self)
    return pickle.loads(payload)
|
| 99 |
+
def asfortranarray(a, dtype=None): raise WontImplement
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
class Simpleton(Ch):
    # Base class for ops that produce a value but no derivative:
    # compute_dr_wrt always reports "no dependency".
    dterms = 'x'
    def compute_dr_wrt(self, wrt):
        return None

class floor(Simpleton):
    # Elementwise floor (derivative treated as zero via Simpleton).
    def compute_r(self): return np.floor(self.x.r)

class ceil(Simpleton):
    # Elementwise ceil (derivative treated as zero via Simpleton).
    def compute_r(self): return np.ceil(self.x.r)

class sign(Simpleton):
    # Elementwise sign (derivative treated as zero via Simpleton).
    def compute_r(self): return np.sign(self.x.r)
|
| 116 |
+
class Cross(Ch):
    """Differentiable np.cross over terms a and b (same axis arguments)."""
    dterms = 'a', 'b'
    terms = 'axisa', 'axisb', 'axisc', 'axis'
    term_order = 'a', 'b', 'axisa', 'axisb', 'axisc', 'axis'

    def compute_r(self):
        return np.cross(self.a.r, self.b.r, self.axisa, self.axisb, self.axisc, self.axis)


    def _load_crossprod_cache(self, h, w):
        # Cache the index arrays and tiled identity used to assemble the
        # sparse Jacobian; rebuilt only when the (h, w) shape changes.
        if not hasattr(self, '_w'):
            self._w = 0
            self._h = 0

        if h!=self._h or w!=self._w:
            sz = h*w
            rng = np.arange(sz)
            self._JS = np.repeat(rng.reshape((-1,w)), w, axis=0).ravel()
            self._IS = np.repeat(rng, w)
            self._tiled_identity = np.tile(np.eye(w), (h, 1))
            self._h = h
            self._w = w

        return self._tiled_identity, self._IS, self._JS,



    # Could be at least 2x faster, with some work
    def compute_dr_wrt(self, wrt):
        # Jacobian of the cross product w.r.t. a or b, assembled by crossing
        # the tiled identity with the (negated, repeated) other operand and
        # scattering into a CSC matrix.  Assumes a is 2-D (h, w).
        if wrt is not self.a and wrt is not self.b:
            return

        sz = self.a.size
        h, w = self.a.shape
        tiled_identity, IS, JS = self._load_crossprod_cache(h, w)

        #import time
        #tm = time.time()
        if wrt is self.a:
            rp = np.repeat(-self.b.r, w, axis=0)
            result = np.cross(
                tiled_identity,
                rp,
                self.axisa,
                self.axisb,
                self.axisc,
                self.axis)

        elif wrt is self.b:
            result = np.cross(
                np.repeat(-self.a.r, w, axis=0),
                tiled_identity,
                self.axisa,
                self.axisb,
                self.axisc,
                self.axis)

        # rng = np.arange(sz)
        # JS = np.repeat(rng.reshape((-1,w)), w, axis=0).ravel()
        # IS = np.repeat(rng, w)
        data = result.ravel()
        result = sp.csc_matrix((data, (IS,JS)), shape=(self.size, wrt.size))
        #import pdb; pdb.set_trace()
        #print 'B TOOK %es' % (time.time() -tm )
        return result
|
| 182 |
+
def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None):
    # Differentiable analogue of np.cross (same signature); builds a Cross node.
    return Cross(a, b, axisa, axisb, axisc, axis)
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
class cumsum(Ch):
    """Differentiable np.cumsum over term a (axis=None only for derivatives)."""
    dterms = 'a'
    terms = 'axis'
    term_order = 'a', 'axis'

    def on_changed(self, which):
        # Default to the flattened cumulative sum.
        if not hasattr(self, 'axis'):
            self.axis = None

    def compute_r(self):
        return np.cumsum(self.a.r, axis=self.axis)

    def compute_dr_wrt(self, wrt):
        if wrt is not self.a:
            return None

        # Axis-wise derivative not implemented; only the flat case below.
        if self.axis is not None:
            raise NotImplementedError

        # The Jacobian of a flat cumsum is a lower-triangular matrix of ones:
        # output i depends (with coefficient 1) on every input j <= i.
        IS = np.tile(row(np.arange(self.a.size)), (self.a.size, 1))
        JS = IS.T
        IS = IS.ravel()
        JS = JS.ravel()
        which = IS >= JS
        IS = IS[which]
        JS = JS[which]
        data = np.ones_like(IS)
        result = sp.csc_matrix((data, (IS, JS)), shape=(self.a.size, self.a.size))
        return result
|
| 218 |
+
|
| 219 |
+
class UnaryElemwise(Ch):
    """Base for elementwise unary ops: subclasses supply _r (the value
    function) and _d (its elementwise derivative)."""
    dterms = 'x'

    def compute_r(self):
        return self._r(self.x.r)

    def compute_dr_wrt(self, wrt):
        if wrt is self.x:
            result = self._d(self.x.r)
            # Diagonal Jacobian for arrays; dense 2-D for a single element.
            # NOTE(review): len(result) raises on 0-d arrays -- presumably
            # inputs are always at least 1-d here; confirm.
            return sp.diags([result.ravel()], [0]) if len(result)>1 else np.atleast_2d(result)
|
| 230 |
+
|
| 231 |
+
# Elementwise ops: _r computes the value, _d its elementwise derivative.

class nan_to_num(UnaryElemwise):
    # Derivative is 1 where the input was finite, 0 elsewhere.
    _r = lambda self, x : np.nan_to_num(x)
    _d = lambda self, x : np.asarray(np.isfinite(x), np.float64)

class reciprocal(UnaryElemwise):
    # d(1/x)/dx = -1/x^2
    _r = np.reciprocal
    _d = lambda self, x : -np.reciprocal(np.square(x))

class square(UnaryElemwise):
    # d(x^2)/dx = 2x
    _r = np.square
    _d = lambda self, x : x * 2.
|
| 243 |
+
def my_power(a, b):
    """np.power with RuntimeWarnings suppressed and non-finite results
    (nan/inf from invalid bases or exponents) mapped to finite numbers
    via np.nan_to_num."""
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", RuntimeWarning)
        raw = np.power(a, b)
        return np.nan_to_num(raw)
|
| 248 |
+
# Elementwise ops: _r computes the value, _d its elementwise derivative.

class sqrt(UnaryElemwise):
    # d(sqrt(x))/dx = 0.5 * x^-0.5; my_power maps nan/inf (x <= 0) to finite.
    _r = np.sqrt
    _d = lambda self, x : .5 * my_power(x, -0.5)

class exp(UnaryElemwise):
    # exp is its own derivative.
    _r = np.exp
    _d = np.exp

class log(UnaryElemwise):
    # d(log x)/dx = 1/x
    _r = np.log
    _d = np.reciprocal

class sin(UnaryElemwise):
    _r = np.sin
    _d = np.cos

class arcsin(UnaryElemwise):
    # d(arcsin x)/dx = 1/sqrt(1 - x^2)
    _r = np.arcsin
    _d = lambda self, x : np.reciprocal(np.sqrt(1.-np.square(x)))

class cos(UnaryElemwise):
    _r = np.cos
    _d = lambda self, x : -np.sin(x)

class arccos(UnaryElemwise):
    # d(arccos x)/dx = -1/sqrt(1 - x^2)
    _r = np.arccos
    _d = lambda self, x : -np.reciprocal(np.sqrt(1.-np.square(x)))

class tan(UnaryElemwise):
    # d(tan x)/dx = 1/cos(x)^2
    _r = np.tan
    _d = lambda self, x : np.reciprocal(np.cos(x)**2.)

class arctan(UnaryElemwise):
    # d(arctan x)/dx = 1/(x^2 + 1)
    _r = np.arctan
    _d = lambda self, x : np.reciprocal(np.square(x)+1.)

class negative(UnaryElemwise):
    _r = np.negative
    _d = lambda self, x : np.negative(np.ones_like(x))

class absolute(UnaryElemwise):
    # Subgradient of |x|: +1 for x > 0, -1 otherwise (x == 0 maps to -1).
    _r = np.abs
    _d = lambda self, x : (x>0)*2-1.

abs = absolute
|
| 294 |
+
class clip(Ch):
    """Differentiable np.clip of term a to the [a_min, a_max] range."""
    dterms = 'a'
    terms = 'a_min', 'a_max'
    term_order = 'a', 'a_min', 'a_max'

    def compute_r(self):
        return np.clip(self.a.r, self.a_min, self.a_max)

    def compute_dr_wrt(self, wrt):
        if wrt is self.a:
            # Derivative is 1 where the value passed through unclipped,
            # 0 where it was pinned to a_min or a_max.
            result = np.asarray((self.r != self.a_min) & (self.r != self.a_max), np.float64)
            # Diagonal Jacobian for arrays; dense 2-D for a single element.
            return sp.diags([result.ravel()], [0]) if len(result)>1 else np.atleast_2d(result)
|
| 307 |
+
class sum(Ch):
    """Differentiable np.sum over an optional axis.

    Terms: ``x`` (differentiable input) and ``axis`` (None for a full
    reduction).  Sparse Jacobians for axis reductions are cached per
    (input shape, axis) in ``dr_cache``.
    """
    dterms = 'x',
    terms = 'axis',
    term_order = 'x', 'axis'

    def on_changed(self, which):
        if not hasattr(self, 'axis'):
            self.axis = None
        if not hasattr(self, 'dr_cache'):
            self.dr_cache = {}

    def compute_r(self):
        return np.sum(self.x.r, axis=self.axis)

    def compute_dr_wrt(self, wrt):
        if wrt is not self.x:
            return
        if self.axis is None:  # fixed: identity comparison, was '== None'
            # Full reduction: every input element contributes with weight 1.
            return row(np.ones((1, len(self.x.r.ravel()))))
        uid = tuple(list(self.x.shape) + [self.axis])
        if uid not in self.dr_cache:
            # Map each pre-reduction element index to its post-reduction slot.
            idxs_presum = np.arange(self.x.size).reshape(self.x.shape)
            idxs_presum = np.rollaxis(idxs_presum, self.axis, 0)
            idxs_postsum = np.arange(self.r.size).reshape(self.r.shape)
            tp = np.ones(idxs_presum.ndim, dtype=np.uint32)
            tp[0] = idxs_presum.shape[0]
            idxs_postsum = np.tile(idxs_postsum, tp)
            data = np.ones(idxs_postsum.size)
            self.dr_cache[uid] = sp.csc_matrix(
                (data, (idxs_postsum.ravel(), idxs_presum.ravel())),
                (self.r.size, wrt.size))
        return self.dr_cache[uid]
|
| 339 |
+
|
| 340 |
+
|
| 341 |
+
class mean(Ch):
    """Differentiable np.mean over an optional axis.

    Same structure as ``sum`` but each Jacobian entry is weighted by the
    reciprocal of the number of elements averaged.
    """
    dterms = 'x',
    terms = 'axis',
    term_order = 'x', 'axis'

    def on_changed(self, which):
        if not hasattr(self, 'axis'):
            self.axis = None
        if not hasattr(self, 'dr_cache'):
            self.dr_cache = {}

    def compute_r(self):
        return np.array(np.mean(self.x.r, axis=self.axis))

    def compute_dr_wrt(self, wrt):
        if wrt is not self.x:
            return
        if self.axis is None:  # fixed: identity comparison, was '== None'
            return row(np.ones((1, len(self.x.r)))) / len(self.x.r)
        uid = tuple(list(self.x.shape) + [self.axis])
        if uid not in self.dr_cache:
            idxs_presum = np.arange(self.x.size).reshape(self.x.shape)
            idxs_presum = np.rollaxis(idxs_presum, self.axis, 0)
            idxs_postsum = np.arange(self.r.size).reshape(self.r.shape)
            tp = np.ones(idxs_presum.ndim, dtype=np.uint32)
            tp[0] = idxs_presum.shape[0]
            idxs_postsum = np.tile(idxs_postsum, tp)
            # Each contributing element weighted by 1/(elements along axis).
            data = np.ones(idxs_postsum.size) / self.x.shape[self.axis]
            self.dr_cache[uid] = sp.csc_matrix(
                (data, (idxs_postsum.ravel(), idxs_presum.ravel())),
                (self.r.size, wrt.size))
        return self.dr_cache[uid]
|
| 373 |
+
|
| 374 |
+
|
| 375 |
+
def var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False):
    """Chumpy stand-in for np.var (numpy-compatible signature).

    Only the default dtype/out/ddof/keepdims are supported.
    NOTE(review): this returns mean(a**2), i.e. the second raw moment, not the
    variance about the mean — kept as-is to match the vendored upstream;
    confirm before relying on numpy-equivalent semantics.
    """
    if dtype is not None or out is not None or ddof != 0 or keepdims != False:
        # Fixed: was `raise NotImplementedException(...)` — an undefined name
        # that would itself crash with NameError instead of this exception.
        raise NotImplementedError('Unimplemented for non-default dtype, out, ddof, and keepdims.')
    return mean(a**2., axis=axis)
|
| 379 |
+
|
| 380 |
+
def std(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False):
    """Chumpy stand-in for np.std: sqrt of ``var`` (see its NOTE on semantics).

    Only the default dtype/out/ddof/keepdims are supported.
    """
    if dtype is not None or out is not None or ddof != 0 or keepdims != False:
        # Fixed: was `raise NotImplementedException(...)`, an undefined name.
        raise NotImplementedError('Unimplemented for non-default dtype, out, ddof, and keepdims.')
    return sqrt(var(a, axis=axis))
|
| 384 |
+
|
| 385 |
+
|
| 386 |
+
class SumOfSquares(Ch):
    """Scalar sum of squared entries of ``x``; gradient is 2*x as a row vector."""
    dterms = 'x',

    def compute_r(self):
        flat = self.x.r.ravel()
        return np.sum(flat ** 2.)

    def compute_dr_wrt(self, wrt):
        if wrt is not self.x:
            return None
        return row(self.x.r.ravel() * 2.)
|
| 395 |
+
|
| 396 |
+
|
| 397 |
+
class divide (Ch):
    """Differentiable broadcast elementwise division x1 / x2."""
    dterms = 'x1', 'x2'

    def compute_r(self):
        return self.x1.r / self.x2.r

    def compute_dr_wrt(self, wrt):
        # None when wrt is neither operand, and also when both operands are
        # the same object (the original behaves identically).
        if (wrt is self.x1) == (wrt is self.x2):
            return None

        IS, JS, input_sz, output_sz = _broadcast_setup(self.x1, self.x2, wrt)

        numer, denom = self.x1.r, self.x2.r
        if wrt is self.x1:
            # d(n/d)/dn = 1/d
            data = (np.ones_like(numer) / denom).ravel()
        else:
            # d(n/d)/dd = -n/d^2
            data = (-numer / (denom * denom)).ravel()

        return sp.csc_matrix((data, (IS, JS)), shape=(self.r.size, wrt.r.size))
|
| 417 |
+
|
| 418 |
+
|
| 419 |
+
|
| 420 |
+
|
| 421 |
+
class NanDivide(divide):
    """Elementwise division whose NaN/inf outputs and Jacobian entries become 0."""
    dterms = 'x1', 'x2'

    def compute_r(self):
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            # Fixed: was `super(self.__class__, self)`, which recurses
            # infinitely if this class is ever subclassed.
            result = super(NanDivide, self).compute_r()
            shape = result.shape
            result = result.ravel()
            result[np.isinf(result)] = 0
            result[np.isnan(result)] = 0
            return result.reshape(shape)

    def compute_dr_wrt(self, wrt):
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            # Same super() fix as compute_r above.
            result = super(NanDivide, self).compute_dr_wrt(wrt)
            if result is None:
                return None
            result = result.copy()
            if sp.issparse(result):
                result.data[np.isinf(result.data)] = 0
                result.data[np.isnan(result.data)] = 0
            else:
                rr = result.ravel()
                rr[np.isnan(rr)] = 0.
                rr[np.isinf(rr)] = 0.
            return result
|
| 449 |
+
|
| 450 |
+
|
| 451 |
+
def shape(a):
    """Shape of ``a``: its own ``.shape`` when present, otherwise np.shape(a)."""
    try:
        return a.shape
    except AttributeError:
        return np.shape(a)
|
| 453 |
+
|
| 454 |
+
|
| 455 |
+
_bs_setup_data1 = {}  # cache: shape -> diagonal-pattern csc matrix (same-shape case)
_bs_setup_data2 = {}  # cache: (a.shape, b.shape, wrt is a, wrt is b) -> index-pattern csc matrix
def _broadcast_matrix(a, b, wrt, data):
    """Sparse Jacobian for a broadcast binary op, with ``data`` as its entries.

    ``data`` may be an ndarray (per-entry values) or a scalar (uniform fill).
    The sparsity pattern is cached; only ``.data`` is rebound per call.
    """
    global _bs_setup_data1, _bs_setup_data2

    if len(set((a.shape, b.shape))) == 1:
        # Same-shape operands: diagonal Jacobian.
        uid = a.shape
        if uid not in _bs_setup_data1:
            asz = a.size
            diag = np.arange(asz)
            _bs_setup_data1[uid] = sp.csc_matrix(
                (np.empty(asz), (diag, diag)), shape=(asz, asz))
        result = copy_copy(_bs_setup_data1[uid])
        if isinstance(data, np.ndarray):
            result.data = data.ravel()
        else:  # assumed scalar
            result.data = np.empty(result.nnz)
            result.data.fill(data)
    else:
        uid = (a.shape, b.shape, wrt is a, wrt is b)
        if uid not in _bs_setup_data2:
            input_sz = wrt.size
            output_sz = np.broadcast(a.r, b.r).size
            a2 = np.arange(a.size).reshape(a.shape) if wrt is a else np.zeros(a.shape)
            b2 = np.arange(b.size).reshape(b.shape) if (wrt is b and wrt is not a) else np.zeros(b.shape)
            IS = np.arange(output_sz)
            JS = np.asarray(np.add(a2, b2).ravel(), np.uint32)
            # The cached matrix stores, as its .data, the permutation that
            # scatters caller-supplied entries into the right nonzero slots.
            _bs_setup_data2[uid] = sp.csc_matrix(
                (np.arange(IS.size), (IS, JS)), shape=(output_sz, input_sz))
        result = copy_copy(_bs_setup_data2[uid])
        if isinstance(data, np.ndarray):
            result.data = data[result.data]
        else:  # assumed scalar
            result.data = np.empty(result.nnz)
            result.data.fill(data)

    if np.prod(result.shape) == 1:
        return np.array(data)
    return result
|
| 495 |
+
|
| 496 |
+
|
| 497 |
+
|
| 498 |
+
|
| 499 |
+
broadcast_shape_cache = {}  # retained for compatibility; unused while disabled
def broadcast_shape(a_shape, b_shape):
    """Intentionally disabled: always raises.

    The pre-existing guard raised unconditionally (shape is not cached here,
    so over-querying can occur), which made the remainder of the original
    function unreachable; that dead body has been removed.
    """
    raise Exception('This function is probably a bad idea, because shape is not cached and overquerying can occur.')
|
| 529 |
+
|
| 530 |
+
|
| 531 |
+
def _broadcast_setup(a, b, wrt):
|
| 532 |
+
if len(set((a.shape, b.shape))) == 1:
|
| 533 |
+
asz = a.size
|
| 534 |
+
IS = np.arange(asz)
|
| 535 |
+
return IS, IS, asz, asz
|
| 536 |
+
input_sz = wrt.r.size
|
| 537 |
+
output_sz = np.broadcast(a.r, b.r).size
|
| 538 |
+
a2 = np.arange(a.size).reshape(a.shape) if wrt is a else np.zeros(a.shape)
|
| 539 |
+
b2 = np.arange(b.size).reshape(b.shape) if (wrt is b and wrt is not a) else np.zeros(b.shape)
|
| 540 |
+
IS = np.arange(output_sz)
|
| 541 |
+
JS = np.asarray((np.add(a2,b2)).ravel(), np.uint32)
|
| 542 |
+
return IS, JS, input_sz, output_sz
|
| 543 |
+
|
| 544 |
+
|
| 545 |
+
|
| 546 |
+
class add(Ch):
    """Differentiable broadcast addition a + b."""
    dterms = 'a', 'b'

    def compute_r(self):
        return self.a.r + self.b.r

    def compute_dr_wrt(self, wrt):
        if wrt is not self.a and wrt is not self.b:
            return None
        # When both operands are literally the same object, each entry of wrt
        # contributes twice (a + a).
        weight = 2. if self.a is self.b else 1.
        return _broadcast_matrix(self.a, self.b, wrt, weight)
|
| 558 |
+
|
| 559 |
+
|
| 560 |
+
|
| 561 |
+
|
| 562 |
+
class subtract(Ch):
    """Differentiable broadcast subtraction a - b."""
    dterms = 'a', 'b'

    def compute_r(self):
        return self.a.r - self.b.r

    def compute_dr_wrt(self, wrt):
        # None when wrt is neither operand or both operands are the same object.
        if (wrt is self.a) == (wrt is self.b):
            return None
        sign = 1. if wrt is self.a else -1.
        return _broadcast_matrix(self.a, self.b, wrt, sign)
|
| 574 |
+
|
| 575 |
+
|
| 576 |
+
|
| 577 |
+
|
| 578 |
+
|
| 579 |
+
class power (Ch):
    """Differentiable elementwise x ** pow (both operands differentiable)."""
    dterms = 'x', 'pow'

    def compute_r(self):
        return self.safe_power(self.x.r, self.pow.r)

    def compute_dr_wrt(self, wrt):
        if wrt is not self.x and wrt is not self.pow:
            return None

        x, pow = self.x.r, self.pow.r
        parts = []
        if wrt is self.x:
            # d/dx x^p = p * x^(p-1)
            parts.append(pow * self.safe_power(x, pow - 1.))
        if wrt is self.pow:
            # d/dp x^p = ln(x) * x^p
            parts.append(np.log(x) * self.safe_power(x, pow))

        total = parts[0]
        for extra in parts[1:]:
            total = total + extra
        return _broadcast_matrix(self.x, self.pow, wrt, total.ravel())

    def safe_power(self, x, sigma):
        # np.power may warn and produce infs; the infs are zeroed afterwards.
        result = np.power(x, sigma)
        result.ravel()[np.isinf(result.ravel())] = 0
        return result
|
| 608 |
+
|
| 609 |
+
|
| 610 |
+
|
| 611 |
+
|
| 612 |
+
|
| 613 |
+
class A_extremum(Ch):
    """Superclass for min/max-style reductions (amin/amax/nanmin/nanmax)."""
    dterms = 'a'
    terms = 'axis'
    term_order = 'a', 'axis'

    def f(self, axis): raise NotImplementedError
    def argf(self, axis): raise NotImplementedError

    def on_changed(self, which):
        if not hasattr(self, 'axis'):
            self.axis = None

    def compute_r(self):
        return self.f(self.a.r, axis=self.axis)

    def compute_dr_wrt(self, wrt):
        if wrt is not self.a:
            return None
        # One nonzero per output: the arg-extremum element gets derivative 1.
        mn, stride = self._stride_for_axis(self.axis, self.a.r)
        JS = np.asarray(
            np.round(mn + stride * self.argf(self.a.r, axis=self.axis)),
            dtype=np.uint32).ravel()
        IS = np.arange(JS.size)
        data = np.ones(JS.size)

        if self.r.size * wrt.r.size == 1:
            return data.ravel()[0]
        return sp.csc_matrix((data, (IS, JS)), shape=(self.r.size, wrt.r.size))

    def _stride_for_axis(self, axis, mtx):
        """Base offsets and element stride converting argf output to flat indices."""
        if axis is None:
            return np.array([0]), np.array([1])
        # TODO: make this less expensive. Shouldn't need to call np.amin
        # here probably.
        idxs = np.arange(mtx.size).reshape(mtx.shape)
        mn = np.amin(idxs, axis=axis)
        mtx_strides = np.array(mtx.strides)
        stride = mtx_strides / np.min(mtx_strides)  # bytes -> element units
        return mn, stride[axis]
|
| 654 |
+
|
| 655 |
+
|
| 656 |
+
class amax(A_extremum):
    """Differentiable np.amax."""
    def f(self, *args, **kwargs):
        return np.amax(*args, **kwargs)

    def argf(self, *args, **kwargs):
        return np.argmax(*args, **kwargs)
|
| 659 |
+
|
| 660 |
+
max = amax
|
| 661 |
+
|
| 662 |
+
class amin(A_extremum):
    """Differentiable np.amin."""
    def f(self, *args, **kwargs):
        return np.amin(*args, **kwargs)

    def argf(self, *args, **kwargs):
        return np.argmin(*args, **kwargs)
|
| 665 |
+
|
| 666 |
+
min = amin
|
| 667 |
+
|
| 668 |
+
class nanmin(A_extremum):
    """Differentiable np.nanmin (NaNs ignored)."""
    def f(self, *args, **kwargs):
        return np.nanmin(*args, **kwargs)

    def argf(self, *args, **kwargs):
        return np.nanargmin(*args, **kwargs)
|
| 671 |
+
|
| 672 |
+
class nanmax(A_extremum):
    """Differentiable np.nanmax (NaNs ignored)."""
    def f(self, *args, **kwargs):
        return np.nanmax(*args, **kwargs)

    def argf(self, *args, **kwargs):
        return np.nanargmax(*args, **kwargs)
|
| 675 |
+
|
| 676 |
+
|
| 677 |
+
class Extremum(Ch):
    """Base for elementwise two-argument max/min; subclasses supply ``f``."""
    dterms = 'a', 'b'

    def compute_r(self):
        return self.f(self.a.r, self.b.r)

    def compute_dr_wrt(self, wrt):
        if wrt is not self.a and wrt is not self.b:
            return None

        IS, JS, input_sz, output_sz = _broadcast_setup(self.a, self.b, wrt)
        # self.f(1, -1) is +1 for max and -1 for min; nudging the other operand
        # by it breaks ties in favor of wrt, as in the original implementation.
        if wrt is self.a:
            winners = (self.r == self.f(self.a.r, self.b.r - self.f(1, -1))).ravel()
        else:
            winners = (self.r == self.f(self.b.r, self.a.r - self.f(1, -1))).ravel()
        IS = IS[winners]
        JS = JS[winners]
        data = np.ones(JS.size)

        return sp.csc_matrix((data, (IS, JS)), shape=(self.r.size, wrt.r.size))
|
| 696 |
+
|
| 697 |
+
class maximum(Extremum):
    """Differentiable elementwise np.maximum."""
    def f(self, a, b):
        return np.maximum(a, b)
|
| 699 |
+
|
| 700 |
+
class minimum(Extremum):
    """Differentiable elementwise np.minimum."""
    def f(self, a, b):
        return np.minimum(a, b)
|
| 702 |
+
|
| 703 |
+
|
| 704 |
+
class multiply(Ch):
    """Differentiable broadcast elementwise product a * b."""
    dterms = 'a', 'b'

    def compute_r(self):
        return self.a.r * self.b.r

    def compute_dr_wrt(self, wrt):
        if wrt is not self.a and wrt is not self.b:
            return None

        # d(ab)/da = b and vice versa; ones stand in for the wrt operand.
        left = self.a.r if wrt is self.b else np.ones(self.a.shape)
        right = self.b.r if (wrt is self.a and wrt is not self.b) else np.ones(self.b.shape)
        data = (left * right).ravel()

        # a * a: both factors contribute, doubling the derivative.
        if self.a is self.b:
            data *= 2.

        return _broadcast_matrix(self.a, self.b, wrt, data)
|
| 722 |
+
|
| 723 |
+
|
| 724 |
+
|
| 725 |
+
|
| 726 |
+
|
| 727 |
+
class dot(Ch):
    """Differentiable matrix product a.dot(b)."""
    dterms = 'a', 'b'

    def compute_r(self):
        return self.a.r.dot(self.b.r)

    def compute_d1(self):
        """Jacobian of the product with respect to ``a``."""
        # To stay consistent with numpy, we must upgrade 1D arrays to 2D
        ar = row(self.a.r) if len(self.a.r.shape) < 2 else self.a.r.reshape((-1, self.a.r.shape[-1]))
        br = col(self.b.r) if len(self.b.r.shape) < 2 else self.b.r.reshape((self.b.r.shape[0], -1))

        if ar.ndim <= 2:
            return sp.kron(sp.eye(ar.shape[0], ar.shape[0]), br.T)
        else:
            raise NotImplementedError

    def compute_d2(self):
        """Jacobian of the product with respect to ``b``."""
        # To stay consistent with numpy, we must upgrade 1D arrays to 2D
        ar = row(self.a.r) if len(self.a.r.shape) < 2 else self.a.r.reshape((-1, self.a.r.shape[-1]))
        br = col(self.b.r) if len(self.b.r.shape) < 2 else self.b.r.reshape((self.b.r.shape[0], -1))

        if br.ndim <= 1:
            # Fixed: was `return self.ar`, a nonexistent attribute that would
            # raise AttributeError; the local `ar` was clearly intended.
            # (Branch looks unreachable since col()/reshape yield 2D arrays,
            # but it should not crash if ever taken.)
            return ar
        elif br.ndim <= 2:
            return sp.kron(ar, sp.eye(br.shape[1], br.shape[1]))
        else:
            raise NotImplementedError

    def compute_dr_wrt(self, wrt):
        if wrt is self.a and wrt is self.b:
            return self.compute_d1() + self.compute_d2()
        elif wrt is self.a:
            return self.compute_d1()
        elif wrt is self.b:
            return self.compute_d2()
|
| 765 |
+
|
| 766 |
+
class BinaryElemwiseNoDrv(Ch):
    """Elementwise binary op treated as non-differentiable (Jacobian is None)."""
    dterms = 'x1', 'x2'

    def compute_r(self):
        return self._f(self.x1.r, self.x2.r)

    def compute_dr_wrt(self, wrt):
        # Comparisons are piecewise-constant, so no derivative is propagated.
        return None
|
| 774 |
+
|
| 775 |
+
class greater(BinaryElemwiseNoDrv):
    """Elementwise a > b (non-differentiable)."""
    def _f(self, lhs, rhs):
        return np.greater(lhs, rhs)
|
| 777 |
+
|
| 778 |
+
class greater_equal(BinaryElemwiseNoDrv):
    """Elementwise a >= b (non-differentiable)."""
    def _f(self, lhs, rhs):
        return np.greater_equal(lhs, rhs)
|
| 780 |
+
|
| 781 |
+
class less(BinaryElemwiseNoDrv):
    """Elementwise a < b (non-differentiable)."""
    def _f(self, lhs, rhs):
        return np.less(lhs, rhs)
|
| 783 |
+
|
| 784 |
+
class less_equal(BinaryElemwiseNoDrv):
    """Elementwise a <= b (non-differentiable)."""
    def _f(self, lhs, rhs):
        return np.less_equal(lhs, rhs)
|
| 786 |
+
|
| 787 |
+
class equal(BinaryElemwiseNoDrv):
    """Elementwise a == b (non-differentiable)."""
    def _f(self, lhs, rhs):
        return np.equal(lhs, rhs)
|
| 789 |
+
|
| 790 |
+
class not_equal(BinaryElemwiseNoDrv):
    """Elementwise a != b (non-differentiable)."""
    def _f(self, lhs, rhs):
        return np.not_equal(lhs, rhs)
|
| 792 |
+
|
| 793 |
+
def nonzero(a):
    """np.nonzero accepting either a plain array or a Ch node (uses its .r)."""
    value = a.r if hasattr(a, 'compute_r') else a
    return np.nonzero(value)
|
| 797 |
+
|
| 798 |
+
# Pull the code for tensordot in from numpy and reinterpret it using chumpy ops
import os
source_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'np_tensordot.py')
with open(source_path, 'r') as f:
    exec(f.read())
__all__ += ['tensordot']
|
| 805 |
+
|
| 806 |
+
|
| 807 |
+
|
| 808 |
+
def main():
    """Placeholder entry point; this module is used as a library."""
    return None


if __name__ == '__main__':
    main()
|
| 814 |
+
|
vendor/chumpy/chumpy/ch_random.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Author(s): Matthew Loper
|
| 3 |
+
|
| 4 |
+
See LICENCE.txt for licensing and contact information.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import numpy.random
|
| 8 |
+
from .ch import Ch
|
| 9 |
+
|
| 10 |
+
api_not_implemented = ['choice','bytes','shuffle','permutation']
|
| 11 |
+
|
| 12 |
+
api_wrapped_simple = [
|
| 13 |
+
# simple random data
|
| 14 |
+
'rand','randn','randint','random_integers','random_sample','random','ranf','sample',
|
| 15 |
+
|
| 16 |
+
# distributions
|
| 17 |
+
'beta','binomial','chisquare','dirichlet','exponential','f','gamma','geometric','gumbel','hypergeometric',
|
| 18 |
+
'laplace','logistic','lognormal','logseries','multinomial','multivariate_normal','negative_binomial',
|
| 19 |
+
'noncentral_chisquare','noncentral_f','normal','pareto','poisson','power','rayleigh','standard_cauchy',
|
| 20 |
+
'standard_exponential','standard_gamma','standard_normal','standard_t','triangular','uniform','vonmises',
|
| 21 |
+
'wald','weibull','zipf']
|
| 22 |
+
|
| 23 |
+
api_wrapped_direct = ['seed', 'get_state', 'set_state']
|
| 24 |
+
|
| 25 |
+
for rtn in api_wrapped_simple:
|
| 26 |
+
exec('def %s(*args, **kwargs) : return Ch(numpy.random.%s(*args, **kwargs))' % (rtn, rtn))
|
| 27 |
+
|
| 28 |
+
for rtn in api_wrapped_direct:
|
| 29 |
+
exec('%s = numpy.random.%s' % (rtn, rtn))
|
| 30 |
+
|
| 31 |
+
__all__ = api_wrapped_simple + api_wrapped_direct
|
| 32 |
+
|
vendor/chumpy/chumpy/extras.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__author__ = 'matt'
|
| 2 |
+
|
| 3 |
+
from . import ch
|
| 4 |
+
import numpy as np
|
| 5 |
+
from .utils import row, col
|
| 6 |
+
import scipy.sparse as sp
|
| 7 |
+
import scipy.special
|
| 8 |
+
|
| 9 |
+
class Interp3D(ch.Ch):
    """Differentiable lookup into a 3D image at (clipped) float locations.

    The value is the floor-voxel sample plus a first-order correction along
    the image gradient, which is also what makes it differentiable wrt
    ``locations``.
    """
    dterms = 'locations'
    terms = 'image'

    def on_changed(self, which):
        if 'image' in which:
            # Per-axis image gradients, used to build the Jacobian.
            self.gx, self.gy, self.gz = np.gradient(self.image)

    def _clipped_locations(self):
        """Locations clipped per axis into the image bounds (float array)."""
        pts = self.locations.r.copy()
        for axis in range(3):
            pts[:, axis] = np.clip(pts[:, axis], 0, self.image.shape[axis] - 1)
        return pts

    def compute_r(self):
        pts = self._clipped_locations()
        cells = np.floor(pts).astype(np.uint32)
        base = self.image[cells[:, 0], cells[:, 1], cells[:, 2]]
        frac = pts - cells
        # First-order correction along the gradient at the floor voxel.
        correction = self.dr_wrt(self.locations).dot(frac.ravel())
        return base + correction

    def compute_dr_wrt(self, wrt):
        if wrt is not self.locations:
            return None
        cells = self._clipped_locations().astype(np.uint32)

        gxc = col(self.gx[cells[:, 0], cells[:, 1], cells[:, 2]])
        gyc = col(self.gy[cells[:, 0], cells[:, 1], cells[:, 2]])
        gzc = col(self.gz[cells[:, 0], cells[:, 1], cells[:, 2]])

        data = np.vstack([gxc.ravel(), gyc.ravel(), gzc.ravel()]).T.copy()
        JS = np.arange(cells.size)
        IS = JS // 3

        return sp.csc_matrix((data.ravel(), (IS, JS)))
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class gamma(ch.Ch):
    """Differentiable elementwise gamma function."""
    dterms = 'x',

    def compute_r(self):
        return scipy.special.gamma(self.x.r)

    def compute_dr_wrt(self, wrt):
        if wrt is not self.x:
            return None
        # d/dx gamma(x) = gamma(x) * digamma(x); polygamma(0, .) is digamma.
        deriv = scipy.special.polygamma(0, self.x.r) * self.r
        return sp.diags([deriv.ravel()], [0])
|
| 55 |
+
|
| 56 |
+
# This function is based directly on the "moment" function
|
| 57 |
+
# in scipy, specifically in mstats_basic.py.
|
| 58 |
+
def moment(a, moment=1, axis=0):
    """Central moment of ``a`` along ``axis`` (patterned on scipy's mstats moment)."""
    if moment == 1:
        # By definition the first moment about the mean is 0.
        out_shape = list(a.shape)
        del out_shape[axis]
        if out_shape:
            # Return an actual array of the appropriate shape.
            return ch.zeros(out_shape, dtype=float)
        # The input was 1D, so return a scalar instead of a rank-0 array.
        return np.float64(0.0)
    centered = a - ch.expand_dims(a.mean(axis=axis), axis)
    return ch.power(centered, moment).mean(axis=axis)
|
vendor/chumpy/chumpy/linalg.py
ADDED
|
@@ -0,0 +1,306 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# encoding: utf-8
|
| 3 |
+
|
| 4 |
+
"""
|
| 5 |
+
Author(s): Matthew Loper
|
| 6 |
+
|
| 7 |
+
See LICENCE.txt for licensing and contact information.
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
__all__ = ['inv', 'svd', 'det', 'slogdet', 'pinv', 'lstsq', 'norm']
|
| 12 |
+
|
| 13 |
+
import numpy as np
|
| 14 |
+
import scipy.sparse as sp
|
| 15 |
+
from .ch import Ch, depends_on
|
| 16 |
+
from .ch_ops import NanDivide
|
| 17 |
+
from .ch_ops import asarray as ch_asarray
|
| 18 |
+
from .ch_ops import sqrt as ch_sqrt
|
| 19 |
+
from .ch_ops import sum as ch_sum
|
| 20 |
+
from .reordering import concatenate as ch_concatenate
|
| 21 |
+
from .ch_random import randn as ch_random_randn
|
| 22 |
+
from .utils import row, col
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
try:
    # Rebuild np.linalg.tensorinv on top of chumpy's asarray by re-executing
    # its source in this namespace; best-effort, skipped on any failure.
    asarray = ch_asarray
    import inspect
    exec(''.join(inspect.getsourcelines(np.linalg.tensorinv)[0]))
    __all__.append('tensorinv')
except Exception:
    # Fixed: was a bare `except:`, which also swallows SystemExit and
    # KeyboardInterrupt.
    pass
|
| 31 |
+
|
| 32 |
+
def norm(x, ord=None, axis=None):
    """Euclidean (Frobenius) norm of ``x``; only ord=None, axis=None supported."""
    if ord is not None or axis is not None:
        raise NotImplementedError("'ord' and 'axis' should be None for now.")
    return ch_sqrt(ch_sum(x ** 2))
|
| 37 |
+
|
| 38 |
+
# This version works but derivatives are too slow b/c of nested loop in Svd implementation.
|
| 39 |
+
# def lstsq(a, b):
|
| 40 |
+
# u, s, v = Svd(a)
|
| 41 |
+
# x = (v.T / s).dot(u.T.dot(b))
|
| 42 |
+
# residuals = NotImplementedError # ch_sum((a.dot(x) - b)**2, axis=0)
|
| 43 |
+
# rank = NotImplementedError
|
| 44 |
+
# s = NotImplementedError
|
| 45 |
+
# return x, residuals, rank, s
|
| 46 |
+
|
| 47 |
+
def lstsq(a, b, rcond=-1):
    """Differentiable least squares via the pseudoinverse: x = pinv(a).dot(b).

    Returns (x, residuals, rank, s); ``rank`` and ``s`` are unimplemented
    placeholders (the NotImplementedError class itself), as in the original.
    """
    if rcond != -1:
        raise Exception('non-default rcond not yet implemented')

    x = Ch(lambda a, b: pinv(a).dot(b))
    x.a = a
    x.b = b
    residuals = ch_sum((x.a.dot(x) - x.b) ** 2, axis=0)
    rank = NotImplementedError
    s = NotImplementedError

    return x, residuals, rank, s
|
| 59 |
+
|
| 60 |
+
def Svd(x, full_matrices=0, compute_uv=1):
    """Thin SVD returning differentiable (u, d, v.T) nodes.

    Only full_matrices=0 and compute_uv=1 are supported.
    """
    if full_matrices != 0:
        raise Exception('full_matrices must be 0')
    if compute_uv != 1:
        raise Exception('compute_uv must be 1')

    # The internals assume rows >= cols; transpose on the way in and swap
    # U and V on the way out.
    need_transpose = x.shape[0] < x.shape[1]
    if need_transpose:
        x = x.T

    svd_d = SvdD(x=x)
    svd_v = SvdV(x=x, svd_d=svd_d)
    svd_u = SvdU(x=x, svd_d=svd_d, svd_v=svd_v)

    if need_transpose:
        return svd_v, svd_d, svd_u.T
    return svd_u, svd_d, svd_v.T
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
class Pinv(Ch):
    """Moore-Penrose pseudoinverse built from the inverse of the Gram matrix."""
    dterms = 'mtx'

    def on_changed(self, which):
        mtx = self.mtx
        # Invert whichever Gram matrix (mtx mtx^T or mtx^T mtx) is smaller.
        if mtx.shape[1] > mtx.shape[0]:
            self._result = mtx.T.dot(Inv(mtx.dot(mtx.T)))
        else:
            self._result = Inv(mtx.T.dot(mtx)).dot(mtx.T)

    def compute_r(self):
        return self._result.r

    def compute_dr_wrt(self, wrt):
        if wrt is self.mtx:
            return self._result.dr_wrt(self.mtx)
|
| 99 |
+
|
| 100 |
+
# Couldn't make the SVD version of pinv work yet...
|
| 101 |
+
#
|
| 102 |
+
# class Pinv(Ch):
|
| 103 |
+
# dterms = 'mtx'
|
| 104 |
+
#
|
| 105 |
+
# def on_changed(self, which):
|
| 106 |
+
# u, s, v = Svd(self.mtx)
|
| 107 |
+
# result = (v.T * (NanDivide(1.,row(s)))).dot(u.T)
|
| 108 |
+
# self.add_dterm('_result', result)
|
| 109 |
+
#
|
| 110 |
+
# def compute_r(self):
|
| 111 |
+
# return self._result.r
|
| 112 |
+
#
|
| 113 |
+
# def compute_dr_wrt(self, wrt):
|
| 114 |
+
# if wrt is self._result:
|
| 115 |
+
# return 1
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
class LogAbsDet(Ch):
    """log|det(x)| via np.linalg.slogdet; the sign is stored on ``self.sign``."""
    dterms = 'x'

    def on_changed(self, which):
        self.sign, self.slogdet = np.linalg.slogdet(self.x.r)

    def compute_r(self):
        return self.slogdet

    def compute_dr_wrt(self, wrt):
        if wrt is not self.x:
            return None
        # d log|det X| / dX = (X^-1)^T, flattened into a row.
        return row(np.linalg.inv(self.x.r).T)
|
| 131 |
+
|
| 132 |
+
class SignLogAbsDet(Ch):
    """Sign of the determinant held by a LogAbsDet node (non-differentiable)."""
    dterms = 'logabsdet',

    def compute_r(self):
        _ = self.logabsdet.r  # force evaluation so .sign is up to date
        return self.logabsdet.sign

    def compute_dr_wrt(self, wrt):
        # The sign is piecewise constant; no derivative is propagated.
        return None
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
class Det(Ch):
    """Differentiable determinant."""
    dterms = 'x'

    def compute_r(self):
        return np.linalg.det(self.x.r)

    def compute_dr_wrt(self, wrt):
        if wrt is not self.x:
            return None
        # d det(X)/dX = det(X) * (X^-1)^T, flattened into a row.
        return row(self.r * np.linalg.inv(self.x.r).T)
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
class Inv(Ch):
    """Differentiable matrix inverse (batched when the input has ndim > 2)."""
    dterms = 'a'

    def compute_r(self):
        return np.linalg.inv(self.a.r)

    def compute_dr_wrt(self, wrt):
        if wrt is not self.a:
            return None

        Ainv = self.r
        if Ainv.ndim <= 2:
            # d(A^-1)/dA = -(A^-1 kron A^-T)
            return -np.kron(Ainv, Ainv.T)

        # Batched case: block-diagonal of per-matrix Jacobians.
        Ainv = np.reshape(Ainv, (-1, Ainv.shape[-2], Ainv.shape[-1]))
        AinvT = np.rollaxis(Ainv, -1, -2)
        AinvT = np.reshape(AinvT, (-1, AinvT.shape[-2], AinvT.shape[-1]))
        blocks = np.dstack([-np.kron(Ainv[i], AinvT[i]).T for i in range(Ainv.shape[0])]).T
        return sp.block_diag(blocks)
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
class SvdD(Ch):
    """Singular values of x; ``UDV`` caches the full thin SVD for SvdU/SvdV."""
    dterms = 'x'

    @depends_on('x')
    def UDV(self):
        u, d, vt = np.linalg.svd(self.x.r, full_matrices=False)
        udv = [u, d, vt.T]
        # Snap numerically-zero singular values to exactly zero.
        udv[1][np.abs(udv[1]) < np.spacing(1)] = 0.
        return udv

    def compute_r(self):
        return self.UDV[1]

    def compute_dr_wrt(self, wrt):
        if wrt is not self.x:
            return

        u, d, v = self.UDV
        shp = self.x.r.shape
        u = u[:shp[0], :shp[1]]
        v = v[:shp[1], :d.size]

        # d(sigma_k)/dX_ij = U_ik * V_jk
        jac = np.einsum('ik,jk->kij', u, v)
        return jac.reshape((jac.shape[0], -1))
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
class SvdV(Ch):
    """Right singular vectors of x, obtained from a shared SvdD node."""
    terms = 'svd_d'
    dterms = 'x'

    def compute_r(self):
        return self.svd_d.UDV[2]

    def compute_dr_wrt(self, wrt):
        if wrt is not self.x:
            return

        U,_D,V = self.svd_d.UDV

        shp = self.svd_d.x.r.shape
        mxsz = max(shp[0], shp[1])
        #mnsz = min(shp[0], shp[1])
        # Pad the singular values up to the larger dimension.
        D = np.zeros(mxsz)
        D[:_D.size] = _D

        # omega[i,j,k,l] couples a perturbation of x[i,j] to a rotation of
        # the (k,l) pair of singular vectors.
        omega = np.zeros((shp[0], shp[1], shp[1], shp[1]))

        M = shp[0]
        N = shp[1]

        # Only the tall/square case is handled.
        assert(M >= N)

        # O(M*N*N^2) loop solving a 2x2 system per (i,j,k,l); the system is
        # singular when singular values are repeated — TODO confirm callers
        # avoid that case.
        for i in range(shp[0]):
            for j in range(shp[1]):
                for k in range(N):
                    for l in range(k+1, N):
                        mtx = np.array([
                            [D[l],D[k]],
                            [D[k],D[l]]])

                        rhs = np.array([U[i,k]*V[j,l], -U[i,l]*V[j,k]])
                        result = np.linalg.solve(mtx, rhs)

                        # Antisymmetric in (k,l) by construction.
                        omega[i,j,k,l] = result[1]
                        omega[i,j,l,k] = -result[1]

        #print 'v size is %s' % (str(V.shape),)
        #print 'v omega size is %s' % (str(omega.shape),)
        assert(V.shape[1] == omega.shape[2])
        return np.einsum('ak,ijkl->alij', -V, omega).reshape((self.r.size, wrt.r.size))
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
class SvdU(Ch):
    """Left singular vectors of x, recovered as U = x V D^-1."""
    dterms = 'x'
    terms = 'svd_d', 'svd_v'

    def compute_r(self):
        return self.svd_d.UDV[0]

    def compute_dr_wrt(self, wrt):
        if wrt is self.x:
            # return (
            #     self.svd_d.x.dot(self.svd_v)
            #     /
            #     self.svd_d.reshape((1,-1))
            #     ).dr_wrt(self.svd_d.x)
            # NanDivide guards against zero singular values (plain division
            # above would produce inf/nan in the Jacobian).
            return (
                NanDivide(
                    self.svd_d.x.dot(self.svd_v),
                    self.svd_d.reshape((1,-1)))
                ).dr_wrt(self.svd_d.x)
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
# Lowercase aliases mirroring the numpy.linalg function names.
inv = Inv
svd = Svd
det = Det
pinv = Pinv
|
| 276 |
+
|
| 277 |
+
def slogdet(*args):
    """Differentiable analogue of np.linalg.slogdet.

    Returns a (sign, log-abs-determinant) pair. For a single matrix both
    entries are single nodes; for several matrices the log terms are
    concatenated into one node while the signs remain a list.
    """
    if len(args) == 1:
        logabs = LogAbsDet(x=args[0])
        return SignLogAbsDet(logabs), logabs

    logabs_terms = [LogAbsDet(x=matrix) for matrix in args]
    sign_terms = [SignLogAbsDet(term) for term in logabs_terms]
    return sign_terms, ch_concatenate(logabs_terms)
|
| 288 |
+
|
| 289 |
+
def main():
    """Smoke test comparing chumpy's slogdet with numpy's."""
    tmp = ch_random_randn(100).reshape((10,10))
    print('chumpy version: ' + str(slogdet(tmp)[1].r))
    print('old version:' + str(np.linalg.slogdet(tmp.r)[1]))

    # Finite-difference check of the log|det| gradient against the
    # analytic directional derivative.
    eps = 1e-10
    diff = np.random.rand(100) * eps
    diff_reshaped = diff.reshape((10,10))
    print(np.linalg.slogdet(tmp.r+diff_reshaped)[1] - np.linalg.slogdet(tmp.r)[1])
    print(slogdet(tmp)[1].dr_wrt(tmp).dot(diff))

    # Signs should agree as well.
    print(np.linalg.slogdet(tmp.r)[0])
    print(slogdet(tmp)[0])

if __name__ == '__main__':
    main()
|
| 306 |
+
|
vendor/chumpy/chumpy/logic.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Author(s): Matthew Loper
|
| 3 |
+
|
| 4 |
+
See LICENCE.txt for licensing and contact information.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
__author__ = 'matt'
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
__all__ = [] # added to incrementally below
|
| 11 |
+
|
| 12 |
+
from . import ch
|
| 13 |
+
from .ch import Ch
|
| 14 |
+
import numpy as np
|
| 15 |
+
|
| 16 |
+
class LogicFunc(Ch):
    """Wraps a numpy predicate (e.g. np.isnan) as a chumpy node."""
    dterms = 'a' # we keep this here so that changes to children of "a" will trigger cache changes
    terms = 'args', 'kwargs', 'funcname'

    def compute_r(self):
        # Dispatch by name so the wrapped function is picklable metadata.
        arr = self.a
        fn = getattr(np, self.funcname)
        return fn(arr, *self.args, **self.kwargs)

    def compute_dr_wrt(self, wrt):
        # Boolean predicates are not differentiable; no Jacobian.
        pass
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
# Names of numpy unary predicates re-exported as chumpy-aware wrappers.
unaries = 'all', 'any', 'isfinite', 'isinf', 'isnan', 'isneginf', 'isposinf', 'logical_not'


def _make_unary(funcname):
    """Build a wrapper deferring the named numpy predicate to LogicFunc."""
    def wrapper(a, *args, **kwargs):
        return LogicFunc(a=a, args=args, kwargs=kwargs, funcname=funcname)
    wrapper.__name__ = funcname
    return wrapper


# Define each wrapper at module scope. Previously this used exec() on a
# formatted string, which is harder to read, debug, and analyze statically;
# a closure factory produces the same module-level functions.
for unary in unaries:
    globals()[unary] = _make_unary(unary)
__all__ += unaries
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
if __name__ == '__main__':
    from . import ch
    # NOTE: 'all' here is the module-level wrapper generated above, which
    # shadows the builtin; both prints show LogicFunc nodes, not booleans.
    print(all(np.array([1,2,3])))
    print(isinf(np.array([0,2,3])))
|
vendor/chumpy/chumpy/monitor.py
ADDED
|
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
'''
|
| 2 |
+
Logging service for tracking dr tree changes from root objective
|
| 3 |
+
and record every step that incrementally changes the dr tree
|
| 4 |
+
|
| 5 |
+
'''
|
| 6 |
+
import os, sys, time
|
| 7 |
+
import json
|
| 8 |
+
import psutil
|
| 9 |
+
|
| 10 |
+
import scipy.sparse as sp
|
| 11 |
+
import numpy as np
|
| 12 |
+
from . import reordering
|
| 13 |
+
|
| 14 |
+
_TWO_20 = float(2 **20)
|
| 15 |
+
|
| 16 |
+
'''
|
| 17 |
+
memory utils
|
| 18 |
+
|
| 19 |
+
'''
|
| 20 |
+
def pdb_mem():
    """Drop into pdb when resident memory exceeds ~7000 MB (debug aid)."""
    from .monitor import get_current_memory
    mem = get_current_memory()
    if mem > 7000:
        import pdb;pdb.set_trace()
|
| 25 |
+
|
| 26 |
+
def get_peak_mem():
    """Peak resident memory of this process since start, in megabytes."""
    import resource

    # ru_maxrss is reported in kilobytes on Linux but in bytes on macOS,
    # so the divisor differs by a factor of 1024.
    denominator = 1024. ** 2 if sys.platform == 'darwin' else 1024.
    return resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / denominator
|
| 37 |
+
|
| 38 |
+
def get_current_memory():
    """Current resident set size (MB) of this process, via psutil."""
    p = psutil.Process(os.getpid())
    # memory_info()[0] is RSS in bytes; _TWO_20 converts to MiB.
    mem = p.memory_info()[0]/_TWO_20

    return mem
|
| 43 |
+
|
| 44 |
+
'''
|
| 45 |
+
Helper for Profiler
|
| 46 |
+
'''
|
| 47 |
+
|
| 48 |
+
def build_cache_info(k, v, info_dict):
    """Record size/sparsity stats for cached Jacobian v under k.short_name.

    Does nothing when v is None. Mutates info_dict in place.
    """
    if v is None:
        return
    is_sparse = sp.issparse(v)
    # Sparse matrices expose their stored entries directly; dense arrays
    # are scanned for nonzeros.
    nnz = len(v.data) if is_sparse else np.count_nonzero(v)
    info_dict[k.short_name] = {
        'sparse': is_sparse,
        'size': str(v.size),
        'nonzero': nnz,
    }
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def cache_info(ch_node):
    """Collect per-wrt derivative-cache statistics for one node as a dict."""
    result = {}
    # Concatenate nodes maintain their own dr_cached dict; prefer it when
    # populated, otherwise fall back to the generic _cache['drs'] store.
    if isinstance(ch_node, reordering.Concatenate) and hasattr(ch_node, 'dr_cached') and len(ch_node.dr_cached) > 0:
        for k, v in ch_node.dr_cached.items():
            build_cache_info(k, v, result)
    elif len(ch_node._cache['drs']) > 0:
        for k, v in ch_node._cache['drs'].items():
            build_cache_info(k, v, result)

    return result
|
| 73 |
+
|
| 74 |
+
class DrWrtProfiler(object):
    """Records dr_wrt evaluation history of a chumpy graph to JSON files."""
    # Default output directory; overridable per instance via __init__.
    base_path = os.path.abspath('profiles')

    def __init__(self, root, base_path=None):
        # root is expected to wrap the objective in .obj — TODO confirm
        # against the ChInputsStacked caller.
        self.root = root.obj
        self.history = []

        ts = time.time()
        if base_path:
            self.base_path = base_path

        # One timestamped file for the incremental log, one for the
        # serialized tree snapshot written immediately below.
        self.path = os.path.join(self.base_path, 'profile_%s.json' % str(ts))
        self.root_path = os.path.join(self.base_path, 'root_%s.json' % str(ts))


        with open(self.root_path, 'w') as f:
            json.dump(self.dump_tree(self.root), f, indent=4)

    def dump_tree(self, node):
        """Serialize the dterm graph rooted at node into nested dicts."""
        if not hasattr(node, 'dterms'):
            return []

        node_dict = self.serialize_node(node, verbose=False)
        # The graph is a DAG: nodes reached a second time are marked
        # 'indirect' and not expanded again (marker left on the node).
        if hasattr(node, 'visited') and node.visited:
            node_dict.update({'indirect':True})
            return node_dict

        node.visited = True
        children_list = []
        for dterm in node.dterms:
            if hasattr(node, dterm):
                child = getattr(node, dterm)
                if hasattr(child, 'dterms') or hasattr(child, 'terms'):
                    children_list.append(self.dump_tree(child))
        node_dict.update({'children':children_list})
        return node_dict

    def serialize_node(self, ch_node, verbose=True):
        """Build a JSON-ready record describing one Ch node."""
        node_id = id(ch_node)
        name = ch_node.short_name
        ts = time.time()
        status = ch_node._status
        mem = get_current_memory()
        node_cache_info = cache_info(ch_node)

        rec = {
            'id': str(node_id),
            'indirect' : False,
        }
        if verbose:
            rec.update({
                'name':name,
                'ts' : ts,
                'status':status,
                'mem': mem,
                'cache': node_cache_info,
            })
        return rec

    def show_tree(self, label):
        '''
        show tree from the root node
        '''
        self.root.show_tree_cache(label)

    def record(self, ch_node):
        '''
        Incremental changes
        '''
        rec = self.serialize_node(ch_node)
        self.history.append(rec)

    def harvest(self):
        # Flush the accumulated history to the timestamped JSON file.
        print('collecting and dump to file %s' % self.path)
        with open(self.path, 'w') as f:
            json.dump(self.history, f, indent=4)
|
vendor/chumpy/chumpy/np_tensordot.py
ADDED
|
@@ -0,0 +1,228 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Up to numpy 1.13, the numpy implementation of tensordot could be
|
| 2 |
+
# reinterpreted using chumpy. With numpy 1.14 the implementation started using
|
| 3 |
+
# ufunc.multiply.reduce which can't be understood by chumpy. This is the
|
| 4 |
+
# chumpy-compatible implementation of tensodrot from numpy 1.13.3.
|
| 5 |
+
#
|
| 6 |
+
# i.e.
|
| 7 |
+
#
|
| 8 |
+
# import inspect
|
| 9 |
+
# with open('np_tensordot.py', 'w') as f:
|
| 10 |
+
# f.write(''.join(inspect.getsourcelines(np.tensordot)[0]))
|
| 11 |
+
|
| 12 |
+
"""
|
| 13 |
+
Copyright (c) 2005-2017, NumPy Developers.
|
| 14 |
+
All rights reserved.
|
| 15 |
+
|
| 16 |
+
Redistribution and use in source and binary forms, with or without
|
| 17 |
+
modification, are permitted provided that the following conditions are
|
| 18 |
+
met:
|
| 19 |
+
|
| 20 |
+
* Redistributions of source code must retain the above copyright
|
| 21 |
+
notice, this list of conditions and the following disclaimer.
|
| 22 |
+
|
| 23 |
+
* Redistributions in binary form must reproduce the above
|
| 24 |
+
copyright notice, this list of conditions and the following
|
| 25 |
+
disclaimer in the documentation and/or other materials provided
|
| 26 |
+
with the distribution.
|
| 27 |
+
|
| 28 |
+
* Neither the name of the NumPy Developers nor the names of any
|
| 29 |
+
contributors may be used to endorse or promote products derived
|
| 30 |
+
from this software without specific prior written permission.
|
| 31 |
+
|
| 32 |
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
| 33 |
+
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
| 34 |
+
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
| 35 |
+
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
| 36 |
+
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
| 37 |
+
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
| 38 |
+
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
| 39 |
+
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
| 40 |
+
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
| 41 |
+
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 42 |
+
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 43 |
+
"""
|
| 44 |
+
|
| 45 |
+
def tensordot(a, b, axes=2):
    """
    Compute tensor dot product along specified axes for arrays >= 1-D.

    Given two tensors (arrays of dimension greater than or equal to one),
    `a` and `b`, and an array_like object containing two array_like
    objects, ``(a_axes, b_axes)``, sum the products of `a`'s and `b`'s
    elements (components) over the axes specified by ``a_axes`` and
    ``b_axes``. The third argument can be a single non-negative
    integer_like scalar, ``N``; if it is such, then the last ``N``
    dimensions of `a` and the first ``N`` dimensions of `b` are summed
    over.

    Parameters
    ----------
    a, b : array_like, len(shape) >= 1
        Tensors to "dot".

    axes : int or (2,) array_like
        * integer_like
          If an int N, sum over the last N axes of `a` and the first N axes
          of `b` in order. The sizes of the corresponding axes must match.
        * (2,) array_like
          Or, a list of axes to be summed over, first sequence applying to `a`,
          second to `b`. Both elements array_like must be of the same length.

    See Also
    --------
    dot, einsum

    Notes
    -----
    Three common use cases are:
        * ``axes = 0`` : tensor product :math:`a\\otimes b`
        * ``axes = 1`` : tensor dot product :math:`a\\cdot b`
        * ``axes = 2`` : (default) tensor double contraction :math:`a:b`

    When `axes` is integer_like, the sequence for evaluation will be: first
    the -Nth axis in `a` and 0th axis in `b`, and the -1th axis in `a` and
    Nth axis in `b` last.

    When there is more than one axis to sum over - and they are not the last
    (first) axes of `a` (`b`) - the argument `axes` should consist of
    two sequences of the same length, with the first axis to sum over given
    first in both sequences, the second axis second, and so forth.

    Examples
    --------
    A "traditional" example:

    >>> a = np.arange(60.).reshape(3,4,5)
    >>> b = np.arange(24.).reshape(4,3,2)
    >>> c = np.tensordot(a,b, axes=([1,0],[0,1]))
    >>> c.shape
    (5, 2)
    >>> c
    array([[ 4400.,  4730.],
           [ 4532.,  4874.],
           [ 4664.,  5018.],
           [ 4796.,  5162.],
           [ 4928.,  5306.]])
    >>> # A slower but equivalent way of computing the same...
    >>> d = np.zeros((5,2))
    >>> for i in range(5):
    ...   for j in range(2):
    ...     for k in range(3):
    ...       for n in range(4):
    ...         d[i,j] += a[k,n,i] * b[n,k,j]
    >>> c == d
    array([[ True,  True],
           [ True,  True],
           [ True,  True],
           [ True,  True],
           [ True,  True]], dtype=bool)

    An extended example taking advantage of the overloading of + and \\*:

    >>> a = np.array(range(1, 9))
    >>> a.shape = (2, 2, 2)
    >>> A = np.array(('a', 'b', 'c', 'd'), dtype=object)
    >>> A.shape = (2, 2)
    >>> a; A
    array([[[1, 2],
            [3, 4]],
           [[5, 6],
            [7, 8]]])
    array([[a, b],
           [c, d]], dtype=object)

    >>> np.tensordot(a, A) # third argument default is 2 for double-contraction
    array([abbcccdddd, aaaaabbbbbbcccccccdddddddd], dtype=object)

    >>> np.tensordot(a, A, 1)
    array([[[acc, bdd],
            [aaacccc, bbbdddd]],
           [[aaaaacccccc, bbbbbdddddd],
            [aaaaaaacccccccc, bbbbbbbdddddddd]]], dtype=object)

    >>> np.tensordot(a, A, 0) # tensor product (result too long to incl.)
    array([[[[[a, b],
              [c, d]],
              ...

    >>> np.tensordot(a, A, (0, 1))
    array([[[abbbbb, cddddd],
            [aabbbbbb, ccdddddd]],
           [[aaabbbbbbb, cccddddddd],
            [aaaabbbbbbbb, ccccdddddddd]]], dtype=object)

    >>> np.tensordot(a, A, (2, 1))
    array([[[abb, cdd],
            [aaabbbb, cccdddd]],
           [[aaaaabbbbbb, cccccdddddd],
            [aaaaaaabbbbbbbb, cccccccdddddddd]]], dtype=object)

    >>> np.tensordot(a, A, ((0, 1), (0, 1)))
    array([abbbcccccddddddd, aabbbbccccccdddddddd], dtype=object)

    >>> np.tensordot(a, A, ((2, 1), (1, 0)))
    array([acccbbdddd, aaaaacccccccbbbbbbdddddddd], dtype=object)

    """
    # NOTE(review): `asarray` and `dot` are not defined in this module;
    # presumably chumpy injects its differentiable versions into the
    # namespace before calling this — confirm against chumpy's __init__.
    try:
        iter(axes)
    # Fix: was a bare `except:`, which also swallowed KeyboardInterrupt /
    # SystemExit; only the "axes is a plain integer" TypeError is intended.
    except Exception:
        # Integer N: contract last N axes of `a` with first N axes of `b`.
        axes_a = list(range(-axes, 0))
        axes_b = list(range(0, axes))
    else:
        axes_a, axes_b = axes
    try:
        na = len(axes_a)
        axes_a = list(axes_a)
    except TypeError:
        axes_a = [axes_a]
        na = 1
    try:
        nb = len(axes_b)
        axes_b = list(axes_b)
    except TypeError:
        axes_b = [axes_b]
        nb = 1

    a, b = asarray(a), asarray(b)
    as_ = a.shape
    nda = a.ndim
    bs = b.shape
    ndb = b.ndim
    # Validate that the contracted axes agree in length, normalizing
    # negative axis indices along the way.
    equal = True
    if na != nb:
        equal = False
    else:
        for k in range(na):
            if as_[axes_a[k]] != bs[axes_b[k]]:
                equal = False
                break
            if axes_a[k] < 0:
                axes_a[k] += nda
            if axes_b[k] < 0:
                axes_b[k] += ndb
    if not equal:
        raise ValueError("shape-mismatch for sum")

    # Move the axes to sum over to the end of "a"
    # and to the front of "b"
    notin = [k for k in range(nda) if k not in axes_a]
    newaxes_a = notin + axes_a
    N2 = 1
    for axis in axes_a:
        N2 *= as_[axis]
    newshape_a = (-1, N2)
    olda = [as_[axis] for axis in notin]

    notin = [k for k in range(ndb) if k not in axes_b]
    newaxes_b = axes_b + notin
    N2 = 1
    for axis in axes_b:
        N2 *= bs[axis]
    newshape_b = (N2, -1)
    oldb = [bs[axis] for axis in notin]

    # The contraction reduces to a single 2-D matrix product.
    at = a.transpose(newaxes_a).reshape(newshape_a)
    bt = b.transpose(newaxes_b).reshape(newshape_b)
    res = dot(at, bt)
    return res.reshape(olda + oldb)
|
vendor/chumpy/chumpy/optimization.py
ADDED
|
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
|
| 3 |
+
"""
|
| 4 |
+
Author(s): Matthew Loper
|
| 5 |
+
|
| 6 |
+
See LICENCE.txt for licensing and contact information.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
__all__ = ['minimize']
|
| 10 |
+
|
| 11 |
+
import numpy as np
|
| 12 |
+
from . import ch
|
| 13 |
+
import scipy.sparse as sp
|
| 14 |
+
import scipy.optimize
|
| 15 |
+
|
| 16 |
+
from .optimization_internal import minimize_dogleg
|
| 17 |
+
|
| 18 |
+
#from memory_profiler import profile, memory_usage
|
| 19 |
+
|
| 20 |
+
# def disable_cache_for_single_parent_node(node):
|
| 21 |
+
# if hasattr(node, '_parents') and len(node._parents.keys()) == 1:
|
| 22 |
+
# node.want_cache = False
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
# Nelder-Mead
|
| 26 |
+
# Powell
|
| 27 |
+
# CG
|
| 28 |
+
# BFGS
|
| 29 |
+
# Newton-CG
|
| 30 |
+
# Anneal
|
| 31 |
+
# L-BFGS-B
|
| 32 |
+
# TNC
|
| 33 |
+
# COBYLA
|
| 34 |
+
# SLSQP
|
| 35 |
+
# dogleg
|
| 36 |
+
# trust-ncg
|
| 37 |
+
def minimize(fun, x0, method='dogleg', bounds=None, constraints=(), tol=None, callback=None, options=None):
    """Minimize a chumpy objective over the free variables in x0.

    fun may be a Ch expression, or a list/tuple/dict of them (then their
    raveled residuals are concatenated). Non-scalar objectives are wrapped
    as a sum of squares. 'dogleg' is handled by chumpy's own solver; any
    other method name is forwarded to scipy.optimize.minimize. The free
    variables are updated in place and returned.
    """
    if method == 'dogleg':
        if options is None: options = {}
        return minimize_dogleg(fun, free_variables=x0, on_step=callback, **options)

    if isinstance(fun, list) or isinstance(fun, tuple):
        fun = ch.concatenate([f.ravel() for f in fun])
    if isinstance(fun, dict):
        fun = ch.concatenate([f.ravel() for f in list(fun.values())])
    obj = fun
    free_variables = x0


    from .ch import SumOfSquares

    hessp = None
    hess = None
    if obj.size == 1:
        obj_scalar = obj
    else:
        obj_scalar = SumOfSquares(obj)

    # Gauss-Newton Hessian-vector product: H ~= 2 J^T J, with J cached on
    # the function object and refreshed only when vs changes.
    def hessp(vs, p,obj, obj_scalar, free_variables):
        changevars(vs,obj,obj_scalar,free_variables)
        if not hasattr(hessp, 'vs'):
            hessp.vs = vs*0+1e16
        if np.max(np.abs(vs-hessp.vs)) > 0:

            J = ns_jacfunc(vs,obj,obj_scalar,free_variables)
            hessp.J = J
            hessp.H = 2. * J.T.dot(J)
            hessp.vs = vs
        return np.array(hessp.H.dot(p)).ravel()
        #return 2*np.array(hessp.J.T.dot(hessp.J.dot(p))).ravel()

    # NOTE(review): this condition looks inverted — newton-cg is the method
    # that uses hessp rather than a full hess; confirm intent before changing.
    if method.lower() != 'newton-cg':
        def hess(vs, obj, obj_scalar, free_variables):
            changevars(vs,obj,obj_scalar,free_variables)
            if not hasattr(hessp, 'vs'):
                hessp.vs = vs*0+1e16
            if np.max(np.abs(vs-hessp.vs)) > 0:
                J = ns_jacfunc(vs,obj,obj_scalar,free_variables)
                hessp.H = 2. * J.T.dot(J)
            return hessp.H

    def changevars(vs, obj, obj_scalar, free_variables):
        # Scatter the flat optimizer vector back into the free variables,
        # writing only those that actually changed; returns whether any did.
        cur = 0
        changed = False
        for idx, freevar in enumerate(free_variables):
            sz = freevar.r.size
            newvals = vs[cur:cur+sz].copy().reshape(free_variables[idx].shape)
            if np.max(np.abs(newvals-free_variables[idx]).ravel()) > 0:
                free_variables[idx][:] = newvals
                changed = True

            cur += sz

        # These scipy methods never invoke the callback themselves, so we
        # fire it manually whenever the variables move.
        methods_without_callback = ('anneal', 'powell', 'cobyla', 'slsqp')
        if callback is not None and changed and method.lower() in methods_without_callback:
            callback(None)

        return changed

    def residuals(vs,obj, obj_scalar, free_variables):
        # Scalar objective value at vs.
        changevars(vs, obj, obj_scalar, free_variables)
        residuals = obj_scalar.r.ravel()[0]
        return residuals

    def scalar_jacfunc(vs,obj, obj_scalar, free_variables):
        # Gradient of the scalar objective, cached against the last vs.
        if not hasattr(scalar_jacfunc, 'vs'):
            scalar_jacfunc.vs = vs*0+1e16
        if np.max(np.abs(vs-scalar_jacfunc.vs)) == 0:
            return scalar_jacfunc.J

        changevars(vs, obj, obj_scalar, free_variables)

        if True: # faster, at least on some problems
            result = np.concatenate([np.array(obj_scalar.lop(wrt, np.array([[1]]))).ravel() for wrt in free_variables])
        else:
            jacs = [obj_scalar.dr_wrt(wrt) for wrt in free_variables]
            for idx, jac in enumerate(jacs):
                if sp.issparse(jac):
                    jacs[idx] = jacs[idx].todense()
            result = np.concatenate([jac.ravel() for jac in jacs])

        scalar_jacfunc.J = result
        scalar_jacfunc.vs = vs
        return result.ravel()

    def ns_jacfunc(vs,obj, obj_scalar, free_variables):
        # Full residual Jacobian (non-scalar), cached against the last vs.
        if not hasattr(ns_jacfunc, 'vs'):
            ns_jacfunc.vs = vs*0+1e16
        if np.max(np.abs(vs-ns_jacfunc.vs)) == 0:
            return ns_jacfunc.J

        changevars(vs, obj, obj_scalar, free_variables)
        jacs = [obj.dr_wrt(wrt) for wrt in free_variables]
        # NOTE(review): `hstack` is not defined/imported in this module
        # (it lives in optimization_internal) — this path would raise
        # NameError if reached; confirm before relying on hess/hessp.
        result = hstack(jacs)

        ns_jacfunc.J = result
        ns_jacfunc.vs = vs
        return result


    x1 = scipy.optimize.minimize(
        method=method,
        fun=residuals,
        callback=callback,
        x0=np.concatenate([free_variable.r.ravel() for free_variable in free_variables]),
        jac=scalar_jacfunc,
        hessp=hessp, hess=hess, args=(obj, obj_scalar, free_variables),
        bounds=bounds, constraints=constraints, tol=tol, options=options).x

    changevars(x1, obj, obj_scalar, free_variables)
    return free_variables
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def main():
    # Placeholder entry point; this module is used as a library.
    pass


if __name__ == '__main__':
    main()
|
| 161 |
+
|
vendor/chumpy/chumpy/optimization_internal.py
ADDED
|
@@ -0,0 +1,455 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import warnings
|
| 3 |
+
import numpy as np
|
| 4 |
+
import scipy.sparse as sp
|
| 5 |
+
from . import ch, utils
|
| 6 |
+
from .ch import pif
|
| 7 |
+
from .utils import timer
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def clear_cache_single(node):
    """Empty the cached derivative matrices of one Ch node, in place."""
    caches = [node._cache['drs']]
    # Concatenate nodes keep an extra per-wrt cache; include it when present.
    if hasattr(node, 'dr_cached'):
        caches.append(node.dr_cached)
    for cache in caches:
        cache.clear()
|
| 14 |
+
|
| 15 |
+
def vstack(x):
    """Stack blocks vertically, materializing LinearOperators first.

    Returns a CSC sparse matrix when any block is sparse, otherwise a
    dense ndarray.
    """
    materialized = []
    for block in x:
        if isinstance(block, sp.linalg.interface.LinearOperator):
            # Realize the operator as a dense matrix by applying it to I.
            block = block.dot(np.eye(block.shape[1]))
        materialized.append(block)
    if any(sp.issparse(block) for block in materialized):
        return sp.vstack(materialized, format='csc')
    return np.vstack(materialized)
|
| 18 |
+
def hstack(x):
    """Horizontally stack blocks, staying sparse iff any block is sparse.

    LinearOperator blocks are first materialized as dense matrices by
    applying them to an identity matrix.
    """
    blocks = [b.dot(np.eye(b.shape[1])) if isinstance(b, sp.linalg.LinearOperator) else b
              for b in x]
    if any(sp.issparse(b) for b in blocks):
        return sp.hstack(blocks, format='csc')
    return np.hstack(blocks)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
# Global generation counter, bumped on every on_changed; passed to Ch.__setattr__
# so the graph can tell apart updates from different optimizer iterations.
_giter = 0

class ChInputsStacked(ch.Ch):
    """Merges several free variables into one flat input vector `x` so a single
    objective `obj` can be optimized over all of them at once.

    `x` holds the concatenated, raveled values of all free variables; setting
    `x` scatters the values back into the individual free variables.
    """
    dterms = 'x', 'obj'
    terms = 'free_variables'

    def compute_r(self):
        """Return the raveled residual vector of the wrapped objective.

        Also counts function evaluations in `self.fevals` (lazily initialized),
        which minimize_dogleg uses for its max_fevals stopping condition.
        """
        if not hasattr(self, 'fevals'):
            self.fevals = 0
        self.fevals += 1
        return self.obj.r.ravel()

    def dr_wrt(self, wrt, profiler=None):
        '''
        Loop over free variables and delete cache for the whole tree after finished each one.
        Returns the horizontal stack of per-variable Jacobians (columns ordered
        like the free variables), or None if `wrt` is not our stacked input.
        '''
        if wrt is self.x:
            jacs = []
            for fvi, freevar in enumerate(self.free_variables):
                tm = timer()
                if isinstance(freevar, ch.Select):
                    # Differentiate wrt the Select's source, then slice out the
                    # selected columns.
                    new_jac = self.obj.dr_wrt(freevar.a, profiler=profiler)
                    try:
                        new_jac = new_jac[:, freevar.idxs]
                    except:
                        # non-csc sparse matrices may not support column-wise indexing
                        new_jac = new_jac.tocsc()[:, freevar.idxs]
                else:
                    new_jac = self.obj.dr_wrt(freevar, profiler=profiler)

                pif('dx wrt {} in {}sec, sparse: {}'.format(freevar.short_name, tm(), sp.issparse(new_jac)))

                # NOTE(review): _make_dense/_make_sparse are presumably populated by
                # the Ch attribute machinery from the make_dense kwarg passed in
                # setup_objective — confirm against ch.Ch.__init__.
                if self._make_dense and sp.issparse(new_jac):
                    new_jac = new_jac.todense()
                if self._make_sparse and not sp.issparse(new_jac):
                    new_jac = sp.csc_matrix(new_jac)

                if new_jac is None:
                    raise Exception(
                        'Objective has no derivative wrt free variable {}. '
                        'You should likely remove it.'.format(fvi))

                jacs.append(new_jac)
            tm = timer()
            # Clear every node's derivative cache between free variables to keep
            # peak memory bounded.
            utils.dfs_do_func_on_graph(self.obj, clear_cache_single)
            pif('dfs_do_func_on_graph in {}sec'.format(tm()))
            tm = timer()
            J = hstack(jacs)
            pif('hstack in {}sec'.format(tm()))
            return J

    def on_changed(self, which):
        """When `x` changes, scatter its entries back into the free variables."""
        global _giter
        _giter += 1
        if 'x' in which:
            pos = 0
            for idx, freevar in enumerate(self.free_variables):
                sz = freevar.r.size
                rng = np.arange(pos, pos+sz)
                if isinstance(self.free_variables[idx], ch.Select):
                    # Deal with nested selects: compose the index chains down to
                    # the underlying variable, then write through them.
                    selects = []
                    a = self.free_variables[idx]
                    while isinstance(a, ch.Select):
                        selects.append(a.idxs)
                        a = a.a
                    newv = a.x.copy()
                    idxs = selects.pop()
                    while len(selects) > 0:
                        idxs = idxs[selects.pop()]
                    newv.ravel()[idxs] = self.x.r.ravel()[rng]
                    a.__setattr__('x', newv, _giter)
                elif isinstance(self.free_variables[idx].x, np.ndarray):
                    self.free_variables[idx].__setattr__('x', self.x.r[rng].copy().reshape(self.free_variables[idx].x.shape), _giter)
                else: # a number
                    self.free_variables[idx].__setattr__('x', self.x.r[rng], _giter)
                pos += sz

    @property
    def J(self):
        '''
        Compute Jacobian. Analyze dr graph first to disable unnecessary caching.

        Returns a 2-d dense array or a sparse matrix, depending on what
        dr_wrt produced.  NOTE(review): assumes `self.profiler` is always
        readable (set in minimize_dogleg only under ch.DEBUG) — presumably
        the Ch machinery supplies a falsy default; confirm.
        '''
        result = self.dr_wrt(self.x, profiler=self.profiler).copy()
        if self.profiler:
            self.profiler.harvest()
        return np.atleast_2d(result) if not sp.issparse(result) else result
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def setup_sparse_solver(sparse_solver):
    """Resolve the ``sparse_solver`` argument of minimize_dogleg into a callable.

    Parameters
    ----------
    sparse_solver : str or callable
        Either 'cg', 'spsolve', or any callable with the api of
        ``scipy.sparse.linalg.spsolve`` (solves ``A x = b`` for ``x``).

    Returns
    -------
    callable
        A ``solve(A, b)`` function.

    Raises
    ------
    Exception
        If ``sparse_solver`` is neither a known solver name nor a callable.
    """
    def _cg(A, x, M=None):
        # scipy >= 1.14 removed cg's `tol` keyword in favor of `rtol`;
        # try the new spelling first and fall back for older scipy.
        try:
            return sp.linalg.cg(A, x, M=M, rtol=1e-10)[0]
        except TypeError:
            return sp.linalg.cg(A, x, M=M, tol=1e-10)[0]

    _solver_fns = {
        'cg': _cg,
        'spsolve': lambda A, x: sp.linalg.spsolve(A, x),
    }
    if callable(sparse_solver):
        return sparse_solver
    elif isinstance(sparse_solver, str) and sparse_solver in _solver_fns:
        return _solver_fns[sparse_solver]
    else:
        raise Exception('sparse_solver argument must be either a string in the set (%s) or have the api of scipy.sparse.linalg.spsolve.' % ', '.join(list(_solver_fns.keys())))
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def setup_objective(obj, free_variables, on_step=None, disp=True, make_dense=False):
    '''
    obj here can be a list of ch objects or a dict of label: ch objects. Either way, the ch
    objects will be merged into one objective using a ChInputsStacked. The labels are just used
    for printing out values per objective with each iteration. If make_dense is True, the
    resulting object will return a dense Jacobian.

    Returns a (stacked_objective, callback) pair: the callback invokes
    `on_step` (if given) and, when `disp` is True, writes the current total
    and per-label sum-of-squares to stderr.
    '''
    # Validate free variables: the same Ch object may not appear twice.
    num_unique_ids = len(np.unique(np.array([id(freevar) for freevar in free_variables])))
    if num_unique_ids != len(free_variables):
        raise Exception('The "free_variables" param contains duplicate variables.')
    # Extract labels (only the dict form carries them).
    labels = {}
    if isinstance(obj, list) or isinstance(obj, tuple):
        obj = ch.concatenate([f.ravel() for f in obj])
    elif isinstance(obj, dict):
        labels = obj
        obj = ch.concatenate([f.ravel() for f in list(obj.values())])
    # Build the stacked objective: x is the flat concatenation of all free
    # variables' current values.
    x = np.concatenate([freevar.r.ravel() for freevar in free_variables])
    obj = ChInputsStacked(obj=obj, free_variables=free_variables, x=x, make_dense=make_dense)
    # Build the per-iteration reporting callback.
    def callback():
        if on_step is not None:
            on_step(obj)
        if disp:
            report_line = ['%.2e' % (np.sum(obj.r**2),)]
            # Deterministic output order: sort labels alphabetically.
            for label, objective in sorted(list(labels.items()), key=lambda x: x[0]):
                report_line.append('%s: %.2e' % (label, np.sum(objective.r**2)))
            report_line = " | ".join(report_line) + '\n'
            sys.stderr.write(report_line)
    return obj, callback
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
class DoglegState(object):
    '''
    Dogleg preserves a great deal of state from iteration to iteration. Many of the things
    that we need to calculate are dependent only on this state (e.g. the various trust region
    steps, the current jacobian and the A & g that depends on it, etc.). Holding the state and
    the various methods based on that state here allows us to separate a lot of the jacobian
    based calculation from the flow control of the optimization.

    There will be one instance of DoglegState per invocation of minimize_dogleg.
    '''
    def __init__(self, delta, solve):
        # delta: trust-region radius (may start as None and be auto-estimated);
        # solve: callable solving A x = b for the Gauss-Newton step.
        self.iteration = 0
        self._d_gn = None # gauss-newton
        self._d_sd = None # steepest descent
        self._d_dl = None # dogleg
        self.J = None     # current Jacobian
        self.A = None     # J^T J (approximate Hessian)
        self.g = None     # J^T (-r), i.e. the negative gradient direction
        self._p = None    # current parameter estimate (column vector)
        self.delta = delta
        self.solve = solve
        self._r = None    # residual at p
        self.rho = None
        self.done = False

    @property
    def p(self):
        '''p is the current proposed input vector'''
        return self._p
    @p.setter
    def p(self, val):
        # Normalize to a column vector so later matrix algebra is consistent.
        self._p = val.reshape((-1, 1))

    # induce some certainty about what the shape of the steps are
    @property
    def d_gn(self):
        return self._d_gn
    @d_gn.setter
    def d_gn(self, val):
        if val is not None:
            val = val.reshape((-1, 1))
        self._d_gn = val

    @property
    def d_sd(self):
        return self._d_sd
    @d_sd.setter
    def d_sd(self, val):
        if val is not None:
            val = val.reshape((-1, 1))
        self._d_sd = val

    @property
    def d_dl(self):
        return self._d_dl
    @d_dl.setter
    def d_dl(self, val):
        if val is not None:
            val = val.reshape((-1, 1))
        self._d_dl = val

    @property
    def step(self):
        # The dogleg step as a column vector.
        return self.d_dl.reshape((-1, 1))
    @property
    def step_size(self):
        return np.linalg.norm(self.d_dl)

    def start_iteration(self):
        """Begin an iteration: compute the steepest-descent (Cauchy) step and
        invalidate any stale Gauss-Newton step."""
        self.iteration += 1
        pif('beginning iteration %d' % (self.iteration,))
        self.d_sd = (np.linalg.norm(self.g)**2 / np.linalg.norm(self.J.dot(self.g))**2 * self.g).ravel()
        self.d_gn = None

    @property
    def r(self):
        '''r is the residual at the current p'''
        return self._r
    @r.setter
    def r(self, val):
        # Assigning a new residual also refreshes A and g.
        self._r = val.copy().reshape((-1, 1))
        self.updateAg()

    def updateAg(self):
        """Recompute A = J^T J and g = J^T(-r) from the current J and r."""
        tm = timer()
        pif('updating A and g...')
        JT = self.J.T
        self.A = JT.dot(self.J)
        self.g = JT.dot(-self.r).reshape((-1, 1))
        pif('A and g updated in %.2fs' % tm())

    def update_step(self):
        """Choose the dogleg step d_dl from d_sd, d_gn and the trust radius."""
        # if the Cauchy point is outside the trust region,
        # take that direction but only to the edge of the trust region
        if self.delta is not None and np.linalg.norm(self.d_sd) >= self.delta:
            pif('PROGRESS: Using stunted cauchy')
            self.d_dl = np.array(self.delta/np.linalg.norm(self.d_sd) * self.d_sd).ravel()
        else:
            if self.d_gn is None:
                # We only need to compute this once per iteration
                self.updateGN()
            # if the gauss-newton solution is within the trust region, use it
            if self.delta is None or np.linalg.norm(self.d_gn) <= self.delta:
                pif('PROGRESS: Using gauss-newton solution')
                self.d_dl = np.array(self.d_gn).ravel()
                if self.delta is None:
                    # First iteration with delta_0=None: seed the trust region
                    # from the GN step length.
                    self.delta = np.linalg.norm(self.d_gn)
            else: # between cauchy step and gauss-newton step
                pif('PROGRESS: between cauchy and gauss-newton')
                # apply step: walk from d_sd toward d_gn until hitting the
                # trust-region boundary.
                self.d_dl = self.d_sd + self.beta_multiplier * (self.d_gn - self.d_sd)

    @property
    def beta_multiplier(self):
        # Interpolation factor placing d_sd + beta*(d_gn - d_sd) on the
        # trust-region boundary (standard dogleg formula).
        delta_sq = self.delta**2
        diff = self.d_gn - self.d_sd
        sqnorm_sd = np.linalg.norm(self.d_sd)**2
        pnow = diff.T.dot(diff)*delta_sq + self.d_gn.T.dot(self.d_sd)**2 - np.linalg.norm(self.d_gn)**2 * sqnorm_sd
        return float(delta_sq - sqnorm_sd) / float((diff).T.dot(self.d_sd) + np.sqrt(pnow))

    def updateGN(self):
        """Compute the Gauss-Newton step by solving A d = g (sparse or dense),
        with an lsqr fallback when the primary solve fails or returns NaN/inf."""
        tm = timer()
        if sp.issparse(self.A):
            self.A.eliminate_zeros()
            pif('sparse solve...sparsity infill is %.3f%% (hessian %dx%d)' % (100. * self.A.nnz / (self.A.shape[0] * self.A.shape[1]), self.A.shape[0], self.A.shape[1]))
            if self.g.size > 1:
                self.d_gn = self.solve(self.A, self.g).ravel()
                if np.any(np.isnan(self.d_gn)) or np.any(np.isinf(self.d_gn)):
                    from scipy.sparse.linalg import lsqr
                    warnings.warn("sparse solve failed, falling back to lsqr")
                    self.d_gn = lsqr(self.A, self.g)[0].ravel()
            else:
                # 1x1 system: solve by scalar division.
                self.d_gn = np.atleast_1d(self.g.ravel()[0]/self.A[0,0])
            pif('sparse solve...done in %.2fs' % tm())
        else:
            pif('dense solve...')
            try:
                self.d_gn = np.linalg.solve(self.A, self.g).ravel()
            except Exception:
                warnings.warn("dense solve failed, falling back to lsqr")
                self.d_gn = np.linalg.lstsq(self.A, self.g)[0].ravel()
            pif('dense solve...done in %.2fs' % tm())

    def updateJ(self, obj):
        """Pull a fresh Jacobian from the objective, convert sparse ones to CSR,
        and sanity-check its shape against the current parameter vector."""
        tm = timer()
        pif('computing Jacobian...')
        self.J = obj.J
        if self.J is None:
            raise Exception("Computing Jacobian failed!")
        if sp.issparse(self.J):
            tm2 = timer()
            self.J = self.J.tocsr()
            pif('converted to csr in {}secs'.format(tm2()))
            assert(self.J.nnz > 0)
        elif ch.VERBOSE:
            nonzero = np.count_nonzero(self.J)
            pif('Jacobian dense with sparsity %.3f' % (nonzero/self.J.size))
        pif('Jacobian (%dx%d) computed in %.2fs' % (self.J.shape[0], self.J.shape[1], tm()))
        if self.J.shape[1] != self.p.size:
            raise Exception('Jacobian size mismatch with objective input')
        return self.J

    class Trial(object):
        '''
        Inside each iteration of dogleg we propose a step and check to see if it's actually
        an improvement before we accept it. This class encapsulates that trial and the
        testing to see if it is actually an improvement.

        There will be one instance of Trial per iteration in dogleg.
        '''
        def __init__(self, proposed_r, state):
            self.r = proposed_r
            self.state = state
            # rho is the ratio of...
            # (improvement in SSE) / (predicted improvement in SSE)
            self.rho = np.linalg.norm(state.r)**2 - np.linalg.norm(proposed_r)**2
            if self.rho > 0:
                with warnings.catch_warnings():
                    # The quadratic-model prediction can overflow; ignore the
                    # resulting RuntimeWarnings.
                    warnings.filterwarnings('ignore',category=RuntimeWarning)
                    predicted_improvement = 2. * state.g.T.dot(state.d_dl) - state.d_dl.T.dot(state.A.dot(state.d_dl))
                    self.rho /= predicted_improvement

        @property
        def is_improvement(self):
            return self.rho > 0

        @property
        def improvement(self):
            # Relative decrease in SSE produced by the proposed step.
            return (np.linalg.norm(self.state.r)**2 - np.linalg.norm(self.r)**2) / np.linalg.norm(self.state.r)**2

    def trial_r(self, proposed_r):
        """Wrap a proposed residual in a Trial bound to this state."""
        return self.Trial(proposed_r, self)

    def updateRadius(self, rho, lb=.05, ub=.9):
        """Grow the trust region after very successful steps (rho > ub),
        shrink it after poor ones (rho < lb)."""
        if rho > ub:
            self.delta = max(self.delta, 2.5*np.linalg.norm(self.d_dl))
        elif rho < lb:
            self.delta *= .25
|
| 356 |
+
|
| 357 |
+
|
| 358 |
+
def minimize_dogleg(obj, free_variables, on_step=None,
                    maxiter=200, max_fevals=np.inf, sparse_solver='spsolve',
                    disp=True, e_1=1e-15, e_2=1e-15, e_3=0., delta_0=None,
                    treat_as_dense=False):
    """Nonlinear optimization using Powell's dogleg method.

    See Lourakis et al, 2005, ICCV '05, "Is Levenberg-Marquardt the
    Most Efficient Optimization for Implementing Bundle Adjustment?":
    http://www.ics.forth.gr/cvrl/publications/conferences/0201-P0401-lourakis-levenberg.pdf

    e_N are stopping conditions:
    e_1 is gradient magnatude threshold
    e_2 is step size magnatude threshold
    e_3 is improvement threshold (as a ratio; 0.1 means it must improve by 10%% at each step)

    maxiter and max_fevals are also stopping conditions. Note that they're not quite the same,
    as an iteration may evaluate the function more than once.

    sparse_solver is the solver to use to calculate the Gauss-Newton step in the common case
    that the Jacobian is sparse. It can be 'spsolve' (in which case scipy.sparse.linalg.spsolve
    will be used), 'cg' (in which case scipy.sparse.linalg.cg will be used), or any callable
    that matches the api of scipy.sparse.linalg.spsolve to solve `A x = b` for x where A is sparse.

    cg, uses a Conjugate Gradient method, and will be faster if A is sparse but x is dense.
    spsolve will be faster if x is also sparse.

    delta_0 defines the initial trust region. Generally speaking, if this is set too low then
    the optimization will never really go anywhere (to small a trust region to make any real
    progress before running out of iterations) and if it's set too high then the optimization
    will diverge immidiately and go wild (such a large trust region that the initial step so
    far overshoots that it can't recover). If it's left as None, it will be automatically
    estimated on the first iteration; it's always updated at each iteration, so this is treated
    only as an initialization.

    treat_as_dense explicitly converts all Jacobians of obj to dense matrices.

    Returns the (updated, in place) list of free variables.
    """

    solve = setup_sparse_solver(sparse_solver)
    obj, callback = setup_objective(obj, free_variables, on_step=on_step, disp=disp,
                                    make_dense=treat_as_dense)

    state = DoglegState(delta=delta_0, solve=solve)
    state.p = obj.x.r

    #inject profiler if in DEBUG mode
    if ch.DEBUG:
        from .monitor import DrWrtProfiler
        obj.profiler = DrWrtProfiler(obj)

    # Initial report, Jacobian and residual before the first iteration.
    callback()
    state.updateJ(obj)
    state.r = obj.r

    def stop(msg):
        # Record the stop reason once; later stop() calls are no-ops.
        if not state.done:
            pif(msg)
            state.done = True

    if np.linalg.norm(state.g, np.inf) < e_1:
        stop('stopping because norm(g, np.inf) < %.2e' % e_1)
    while not state.done:
        state.start_iteration()
        # Inner loop: keep shrinking the trust region until a step improves
        # the objective (or a stopping condition fires).
        while True:
            state.update_step()
            if state.step_size <= e_2 * np.linalg.norm(state.p):
                stop('stopping because of small step size (norm_dl < %.2e)' % (e_2 * np.linalg.norm(state.p)))
            else:
                tm = timer()
                obj.x = state.p + state.step
                trial = state.trial_r(obj.r)
                pif('Residuals computed in %.2fs' % tm())
                # if the objective function improved, update input parameter estimate.
                # Note that the obj.x already has the new parms,
                # and we should not set them again to the same (or we'll bust the cache)
                if trial.is_improvement:
                    state.p = state.p + state.step
                    callback()
                    if e_3 > 0. and trial.improvement < e_3:
                        stop('stopping because improvement < %.1e%%' % (100*e_3))
                    else:
                        state.updateJ(obj)
                        state.r = trial.r
                        if np.linalg.norm(state.g, np.inf) < e_1:
                            stop('stopping because norm(g, np.inf) < %.2e' % e_1)
                else: # Put the old parms back
                    obj.x = ch.Ch(state.p)
                    obj.on_changed('x') # copies from flat vector to free variables
                # update our trust region
                state.updateRadius(trial.rho)
                if state.delta <= e_2*np.linalg.norm(state.p):
                    stop('stopping because trust region is too small')
            # NOTE: if the small-step branch fired, `trial` is stale/unbound from a
            # prior pass, but `state.done` short-circuits the check below.
            if state.done or trial.is_improvement or (obj.fevals >= max_fevals):
                break
        if state.iteration >= maxiter:
            stop('stopping because max number of user-specified iterations (%d) has been met' % maxiter)
        elif obj.fevals >= max_fevals:
            stop('stopping because max number of user-specified func evals (%d) has been met' % max_fevals)
    return obj.free_variables
|
vendor/chumpy/chumpy/optional_test_performance.py
ADDED
|
@@ -0,0 +1,183 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
# encoding: utf-8
"""
Author(s): Matthew Loper

See LICENCE.txt for licensing and contact information.

Optional performance tests for chumpy. Picks a CPU-time backend at import
time: the POSIX `resource` module when available, otherwise the Windows
QueryPerformanceCounter API via ctypes.
"""

import unittest
import numpy as np
from functools import reduce


has_ressources = True
try:
    # POSIX backend: CPU time (user + system) via getrusage.
    import resource

    def abstract_ressource_timer():
        # Opaque timestamp; only meaningful to abstract_ressource_counter.
        return resource.getrusage(resource.RUSAGE_SELF)
    def abstract_ressource_counter(r1, r2):
        # CPU seconds elapsed between the two snapshots.
        _r1 = r1.ru_stime + r1.ru_utime
        _r2 = r2.ru_stime + r2.ru_utime

        return _r2 - _r1
except ImportError:
    has_ressources = False
    pass


if not has_ressources:
    # Windows fallback: wall-clock via QueryPerformanceCounter.
    # NOTE(review): `from ctypes import *` succeeds on non-Windows platforms
    # too, where `windll` does not exist — on such systems these functions
    # would raise NameError when called; confirm intended platforms.
    try:
        from ctypes import *

        def abstract_ressource_timer():
            val = c_int64()
            windll.Kernel32.QueryPerformanceCounter(byref(val))
            return val
        def abstract_ressource_counter(r1, r2):
            """Returns the elapsed time between r2 and r1 (r2 > r1) in milliseconds"""
            val = c_int64()
            windll.Kernel32.QueryPerformanceFrequency(byref(val))

            return (1000*float(r2.value-r1.value))/val.value

    except ImportError:
        has_win32api = False


from . import ch
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class Timer(object):
    """Context manager measuring the time spent inside a ``with`` block.

    After exit, the measurement is available as ``self.elapsed``. Units depend
    on the backend selected at import time (CPU seconds with the `resource`
    backend, milliseconds with the Windows counter backend).
    """

    def __enter__(self):
        # Snapshot on entry.
        self.r1 = abstract_ressource_timer()

    def __exit__(self, exception_type, exception_value, traceback):
        # Snapshot on exit and store the difference; exceptions propagate
        # (nothing is returned, so no suppression).
        self.r2 = abstract_ressource_timer()

        self.elapsed = abstract_ressource_counter(self.r1, self.r2)
|
| 66 |
+
|
| 67 |
+
# def timer():
|
| 68 |
+
# tm = resource.getrusage(resource.RUSAGE_SELF)
|
| 69 |
+
# return tm.ru_stime + tm.ru_utime
|
| 70 |
+
#
|
| 71 |
+
# svd1
|
| 72 |
+
|
| 73 |
+
def timer(setup, go, n):
    """Average cost of calling ``go()`` over ``n`` runs.

    ``setup`` (if not None) is invoked before each timed run and is not
    included in the measurement. Units follow the platform timing backend.
    """
    samples = []
    for _ in range(n):
        if setup is not None:
            setup()

        start = abstract_ressource_timer()
        go()
        stop = abstract_ressource_timer()

        samples.append(abstract_ressource_counter(start, stop))

    return np.mean(samples) # see docs for timeit, which recommend getting minimum
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
import timeit
|
| 97 |
+
class TestPerformance(unittest.TestCase):
    """Regression tests asserting that chumpy's overhead over plain numpy
    stays within empirically chosen factors."""

    def setUp(self):
        np.random.seed(0)
        # Small and large chumpy matrices; note mtx_1k is built but only
        # mtx_10-sized data is used by the tests below.
        self.mtx_10 = ch.array(np.random.randn(100).reshape((10,10)))
        self.mtx_1k = ch.array(np.random.randn(1000000).reshape((1000,1000)))

    def compute_binary_ratios(self, vecsize, numvecs):
        """Return {opname: chumpy_time / numpy_time} for chained binary ops
        over `numvecs` random vectors of length `vecsize`."""
        ratio = {}
        for funcname in ['add', 'subtract', 'multiply', 'divide', 'power']:
            for xp in ch, np:
                func = getattr(xp, funcname)
                vecs = [xp.random.rand(vecsize) for i in range(numvecs)]

                if xp is ch:
                    # Build the expression graph once; each timed run mutates
                    # the leaves and re-evaluates (exercising cache refresh).
                    f = reduce(lambda x, y : func(x,y), vecs)
                    def go():
                        for v in vecs:
                            v.x *= -1
                        _ = f.r

                    tm_ch = timer(None, go, 10)
                else: # xp is np
                    def go():
                        for v in vecs:
                            v *= -1
                        _ = reduce(lambda x, y : func(x,y), vecs)

                    tm_np = timer(None, go, 10)

            ratio[funcname] = tm_ch / tm_np

        return ratio

    def test_binary_ratios(self):
        # Thresholds are empirical slowdown budgets per operation.
        ratios = self.compute_binary_ratios(vecsize=5000, numvecs=100)
        tol = 1e-1
        self.assertLess(ratios['add'], 8+tol)
        self.assertLess(ratios['subtract'], 8+tol)
        self.assertLess(ratios['multiply'], 8+tol)
        self.assertLess(ratios['divide'], 4+tol)
        self.assertLess(ratios['power'], 2+tol)
        #print ratios


    def test_svd(self):
        mtx = ch.array(np.random.randn(100).reshape((10,10)))

        # Get times for svd
        from .linalg import svd
        u, s, v = svd(mtx)
        def setup():
            # Perturb the input so chumpy cannot reuse cached results.
            mtx.x = -mtx.x

        def go_r():
            _ = u.r
            _ = s.r
            _ = v.r

        def go_dr():
            _ = u.dr_wrt(mtx)
            _ = s.dr_wrt(mtx)
            _ = v.dr_wrt(mtx)

        cht_r = timer(setup, go_r, 20)
        cht_dr = timer(setup, go_dr, 1)

        # Get times for numpy svd
        def go():
            u,s,v = np.linalg.svd(mtx.x)
        npt = timer(setup = None, go = go, n = 20)

        # Compare against empirical slowdown budgets.
        #print cht_r / npt
        #print cht_dr / npt
        self.assertLess(cht_r / npt, 3.3)
        self.assertLess(cht_dr / npt, 2700)
| 181 |
+
if __name__ == '__main__':
    # Run only the performance suite, with verbose output.
    loader = unittest.TestLoader()
    runner = unittest.TextTestRunner(verbosity=2)
    runner.run(loader.loadTestsFromTestCase(TestPerformance))
|
vendor/chumpy/chumpy/reordering.py
ADDED
|
@@ -0,0 +1,454 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Author(s): Matthew Loper
|
| 3 |
+
|
| 4 |
+
See LICENCE.txt for licensing and contact information.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from .ch import Ch
|
| 8 |
+
import numpy as np
|
| 9 |
+
from .utils import row, col
|
| 10 |
+
import scipy.sparse as sp
|
| 11 |
+
import weakref
|
| 12 |
+
|
| 13 |
+
__all__ = ['sort', 'tile', 'repeat', 'transpose', 'rollaxis', 'swapaxes', 'reshape', 'Select',
|
| 14 |
+
'atleast_1d', 'atleast_2d', 'atleast_3d', 'squeeze', 'expand_dims', 'fliplr', 'flipud',
|
| 15 |
+
'concatenate', 'vstack', 'hstack', 'dstack', 'ravel', 'diag', 'diagflat', 'roll', 'rot90']
|
| 16 |
+
|
| 17 |
+
# Classes deriving from "Permute" promise to only reorder/reshape
class Permute(Ch):
    """Marker base class: subclasses only rearrange their input's elements
    (their Jacobians are permutation-like), never change its values."""
    pass
|
| 20 |
+
|
| 21 |
+
def ravel(a, order='C'):
    """chumpy analogue of numpy.ravel: return `a` flattened to 1-d.

    Only C (row-major) order is supported.

    Parameters
    ----------
    a : Ch or ndarray
        Input to flatten; raw ndarrays are wrapped by reshape() itself.

    Returns
    -------
    A reshape node with shape (-1,).
    """
    assert(order=='C')
    # The original code assigned `self = Ch(a)` for ndarray inputs, but the
    # result was never used — reshape() handles the wrapping, so the dead
    # assignment is removed.
    return reshape(a=a, newshape=(-1,))
|
| 27 |
+
|
| 28 |
+
class Reorder(Permute):
    """Base class for value-preserving rearrangements (sort, tile, transpose...).

    Subclasses implement `reorder(a)` (the numpy operation) and
    `unique_reorder_id()` (a hashable cache key for the Jacobian, or None to
    disable caching across shape changes).
    """
    dterms = 'a',

    def on_changed(self, which):
        # Lazily create the per-instance Jacobian cache.
        if not hasattr(self, 'dr_lookup'):
            self.dr_lookup = {}

    def compute_r(self):
        return self.reorder(self.a.r)

    def compute_dr_wrt(self, wrt):
        """Jacobian of a pure reordering: a sparse 0/1 permutation-like matrix
        mapping input positions to output positions."""
        if wrt is self.a:
            if False:
                # Disabled alternative: represent the Jacobian lazily as a
                # LinearOperator instead of materializing it.
                from scipy.sparse.linalg.interface import LinearOperator
                return LinearOperator((self.size, wrt.size), lambda x : self.reorder(x.reshape(self.a.shape)).ravel())
            else:
                a = self.a
                asz = a.size
                ashape = a.shape
                key = self.unique_reorder_id()
                # key is None means "never cache": recompute each time.
                if key not in self.dr_lookup or key is None:
                    # Reorder the flat indices 0..size-1 to learn where each
                    # input element lands in the output.
                    JS = self.reorder(np.arange(asz).reshape(ashape))
                    IS = np.arange(JS.size)
                    data = np.ones_like(IS)
                    shape = JS.shape
                    self.dr_lookup[key] = sp.csc_matrix((data, (IS, JS.ravel())), shape=(self.r.size, wrt.r.size))
                return self.dr_lookup[key]
|
| 55 |
+
|
| 56 |
+
class Sort(Reorder):
    """Differentiable np.sort: reorders elements along an axis."""
    dterms = 'a'
    terms = 'axis', 'kind', 'order'

    def reorder(self, a): return np.sort(a, self.axis, self.kind, self.order)
    # Sorting depends on the data, so the Jacobian can never be cached.
    def unique_reorder_id(self): return None

def sort(a, axis=-1, kind='quicksort', order=None):
    """chumpy analogue of np.sort (same signature)."""
    return Sort(a=a, axis=axis, kind=kind, order=order)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class Tile(Reorder):
    """Differentiable np.tile: repeats the whole array per `reps`."""
    dterms = 'a',
    terms = 'reps',
    term_order = 'a', 'reps'

    def reorder(self, a): return np.tile(a, self.reps)
    # Jacobian depends only on input shape and reps, so cache on both.
    def unique_reorder_id(self): return (self.a.shape, tuple(self.reps))

def tile(A, reps):
    """chumpy analogue of np.tile (same signature)."""
    return Tile(a=A, reps=reps)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
class Diag(Reorder):
    # Differentiable np.diag: extracts the k-th diagonal from a 2-D input,
    # or builds a diagonal matrix from a 1-D input.
    # NOTE(review): in the 1-D (matrix-building) direction np.diag fills
    # off-diagonals with zeros; the base-class Jacobian construction sends
    # an index array through reorder(), so those zeros read as "comes from
    # input index 0" — confirm derivatives for the 1-D case.
    dterms = 'a',
    terms = 'k',

    def reorder(self, a): return np.diag(a, self.k)
    def unique_reorder_id(self): return (self.a.shape, self.k)

def diag(v, k=0):
    """Differentiable analogue of np.diag."""
    return Diag(a=v, k=k)
class DiagFlat(Reorder):
    # Differentiable np.diagflat: flattens the input and places it on the
    # k-th diagonal of a square matrix.
    # NOTE(review): same caveat as Diag — off-diagonal zeros are treated
    # as input index 0 by the base-class Jacobian builder; verify.
    dterms = 'a',
    terms = 'k',

    def reorder(self, a): return np.diagflat(a, self.k)
    def unique_reorder_id(self): return (self.a.shape, self.k)

def diagflat(v, k=0):
    """Differentiable analogue of np.diagflat."""
    return DiagFlat(a=v, k=k)
class Repeat(Reorder):
    # Differentiable np.repeat: repeats elements along `axis`.
    dterms = 'a',
    terms = 'repeats', 'axis'

    def reorder(self, a): return np.repeat(a, self.repeats, self.axis)
    # NOTE(review): unlike Tile/Diag this key omits a.shape, yet the
    # repeat pattern depends on it — a same-size reshape of `a` could hit
    # a stale cached Jacobian; confirm.
    def unique_reorder_id(self): return (self.repeats, self.axis)

def repeat(a, repeats, axis=None):
    """Differentiable analogue of np.repeat."""
    return Repeat(a=a, repeats=repeats, axis=axis)
class transpose(Reorder):
    # Differentiable np.transpose (lowercase name mirrors the numpy API).
    dterms = 'a'
    terms = 'axes'
    term_order = 'a', 'axes'

    # np.require(..., 'C') forces a C-contiguous copy so downstream code
    # can rely on the memory layout.
    def reorder(self, a): return np.require(np.transpose(a, axes=self.axes), requirements='C')
    def unique_reorder_id(self): return (self.a.shape, None if self.axes is None else tuple(self.axes))
    def on_changed(self, which):
        # Default axes=None reverses all axes, as np.transpose does.
        if not hasattr(self, 'axes'):
            self.axes = None
        # Name the class explicitly: the original `super(self.__class__, self)`
        # recurses infinitely as soon as this class is subclassed.
        super(transpose, self).on_changed(which)
class rollaxis(Reorder):
    # Differentiable np.rollaxis: moves `axis` to position `start`.
    dterms = 'a'
    terms = 'axis', 'start'
    term_order = 'a', 'axis', 'start'

    def reorder(self, a): return np.rollaxis(a, axis=self.axis, start=self.start)
    def unique_reorder_id(self): return (self.a.shape, self.axis, self.start)
    def on_changed(self, which):
        # Default start=0 matches np.rollaxis's own default.
        if not hasattr(self, 'start'):
            self.start = 0
        # Name the class explicitly: the original `super(self.__class__, self)`
        # recurses infinitely as soon as this class is subclassed.
        super(rollaxis, self).on_changed(which)
class swapaxes(Reorder):
    # Differentiable np.swapaxes: exchanges two axes of the input.
    dterms = 'a'
    terms = 'axis1', 'axis2'
    term_order = 'a', 'axis1', 'axis2'

    def reorder(self, a): return np.swapaxes(a, axis1=self.axis1, axis2=self.axis2)
    def unique_reorder_id(self): return (self.a.shape, self.axis1, self.axis2)
class Roll(Reorder):
    # Differentiable np.roll: circularly shifts elements along `axis`.
    dterms = 'a',
    terms = 'shift', 'axis'
    term_order = 'a', 'shift', 'axis'

    def reorder(self, a): return np.roll(a, self.shift, self.axis)
    # NOTE(review): key omits a.shape although the roll pattern depends on
    # it; a same-size reshape could hit a stale cached Jacobian — confirm.
    def unique_reorder_id(self): return (self.shift, self.axis)

def roll(a, shift, axis=None):
    """Differentiable analogue of np.roll."""
    return Roll(a, shift, axis)
class Rot90(Reorder):
    # Differentiable np.rot90: rotates the first two axes by k*90 degrees.
    dterms = 'a',
    terms = 'k',

    def reorder(self, a): return np.rot90(a, self.k)
    def unique_reorder_id(self): return (self.a.shape, self.k)

def rot90(m, k=1):
    """Differentiable analogue of np.rot90."""
    return Rot90(a=m, k=k)
class Reshape(Permute):
    # Differentiable reshape: values keep their flat order, only the shape
    # metadata changes, so the Jacobian wrt the input is the identity.
    dterms = 'a',
    terms = 'newshape',
    term_order= 'a', 'newshape'

    def compute_r(self):
        return self.a.r.reshape(self.newshape)

    def compute_dr_wrt(self, wrt):
        if wrt is self.a:
            # No entry moves: d(out)/d(in) is a (sparse) identity.
            return sp.eye(self.a.size, self.a.size)
            #return self.a.dr_wrt(wrt)
# def reshape(a, newshape):
|
| 179 |
+
# if isinstance(a, Reshape) and a.newshape == newshape:
|
| 180 |
+
# return a
|
| 181 |
+
# return Reshape(a=a, newshape=newshape)
|
| 182 |
+
def reshape(a, newshape):
    """Build a Reshape node over `a`, collapsing chains of reshapes.

    Reshaping a reshape is just a reshape of the innermost source, so we
    walk down through any nested Reshape nodes before wrapping once.
    """
    source = a
    while isinstance(source, Reshape):
        source = source.a
    return Reshape(a=source, newshape=newshape)
# class And(Ch):
|
| 188 |
+
# dterms = 'x1', 'x2'
|
| 189 |
+
#
|
| 190 |
+
# def compute_r(self):
|
| 191 |
+
# if True:
|
| 192 |
+
# needs_work = [self.x1, self.x2]
|
| 193 |
+
# done = []
|
| 194 |
+
# while len(needs_work) > 0:
|
| 195 |
+
# todo = needs_work.pop()
|
| 196 |
+
# if isinstance(todo, And):
|
| 197 |
+
# needs_work += [todo.x1, todo.x2]
|
| 198 |
+
# else:
|
| 199 |
+
# done = [todo] + done
|
| 200 |
+
# return np.concatenate([d.r.ravel() for d in done])
|
| 201 |
+
# else:
|
| 202 |
+
# return np.concatenate((self.x1.r.ravel(), self.x2.r.ravel()))
|
| 203 |
+
#
|
| 204 |
+
# # This is only here for reverse mode to work.
|
| 205 |
+
# # Most of the time, the overridden dr_wrt is callpath gets used.
|
| 206 |
+
# def compute_dr_wrt(self, wrt):
|
| 207 |
+
#
|
| 208 |
+
# if wrt is not self.x1 and wrt is not self.x2:
|
| 209 |
+
# return
|
| 210 |
+
#
|
| 211 |
+
# input_len = wrt.r.size
|
| 212 |
+
# x1_len = self.x1.r.size
|
| 213 |
+
# x2_len = self.x2.r.size
|
| 214 |
+
#
|
| 215 |
+
# mtxs = []
|
| 216 |
+
# if wrt is self.x1:
|
| 217 |
+
# mtxs.append(sp.spdiags(np.ones(x1_len), 0, x1_len, x1_len))
|
| 218 |
+
# else:
|
| 219 |
+
# mtxs.append(sp.csc_matrix((x1_len, input_len)))
|
| 220 |
+
#
|
| 221 |
+
# if wrt is self.x2:
|
| 222 |
+
# mtxs.append(sp.spdiags(np.ones(x2_len), 0, x2_len, x2_len))
|
| 223 |
+
# else:
|
| 224 |
+
# mtxs.append(sp.csc_matrix((x2_len, input_len)))
|
| 225 |
+
#
|
| 226 |
+
#
|
| 227 |
+
# if any([sp.issparse(mtx) for mtx in mtxs]):
|
| 228 |
+
# result = sp.vstack(mtxs, format='csc')
|
| 229 |
+
# else:
|
| 230 |
+
# result = np.vstack(mtxs)
|
| 231 |
+
#
|
| 232 |
+
# return result
|
| 233 |
+
#
|
| 234 |
+
# def dr_wrt(self, wrt, want_stacks=False, reverse_mode=False):
|
| 235 |
+
# self._call_on_changed()
|
| 236 |
+
#
|
| 237 |
+
# input_len = wrt.r.size
|
| 238 |
+
# x1_len = self.x1.r.size
|
| 239 |
+
# x2_len = self.x2.r.size
|
| 240 |
+
#
|
| 241 |
+
# mtxs = []
|
| 242 |
+
# if wrt is self.x1:
|
| 243 |
+
# mtxs.append(sp.spdiags(np.ones(x1_len), 0, x1_len, x1_len))
|
| 244 |
+
# else:
|
| 245 |
+
# if isinstance(self.x1, And):
|
| 246 |
+
# tmp_mtxs = self.x1.dr_wrt(wrt, want_stacks=True, reverse_mode=reverse_mode)
|
| 247 |
+
# for mtx in tmp_mtxs:
|
| 248 |
+
# mtxs.append(mtx)
|
| 249 |
+
# else:
|
| 250 |
+
# mtxs.append(self.x1.dr_wrt(wrt, reverse_mode=reverse_mode))
|
| 251 |
+
# if mtxs[-1] is None:
|
| 252 |
+
# mtxs[-1] = sp.csc_matrix((x1_len, input_len))
|
| 253 |
+
#
|
| 254 |
+
# if wrt is self.x2:
|
| 255 |
+
# mtxs.append(sp.spdiags(np.ones(x2_len), 0, x2_len, x2_len))
|
| 256 |
+
# else:
|
| 257 |
+
# if isinstance(self.x2, And):
|
| 258 |
+
# tmp_mtxs = self.x2.dr_wrt(wrt, want_stacks=True, reverse_mode=reverse_mode)
|
| 259 |
+
# for mtx in tmp_mtxs:
|
| 260 |
+
# mtxs.append(mtx)
|
| 261 |
+
# else:
|
| 262 |
+
# mtxs.append(self.x2.dr_wrt(wrt, reverse_mode=reverse_mode))
|
| 263 |
+
# if mtxs[-1] is None:
|
| 264 |
+
# mtxs[-1] = sp.csc_matrix((x2_len, input_len))
|
| 265 |
+
#
|
| 266 |
+
# if want_stacks:
|
| 267 |
+
# return mtxs
|
| 268 |
+
# else:
|
| 269 |
+
# if any([sp.issparse(mtx) for mtx in mtxs]):
|
| 270 |
+
# result = sp.vstack(mtxs, format='csc')
|
| 271 |
+
# else:
|
| 272 |
+
# result = np.vstack(mtxs)
|
| 273 |
+
#
|
| 274 |
+
# return result
|
| 275 |
+
|
| 276 |
+
class Select(Permute):
    # Differentiable fancy indexing: picks the flat indices `idxs` out of
    # `a`, optionally reshaping the result to `preferred_shape`.
    terms = ['idxs', 'preferred_shape']
    dterms = ['a']
    term_order = 'a', 'idxs', 'preferred_shape'

    def compute_r(self):
        result = self.a.r.ravel()[self.idxs].copy()
        if hasattr(self, 'preferred_shape'):
            return result.reshape(self.preferred_shape)
        else:
            return result

    def compute_dr_wrt(self, obj):
        if obj is self.a:
            if not hasattr(self, '_dr_cached'):
                # Selection matrix: output row i reads input column idxs[i].
                IS = np.arange(len(self.idxs))
                JS = self.idxs.ravel()
                ij = np.vstack((row(IS), row(JS)))
                data = np.ones(len(self.idxs))
                self._dr_cached = sp.csc_matrix((data, ij), shape=(len(self.idxs), np.prod(self.a.shape)))
            return self._dr_cached

    def on_changed(self, which):
        # Invalidate the cached Jacobian when the indices changed or the
        # input size no longer matches the cached matrix width.
        if hasattr(self, '_dr_cached'):
            if 'idxs' in which or self.a.r.size != self._dr_cached.shape[1]:
                del self._dr_cached
class AtleastNd(Ch):
    # Promote x to at least `ndims` dimensions (1, 2 or 3), mirroring
    # numpy's atleast_1d/2d/3d while keeping the graph differentiable.
    dterms = 'x'
    terms = 'ndims'

    def compute_r(self):
        xr = self.x.r
        if self.ndims == 1:
            target_shape = np.atleast_1d(xr).shape
        elif self.ndims == 2:
            target_shape = np.atleast_2d(xr).shape
        elif self.ndims == 3:
            target_shape = np.atleast_3d(xr).shape
        else:
            raise Exception('Need ndims to be 1, 2, or 3.')

        # NOTE(review): only the target *shape* is borrowed from numpy;
        # the data is reshaped in flat order. np.atleast_3d can place the
        # new axis in the middle for 1-D input — confirm equivalence.
        return xr.reshape(target_shape)

    def compute_dr_wrt(self, wrt):
        if wrt is self.x:
            # Values are untouched, only the shape changes: identity
            # derivative expressed as the scalar 1.
            return 1
def atleast_nd(ndims, *arys):
    """Wrap each input in an AtleastNd node promoting it to >= ndims dims.

    Mirrors numpy's convention: a single input yields a single node, while
    several inputs yield a list of nodes.
    """
    wrapped = [AtleastNd(x=ary, ndims=ndims) for ary in arys]
    if len(wrapped) == 1:
        return wrapped[0]
    return wrapped

def atleast_1d(*arys):
    """Differentiable analogue of np.atleast_1d."""
    return atleast_nd(1, *arys)

def atleast_2d(*arys):
    """Differentiable analogue of np.atleast_2d."""
    return atleast_nd(2, *arys)

def atleast_3d(*arys):
    """Differentiable analogue of np.atleast_3d."""
    return atleast_nd(3, *arys)
def squeeze(a, axis=None):
    """Remove length-one axes from `a`.

    Plain numpy arrays are delegated straight to np.squeeze; differentiable
    nodes are reshaped to the squeezed shape so the graph stays intact.
    """
    if isinstance(a, np.ndarray):
        return np.squeeze(a, axis)
    squeezed_shape = np.squeeze(a.r, axis).shape
    return a.reshape(squeezed_shape)
def expand_dims(a, axis):
    """Insert a new length-one axis into `a` at position `axis`.

    Plain numpy arrays are delegated straight to np.expand_dims;
    differentiable nodes are reshaped to the expanded shape instead.
    """
    if isinstance(a, np.ndarray):
        return np.expand_dims(a, axis)
    expanded_shape = np.expand_dims(a.r, axis).shape
    return a.reshape(expanded_shape)
def fliplr(m):
    """Reverse the column order (second axis) of `m`.

    Implemented with plain slicing so it works unchanged for both numpy
    arrays and differentiable nodes.
    """
    return m[:, ::-1]

def flipud(m):
    """Reverse the row order (first axis) of `m`, via plain slicing."""
    return m[::-1, ...]
class Concatenate(Ch):
    # Differentiable np.concatenate over an arbitrary number of inputs.
    # The inputs are not fixed at class definition time: concatenate()
    # below installs them as dterms m0, m1, ... plus an `axis` attribute.

    def on_changed(self, which):
        # dr_cached maps each input node to its sparse Jacobian; weak keys
        # let dead inputs be garbage collected.
        if not hasattr(self, 'dr_cached'):
            self.dr_cached = weakref.WeakKeyDictionary()

    @property
    def our_terms(self):
        # The concatenated child nodes, in m0, m1, ... order.
        if not hasattr(self, '_our_terms'):
            self._our_terms = [getattr(self, s) for s in self.dterms]
        return self._our_terms

    def __getstate__(self):
        # Have to get rid of WeakKeyDictionaries for serialization
        if hasattr(self, 'dr_cached'):
            del self.dr_cached
        # NOTE(review): `super(self.__class__, ...)` recurses infinitely if
        # this class is ever subclassed; prefer super(Concatenate, self).
        return super(self.__class__, self).__getstate__()

    def compute_r(self):
        # Forward pass: plain numpy concatenation of the children's values.
        return np.concatenate([t.r for t in self.our_terms], axis=self.axis)

    @property
    def everything(self):
        # Flat output indices laid out in output shape, with the
        # concatenation axis swapped to the front so each input's slab is
        # a contiguous range along axis 0.
        if not hasattr(self, '_everything'):
            self._everything = np.arange(self.r.size).reshape(self.r.shape)
            self._everything = np.swapaxes(self._everything, self.axis, 0)
        return self._everything

    def compute_dr_wrt(self, wrt):
        if not hasattr(self, 'dr_cached'):
            self.dr_cached = weakref.WeakKeyDictionary()
        if wrt in self.dr_cached and self.dr_cached[wrt] is not None:
            return self.dr_cached[wrt]

        if wrt not in self.our_terms:
            return

        _JS = np.arange(wrt.size)
        _data = np.ones(wrt.size)

        IS = []
        JS = []
        data = []

        # Walk the children, tracking each one's slab offset along the
        # concatenation axis; for every occurrence of `wrt` (an input may
        # appear more than once) emit an identity map into its slab.
        offset = 0
        for term in self.our_terms:
            tsz = term.shape[self.axis]
            if term is wrt:
                JS += [_JS]
                data += [_data]
                IS += [np.swapaxes(self.everything[offset:offset+tsz], self.axis, 0).ravel()]
            offset += tsz
        IS = np.concatenate(IS).ravel()
        JS = np.concatenate(JS).ravel()
        data = np.concatenate(data)

        res = sp.csc_matrix((data, (IS, JS)), shape=(self.r.size, wrt.size))

        # Only cache when this node is shared by multiple parents; with a
        # single parent the Jacobian is requested once and caching would
        # just hold memory.
        if len(list(self._parents.keys())) != 1:
            self.dr_cached[wrt] = res
        else:
            self.dr_cached[wrt] = None

        return res
def expand_concatenates(mtxs, axis=0):
    """Flatten nested Concatenate nodes (along the same axis) into a list.

    Any Concatenate whose axis matches is replaced by its children, applied
    recursively, so the returned list contains only leaf inputs in order.
    """
    pending = list(mtxs)
    flattened = []
    while pending:
        node = pending.pop(0)
        if isinstance(node, Concatenate) and node.axis == axis:
            # Splice the children in front so ordering is preserved.
            pending = [getattr(node, name) for name in node.dterms] + pending
        else:
            flattened.append(node)
    return flattened
def concatenate(mtxs, axis=0, **kwargs):
    """Differentiable analogue of np.concatenate.

    Nested concatenations along the same axis are flattened first, so the
    resulting node depends directly on the leaf inputs, exposed as
    attributes m0, m1, ...
    """
    flat = expand_concatenates(mtxs, axis)

    node = Concatenate(**kwargs)
    node.dterms = []
    for idx, mtx in enumerate(flat):
        attr_name = 'm%d' % (idx,)
        node.dterms.append(attr_name)
        setattr(node, attr_name, mtx)
    node.axis = axis
    return node
def hstack(mtxs, **kwargs):
    """Differentiable horizontal stack: concatenate along axis 1."""
    return concatenate(mtxs, axis=1, **kwargs)

def vstack(mtxs, **kwargs):
    """Differentiable vertical stack: promote inputs to >= 2-D, then
    concatenate along axis 0."""
    return concatenate([atleast_2d(m) for m in mtxs], axis=0, **kwargs)

def dstack(mtxs, **kwargs):
    """Differentiable depth stack: promote inputs to >= 3-D, then
    concatenate along axis 2."""
    return concatenate([atleast_3d(m) for m in mtxs], axis=2, **kwargs)
vendor/chumpy/chumpy/test_ch.py
ADDED
|
@@ -0,0 +1,621 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# encoding: utf-8
|
| 3 |
+
"""
|
| 4 |
+
Author(s): Matthew Loper
|
| 5 |
+
|
| 6 |
+
See LICENCE.txt for licensing and contact information.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import time
|
| 10 |
+
|
| 11 |
+
import unittest
|
| 12 |
+
import numpy as np
|
| 13 |
+
import scipy.sparse as sp
|
| 14 |
+
|
| 15 |
+
from . import ch
|
| 16 |
+
|
| 17 |
+
class TestCh(unittest.TestCase):
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def test_cachehits(self):
|
| 21 |
+
"""Test how many nodes are visited when cache is cleared.
|
| 22 |
+
If the number of hits changes, it has to be carefully
|
| 23 |
+
looked at to make sure that correctness and performance
|
| 24 |
+
don't get messed up by a change."""
|
| 25 |
+
|
| 26 |
+
a = ch.array(1)
|
| 27 |
+
b = ch.array(2)
|
| 28 |
+
c = a
|
| 29 |
+
for i in range(10):
|
| 30 |
+
c = a + c + b
|
| 31 |
+
|
| 32 |
+
c.dr_wrt(a)
|
| 33 |
+
c.dr_wrt(b)
|
| 34 |
+
self.assertEqual(a.clear_cache() + b.clear_cache(), 59)
|
| 35 |
+
c.dr_wrt(a)
|
| 36 |
+
c.dr_wrt(b)
|
| 37 |
+
self.assertEqual(a.clear_cache(123) + b.clear_cache(123), 41)
|
| 38 |
+
|
| 39 |
+
def test_nested_concatenate(self):
|
| 40 |
+
aa = ch.arange(3)
|
| 41 |
+
bb = ch.arange(4)
|
| 42 |
+
cc = ch.arange(5)
|
| 43 |
+
|
| 44 |
+
result = ch.concatenate((ch.concatenate((aa,bb)),cc))
|
| 45 |
+
self.assertTrue(result.m0 is aa)
|
| 46 |
+
self.assertTrue(result.m1 is bb)
|
| 47 |
+
self.assertTrue(result.m2 is cc)
|
| 48 |
+
|
| 49 |
+
self.assertTrue(result.dr_wrt(aa).nnz > 0)
|
| 50 |
+
self.assertTrue(result.dr_wrt(bb).nnz > 0)
|
| 51 |
+
self.assertTrue(result.dr_wrt(cc).nnz > 0)
|
| 52 |
+
|
| 53 |
+
def test_nandivide(self):
|
| 54 |
+
foo = ch.array(np.random.randn(16).reshape((4,4)))
|
| 55 |
+
bar = ch.array(np.random.randn(16).reshape((4,4)))
|
| 56 |
+
bar[2,2] = 0
|
| 57 |
+
self.assertEqual(ch.NanDivide(foo,bar)[2,2].r, 0.)
|
| 58 |
+
foo[2,2] = 0
|
| 59 |
+
self.assertEqual(ch.NanDivide(foo,bar)[2,2].r, 0.)
|
| 60 |
+
|
| 61 |
+
def test_casting(self):
|
| 62 |
+
for fn in float, int:
|
| 63 |
+
self.assertEqual(fn(np.array(5)), fn(ch.array(5)))
|
| 64 |
+
self.assertEqual(fn(np.array([[5]])), fn(ch.array([[5]])))
|
| 65 |
+
|
| 66 |
+
def test_tensordot(self):
|
| 67 |
+
an = np.arange(60.).reshape(3,4,5)
|
| 68 |
+
bn = np.arange(24.).reshape(4,3,2)
|
| 69 |
+
cn = np.tensordot(an,bn, axes=([1,0],[0,1]))
|
| 70 |
+
|
| 71 |
+
ac = ch.arange(60.).reshape(3,4,5)
|
| 72 |
+
bc = ch.arange(24.).reshape(4,3,2)
|
| 73 |
+
cc = ch.tensordot(ac,bc, axes=([1,0],[0,1]))
|
| 74 |
+
|
| 75 |
+
cc.r
|
| 76 |
+
cc.dr_wrt(ac)
|
| 77 |
+
cc.dr_wrt(bc)
|
| 78 |
+
#print cn
|
| 79 |
+
|
| 80 |
+
def test_make_sure_is_double(self):
|
| 81 |
+
x = ch.array([0])
|
| 82 |
+
self.assertTrue(isinstance(x.r[0], np.float64))
|
| 83 |
+
|
| 84 |
+
def test_cross(self):
|
| 85 |
+
aa = ch.random.randn(30).reshape((10,3))
|
| 86 |
+
bb = ch.random.randn(30).reshape((10,3))
|
| 87 |
+
|
| 88 |
+
cross_ch = ch.cross(aa, bb)
|
| 89 |
+
cross_np = np.cross(aa.r, bb.r)
|
| 90 |
+
|
| 91 |
+
# print cross_ch.r
|
| 92 |
+
# print cross_np
|
| 93 |
+
|
| 94 |
+
eps = 1.0
|
| 95 |
+
step = (np.random.rand(30) - .5).reshape((10,3)) * eps
|
| 96 |
+
|
| 97 |
+
gt_diff = np.cross(aa.r, bb.r+step) - cross_np
|
| 98 |
+
pr_diff = cross_ch.dr_wrt(bb).dot(step.ravel())
|
| 99 |
+
# print gt_diff
|
| 100 |
+
# print pr_diff
|
| 101 |
+
# print np.max(np.abs(gt_diff.ravel()-pr_diff.ravel()))
|
| 102 |
+
self.assertTrue(1e-14 > np.max(np.abs(gt_diff.ravel()-pr_diff.ravel())))
|
| 103 |
+
|
| 104 |
+
gt_diff = np.cross(aa.r+step, bb.r) - cross_np
|
| 105 |
+
pr_diff = cross_ch.dr_wrt(aa).dot(step.ravel())
|
| 106 |
+
#print gt_diff
|
| 107 |
+
# print pr_diff
|
| 108 |
+
# print np.max(np.abs(gt_diff.ravel()-pr_diff.ravel()))
|
| 109 |
+
self.assertTrue(1e-14 > np.max(np.abs(gt_diff.ravel()-pr_diff.ravel())))
|
| 110 |
+
|
| 111 |
+
def test_dr_wrt_selection(self):
|
| 112 |
+
aa = ch.arange(10,20)
|
| 113 |
+
bb = ch.arange(1,11)
|
| 114 |
+
cc = aa * bb + aa + bb +2
|
| 115 |
+
|
| 116 |
+
dr0 = cc.dr_wrt(aa[4:6])
|
| 117 |
+
dr1 = cc.dr_wrt(aa)[:,4:6]
|
| 118 |
+
self.assertTrue((dr0 - dr1).nnz == 0)
|
| 119 |
+
|
| 120 |
+
dr0 = cc.dr_wrt(bb[5:8])
|
| 121 |
+
dr1 = cc.dr_wrt(bb)[:,5:8]
|
| 122 |
+
self.assertTrue((dr0 - dr1).nnz == 0)
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def test_sum_mean_std_var(self):
|
| 126 |
+
for fn in [ch.sum, ch.mean, ch.var, ch.std]:
|
| 127 |
+
|
| 128 |
+
# Create fake input and differences in input space
|
| 129 |
+
data1 = ch.ones((3,4,7,2))
|
| 130 |
+
data2 = ch.array(data1.r + .1 * np.random.rand(data1.size).reshape(data1.shape))
|
| 131 |
+
diff = data2.r - data1.r
|
| 132 |
+
|
| 133 |
+
# Compute outputs
|
| 134 |
+
result1 = fn(data1, axis=2)
|
| 135 |
+
result2 = fn(data2, axis=2)
|
| 136 |
+
|
| 137 |
+
# Empirical and predicted derivatives
|
| 138 |
+
gt = result2.r - result1.r
|
| 139 |
+
pred = result1.dr_wrt(data1).dot(diff.ravel()).reshape(gt.shape)
|
| 140 |
+
|
| 141 |
+
#print np.max(np.abs(gt - pred))
|
| 142 |
+
|
| 143 |
+
if fn in [ch.std, ch.var]:
|
| 144 |
+
self.assertTrue(1e-2 > np.max(np.abs(gt - pred)))
|
| 145 |
+
else:
|
| 146 |
+
self.assertTrue(1e-14 > np.max(np.abs(gt - pred)))
|
| 147 |
+
# test caching
|
| 148 |
+
dr0 = result1.dr_wrt(data1)
|
| 149 |
+
data1[:] = np.random.randn(data1.size).reshape(data1.shape)
|
| 150 |
+
self.assertTrue(result1.dr_wrt(data1) is dr0) # changing values shouldn't force recompute
|
| 151 |
+
result1.axis=1
|
| 152 |
+
self.assertTrue(result1.dr_wrt(data1) is not dr0)
|
| 153 |
+
|
| 154 |
+
self.assertEqual(ch.mean(ch.eye(3),axis=1).ndim, np.mean(np.eye(3),axis=1).ndim)
|
| 155 |
+
self.assertEqual(ch.mean(ch.eye(3),axis=0).ndim, np.mean(np.eye(3),axis=0).ndim)
|
| 156 |
+
self.assertEqual(ch.sum(ch.eye(3),axis=1).ndim, np.sum(np.eye(3),axis=1).ndim)
|
| 157 |
+
self.assertEqual(ch.sum(ch.eye(3),axis=0).ndim, np.sum(np.eye(3),axis=0).ndim)
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def test_cumsum(self):
|
| 162 |
+
a = ch.array([1.,5.,3.,7.])
|
| 163 |
+
cs = ch.cumsum(a)
|
| 164 |
+
r1 = cs.r
|
| 165 |
+
dr = cs.dr_wrt(a)
|
| 166 |
+
diff = (ch.random.rand(4)-.5)*.1
|
| 167 |
+
a.x += diff.r
|
| 168 |
+
pred = dr.dot(diff.r)
|
| 169 |
+
gt = cs.r - r1
|
| 170 |
+
self.assertTrue(1e-13 > np.max(np.abs(gt - pred)))
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def test_iteration_cache(self):
|
| 174 |
+
""" Each time you set an attribute, the cache (of r's and dr's) of
|
| 175 |
+
ancestors is cleared. Because children share ancestors, this means
|
| 176 |
+
these can be cleared multiple times unnecessarily; in some cases,
|
| 177 |
+
where lots of objects exist, this cache clearing can actually be a bottleneck.
|
| 178 |
+
|
| 179 |
+
Therefore, the concept of an iteration was added; intended to be used in
|
| 180 |
+
an optimization setting (see optimization.py) and in the set() method, it
|
| 181 |
+
avoids such redundant clearing of cache."""
|
| 182 |
+
|
| 183 |
+
a, b, c = ch.Ch(1), ch.Ch(2), ch.Ch(3)
|
| 184 |
+
x = a+b
|
| 185 |
+
y = x+c
|
| 186 |
+
self.assertTrue(y.r[0]==6)
|
| 187 |
+
|
| 188 |
+
a.__setattr__('x', 10, 1)
|
| 189 |
+
self.assertTrue(y.r == 15)
|
| 190 |
+
a.__setattr__('x', 100, 1)
|
| 191 |
+
self.assertTrue(y.r == 15)
|
| 192 |
+
a.__setattr__('x', 100, 2)
|
| 193 |
+
self.assertTrue(y.r == 105)
|
| 194 |
+
|
| 195 |
+
a, b, c = ch.array([1]), ch.array([2]), ch.array([3])
|
| 196 |
+
x = a+b
|
| 197 |
+
y = x+c
|
| 198 |
+
self.assertTrue(y.r[0]==6)
|
| 199 |
+
|
| 200 |
+
a.__setattr__('x', np.array([10]), 1)
|
| 201 |
+
self.assertTrue(y.r[0] == 15)
|
| 202 |
+
a.__setattr__('x', np.array(100), 1)
|
| 203 |
+
self.assertTrue(y.r[0] == 15)
|
| 204 |
+
a.__setattr__('x', np.array(100), 2)
|
| 205 |
+
self.assertTrue(y.r[0] == 105)
|
| 206 |
+
a.__setitem__(list(range(0,1)), np.array(200), 2)
|
| 207 |
+
self.assertTrue(y.r[0] == 105)
|
| 208 |
+
a.__setitem__(list(range(0,1)), np.array(200), 3)
|
| 209 |
+
self.assertTrue(y.r[0] == 205)
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
def test_stacking(self):
|
| 214 |
+
|
| 215 |
+
a1 = ch.Ch(np.arange(10).reshape(2,5))
|
| 216 |
+
b1 = ch.Ch(np.arange(20).reshape(4,5))
|
| 217 |
+
c1 = ch.vstack((a1,b1))
|
| 218 |
+
c1_check = np.vstack((a1.r, b1.r))
|
| 219 |
+
residuals1 = (c1_check - c1.r).ravel()
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
a2 = ch.Ch(np.arange(10).reshape(5,2))
|
| 223 |
+
b2 = ch.Ch(np.arange(20).reshape(5,4))
|
| 224 |
+
c2 = ch.hstack((a2,b2))
|
| 225 |
+
c2_check = np.hstack((a2.r, b2.r))
|
| 226 |
+
residuals2 = (c2_check - c2.r).ravel()
|
| 227 |
+
|
| 228 |
+
self.assertFalse(np.any(residuals1))
|
| 229 |
+
self.assertFalse(np.any(residuals2))
|
| 230 |
+
|
| 231 |
+
d0 = ch.array(np.arange(60).reshape((10,6)))
|
| 232 |
+
d1 = ch.vstack((d0[:4], d0[4:]))
|
| 233 |
+
d2 = ch.hstack((d1[:,:3], d1[:,3:]))
|
| 234 |
+
tmp = d2.dr_wrt(d0).todense()
|
| 235 |
+
diff = tmp - np.eye(tmp.shape[0])
|
| 236 |
+
self.assertFalse(np.any(diff.ravel()))
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
#def test_drs(self):
|
| 241 |
+
# a = ch.Ch(2)
|
| 242 |
+
# b = ch.Ch(3)
|
| 243 |
+
# c = a * b
|
| 244 |
+
# print c.dr_wrt(a)
|
| 245 |
+
# print c.compute_drs_wrt(a).r
|
| 246 |
+
|
| 247 |
+
@unittest.skip('We are using LinearOperator for this for now. Might change back though.')
|
| 248 |
+
def test_reorder_caching(self):
|
| 249 |
+
a = ch.Ch(np.zeros(8).reshape((4,2)))
|
| 250 |
+
b = a.T
|
| 251 |
+
dr0 = b.dr_wrt(a)
|
| 252 |
+
a.x = a.x + 1.
|
| 253 |
+
dr1 = b.dr_wrt(a)
|
| 254 |
+
self.assertTrue(dr0 is dr1)
|
| 255 |
+
a.x = np.zeros(4).reshape((2,2))
|
| 256 |
+
dr2 = b.dr_wrt(a)
|
| 257 |
+
self.assertTrue(dr2 is not dr1)
|
| 258 |
+
|
| 259 |
+
def test_transpose(self):
|
| 260 |
+
from .utils import row, col
|
| 261 |
+
from copy import deepcopy
|
| 262 |
+
for which in ('C', 'F'): # test in fortran and contiguous mode
|
| 263 |
+
a = ch.Ch(np.require(np.zeros(8).reshape((4,2)), requirements=which))
|
| 264 |
+
b = a.T
|
| 265 |
+
|
| 266 |
+
b1 = b.r.copy()
|
| 267 |
+
#dr = b.dr_wrt(a).copy()
|
| 268 |
+
dr = deepcopy(b.dr_wrt(a))
|
| 269 |
+
|
| 270 |
+
diff = np.arange(a.size).reshape(a.shape)
|
| 271 |
+
a.x = np.require(a.r + diff, requirements=which)
|
| 272 |
+
b2 = b.r.copy()
|
| 273 |
+
|
| 274 |
+
diff_pred = dr.dot(col(diff)).ravel()
|
| 275 |
+
diff_emp = (b2 - b1).ravel()
|
| 276 |
+
np.testing.assert_array_equal(diff_pred, diff_emp)
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
def test_unary(self):
|
| 280 |
+
fns = [ch.exp, ch.log, ch.sin, ch.arcsin, ch.cos, ch.arccos, ch.tan, ch.arctan, ch.negative, ch.square, ch.sqrt, ch.abs, ch.reciprocal]
|
| 281 |
+
|
| 282 |
+
eps = 1e-8
|
| 283 |
+
for f in fns:
|
| 284 |
+
|
| 285 |
+
x0 = ch.Ch(.25)
|
| 286 |
+
x1 = ch.Ch(x0.r+eps)
|
| 287 |
+
|
| 288 |
+
pred = f(x0).dr_wrt(x0)
|
| 289 |
+
empr = (f(x1).r - f(x0).r) / eps
|
| 290 |
+
|
| 291 |
+
# print pred
|
| 292 |
+
# print empr
|
| 293 |
+
if f is ch.reciprocal:
|
| 294 |
+
self.assertTrue(1e-6 > np.abs(pred.ravel()[0] - empr.ravel()[0]))
|
| 295 |
+
else:
|
| 296 |
+
self.assertTrue(1e-7 > np.abs(pred.ravel()[0] - empr.ravel()[0]))
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
def test_serialization(self):
|
| 300 |
+
# The main challenge with serialization is the "_parents"
|
| 301 |
+
# attribute, which is a nonserializable WeakKeyDictionary.
|
| 302 |
+
# So we pickle/unpickle, change a child and verify the value
|
| 303 |
+
# at root, and verify that both children have parentage.
|
| 304 |
+
from six.moves import cPickle as pickle
|
| 305 |
+
tmp = ch.Ch(10) + ch.Ch(20)
|
| 306 |
+
tmp = pickle.loads(pickle.dumps(tmp))
|
| 307 |
+
tmp.b.x = 30
|
| 308 |
+
self.assertTrue(tmp.r[0] == 40)
|
| 309 |
+
self.assertTrue(list(tmp.a._parents.keys())[0] == tmp)
|
| 310 |
+
self.assertTrue(list(tmp.a._parents.keys())[0] == list(tmp.b._parents.keys())[0])
|
| 311 |
+
|
| 312 |
+
def test_chlambda1(self):
|
| 313 |
+
c1, c2, c3 = ch.Ch(1), ch.Ch(2), ch.Ch(3)
|
| 314 |
+
adder = ch.ChLambda(lambda x, y: x+y)
|
| 315 |
+
adder.x = c1
|
| 316 |
+
adder.y = c2
|
| 317 |
+
self.assertTrue(adder.r == 3)
|
| 318 |
+
adder.x = c2
|
| 319 |
+
self.assertTrue(adder.r == 4)
|
| 320 |
+
adder.x = c1
|
| 321 |
+
self.assertTrue(adder.r == 3)
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
def test_chlambda2(self):
|
| 325 |
+
passthrough = ch.ChLambda( lambda x : x)
|
| 326 |
+
self.assertTrue(passthrough.dr_wrt(passthrough.x) is not None)
|
| 327 |
+
passthrough.x = ch.Ch(123)
|
| 328 |
+
self.assertTrue(passthrough.dr_wrt(passthrough.x) is not None)
|
| 329 |
+
|
| 330 |
+
# It's probably not reasonable to expect this
|
| 331 |
+
# to work for ChLambda
|
| 332 |
+
#def test_chlambda3(self):
|
| 333 |
+
# c1, c2, c3 = ch.Ch(1), ch.Ch(2), ch.Ch(3)
|
| 334 |
+
# triple = ch.ChLambda( lambda x, y, z : x(y, z))
|
| 335 |
+
# triple.x = Add
|
| 336 |
+
# triple.y = c2
|
| 337 |
+
# triple.z = c3
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
|
| 341 |
+
|
| 342 |
+
|
| 343 |
+
def test_amax(self):
|
| 344 |
+
from .ch import amax
|
| 345 |
+
import numpy as np
|
| 346 |
+
arr = np.empty((5,2,3,7))
|
| 347 |
+
arr.flat[:] = np.sin(np.arange(arr.size)*1000.)
|
| 348 |
+
#arr = np.array(np.sin(np.arange(24)*10000.).reshape(2,3,4))
|
| 349 |
+
|
| 350 |
+
for axis in range(len(arr.shape)):
|
| 351 |
+
a = amax(a=arr, axis=axis)
|
| 352 |
+
pred = a.dr_wrt(a.a).dot(arr.ravel())
|
| 353 |
+
real = np.amax(arr, axis=axis).ravel()
|
| 354 |
+
self.assertTrue(np.max(np.abs(pred-real)) < 1e-10)
|
| 355 |
+
|
| 356 |
+
    def test_maximum(self):
        """maximum(): a tie gives derivative 1 wrt both operands; otherwise
        the Jacobian matches a finite-difference estimate for every
        scalar/2d-array operand combination."""
        from .utils import row, col
        from .ch import maximum

        # Make sure that when we compare the max of two *identical* numbers,
        # we get the right derivatives wrt both
        the_max = maximum(ch.Ch(1), ch.Ch(1))
        self.assertTrue(the_max.r.ravel()[0] == 1.)
        self.assertTrue(the_max.dr_wrt(the_max.a)[0,0] == 1.)
        self.assertTrue(the_max.dr_wrt(the_max.b)[0,0] == 1.)

        # Now test given that all numbers are different, by allocating from
        # a pool of randomly permuted numbers.
        # We test combinations of scalars and 2d arrays.
        rnd = np.asarray(np.random.permutation(np.arange(20)), np.float64)
        c1 = ch.Ch(rnd[:6].reshape((2,3)))
        c2 = ch.Ch(rnd[6:12].reshape((2,3)))
        s1 = ch.Ch(rnd[12])
        s2 = ch.Ch(rnd[13])

        eps = .1
        for first in [c1, s1]:
            for second in [c2, s2]:
                the_max = maximum(first, second)

                for which_to_change in [first, second]:

                    max_r0 = the_max.r.copy()
                    max_r_diff = np.max(np.abs(max_r0 - np.maximum(first.r, second.r)))
                    self.assertTrue(max_r_diff == 0)
                    max_dr = the_max.dr_wrt(which_to_change).copy()
                    # Perturb one operand and compare the empirical change in r
                    # against the Jacobian-predicted change.
                    which_to_change.x = which_to_change.x + eps
                    max_r1 = the_max.r.copy()

                    emp_diff = (the_max.r - max_r0).ravel()
                    pred_diff = max_dr.dot(col(eps*np.ones(max_dr.shape[1]))).ravel()

                    max_dr_diff = np.max(np.abs(emp_diff-pred_diff))
                    self.assertTrue(max_dr_diff < 1e-14)
|
| 405 |
+
|
| 406 |
+
|
| 407 |
+
    def test_shared(self):
        """A chumpy expression over shared leaves evaluates like plain float
        arithmetic, and `replace` swaps a leaf in the existing graph."""
        chs = [ch.Ch(i) for i in range(10)]
        vrs = [float(i) for i in range(10)]

        func = lambda a : a[0]*a[1] + (a[2]*a[3])/a[4]

        chained_result = func(chs).r
        regular_result = func(vrs)

        self.assertTrue(chained_result == regular_result)

        # Replacing a leaf node should change the evaluated result accordingly.
        chained_func = func(chs)
        chained_func.replace(chs[0], ch.Ch(50))
        vrs[0] = 50

        chained_result = chained_func.r
        regular_result = func(vrs)

        self.assertTrue(chained_result == regular_result)
|
| 431 |
+
|
| 432 |
+
|
| 433 |
+
    def test_matmatmult(self):
        """Matrix-matrix `dot`: the Jacobian wrt each operand predicts the
        change in the product exactly (the op is linear in each operand)."""
        from .ch import dot
        mtx1 = ch.Ch(np.arange(6).reshape((3,2)))
        mtx2 = ch.Ch(np.arange(8).reshape((2,4))*10)

        mtx3 = dot(mtx1, mtx2)

        for mtx in [mtx1, mtx2]:
            oldval = mtx3.r.copy()
            mtxd = mtx3.dr_wrt(mtx).copy()
            mtx_diff = np.random.rand(mtx.r.size).reshape(mtx.r.shape)
            mtx.x = mtx.r + mtx_diff
            mtx_emp = mtx3.r - oldval
            mtx_pred = mtxd.dot(mtx_diff.ravel()).reshape(mtx_emp.shape)

            self.assertTrue(np.max(np.abs(mtx_emp - mtx_pred)) < 1e-11)
|
| 454 |
+
|
| 455 |
+
|
| 456 |
+
    def test_ndim(self):
        """Elementwise arithmetic on (2,3) operands keeps the (2,3) shape,
        both with and without scalar constants mixed in; derivatives can be
        requested wrt every input."""
        vs = [ch.Ch(np.random.randn(6).reshape(2,3)) for i in range(6)]
        res = vs[0] + vs[1] - vs[2] * vs[3] / (vs[4] ** 2) ** vs[5]
        self.assertTrue(res.shape[0]==2 and res.shape[1]==3)
        res = (vs[0] + 1) + (vs[1] - 2) - (vs[2] * 3) * (vs[3] / 4) / (vs[4] ** 2) ** vs[5]
        self.assertTrue(res.shape[0]==2 and res.shape[1]==3)
        # Only checks that Jacobian construction does not raise.
        drs = [res.dr_wrt(v) for v in vs]
|
| 463 |
+
|
| 464 |
+
|
| 465 |
+
    def test_indexing(self):
        """Slicing/indexing of Ch arrays: values, Jacobian sparsity pattern,
        and assignment through reordered views (including NaN writes)."""
        big = ch.Ch(np.arange(60).reshape((10,6)))
        little = big[1:3, 3:6]
        self.assertTrue(np.max(np.abs(little.r - np.array([[9,10,11],[15,16,17]]))) == 0)

        # Row selection: the Jacobian's column indices identify the source cells.
        little = big[5]
        self.assertTrue(np.max(np.abs(little.r - np.arange(30, 36))) == 0)
        self.assertTrue(np.max(np.abs(sp.coo_matrix(little.dr_wrt(big)).col - np.arange(30,36))) == 0)

        little = big[2, 3]
        self.assertTrue(little.r[0] == 15.0)

        little = big[2, 3:5]
        self.assertTrue(np.max(np.abs(little.r - np.array([15, 16]))) == 0.)
        _ = little.dr_wrt(big)

        # Tests assignment through reorderings
        aa = ch.arange(4*4*4).reshape((4,4,4))[:3,:3,:3]
        aa[0,1,2] = 100
        self.assertTrue(aa[0,1,2].r[0] == 100)

        # Tests assignment through reorderings (NaN's are a special case)
        aa = ch.arange(9).reshape((3,3))
        aa[1,1] = np.nan
        self.assertTrue(np.isnan(aa.r[1,1]))
        self.assertFalse(np.isnan(aa.r[0,0]))
|
| 491 |
+
|
| 492 |
+
|
| 493 |
+
    def test_redundancy_removal(self):
        """remove_redundancy() merges structurally identical subtrees so they
        become the same object, with and without the MT flag set."""
        for MT in [False, True]:
            x1, x2 = ch.Ch(10), ch.Ch(20)
            x1_plus_x2_1 = x1 + x2
            x1_plus_x2_2 = x1 + x2
            redundant_sum = (x1_plus_x2_1 + x1_plus_x2_2) * 2
            redundant_sum.MT = MT

            # Before: two distinct (but equal) subtrees; after: a single shared node.
            self.assertTrue(redundant_sum.a.a is not redundant_sum.a.b)
            redundant_sum.remove_redundancy()
            self.assertTrue(redundant_sum.a.a is redundant_sum.a.b)
|
| 505 |
+
|
| 506 |
+
    def test_caching(self):
        """Residuals and Jacobians are cached (identical object returned)
        until an upstream input changes, then recomputed."""
        vals = [10, 20, 30, 40, 50]
        f = lambda a, b, c, d, e : a + (b * c) - d ** e

        # Set up our objects
        Cs = [ch.Ch(v) for v in vals]
        C_result = f(*Cs)

        # Sometimes residuals should be cached
        r1 = C_result.r
        r2 = C_result.r
        self.assertTrue(r1 is r2)

        # Other times residuals need refreshing
        Cs[0].set(x=5)
        r3 = C_result.r
        self.assertTrue(r3 is not r2)

        # Sometimes derivatives should be cached
        dr1 = C_result.dr_wrt(Cs[1])
        dr2 = C_result.dr_wrt(Cs[1])
        self.assertTrue(dr1 is dr2)

        # Other times derivatives need refreshing
        Cs[2].set(x=5)
        dr3 = C_result.dr_wrt(Cs[1])
        self.assertTrue(dr3 is not dr2)
|
| 534 |
+
|
| 535 |
+
|
| 536 |
+
def test_scalars(self):
|
| 537 |
+
|
| 538 |
+
try:
|
| 539 |
+
import theano.tensor as T
|
| 540 |
+
from theano import function
|
| 541 |
+
except:
|
| 542 |
+
return
|
| 543 |
+
|
| 544 |
+
# Set up variables and function
|
| 545 |
+
vals = [1, 2, 3, 4, 5]
|
| 546 |
+
f = lambda a, b, c, d, e : a + (b * c) - d ** e
|
| 547 |
+
|
| 548 |
+
# Set up our objects
|
| 549 |
+
Cs = [ch.Ch(v) for v in vals]
|
| 550 |
+
C_result = f(*Cs)
|
| 551 |
+
|
| 552 |
+
# Set up Theano's equivalents
|
| 553 |
+
Ts = T.dscalars('T1', 'T2', 'T3', 'T4', 'T5')
|
| 554 |
+
TF = f(*Ts)
|
| 555 |
+
T_result = function(Ts, TF)
|
| 556 |
+
|
| 557 |
+
# Make sure values and derivatives are equal
|
| 558 |
+
self.assertEqual(C_result.r, T_result(*vals))
|
| 559 |
+
for k in range(len(vals)):
|
| 560 |
+
theano_derivative = function(Ts, T.grad(TF, Ts[k]))(*vals)
|
| 561 |
+
#print C_result.dr_wrt(Cs[k])
|
| 562 |
+
our_derivative = C_result.dr_wrt(Cs[k])[0,0]
|
| 563 |
+
#print theano_derivative, our_derivative
|
| 564 |
+
self.assertEqual(theano_derivative, our_derivative)
|
| 565 |
+
|
| 566 |
+
|
| 567 |
+
def test_vectors(self):
|
| 568 |
+
|
| 569 |
+
try:
|
| 570 |
+
import theano.tensor as T
|
| 571 |
+
from theano import function
|
| 572 |
+
except:
|
| 573 |
+
return
|
| 574 |
+
|
| 575 |
+
for MT in [False, True]:
|
| 576 |
+
|
| 577 |
+
# Set up variables and function
|
| 578 |
+
vals = [np.random.randn(20) for i in range(5)]
|
| 579 |
+
f = lambda a, b, c, d, e : a + (b * c) - d ** e
|
| 580 |
+
|
| 581 |
+
# Set up our objects
|
| 582 |
+
Cs = [ch.Ch(v) for v in vals]
|
| 583 |
+
C_result = f(*Cs)
|
| 584 |
+
C_result.MT = MT
|
| 585 |
+
|
| 586 |
+
# Set up Theano equivalents
|
| 587 |
+
Ts = T.dvectors('T1', 'T2', 'T3', 'T4', 'T5')
|
| 588 |
+
TF = f(*Ts)
|
| 589 |
+
T_result = function(Ts, TF)
|
| 590 |
+
|
| 591 |
+
if False:
|
| 592 |
+
import theano.gradient
|
| 593 |
+
which = 1
|
| 594 |
+
theano_sse = (TF**2.).sum()
|
| 595 |
+
theano_grad = theano.gradient.grad(theano_sse, Ts[which])
|
| 596 |
+
theano_fn = function(Ts, theano_grad)
|
| 597 |
+
print(theano_fn(*vals))
|
| 598 |
+
C_result_grad = ch.SumOfSquares(C_result).dr_wrt(Cs[which])
|
| 599 |
+
print(C_result_grad)
|
| 600 |
+
|
| 601 |
+
# if True:
|
| 602 |
+
# aaa = np.linalg.solve(C_result_grad.T.dot(C_result_grad), C_result_grad.dot(np.zeros(C_result_grad.shape[1])))
|
| 603 |
+
# theano_hes = theano.R_obbb = theano.R_op()
|
| 604 |
+
|
| 605 |
+
import pdb; pdb.set_trace()
|
| 606 |
+
|
| 607 |
+
# Make sure values and derivatives are equal
|
| 608 |
+
np.testing.assert_array_equal(C_result.r, T_result(*vals))
|
| 609 |
+
for k in range(len(vals)):
|
| 610 |
+
theano_derivative = function(Ts, T.jacobian(TF, Ts[k]))(*vals)
|
| 611 |
+
our_derivative = np.array(C_result.dr_wrt(Cs[k]).todense())
|
| 612 |
+
#print theano_derivative, our_derivative
|
| 613 |
+
|
| 614 |
+
# Theano produces has more nans than we do during exponentiation.
|
| 615 |
+
# So we test only on entries where Theano is without NaN's
|
| 616 |
+
without_nans = np.nonzero(np.logical_not(np.isnan(theano_derivative.flatten())))[0]
|
| 617 |
+
np.testing.assert_array_equal(theano_derivative.flatten()[without_nans], our_derivative.flatten()[without_nans])
|
| 618 |
+
|
| 619 |
+
|
| 620 |
+
if __name__ == '__main__':
    # Discover and run all tests in this module.
    unittest.main()
|
vendor/chumpy/chumpy/test_inner_composition.py
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# encoding: utf-8
|
| 3 |
+
"""
|
| 4 |
+
Author(s): Matthew Loper
|
| 5 |
+
|
| 6 |
+
See LICENCE.txt for licensing and contact information.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import unittest
|
| 10 |
+
from .ch import Ch, depends_on
|
| 11 |
+
|
| 12 |
+
class TestInnerComposition(unittest.TestCase):
    """Checks that an aliased dterm (Parent.aliased -> Child.a) triggers the
    dependency/caching machinery exactly the expected number of times.
    Counter meanings: dcount = aliased_dependency computations,
    ocount = on_changed('aliased') notifications, rcount = compute_r calls."""

    def test_ic(self):
        child = Child(a=Ch(10))
        parent = Parent(child=child, aliased=Ch(50))

        # Repeated reads of the @depends_on property: computed only once.
        junk = [parent.aliased_dependency for k in range(3)]
        self.assertTrue(parent.dcount == 1)
        self.assertTrue(parent.ocount == 0)
        self.assertTrue(parent.rcount == 0)

        # Repeated reads of r: on_changed and compute_r each fire once.
        junk = [parent.r for k in range(3)]
        self.assertTrue(parent.dcount == 1)
        self.assertTrue(parent.ocount == 1)
        self.assertTrue(parent.rcount == 1)

        # Reassigning the aliased term invalidates the cached dependency...
        parent.aliased = Ch(20)
        junk = [parent.aliased_dependency for k in range(3)]
        self.assertTrue(parent.dcount == 2)
        self.assertTrue(parent.ocount == 1)
        self.assertTrue(parent.rcount == 1)

        # ...and the cached residual.
        junk = [parent.r for k in range(3)]
        self.assertTrue(parent.dcount == 2)
        self.assertTrue(parent.ocount == 2)
        self.assertTrue(parent.rcount == 2)
|
| 38 |
+
|
| 39 |
+
class Parent(Ch):
    # Fixture: declares `aliased` as a differentiable term while also
    # exposing it as a property that forwards to the child's `a` term.
    dterms = ('aliased', 'child')

    def __init__(self, *args, **kwargs):
        # Counters inspected by TestInnerComposition (see class docstring there).
        self.dcount = 0
        self.ocount = 0
        self.rcount = 0


    def on_changed(self, which):
        assert('aliased' in which and 'child' in which)
        if 'aliased' in which:
            self.ocount += 1

    @depends_on('aliased')
    def aliased_dependency(self):
        # Cached until 'aliased' changes; the counter tracks recomputations.
        self.dcount += 1

    @property
    def aliased(self):
        # Alias for the child's `a` term.
        return self.child.a

    @aliased.setter
    def aliased(self, val):
        self.child.a = val

    def compute_r(self):
        self.rcount += 1
        return 0

    def compute_dr_wrt(self, wrt):
        # No derivatives needed for this fixture.
        pass
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class Child(Ch):
    # Minimal Ch subclass with a single differentiable term.
    dterms = ('a',)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
if __name__ == '__main__':
    # Run only this module's test case, with verbose output.
    suite = unittest.TestLoader().loadTestsFromTestCase(TestInnerComposition)
    unittest.TextTestRunner(verbosity=2).run(suite)
|
vendor/chumpy/chumpy/test_linalg.py
ADDED
|
@@ -0,0 +1,272 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# encoding: utf-8
|
| 3 |
+
"""
|
| 4 |
+
Author(s): Matthew Loper
|
| 5 |
+
|
| 6 |
+
See LICENCE.txt for licensing and contact information.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
import unittest
|
| 11 |
+
|
| 12 |
+
from .ch import Ch
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class TestLinalg(unittest.TestCase):
    """Finite-difference checks for chumpy's differentiable linalg ops
    (slogdet, lstsq, pinv, svd, det, inv)."""

    def setUp(self):
        # Deterministic inputs for every test.
        np.random.seed(0)


    def test_slogdet(self):
        from . import ch
        tmp = ch.random.randn(100).reshape((10,10))

        # Directional finite difference of log|det| vs the chumpy Jacobian.
        eps = 1e-10
        diff = np.random.rand(100) * eps
        diff_reshaped = diff.reshape((10,10))
        gt = np.linalg.slogdet(tmp.r+diff_reshaped)[1] - np.linalg.slogdet(tmp.r)[1]
        pred = ch.linalg.slogdet(tmp)[1].dr_wrt(tmp).dot(diff)
        diff = gt - pred

        self.assertTrue(np.max(np.abs(diff)) < 1e-12)

        # The sign of the determinant must agree as well.
        sgn_gt = np.linalg.slogdet(tmp.r)[0]
        sgn_pred = ch.linalg.slogdet(tmp)[0]

        diff = sgn_gt - sgn_pred.r
        self.assertTrue(np.max(np.abs(diff)) < 1e-12)


    def test_lstsq(self):
        from .linalg import lstsq

        shapes = ([10, 3], [3, 10])

        for shape in shapes:
            for b2d in True, False:
                A = (np.random.rand(np.prod(shape))-.5).reshape(shape)
                if b2d:
                    b = np.random.randn(shape[0],2)
                else:
                    b = np.random.randn(shape[0])

                x1, residuals1, rank1, s1 = lstsq(A, b)
                # rcond=-1 pins the historical (pre-numpy-1.14) cutoff this
                # test was written against and silences the FutureWarning.
                # NOTE(review): assumes chumpy's lstsq wrapper uses the same
                # cutoff -- confirm against vendor/chumpy/chumpy/linalg.py.
                x2, residuals2, rank2, s2 = np.linalg.lstsq(A, b, rcond=-1)

                self.assertTrue(np.max(np.abs(x1.r-x2)) < 1e-14)
                if len(residuals2) > 0:
                    self.assertTrue(np.max(np.abs(residuals1.r-residuals2)) < 1e-14)


    def test_pinv(self):
        from .linalg import Pinv

        data = (np.random.rand(12)-.5).reshape((3, 4))
        pc_tall = Pinv(data)
        pc_wide = Pinv(data.T)

        pn_tall = np.linalg.pinv(data)
        pn_wide = np.linalg.pinv(data.T)

        tall_correct = np.max(np.abs(pc_tall.r - pn_tall)) < 1e-12
        wide_correct = np.max(np.abs(pc_wide.r - pn_wide)) < 1e-12
        self.assertTrue(tall_correct)
        self.assertTrue(wide_correct)

        return # FIXME. how to test derivs?

        for pc in [pc_tall, pc_wide]:
            self.chkd(pc, pc.mtx)
            import pdb; pdb.set_trace()


    def test_svd(self):
        from .linalg import Svd
        eps = 1e-3
        idx = 10

        # Deterministic, well-conditioned test matrix.
        data = np.sin(np.arange(300)*100+10).reshape((-1,3))
        data[3,:] = data[3,:]*0+10
        data[:,1] *= 2
        data[:,2] *= 4
        data = data.copy()
        u,s,v = np.linalg.svd(data, full_matrices=False)
        data = Ch(data)
        data2 = data.r.copy()
        data2.ravel()[idx] += eps
        u2,s2,v2 = np.linalg.svd(data2, full_matrices=False)

        svdu, svdd, svdv = Svd(x=data)

        # test singular values
        diff_emp = (s2-s) / eps
        diff_pred = svdd.dr_wrt(data)[:,idx]
        ratio = diff_emp / diff_pred
        self.assertTrue(np.max(np.abs(ratio - 1.)) < 1e-4)

        # test V
        diff_emp = (v2 - v) / eps
        diff_pred = svdv.dr_wrt(data)[:,idx].reshape(diff_emp.shape)
        ratio = diff_emp / diff_pred
        self.assertTrue(np.max(np.abs(ratio - 1.)) < 1e-2)

        # test U
        diff_emp = (u2 - u) / eps
        diff_pred = svdu.dr_wrt(data)[:,idx].reshape(diff_emp.shape)
        ratio = diff_emp / diff_pred
        self.assertTrue(np.max(np.abs(ratio - 1.)) < 1e-2)


    def test_det(self):
        from .linalg import Det

        mtx1 = Ch(np.sin(2**np.arange(9)).reshape((3,3)))
        mtx1_det = Det(mtx1)
        dr = mtx1_det.dr_wrt(mtx1)

        eps = 1e-5
        mtx2 = mtx1.r.copy()
        input_diff = np.sin(np.arange(mtx2.size)).reshape(mtx2.shape) * eps
        mtx2 += input_diff
        mtx2_det = Det(mtx2)

        output_diff_emp = (np.linalg.det(mtx2) - np.linalg.det(mtx1.r)).ravel()

        output_diff_pred = Det(mtx1).dr_wrt(mtx1).dot(input_diff.ravel())

        self.assertTrue(np.max(np.abs(output_diff_emp - output_diff_pred)) < eps*1e-4)
        self.assertTrue(np.max(np.abs(mtx1_det.r - np.linalg.det(mtx1.r)).ravel()) == 0)


    def test_inv1(self):
        from .linalg import Inv

        mtx1 = Ch(np.sin(2**np.arange(9)).reshape((3,3)))
        mtx1_inv = Inv(mtx1)
        dr = mtx1_inv.dr_wrt(mtx1)

        eps = 1e-5
        mtx2 = mtx1.r.copy()
        input_diff = np.sin(np.arange(mtx2.size)).reshape(mtx2.shape) * eps
        mtx2 += input_diff
        mtx2_inv = Inv(mtx2)

        output_diff_emp = (np.linalg.inv(mtx2) - np.linalg.inv(mtx1.r)).ravel()
        output_diff_pred = Inv(mtx1).dr_wrt(mtx1).dot(input_diff.ravel())

        self.assertTrue(np.max(np.abs(output_diff_emp - output_diff_pred)) < eps*1e-4)
        self.assertTrue(np.max(np.abs(mtx1_inv.r - np.linalg.inv(mtx1.r)).ravel()) == 0)

    def test_inv2(self):
        from .linalg import Inv

        eps = 1e-8
        idx = 13

        mtx1 = np.random.rand(100).reshape((10,10))
        mtx2 = mtx1.copy()
        mtx2.ravel()[idx] += eps

        diff_emp = (np.linalg.inv(mtx2) - np.linalg.inv(mtx1)) / eps

        mtx1 = Ch(mtx1)
        diff_pred = Inv(mtx1).dr_wrt(mtx1)[:,13].reshape(diff_emp.shape)
        self.assertTrue(np.max(np.abs(diff_pred.ravel()-diff_emp.ravel())) < 1e-4)

    # The original guard compared version *strings* (np.__version__ < '1.8'),
    # which is lexicographic and wrongly skipped this test on numpy
    # 1.10+/1.26/2.x.  Compare (major, minor) tuples instead.
    @unittest.skipIf(tuple(int(p) for p in np.__version__.split('.')[:2]) < (1, 8),
                     "broadcasting for matrix inverse not supported in numpy < 1.8")
    def test_inv3(self):
        """Test linalg.inv with broadcasting support."""

        from .linalg import Inv

        mtx1 = Ch(np.sin(2**np.arange(12)).reshape((3,2,2)))
        mtx1_inv = Inv(mtx1)
        dr = mtx1_inv.dr_wrt(mtx1)

        eps = 1e-5
        mtx2 = mtx1.r.copy()
        input_diff = np.sin(np.arange(mtx2.size)).reshape(mtx2.shape) * eps
        mtx2 += input_diff
        mtx2_inv = Inv(mtx2)

        output_diff_emp = (np.linalg.inv(mtx2) - np.linalg.inv(mtx1.r)).ravel()
        output_diff_pred = Inv(mtx1).dr_wrt(mtx1).dot(input_diff.ravel())

        self.assertTrue(np.max(np.abs(output_diff_emp.ravel() - output_diff_pred.ravel())) < eps*1e-3)
        self.assertTrue(np.max(np.abs(mtx1_inv.r - np.linalg.inv(mtx1.r)).ravel()) == 0)

    def chkd(self, obj, parm, eps=1e-14):
        """Print a central finite-difference check of obj.dr_wrt(parm);
        restores parm.x afterwards. Debug helper, not an assertion."""
        backed_up = parm.x

        if True:
            diff = (np.random.rand(parm.size)-.5).reshape(parm.shape)
        else:
            diff = np.zeros(parm.shape)
            diff.ravel()[4] = 2.

        dr = obj.dr_wrt(parm)

        parm.x = backed_up - diff*eps
        r_lower = obj.r

        parm.x = backed_up + diff*eps
        r_upper = obj.r

        diff_emp = (r_upper - r_lower) / (eps*2.)
        diff_pred = dr.dot(diff.ravel()).reshape(diff_emp.shape)

        print(diff_emp / diff_pred)
        print(diff_emp - diff_pred)

        parm.x = backed_up
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
# Suite object kept for external runners that import this module.
suite = unittest.TestLoader().loadTestsFromTestCase(TestLinalg)

if __name__ == '__main__':
    unittest.main()
|
| 272 |
+
|
vendor/chumpy/chumpy/test_optimization.py
ADDED
|
@@ -0,0 +1,204 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# encoding: utf-8
|
| 3 |
+
"""
|
| 4 |
+
Author(s): Matthew Loper
|
| 5 |
+
|
| 6 |
+
See LICENCE.txt for licensing and contact information.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import time
|
| 10 |
+
from numpy import *
|
| 11 |
+
import unittest
|
| 12 |
+
from . import ch
|
| 13 |
+
from .optimization import minimize
|
| 14 |
+
from .ch import Ch
|
| 15 |
+
import numpy as np
|
| 16 |
+
from scipy.optimize import rosen, rosen_der
|
| 17 |
+
from .utils import row, col
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
visualize = False
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def Rosen():
    """Build the Rosenbrock function as two chumpy residual terms:
    r1 = 10*(x2 - x1**2) and r2 = 1 - x1, with minimum at (1, 1).

    Returns (residual_list, free_variable_list) for use with `minimize`.
    """
    args = {
        'x1': Ch(-120.),
        'x2': Ch(-100.)
    }
    r1 = Ch(lambda x1, x2 : (x2 - x1**2.) * 10., args)
    r2 = Ch(lambda x1 : x1 * -1. + 1, args)

    func = [r1, r2]

    return func, [args['x1'], args['x2']]
|
| 35 |
+
|
| 36 |
+
class Madsen(Ch):
    # Madsen's test problem: three residuals of a 2-vector x, with a
    # hand-coded 3x2 Jacobian.
    dterms = ('x',)
    def compute_r(self):
        x1 = self.x.r[0]
        x2 = self.x.r[1]
        result = np.array((
            x1**2 + x2**2 + x1 * x2,
            np.sin(x1),
            np.cos(x2)
        ))
        return result

    def compute_dr_wrt(self, wrt):
        # Analytic Jacobian of the residuals wrt x; None for other variables.
        if wrt is not self.x:
            return None
        jac = np.zeros((3,2))
        x1 = self.x.r[0]
        x2 = self.x.r[1]
        jac[0,0] = 2. * x1 + x2
        jac[0,1] = 2. * x2 + x1

        jac[1,0] = np.cos(x1)
        jac[1,1] = 0

        jac[2,0] = 0
        jac[2,1] = -np.sin(x2)

        return jac


    def set_and_get_r(self, x_in):
        # Convenience for scipy-style drivers: set x, return residuals as a column.
        self.x = Ch(x_in)
        return col(self.r)

    def set_and_get_dr(self, x_in):
        self.x = Ch(x_in)
        return self.dr_wrt(self.x)
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
class RosenCh(Ch):
    # Rosenbrock wrapped as a Ch node, delegating value and gradient to
    # scipy.optimize's rosen / rosen_der.
    dterms = ('x',)
    def compute_r(self):

        result = np.array((rosen(self.x.r) ))

        return result

    def set_and_get_r(self, x_in):
        self.x = Ch(x_in)
        return col(self.r)

    def set_and_get_dr(self, x_in):
        self.x = Ch(x_in)
        return self.dr_wrt(self.x).flatten()


    def compute_dr_wrt(self, wrt):
        if wrt is self.x:
            if visualize:
                # Live plot: residual history (left) and (x1, x2) trajectory
                # (right); only active when the module-level flag is set.
                import matplotlib.pyplot as plt
                residuals = np.sum(self.r**2)
                print('------> RESIDUALS %.2e' % (residuals,))
                print('------> CURRENT GUESS %s' % (str(self.x.r),))
                plt.figure(123)

                if not hasattr(self, 'vs'):
                    self.vs = []
                    self.xs = []
                    self.ys = []
                self.vs.append(residuals)
                self.xs.append(self.x.r[0])
                self.ys.append(self.x.r[1])
                plt.clf();
                plt.subplot(1,2,1)
                plt.plot(self.vs)
                plt.subplot(1,2,2)
                plt.plot(self.xs, self.ys)
                plt.draw()


            return row(rosen_der(self.x.r))
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
class TestOptimization(unittest.TestCase):
    """End-to-end checks of chumpy's `minimize` on classic test problems."""

    def test_dogleg_rosen(self):
        # Dogleg on Rosenbrock should land exactly on the (1, 1) minimum.
        obj, freevars = Rosen()
        minimize(fun=obj, x0=freevars, method='dogleg', options={'maxiter': 337, 'disp': False})
        self.assertTrue(freevars[0].r[0]==1.)
        self.assertTrue(freevars[1].r[0]==1.)

    def test_dogleg_madsen(self):
        # Bound taken from the known optimum of Madsen's problem.
        obj = Madsen(x = Ch(np.array((3.,1.))))
        minimize(fun=obj, x0=[obj.x], method='dogleg', options={'maxiter': 34, 'disp': False})
        self.assertTrue(np.sum(obj.r**2)/2 < 0.386599528247)

    @unittest.skip('negative sign in exponent screws with reverse mode')
    def test_bfgs_rosen(self):
        from .optimization import minimize_bfgs_lsq
        obj, freevars = Rosen()
        minimize_bfgs_lsq(obj=obj, niters=421, verbose=False, free_variables=freevars)
        self.assertTrue(freevars[0].r[0]==1.)
        self.assertTrue(freevars[1].r[0]==1.)

    def test_bfgs_madsen(self):
        from .ch import SumOfSquares
        import scipy.optimize
        obj = Ch(lambda x : SumOfSquares(Madsen(x = x)) )

        def errfunc(x):
            obj.x = Ch(x)
            return obj.r

        def gradfunc(x):
            obj.x = Ch(x)
            return obj.dr_wrt(obj.x).ravel()

        x0 = np.array((3., 1.))

        # Optimize with built-in bfgs.
        # Note: with 8 iters, this actually requires 14 gradient evaluations.
        # This can be verified by setting "disp" to 1.
        x1 = scipy.optimize.fmin_bfgs(errfunc, x0, fprime=gradfunc, maxiter=8, disp=0)
        self.assertLess(obj.r/2., 0.4)

        # Optimize with chumpy's minimize (which uses scipy's bfgs).
        obj.x = x0
        minimize(fun=obj, x0=[obj.x], method='bfgs', options={'maxiter': 8, 'disp': False})
        self.assertLess(obj.r/2., 0.4)

    def test_nested_select(self):
        # Regression test: free variables built from nested selections
        # (slice of a slice) must behave like a single selection.
        def beales(x, y):
            e1 = 1.5 - x + x*y
            e2 = 2.25 - x + x*(y**2)
            e3 = 2.625 - x + x*(y**3)
            return {'e1': e1, 'e2': e2, 'e3': e3}

        x1 = ch.zeros(10)
        y1 = ch.zeros(10)

        # With a single select this worked
        minimize(beales(x1, y1), x0=[x1[1:4], y1], method='dogleg', options={'disp': False})

        x2 = ch.zeros(10)
        y2 = ch.zeros(10)

        # But this used to raise `AttributeError: 'Select' object has no attribute 'x'`
        minimize(beales(x2, y2), x0=[x2[1:8][:3], y2], method='dogleg', options={'disp': False})
        np.testing.assert_array_equal(x1, x2)
        np.testing.assert_array_equal(y1, y2)
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
# Suite object kept for external runners that import this module.
suite = unittest.TestLoader().loadTestsFromTestCase(TestOptimization)

if __name__ == '__main__':

    if False: # show rosen
        # Interactive mode: plot optimizer progress while the tests run.
        import matplotlib.pyplot as plt
        visualize = True
        plt.ion()
        unittest.main()
        import pdb; pdb.set_trace()
    else:
        unittest.main()
|
vendor/chumpy/chumpy/testing.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from . import ch
import numpy as np

# Thin wrappers around numpy.testing: for the names in fn1, the first two
# positional arguments are coerced with np.asarray so chumpy Ch objects can
# be compared directly against arrays.
fn1 = 'assert_allclose', 'assert_almost_equal', 'assert_approx_equal', 'assert_array_almost_equal', 'assert_array_almost_equal_nulp', 'assert_array_equal', 'assert_array_less', 'assert_array_max_ulp', 'assert_equal', 'assert_no_warnings', 'assert_string_equal'
fn2 = 'assert_raises', 'assert_warns'

# These are unhandled
fn3 = 'build_err_msg', 'dec', 'decorate_methods', 'decorators', 'division', 'importall', 'jiffies', 'measure', 'memusage', 'nosetester', 'numpytest', 'print_assert_equal', 'print_function', 'raises', 'rand', 'run_module_suite', 'rundocs', 'runstring', 'test', 'utils', 'verbose'

__all__ = fn1 + fn2

# fn1 wrappers coerce the two compared arguments; fn2 wrappers pass through.
for rtn in fn1:
    exec('def %s(*args, **kwargs) : return np.testing.%s(np.asarray(args[0]), np.asarray(args[1]), *args[2:], **kwargs)' % (rtn, rtn))

for rtn in fn2:
    exec('def %s(*args, **kwargs) : return np.testing.%s(*args, **kwargs)' % (rtn, rtn))



if __name__ == '__main__':
    # This module defines no tests of its own, so running it directly is a
    # no-op.  (Previously this called an undefined `main()`, which raised
    # NameError when the module was executed as a script.)
    pass
|
vendor/chumpy/chumpy/utils.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Author(s): Matthew Loper
|
| 3 |
+
|
| 4 |
+
See LICENCE.txt for licensing and contact information.
|
| 5 |
+
"""
|
| 6 |
+
import scipy.sparse as sp
|
| 7 |
+
import numpy as np
|
| 8 |
+
|
| 9 |
+
def row(A):
    """View *A* as a single-row matrix of shape (1, N)."""
    row_shape = (1, -1)
    return A.reshape(row_shape)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def col(A):
    """View *A* as a single-column matrix of shape (N, 1)."""
    col_shape = (-1, 1)
    return A.reshape(col_shape)
|
| 15 |
+
|
| 16 |
+
class timer(object):
    """A pausable stopwatch.

    Starts running on construction; calling the instance returns the total
    elapsed seconds (excluding paused intervals).
    """

    def time(self):
        # Wall-clock source, imported lazily as in the original; kept as a
        # method so it can be overridden.
        import time
        return time.time()

    def __init__(self):
        # _elapsed accumulates finished segments; _start marks the current
        # running segment (None while paused).
        self._elapsed = 0
        self._start = self.time()

    def __call__(self):
        if self._start is None:
            return self._elapsed
        return self._elapsed + self.time() - self._start

    def pause(self):
        # Fold the current running segment into the accumulator.
        assert self._start is not None
        self._elapsed += self.time() - self._start
        self._start = None

    def resume(self):
        # Begin a new running segment; must currently be paused.
        assert self._start is None
        self._start = self.time()
|
| 35 |
+
|
| 36 |
+
def dfs_do_func_on_graph(node, func, *args, **kwargs):
    '''
    Invoke *func* on every node yielded by node.tree_iterator(),
    forwarding any extra positional and keyword arguments.
    '''
    for visited in node.tree_iterator():
        func(visited, *args, **kwargs)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def sparse_is_desireable(lhs, rhs):
    '''
    Examines a pair of matrices and determines if the result of their
    multiplication should be sparse or not.

    NOTE(review): the heuristic below is deliberately disabled by the
    unconditional ``return False`` (a dense result is always chosen).
    The dead code is retained for reference; its ``NameError`` (it used an
    undefined name ``size`` in the final ratio) has been fixed to use
    ``result_size`` so it would run correctly if ever re-enabled.
    '''
    return False
    if len(lhs.shape) == 1:
        return False
    else:
        lhs_rows, lhs_cols = lhs.shape

    if len(rhs.shape) == 1:
        rhs_rows = 1
        rhs_cols = rhs.size
    else:
        rhs_rows, rhs_cols = rhs.shape

    result_size = lhs_rows * rhs_cols

    if sp.issparse(lhs) and sp.issparse(rhs):
        return True
    elif sp.issparse(lhs):
        lhs_zero_rows = lhs_rows - np.unique(lhs.nonzero()[0]).size
        rhs_zero_cols = np.all(rhs == 0, axis=0).sum()
    elif sp.issparse(rhs):
        lhs_zero_rows = np.all(lhs == 0, axis=1).sum()
        rhs_zero_cols = rhs_cols - np.unique(rhs.nonzero()[1]).size
    else:
        lhs_zero_rows = np.all(lhs == 0, axis=1).sum()
        rhs_zero_cols = np.all(rhs == 0, axis=0).sum()

    num_zeros = lhs_zero_rows * rhs_cols + rhs_zero_cols * lhs_rows - lhs_zero_rows * rhs_zero_cols

    # A sparse matrix uses roughly 16 bytes per nonzero element (8 + 2 4-byte
    # inds), while a dense matrix uses 8 bytes per element. So the break even
    # point for sparsity is 50% nonzero. But in practice, it seems to be that
    # the compression in a csc or csr matrix gets us break even at ~65%
    # nonzero, which lets us say 50% is a conservative, worst case cutoff.
    return (float(num_zeros) / float(result_size)) >= 0.5
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def convert_inputs_to_sparse_if_necessary(lhs, rhs):
    '''
    Checks whether a sparse result is desirable for these operands and, if
    so, converts any dense operand to a sparse (CSC) matrix.  Returns the
    (possibly converted) pair; already-sparse pairs are returned untouched.
    '''
    both_sparse = sp.issparse(lhs) and sp.issparse(rhs)
    if not both_sparse and sparse_is_desireable(lhs, rhs):
        if not sp.issparse(lhs):
            lhs = sp.csc_matrix(lhs)
        if not sp.issparse(rhs):
            rhs = sp.csc_matrix(rhs)
    return lhs, rhs
|
vendor/chumpy/chumpy/version.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Single source of truth for the vendored chumpy package version.
version = '0.71'
# Kept identical to `version` for callers expecting numpy-style
# short_version / full_version attributes.
short_version = version
full_version = version
|
vendor/chumpy/requirements.txt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
numpy>=1.8.1
|
| 2 |
+
scipy>=0.13.0
|
| 3 |
+
six>=1.11.0
|
vendor/chumpy/setup.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Author(s): Matthew Loper
|
| 3 |
+
|
| 4 |
+
See LICENCE.txt for licensing and contact information.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from distutils.core import setup
|
| 8 |
+
from runpy import run_path
|
| 9 |
+
|
| 10 |
+
def get_version():
|
| 11 |
+
namespace = run_path('chumpy/version.py')
|
| 12 |
+
return namespace['version']
|
| 13 |
+
|
| 14 |
+
setup(name='chumpy',
|
| 15 |
+
version=get_version(),
|
| 16 |
+
packages = ['chumpy'],
|
| 17 |
+
author='Matthew Loper',
|
| 18 |
+
author_email='matt.loper@gmail.com',
|
| 19 |
+
url='https://github.com/mattloper/chumpy',
|
| 20 |
+
description='chumpy',
|
| 21 |
+
license='MIT',
|
| 22 |
+
install_requires=['numpy', 'scipy', 'matplotlib'],
|
| 23 |
+
|
| 24 |
+
classifiers=[
|
| 25 |
+
'Development Status :: 4 - Beta',
|
| 26 |
+
'Intended Audience :: Science/Research',
|
| 27 |
+
'Topic :: Scientific/Engineering :: Mathematics',
|
| 28 |
+
'License :: OSI Approved :: MIT License',
|
| 29 |
+
'Programming Language :: Python :: 2',
|
| 30 |
+
'Programming Language :: Python :: 2.7',
|
| 31 |
+
'Programming Language :: Python :: 3',
|
| 32 |
+
'Operating System :: MacOS :: MacOS X',
|
| 33 |
+
'Operating System :: POSIX :: Linux'
|
| 34 |
+
],
|
| 35 |
+
)
|